package org.nd4j.parameterserver.distributed.messages.intercom;

import java.util.Arrays;
import lombok.NonNull;
import org.apache.camel.util.URISupport;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.enums.ExecutionMode;
import org.nd4j.parameterserver.distributed.logic.storage.WordVectorStorage;
import org.nd4j.parameterserver.distributed.messages.BaseVoidMessage;
import org.nd4j.parameterserver.distributed.messages.DistributedMessage;
import org.nd4j.parameterserver.distributed.messages.aggregations.DotAggregation;
import org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage;
import org.nd4j.parameterserver.distributed.training.impl.CbowTrainer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Distributed message that starts a CBOW training round on the local trainer and computes
 * the dot products between the averaged context vector and the output-layer rows.
 *
 * <p>When processed, the message:
 * <ol>
 *   <li>builds a {@link CbowRequestMessage} from its fields and hands it to the {@link CbowTrainer};</li>
 *   <li>averages the {@code SYN_0} rows selected by {@link #rowsA};</li>
 *   <li>dots that mean against {@code SYN_1} rows (the first {@code codes.length} entries of
 *       {@link #rowsB}) and, when negative sampling is enabled, against {@code SYN_1_NEGATIVE} rows
 *       (the following entries of {@link #rowsB});</li>
 *   <li>ships the resulting column vector as a {@link DotAggregation}, according to the configured
 *       {@link ExecutionMode}.</li>
 * </ol>
 */
public class DistributedCbowDotMessage extends BaseVoidMessage implements DistributedMessage {
    private static final Logger log = LoggerFactory.getLogger((Class<?>) DistributedCbowDotMessage.class);

    /** Row indices of {@code SYN_0} whose vectors are averaged into the context ("mean") vector. */
    protected int[] rowsA;
    /**
     * Row indices used for the dot products: the first {@code codes.length} entries index
     * {@code SYN_1}; any remaining entries index {@code SYN_1_NEGATIVE}.
     */
    protected int[] rowsB;
    /** Forwarded unchanged to the {@link CbowRequestMessage}; its meaning is not visible in this file. */
    protected int w1;
    /** Hierarchical-softmax flag; stored and compared, but not read by {@link #processMessage()}. */
    protected boolean useHS;
    /** Number of negative samples; {@code 0} disables the {@code SYN_1_NEGATIVE} dot products. */
    protected short negSamples;
    /** Learning rate forwarded to the {@link CbowRequestMessage}. */
    protected float alpha;
    /** Codes forwarded to the training request; their count selects how many {@code SYN_1} dots are computed. */
    protected byte[] codes;

    /** Creates an empty message with only the message type set (used for deserialization/reflection). */
    public DistributedCbowDotMessage() {
        this.messageType = 22;
    }

    /**
     * Legacy single-row constructor.
     *
     * @param taskId task identifier
     * @param rowA   single {@code SYN_0} row; also used as {@code w1}
     * @param rowB   single {@code SYN_1} row
     * @deprecated use
     *     {@link #DistributedCbowDotMessage(long, int[], int[], int, byte[], boolean, short, float)} instead.
     *     This constructor passes empty codes, no negative sampling and {@code alpha = 0.001f}.
     */
    @Deprecated
    public DistributedCbowDotMessage(long taskId, int rowA, int rowB) {
        this(taskId, new int[] {rowA}, new int[] {rowB}, rowA, new byte[0], false, (short) 0, 0.001f);
    }

    /**
     * Creates a fully-specified CBOW dot message.
     *
     * @param taskId     task identifier propagated to the request and the aggregation
     * @param rowsA      {@code SYN_0} rows to average; must not be {@code null}
     * @param rowsB      {@code SYN_1} rows followed by negative-sampling rows; must not be {@code null}
     * @param w1         value forwarded to the training request
     * @param codes      codes forwarded to the training request; must not be {@code null}
     * @param useHS      hierarchical-softmax flag
     * @param negSamples number of negative samples ({@code 0} disables negative sampling)
     * @param alpha      learning rate
     * @throws NullPointerException if {@code rowsA}, {@code rowsB} or {@code codes} is {@code null}
     */
    public DistributedCbowDotMessage(long taskId, @NonNull int[] rowsA, @NonNull int[] rowsB, int w1,
                    @NonNull byte[] codes, boolean useHS, short negSamples, float alpha) {
        this();
        if (rowsA == null) {
            throw new NullPointerException("rowsA is marked @NonNull but is null");
        }
        if (rowsB == null) {
            throw new NullPointerException("rowsB is marked @NonNull but is null");
        }
        if (codes == null) {
            throw new NullPointerException("codes is marked @NonNull but is null");
        }
        this.rowsA = rowsA;
        this.rowsB = rowsB;
        this.taskId = taskId;
        this.w1 = w1;
        this.useHS = useHS;
        this.negSamples = negSamples;
        this.alpha = alpha;
        this.codes = codes;
    }

    /**
     * Kicks off the local CBOW training round, computes the dot-product vector and sends it on
     * as a {@link DotAggregation}.
     *
     * <p>In {@code AVERAGING} mode the aggregation expects a single shard and is delivered with
     * {@code transport.putMessage}. In {@code SHARDED} mode it expects all shards and is sent with
     * {@code transport.sendMessage}. Other modes compute the vector but send nothing.
     */
    @Override // org.nd4j.parameterserver.distributed.messages.VoidMessage
    public void processMessage() {
        // NOTE(review): 119L / -119L are magic values carried over from the original code;
        // their meaning is not visible in this file.
        CbowRequestMessage request = new CbowRequestMessage(this.rowsA, this.rowsB, this.w1, this.codes,
                        this.negSamples, this.alpha, 119L);
        if (this.negSamples > 0) {
            // Everything after the code-aligned prefix of rowsB is handed over as negatives.
            request.setNegatives(Arrays.copyOfRange(this.rowsB, this.codes.length, this.rowsB.length));
        }
        request.setFrameId(-119L);
        request.setTaskId(this.taskId);
        request.setOriginatorId(getOriginatorId());
        ((CbowTrainer) this.trainer).pickTraining(request);

        // Context vector: mean over the SYN_0 rows selected by rowsA.
        INDArray mean = Nd4j.pullRows(this.storage.getArray(WordVectorStorage.SYN_0), 1, this.rowsA, 'c').mean(0);

        int resultLength = this.codes.length + (this.negSamples > 0 ? this.negSamples + 1 : 0);
        INDArray result = Nd4j.createUninitialized(resultLength, 1L);

        int e = 0;
        if (e < this.codes.length) {
            // Fetched once (loop-invariant), and only when there is something to dot against.
            INDArray syn1 = this.storage.getArray(WordVectorStorage.SYN_1);
            for (; e < this.codes.length; e++) {
                result.putScalar(e, Nd4j.getBlasWrapper().dot(mean, syn1.getRow(this.rowsB[e])));
            }
        }
        if (e < resultLength) {
            INDArray syn1Negative = this.storage.getArray(WordVectorStorage.SYN_1_NEGATIVE);
            for (; e < resultLength; e++) {
                result.putScalar(e, Nd4j.getBlasWrapper().dot(mean, syn1Negative.getRow(this.rowsB[e])));
            }
        }

        ExecutionMode mode = this.voidConfiguration.getExecutionMode();
        if (mode == ExecutionMode.AVERAGING) {
            DotAggregation aggregation = new DotAggregation(this.taskId, (short) 1, this.shardIndex, result);
            aggregation.setTargetId((short) -1);
            aggregation.setOriginatorId(getOriginatorId());
            this.transport.putMessage(aggregation);
        } else if (mode == ExecutionMode.SHARDED) {
            DotAggregation aggregation = new DotAggregation(this.taskId,
                            (short) this.voidConfiguration.getNumberOfShards(), this.shardIndex, result);
            aggregation.setTargetId((short) -1);
            aggregation.setOriginatorId(getOriginatorId());
            this.transport.sendMessage(aggregation);
        }
    }

    public int[] getRowsA() {
        return this.rowsA;
    }

    public int[] getRowsB() {
        return this.rowsB;
    }

    public int getW1() {
        return this.w1;
    }

    public boolean isUseHS() {
        return this.useHS;
    }

    public short getNegSamples() {
        return this.negSamples;
    }

    public float getAlpha() {
        return this.alpha;
    }

    public byte[] getCodes() {
        return this.codes;
    }

    public void setRowsA(int[] rowsA) {
        this.rowsA = rowsA;
    }

    public void setRowsB(int[] rowsB) {
        this.rowsB = rowsB;
    }

    public void setW1(int w1) {
        this.w1 = w1;
    }

    public void setUseHS(boolean useHS) {
        this.useHS = useHS;
    }

    public void setNegSamples(short negSamples) {
        this.negSamples = negSamples;
    }

    public void setAlpha(float alpha) {
        this.alpha = alpha;
    }

    public void setCodes(byte[] codes) {
        this.codes = codes;
    }

    /**
     * Compares the fields declared in this class (array contents by value). Superclass state such as
     * {@code taskId} is not part of the comparison.
     */
    @Override // org.nd4j.parameterserver.distributed.messages.BaseVoidMessage
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof DistributedCbowDotMessage)) {
            return false;
        }
        DistributedCbowDotMessage other = (DistributedCbowDotMessage) obj;
        return other.canEqual(this)
                        && Arrays.equals(getRowsA(), other.getRowsA())
                        && Arrays.equals(getRowsB(), other.getRowsB())
                        && getW1() == other.getW1()
                        && isUseHS() == other.isUseHS()
                        && getNegSamples() == other.getNegSamples()
                        && Float.compare(getAlpha(), other.getAlpha()) == 0
                        && Arrays.equals(getCodes(), other.getCodes());
    }

    @Override // org.nd4j.parameterserver.distributed.messages.BaseVoidMessage
    protected boolean canEqual(Object obj) {
        return obj instanceof DistributedCbowDotMessage;
    }

    /** Hash consistent with {@link #equals(Object)}; same Lombok-style 59/79/97 scheme as before. */
    @Override // org.nd4j.parameterserver.distributed.messages.BaseVoidMessage
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + Arrays.hashCode(getRowsA());
        result = result * prime + Arrays.hashCode(getRowsB());
        result = result * prime + getW1();
        result = result * prime + (isUseHS() ? 79 : 97);
        result = result * prime + getNegSamples();
        result = result * prime + Float.floatToIntBits(getAlpha());
        result = result * prime + Arrays.hashCode(getCodes());
        return result;
    }

    @Override // org.nd4j.parameterserver.distributed.messages.BaseVoidMessage
    public String toString() {
        // The closing parenthesis was previously spelled URISupport.RAW_TOKEN_END (a decompiler
        // artifact pulling in Apache Camel); a plain literal yields the same text.
        return "DistributedCbowDotMessage(rowsA=" + Arrays.toString(getRowsA())
                        + ", rowsB=" + Arrays.toString(getRowsB())
                        + ", w1=" + getW1()
                        + ", useHS=" + isUseHS()
                        + ", negSamples=" + ((int) getNegSamples())
                        + ", alpha=" + getAlpha()
                        + ", codes=" + Arrays.toString(getCodes())
                        + ")";
    }
}
