package org.nd4j.linalg.util;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/util/DeviceLocalNDArray.class */
/**
 * A {@link DeviceLocal} holder specialised for {@link INDArray} values.
 *
 * <p>The only behaviour added on top of the parent class is {@link #broadcast(INDArray)},
 * which fills one slot per device index reported by the ND4J affinity manager.
 */
public class DeviceLocalNDArray extends DeviceLocal<INDArray> {
    // Not referenced anywhere inside this class; kept so the class layout is unchanged.
    private static final Logger log = LoggerFactory.getLogger(DeviceLocalNDArray.class);

    /**
     * Creates a holder without storing any array in it.
     */
    public DeviceLocalNDArray() {
        super();
    }

    /**
     * Creates a holder and immediately broadcasts the given array via
     * {@link #broadcast(INDArray)}.
     *
     * @param iNDArray array to broadcast; a {@code null} value results in nothing being stored
     */
    public DeviceLocalNDArray(INDArray iNDArray) {
        super();
        broadcast(iNDArray);
    }

    /**
     * Stores the given array, or a per-device replica of it, under every device index.
     *
     * <p>For each index in {@code [0, numberOfDevices)} the device of the calling thread is
     * queried: the matching index receives {@code iNDArray} itself, every other index receives
     * whatever {@code replicateToDevice} returns for that index.
     *
     * @param iNDArray array to broadcast; when {@code null} the method does nothing
     */
    public void broadcast(INDArray iNDArray) {
        if (iNDArray != null) {
            // NOTE(review): presumably flushes queued operations so replicas see final data - confirm.
            Nd4j.getExecutioner().commit();

            final int deviceCount = Nd4j.getAffinityManager().getNumberOfDevices();
            for (int device = 0; device < deviceCount; device++) {
                // The current thread's device is re-queried on every pass, as the original code did.
                final boolean isThreadDevice =
                        Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue() == device;

                // The thread's own device keeps the caller's instance; others get a replica.
                final INDArray valueForDevice = isThreadDevice
                        ? iNDArray
                        : Nd4j.getAffinityManager().replicateToDevice(Integer.valueOf(device), iNDArray);

                set(device, valueForDevice);
            }
        }
    }
}
