package org.nd4j.linalg.dataset.api;

import java.util.Arrays;
import lombok.NonNull;
import org.apache.camel.util.URISupport;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/DataSetUtil.class */
/**
 * Static helpers for converting data set arrays to a 2d representation ("tailoring") and for
 * merging the feature, label and mask arrays of several data sets along the example dimension
 * (dimension 0).
 */
public class DataSetUtil {
    private static final Logger log = LoggerFactory.getLogger(DataSetUtil.class);

    /**
     * Converts the features or the labels of a data set to a 2d matrix, applying the matching mask.
     *
     * @param dataSet     data set to read from; must not be null
     * @param areFeatures {@code true} to use the features and features mask, {@code false} to use
     *                    the labels and labels mask
     * @return the array produced by {@link #tailor2d(INDArray, INDArray)}
     * @throws NullPointerException if {@code dataSet} is null
     */
    public static INDArray tailor2d(@NonNull DataSet dataSet, boolean areFeatures) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked @NonNull but is null");
        }
        INDArray data = areFeatures ? dataSet.getFeatures() : dataSet.getLabels();
        INDArray mask = areFeatures ? dataSet.getFeaturesMaskArray() : dataSet.getLabelsMaskArray();
        return tailor2d(data, mask);
    }

    /**
     * Converts an array of rank 1 to 4 to a 2d representation.
     *
     * <p>Rank 1 and 2 arrays are returned unchanged. Rank 3 arrays are converted by
     * {@link #tailor3d2d(INDArray, INDArray)} using {@code mask}. Rank 4 arrays are converted by
     * {@link #tailor4d2d(INDArray)}, and the mask is ignored.
     *
     * @param data array to convert; must not be null
     * @param mask optional mask, only used for rank 3 data
     * @return the 2d representation; may be {@code null} for rank 3 data whose mask is all zero
     * @throws NullPointerException if {@code data} is null
     * @throws RuntimeException     if the rank of {@code data} is not in 1..4
     */
    public static INDArray tailor2d(@NonNull INDArray data, INDArray mask) {
        if (data == null) {
            throw new NullPointerException("data is marked @NonNull but is null");
        }
        switch (data.rank()) {
            case 1:
            case 2:
                return data;
            case 3:
                return tailor3d2d(data, mask);
            case 4:
                return tailor4d2d(data);
            default:
                throw new RuntimeException("Unsupported data rank: " + data.rank());
        }
    }

    /**
     * Converts the rank 3 features or labels of a data set to 2d.
     *
     * @param dataSet     data set to read from
     * @param areFeatures {@code true} for features + features mask, {@code false} for labels + labels mask
     * @return the array produced by {@link #tailor3d2d(INDArray, INDArray)}
     */
    public static INDArray tailor3d2d(DataSet dataSet, boolean areFeatures) {
        INDArray data = areFeatures ? dataSet.getFeatures() : dataSet.getLabels();
        INDArray mask = areFeatures ? dataSet.getFeaturesMaskArray() : dataSet.getLabelsMaskArray();
        return tailor3d2d(data, mask);
    }

    /**
     * Converts rank 3 time series data to a 2d matrix.
     *
     * <p>The input has shape [minibatch, vectorSize, timeSeriesLength]. The output has one row per
     * (example, time step) pair and {@code vectorSize} columns.
     *
     * <p>If a mask with shape [minibatch, timeSeriesLength] is given, only the rows whose mask
     * entry is non-zero are kept.
     *
     * @param data rank 3 data; must not be null
     * @param mask optional mask with shape [minibatch, timeSeriesLength]
     * @return the 2d data, or {@code null} if every mask entry is zero
     * @throws NullPointerException     if {@code data} is null
     * @throws IllegalArgumentException if the mask's minibatch or time series length do not
     *                                  match those of {@code data}
     */
    public static INDArray tailor3d2d(@NonNull INDArray data, INDArray mask) {
        if (data == null) {
            throw new NullPointerException("data is marked @NonNull but is null");
        }
        if (mask != null && (data.size(0) != mask.size(0) || data.size(2) != mask.size(1))) {
            throw new IllegalArgumentException("Invalid mask array/data combination: got data with shape [minibatch, vectorSize, timeSeriesLength] = "
                    + Arrays.toString(data.shape()) + "; got mask with shape [minibatch,timeSeriesLength] = "
                    + Arrays.toString(mask.shape()) + "; minibatch and timeSeriesLength dimensions must match");
        }
        // The 'f' reshape below relies on a contiguous, non-view 'f' ordered buffer
        data = asContiguousF(data);
        long[] shape = data.shape();

        INDArray as2d;
        if (shape[0] == 1) {
            // Single example: the [vectorSize, timeSeriesLength] tensor is transposed to
            // [timeSeriesLength, vectorSize]
            as2d = data.tensorAlongDimension(0L, 1, 2).permutei(1, 0);
        } else if (shape[2] == 1) {
            // Single time step
            as2d = data.tensorAlongDimension(0L, 1, 0);
        } else {
            // Permute to [minibatch, timeSeriesLength, vectorSize] so that an 'f' reshape
            // gives one row per (example, time step)
            as2d = data.permute(0, 2, 1).reshape('f', shape[0] * shape[2], shape[1]);
        }
        if (mask == null) {
            return as2d;
        }

        // The flattened 'f' order of the mask matches the row order of as2d
        mask = asContiguousF(mask);
        float[] maskValues = mask.reshape('f', mask.length(), 1).data().asFloat();

        // Count non-zero entries directly rather than trusting sum(mask): the sum only equals the
        // number of kept rows for strictly binary masks, and anything else would overflow or
        // under-fill the row index array.
        int presentCount = 0;
        for (float value : maskValues) {
            if (value != 0.0f) {
                presentCount++;
            }
        }
        if (presentCount == mask.length()) {
            return as2d;
        }
        if (presentCount == 0) {
            return null;
        }

        int[] rowsToPull = new int[presentCount];
        int next = 0;
        for (int row = 0; row < maskValues.length; row++) {
            if (maskValues[row] != 0.0f) {
                rowsToPull[next++] = row;
            }
        }
        return Nd4j.pullRows(as2d, 1, rowsToPull);
    }

    /**
     * Returns {@code arr} unchanged if it is an 'f' ordered, non-view array that passes
     * {@link Shape#strideDescendingCAscendingF(INDArray)}; otherwise returns an 'f' ordered copy.
     */
    private static INDArray asContiguousF(INDArray arr) {
        if (arr.ordering() != 'f' || arr.isView() || !Shape.strideDescendingCAscendingF(arr)) {
            return arr.dup('f');
        }
        return arr;
    }

    /**
     * Converts the rank 4 features or labels of a data set to 2d.
     *
     * @param dataSet     data set to read from
     * @param areFeatures {@code true} for features, {@code false} for labels
     * @return the array produced by {@link #tailor4d2d(INDArray)}
     */
    public static INDArray tailor4d2d(DataSet dataSet, boolean areFeatures) {
        return tailor4d2d(areFeatures ? dataSet.getFeatures() : dataSet.getLabels());
    }

    /**
     * Converts rank 4 data to a 2d matrix.
     *
     * <p>The result has shape [size(0) * size(2) * size(3), size(1)], one column per entry of
     * dimension 1. The result has the same data type as {@code data}.
     *
     * @param data rank 4 array; must not be null
     * @return the 2d matrix
     * @throws NullPointerException if {@code data} is null
     */
    public static INDArray tailor4d2d(@NonNull INDArray data) {
        if (data == null) {
            throw new NullPointerException("data is marked @NonNull but is null");
        }
        long rowsPerTad = data.size(2) * data.size(3) * data.size(0);
        // Allocate with the input's data type so values are not silently cast to the default type
        INDArray transposed = Nd4j.create(data.dataType(), data.size(1), rowsPerTad);
        long numTads = data.tensorsAlongDimension(3, 2, 0);
        for (int i = 0; i < numTads; i++) {
            transposed.putRow(i, Nd4j.toFlattened(data.tensorAlongDimension(i, 3, 2, 0)));
        }
        return transposed.transposei();
    }

    /**
     * Multiplies rank 3 {@code data} in place by {@code mask}.
     *
     * <p>The mask is broadcast along dimensions 0 and 2 of the data, so masked-out positions
     * (mask value 0) become zero.
     *
     * <p>This does nothing if {@code mask} is null or {@code data} is not rank 3.
     *
     * @param data array to modify in place
     * @param mask optional mask; assumed shape [minibatch, timeSeriesLength] — TODO confirm
     */
    public static void setMaskedValuesToZero(INDArray data, INDArray mask) {
        if (mask == null || data.rank() != 3) {
            return;
        }
        Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastMulOp(data, mask, data, 0, 2));
    }

    /**
     * Merges the feature arrays of several multi-input examples, one input at a time.
     *
     * @param featuresToMerge     indexed as [example][input]; must not be null
     * @param featureMasksToMerge optional masks, indexed the same way
     * @return the merged features per input, and the merged masks per input. The mask array is
     *         {@code null} if no input produced a mask.
     * @throws NullPointerException if {@code featuresToMerge} is null
     */
    public static Pair<INDArray[], INDArray[]> mergeFeatures(@NonNull INDArray[][] featuresToMerge, INDArray[][] featureMasksToMerge) {
        if (featuresToMerge == null) {
            throw new NullPointerException("featuresToMerge is marked @NonNull but is null");
        }
        int numInputs = featuresToMerge[0].length;
        INDArray[] mergedFeatures = new INDArray[numInputs];
        INDArray[] mergedMasks = null;
        for (int input = 0; input < numInputs; input++) {
            Pair<INDArray, INDArray> merged = mergeFeatures(featuresToMerge, featureMasksToMerge, input);
            mergedFeatures[input] = merged.getFirst();
            INDArray mergedMask = merged.getSecond();
            if (mergedMask != null) {
                // Allocated lazily so that "no masks at all" is reported as null
                if (mergedMasks == null) {
                    mergedMasks = new INDArray[numInputs];
                }
                mergedMasks[input] = mergedMask;
            }
        }
        return new Pair<>(mergedFeatures, mergedMasks);
    }

    /**
     * Merges feature arrays of rank 2, 3 or 4 along dimension 0, dispatching on the rank of the
     * first array.
     *
     * @param featuresToMerge     arrays to merge; must not be null and its first entry must not be null
     * @param featureMasksToMerge optional masks, one per array
     * @return the merged features and the merged mask (the mask may be null)
     * @throws NullPointerException  if {@code featuresToMerge} is null
     * @throws IllegalStateException if the first array's rank is not 2, 3 or 4
     */
    public static Pair<INDArray, INDArray> mergeFeatures(@NonNull INDArray[] featuresToMerge, INDArray[] featureMasksToMerge) {
        if (featuresToMerge == null) {
            throw new NullPointerException("featuresToMerge is marked @NonNull but is null");
        }
        Preconditions.checkNotNull(featuresToMerge[0], "Encountered null feature array when merging");
        switch (featuresToMerge[0].rank()) {
            case 2:
                return merge2d(featuresToMerge, featureMasksToMerge);
            case 3:
                return mergeTimeSeries(featuresToMerge, featureMasksToMerge);
            case 4:
                return merge4d(featuresToMerge, featureMasksToMerge);
            default:
                throw new IllegalStateException("Cannot merge examples: features rank must be in range 2 to 4 inclusive. First example features shape: " + Arrays.toString(featuresToMerge[0].shape()));
        }
    }

    /**
     * Merges input {@code inOutIdx} of every example in a [example][input] layout.
     *
     * @param featuresToMerge     feature arrays indexed as [example][input]
     * @param featureMasksToMerge optional masks indexed the same way
     * @param inOutIdx            index of the input to merge
     * @return the merged features and the merged mask (the mask may be null)
     */
    public static Pair<INDArray, INDArray> mergeFeatures(INDArray[][] featuresToMerge, INDArray[][] featureMasksToMerge, int inOutIdx) {
        Pair<INDArray[], INDArray[]> column = selectColumnFromMDSData(featuresToMerge, featureMasksToMerge, inOutIdx);
        return mergeFeatures(column.getFirst(), column.getSecond());
    }

    /**
     * Merges label arrays of rank 2, 3 or 4 along dimension 0, dispatching on the rank of the
     * first array.
     *
     * @param labelsToMerge     arrays to merge; the first entry must not be null
     * @param labelMasksToMerge optional masks, one per array
     * @return the merged labels and the merged mask (the mask may be null)
     * @throws ND4JIllegalStateException if the first array's rank is not 2, 3 or 4
     */
    public static Pair<INDArray, INDArray> mergeLabels(INDArray[] labelsToMerge, INDArray[] labelMasksToMerge) {
        Preconditions.checkNotNull(labelsToMerge[0], "Cannot merge data: Encountered null labels array");
        switch (labelsToMerge[0].rank()) {
            case 2:
                return merge2d(labelsToMerge, labelMasksToMerge);
            case 3:
                return mergeTimeSeries(labelsToMerge, labelMasksToMerge);
            case 4:
                return merge4d(labelsToMerge, labelMasksToMerge);
            default:
                throw new ND4JIllegalStateException("Cannot merge examples: labels rank must be in range 2 to 4 inclusive. First example labels shape: " + Arrays.toString(labelsToMerge[0].shape()));
        }
    }

    /**
     * Merges output {@code inOutIdx} of every example in a [example][output] layout.
     *
     * @param labelsToMerge     label arrays indexed as [example][output]; must not be null
     * @param labelMasksToMerge optional masks indexed the same way
     * @param inOutIdx          index of the output to merge
     * @return the merged labels and the merged mask (the mask may be null)
     * @throws NullPointerException if {@code labelsToMerge} is null
     */
    public static Pair<INDArray, INDArray> mergeLabels(@NonNull INDArray[][] labelsToMerge, INDArray[][] labelMasksToMerge, int inOutIdx) {
        if (labelsToMerge == null) {
            throw new NullPointerException("labelsToMerge is marked @NonNull but is null");
        }
        Pair<INDArray[], INDArray[]> column = selectColumnFromMDSData(labelsToMerge, labelMasksToMerge, inOutIdx);
        return mergeLabels(column.getFirst(), column.getSecond());
    }

    /**
     * Extracts element {@code inOutIdx} of every row of {@code arrays}, and of {@code masks}.
     *
     * <p>A mask entry is {@code null} where {@code masks} or its row is null.
     */
    private static Pair<INDArray[], INDArray[]> selectColumnFromMDSData(@NonNull INDArray[][] arrays, INDArray[][] masks, int inOutIdx) {
        if (arrays == null) {
            throw new NullPointerException("arrays is marked @NonNull but is null");
        }
        INDArray[] selectedArrays = new INDArray[arrays.length];
        INDArray[] selectedMasks = new INDArray[arrays.length];
        for (int example = 0; example < arrays.length; example++) {
            selectedArrays[example] = arrays[example][inOutIdx];
            if (masks != null && masks[example] != null) {
                selectedMasks[example] = masks[example][inOutIdx];
            }
        }
        return new Pair<>(selectedArrays, selectedMasks);
    }

    /**
     * Merges column {@code inOutIdx} of a [example][input] layout using {@link #merge2d(INDArray[], INDArray[])}.
     *
     * @throws NullPointerException if {@code arrays} is null
     */
    public static Pair<INDArray, INDArray> merge2d(@NonNull INDArray[][] arrays, INDArray[][] masks, int inOutIdx) {
        if (arrays == null) {
            throw new NullPointerException("arrays is marked @NonNull but is null");
        }
        Pair<INDArray[], INDArray[]> column = selectColumnFromMDSData(arrays, masks, inOutIdx);
        return merge2d(column.getFirst(), column.getSecond());
    }

    /**
     * Concatenates 2d arrays along dimension 0.
     *
     * @param arrays 2d arrays that must all have the same number of columns
     * @param masks  optional per-array masks; a null entry means "all present"
     * @return the concatenated array and the merged mask. The mask is {@code null} if no mask was given.
     * @throws IllegalStateException if the column counts differ
     */
    public static Pair<INDArray, INDArray> merge2d(INDArray[] arrays, INDArray[] masks) {
        long columns = arrays[0].columns();
        boolean hasMasks = false;
        for (int i = 0; i < arrays.length; i++) {
            Preconditions.checkNotNull((Object) arrays[i], "Encountered null array at position %s when merging data", i);
            if (arrays[i].columns() != columns) {
                throw new IllegalStateException("Cannot merge 2d arrays with different numbers of columns (firstNCols=" + columns + ", ithNCols=" + arrays[i].columns() + ")");
            }
            if (masks != null && masks[i] != null) {
                hasMasks = true;
            }
        }
        INDArray merged = Nd4j.specialConcat(0, arrays);
        INDArray mergedMask = hasMasks ? mergeMasks2d(merged.shape(), arrays, masks) : null;
        return new Pair<>(merged, mergedMask);
    }

    /**
     * Merges the 2d masks of column {@code inOutIdx} of a [example][input] layout into an array of
     * shape {@code outShape}.
     */
    public static INDArray mergePerOutputMasks2d(long[] outShape, INDArray[][] arrays, INDArray[][] masks, int inOutIdx) {
        Pair<INDArray[], INDArray[]> column = selectColumnFromMDSData(arrays, masks, inOutIdx);
        return mergeMasks2d(outShape, column.getFirst(), column.getSecond());
    }

    /**
     * Equivalent to {@link #mergeMasks2d(long[], INDArray[], INDArray[])}.
     *
     * @deprecated use {@link #mergeMasks2d(long[], INDArray[], INDArray[])}
     */
    @Deprecated
    public static INDArray mergePerOutputMasks2d(long[] outShape, INDArray[] arrays, INDArray[] masks) {
        return mergeMasks2d(outShape, arrays, masks);
    }

    /**
     * Merges per-array 2d masks into a single mask of shape {@code outShape}.
     *
     * <p>The mask starts as all ones. Each non-null mask is written into the row range that its
     * array occupies in the merged data. Rows of arrays with a {@code null} mask stay 1.
     *
     * @param outShape shape of the merged mask
     * @param arrays   the data arrays being merged; their {@code size(0)} gives each row range
     * @param masks    per-array masks; null entries mean "all present"
     * @return the merged mask, with the same data type as {@code arrays[0]}
     */
    public static INDArray mergeMasks2d(long[] outShape, INDArray[] arrays, INDArray[] masks) {
        INDArray merged = Nd4j.ones(arrays[0].dataType(), outShape);
        long rowsSoFar = 0;
        for (int i = 0; i < masks.length; i++) {
            long rows = arrays[i].size(0);
            if (masks[i] != null) {
                merged.put(new INDArrayIndex[]{NDArrayIndex.interval(rowsSoFar, rowsSoFar + rows), NDArrayIndex.all()}, masks[i]);
            }
            // Advance even for null masks: those rows still occupy space in the merged output.
            rowsSoFar += rows;
        }
        return merged;
    }

    /**
     * Merges per-array rank 4 masks into a single mask along dimension 0.
     *
     * <p>Rows of arrays with a {@code null} mask are filled with ones.
     *
     * @param featuresOrLabels the data arrays being merged; used for the row count of null-mask entries
     * @param masks            per-array rank 4 masks; null entries mean "all present"
     * @return the merged mask, or {@code null} if every mask is null
     * @throws IllegalStateException if a mask is not rank 4, or if masks differ in dimensions 1..3
     */
    public static INDArray mergeMasks4d(INDArray[] featuresOrLabels, INDArray[] masks) {
        long[] outShape = null;
        INDArray firstMask = null;
        long rowsWithoutMask = 0;
        for (int i = 0; i < masks.length; i++) {
            INDArray mask = masks[i];
            if (mask == null) {
                rowsWithoutMask += featuresOrLabels[i].size(0);
                continue;
            }
            if (mask.rank() != 4) {
                throw new IllegalStateException("Cannot merge mask arrays: expected mask array of rank 4. Got mask array of rank " + mask.rank() + " with shape " + Arrays.toString(mask.shape()));
            }
            if (outShape == null) {
                outShape = mask.shape().clone();
                firstMask = mask;
            } else {
                if (mask.size(1) != outShape[1] || mask.size(2) != outShape[2] || mask.size(3) != outShape[3]) {
                    throw new IllegalStateException("Mismatched mask shapes: masks should have same depth/height/width for all examples. Prior examples had shape [mb," + outShape[1] + "," + outShape[2] + "," + outShape[3] + "], next example has shape " + Arrays.toString(mask.shape()));
                }
                outShape[0] += mask.size(0);
            }
        }
        if (outShape == null) {
            return null;
        }
        outShape[0] += rowsWithoutMask;

        INDArray merged = Nd4j.ones(firstMask.dataType(), outShape);
        long rowsSoFar = 0;
        for (int i = 0; i < masks.length; i++) {
            long rows;
            if (masks[i] == null) {
                rows = featuresOrLabels[i].size(0);
            } else {
                rows = masks[i].size(0);
                merged.put(new INDArrayIndex[]{NDArrayIndex.interval(rowsSoFar, rowsSoFar + rows), NDArrayIndex.all()}, masks[i]);
            }
            rowsSoFar += rows;
        }
        return merged;
    }

    /**
     * Merges column {@code inOutIdx} of a [example][input] layout using
     * {@link #mergeTimeSeries(INDArray[], INDArray[])}.
     */
    public static Pair<INDArray, INDArray> mergeTimeSeries(INDArray[][] arrays, INDArray[][] masks, int inOutIdx) {
        Pair<INDArray[], INDArray[]> column = selectColumnFromMDSData(arrays, masks, inOutIdx);
        return mergeTimeSeries(column.getFirst(), column.getSecond());
    }

    /**
     * Merges rank 3 time series arrays along dimension 0.
     *
     * <p>The inputs have shape [minibatch, size, length]. The output is padded with zeros to the
     * longest length.
     *
     * <p>A mask is produced only when some input had a mask or the lengths differ:
     * <ul>
     *   <li>With rank 2 masks, or no masks but differing lengths, the result is a 2d mask
     *       [totalExamples, maxLength] that is zero over padding.</li>
     *   <li>With rank 3 masks, the result is a mask with the data's shape. It is 1 over unmasked
     *       inputs' real time steps and 0 over padding.</li>
     * </ul>
     *
     * @param arrays rank 3 arrays; all must have the same {@code size(1)}
     * @param masks  optional per-array masks
     * @return the merged data and the merged mask (the mask may be null)
     * @throws IllegalStateException         if the arrays differ in {@code size(1)}
     * @throws UnsupportedOperationException if the mask rank is not 2 or 3
     */
    public static Pair<INDArray, INDArray> mergeTimeSeries(INDArray[] arrays, INDArray[] masks) {
        long firstLength = arrays[0].size(2);
        long vectorSize = arrays[0].size(1);
        long maxLength = firstLength;
        boolean hasMask = false;
        int maskRank = -1;
        boolean lengthsDiffer = false;
        long totalExamples = 0;
        for (int i = 0; i < arrays.length; i++) {
            totalExamples += arrays[i].size(0);
            long length = arrays[i].size(2);
            maxLength = Math.max(maxLength, length);
            if (length != firstLength) {
                lengthsDiffer = true;
            }
            if (masks != null && masks[i] != null) {
                maskRank = masks[i].rank();
                hasMask = true;
            }
            if (arrays[i].size(1) != vectorSize) {
                throw new IllegalStateException("Cannot merge time series with different size for dimension 1 (first shape: " + Arrays.toString(arrays[0].shape()) + ", " + i + "th shape: " + Arrays.toString(arrays[i].shape()) + ")");
            }
        }

        INDArray merged = Nd4j.create(arrays[0].dataType(), totalExamples, vectorSize, maxLength);
        long examplesSoFar = 0;

        if (!hasMask && !lengthsDiffer) {
            // Same length everywhere and no masks: plain concatenation, no output mask needed
            for (INDArray array : arrays) {
                long rows = array.size(0);
                merged.put(new INDArrayIndex[]{NDArrayIndex.interval(examplesSoFar, examplesSoFar + rows), NDArrayIndex.all(), NDArrayIndex.all()}, array);
                examplesSoFar += rows;
            }
            return new Pair<>(merged, null);
        }

        INDArray mergedMask;
        if ((lengthsDiffer && !hasMask) || maskRank == 2) {
            // 2d per-time-step mask: starts as ones; padding beyond each example's valid length is zeroed
            mergedMask = Nd4j.ones(arrays[0].dataType(), totalExamples, maxLength);
            for (int i = 0; i < arrays.length; i++) {
                INDArray array = arrays[i];
                long rows = array.size(0);
                long length = array.size(2);
                merged.put(new INDArrayIndex[]{NDArrayIndex.interval(examplesSoFar, examplesSoFar + rows), NDArrayIndex.all(), NDArrayIndex.interval(0L, length)}, array);

                long validLength;
                if (masks != null && masks[i] != null) {
                    INDArray mask = masks[i];
                    validLength = mask.size(1);
                    mergedMask.put(new INDArrayIndex[]{NDArrayIndex.interval(examplesSoFar, examplesSoFar + rows), NDArrayIndex.interval(0L, validLength)}, mask);
                } else {
                    validLength = length;
                }
                if (validLength < maxLength) {
                    mergedMask.put(new INDArrayIndex[]{NDArrayIndex.interval(examplesSoFar, examplesSoFar + rows), NDArrayIndex.interval(validLength, maxLength)}, Nd4j.zeros(rows, maxLength - validLength));
                }
                examplesSoFar += rows;
            }
        } else {
            if (maskRank != 3) {
                throw new UnsupportedOperationException("Cannot merge time series with mask rank " + maskRank);
            }
            // Per-output (rank 3) mask: starts as zeros; real time steps are filled from the mask, or with 1s
            mergedMask = Nd4j.create(merged.dataType(), merged.shape());
            for (int i = 0; i < arrays.length; i++) {
                INDArray mask = masks[i];
                INDArray array = arrays[i];
                long rows = array.size(0);
                long length = array.size(2);
                merged.put(new INDArrayIndex[]{NDArrayIndex.interval(examplesSoFar, examplesSoFar + rows), NDArrayIndex.all(), NDArrayIndex.interval(0L, length)}, array);
                if (mask == null) {
                    mergedMask.get(NDArrayIndex.interval(examplesSoFar, examplesSoFar + rows), NDArrayIndex.all(), NDArrayIndex.interval(0L, length)).assign((Number) 1);
                } else {
                    mergedMask.put(new INDArrayIndex[]{NDArrayIndex.interval(examplesSoFar, examplesSoFar + rows), NDArrayIndex.all(), NDArrayIndex.interval(0L, length)}, mask);
                }
                examplesSoFar += rows;
            }
        }
        return new Pair<>(merged, mergedMask);
    }

    /**
     * Merges column {@code inOutIdx} of a [example][input] layout using
     * {@link #merge4d(INDArray[], INDArray[])}.
     */
    public static Pair<INDArray, INDArray> merge4d(INDArray[][] arrays, INDArray[][] masks, int inOutIdx) {
        Pair<INDArray[], INDArray[]> column = selectColumnFromMDSData(arrays, masks, inOutIdx);
        return merge4d(column.getFirst(), column.getSecond());
    }

    /**
     * Concatenates rank 4 arrays along dimension 0.
     *
     * <p>Rank 2 masks are merged into a [totalExamples, maskColumns] mask, where maskColumns is
     * {@code size(1)} of the last non-null mask. Rank 4 masks are merged with
     * {@link #mergeMasks4d(INDArray[], INDArray[])}.
     *
     * @param arrays rank 4 arrays that agree in dimensions 1..3
     * @param masks  optional per-array masks
     * @return the concatenated data and the merged mask (the mask may be null)
     * @throws IllegalStateException if an array is not rank 4 or shapes differ beyond dimension 0
     */
    public static Pair<INDArray, INDArray> merge4d(INDArray[] arrays, INDArray[] masks) {
        long[] firstShape = arrays[0].shape();
        boolean hasMask = false;
        INDArray lastMask = null;
        for (int i = 0; i < arrays.length; i++) {
            Preconditions.checkNotNull((Object) arrays[i], "Encountered null array when merging data at position %s", i);
            long[] shape = arrays[i].shape();
            if (shape.length != 4) {
                throw new IllegalStateException("Cannot merge 4d arrays with non 4d arrays");
            }
            for (int dim = 1; dim < 4; dim++) {
                if (shape[dim] != firstShape[dim]) {
                    throw new IllegalStateException("Cannot merge 4d arrays with different shape (other than # examples):  data[0].shape = " + Arrays.toString(firstShape) + ", data[" + i + "].shape = " + Arrays.toString(shape));
                }
            }
            if (masks != null && masks[i] != null) {
                hasMask = true;
                lastMask = masks[i];
            }
        }
        INDArray merged = Nd4j.specialConcat(0, arrays);
        INDArray mergedMask = null;
        if (hasMask) {
            int maskRank = lastMask.rank();
            if (maskRank == 2) {
                // A merged 2d mask is itself 2d: [totalExamples, maskColumns], not the 4d data shape
                mergedMask = mergeMasks2d(new long[]{merged.size(0), lastMask.size(1)}, arrays, masks);
            } else if (maskRank == 4) {
                mergedMask = mergeMasks4d(arrays, masks);
            }
            // NOTE(review): masks of any other rank are currently dropped (merged mask stays null)
        }
        return new Pair<>(merged, mergedMask);
    }
}
