package org.nd4j.parameterserver.updater;

import java.util.HashMap;
import java.util.Map;
import org.nd4j.aeron.ipc.NDArrayHolder;
import org.nd4j.aeron.ipc.NDArrayMessage;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.parameterserver.updater.storage.UpdateStorage;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.ObjectMapper;

/* loaded from: input_file:org/nd4j/parameterserver/updater/SynchronousParameterUpdater.class */
/**
 * Parameter-server updater that reports itself as synchronous ({@link #isAsync()} returns
 * {@code false}).
 *
 * <p>Every incoming update is recorded in the update storage and added in place into the array
 * held by the {@code NDArrayHolder}. {@link #shouldReplicate()} becomes {@code true} once the
 * number of accumulated updates equals the configured worker count.
 */
public class SynchronousParameterUpdater extends BaseParameterUpdater {

    /** Shared, reusable Jackson mapper used by {@link #toJson()}. */
    private static final ObjectMapper objectMapper = new ObjectMapper();

    /** Number of updates that make up one pass; also reported by {@link #status()}. */
    private final int workers;

    /**
     * Creates an updater backed by the given storage and array holder.
     *
     * @param updateStorage storage that each incoming {@link NDArrayMessage} is added to
     * @param nDArrayHolder holder whose array receives the accumulated updates
     * @param workers       number of updates required before {@link #shouldReplicate()} is true
     */
    public SynchronousParameterUpdater(UpdateStorage updateStorage, NDArrayHolder nDArrayHolder, int workers) {
        super(updateStorage, nDArrayHolder);
        this.workers = workers;
    }

    /**
     * Creates an updater backed by the given storage. The array holder is whatever the
     * single-argument {@code BaseParameterUpdater} constructor sets up (NOTE(review): confirm there).
     *
     * @param updateStorage storage that each incoming {@link NDArrayMessage} is added to
     * @param workers       number of updates required before {@link #shouldReplicate()} is true
     */
    public SynchronousParameterUpdater(UpdateStorage updateStorage, int workers) {
        super(updateStorage);
        this.workers = workers;
    }

    /**
     * Creates an updater using the no-argument {@code BaseParameterUpdater} constructor
     * (NOTE(review): its default storage/holder are defined in the base class).
     *
     * @param workers number of updates required before {@link #shouldReplicate()} is true
     */
    public SynchronousParameterUpdater(int workers) {
        super();
        this.workers = workers;
    }

    /**
     * Returns the number of updates required for one pass.
     *
     * @return the configured worker count
     */
    @Override
    public int requiredUpdatesForPass() {
        return workers;
    }

    /**
     * Reports that this updater is not asynchronous.
     *
     * @return always {@code false}
     */
    @Override
    public boolean isAsync() {
        return false;
    }

    /**
     * Returns a snapshot of this updater's state.
     *
     * @return a new mutable map containing {@code "workers"} (the configured worker count) and
     *     {@code "accumulatedUpdates"} (the current {@code numUpdates()})
     */
    @Override
    public Map<String, Number> status() {
        Map<String, Number> status = new HashMap<>();
        status.put("workers", workers);
        status.put("accumulatedUpdates", numUpdates());
        return status;
    }

    /**
     * Serializes {@link #status()} to a JSON string.
     *
     * @return the JSON representation of the status map
     * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialization fails
     */
    @Override
    public String toJson() {
        try {
            return objectMapper.writeValueAsString(status());
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Indicates whether the accumulated state should be replicated.
     *
     * @return {@code true} exactly when the number of accumulated updates equals the worker count
     */
    @Override
    public boolean shouldReplicate() {
        // Exact equality: replication is signalled only at the moment the pass completes.
        return numUpdates() == workers;
    }

    /**
     * Records {@code message} in the update storage and applies its array to the held array.
     *
     * <p>A dimensions array of exactly {@code [-1]} marks a whole-array update. Any other
     * dimensions mean a partial update of the slice at {@code message.getIndex()} along those
     * dimensions.
     *
     * @param message the incoming update
     */
    @Override
    public void update(NDArrayMessage message) {
        updateStorage.addUpdate(message);
        INDArray arr = message.getArr();
        int[] dimensions = message.getDimensions();
        boolean wholeArray = dimensions.length == 1 && dimensions[0] == -1;
        if (wholeArray) {
            update(arr, ndArrayHolder.get());
        } else {
            partialUpdate(arr, ndArrayHolder.get(), message.getIndex(), dimensions);
        }
    }

    /**
     * Adds {@code arr} in place into the tensor of {@code result} at {@code idx} along
     * {@code dimensions}.
     *
     * @param arr        the update to add
     * @param result     the array to modify in place
     * @param idx        index of the tensor along {@code dimensions}
     * @param dimensions dimensions defining the tensor slices
     * @throws ArithmeticException if {@code idx} does not fit in an {@code int}; previously the
     *     index was silently truncated, which would update the wrong slice
     */
    @Override
    public void partialUpdate(INDArray arr, INDArray result, long idx, int... dimensions) {
        result.tensorAlongDimension(Math.toIntExact(idx), dimensions).addi(arr);
    }

    /**
     * Adds {@code arr} in place into {@code result}.
     *
     * @param arr    the update to add
     * @param result the array to modify in place
     */
    @Override
    public void update(INDArray arr, INDArray result) {
        result.addi(arr);
    }
}
