package org.nd4j.linalg.dataset.api.preprocessor;

import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/LabelLastTimeStepPreProcessor.class */
/**
 * {@link DataSetPreProcessor} that reduces rank 3 labels to rank 2 by keeping only the values at
 * each example's last time step. Dimension 2 of the labels is treated as the time dimension.
 * <p>
 * If the data set has no labels mask, index {@code labels.size(2) - 1} is used for every example.
 * If a labels mask is present, each example uses the last index along mask dimension 1 whose value
 * is greater than 0. The labels mask is removed afterwards.
 */
public class LabelLastTimeStepPreProcessor implements DataSetPreProcessor {

    /**
     * Replaces the labels of {@code dataSet} with a rank 2 array of shape
     * {@code [labels.size(0), labels.size(1)]} holding each example's last time step, and sets the
     * labels mask to {@code null}.
     *
     * @param dataSet data set to modify in place; its labels must be rank 3
     * @throws IllegalStateException if the labels are not rank 3, or if a labels mask is present and
     *     some example has no mask value greater than 0 (it is entirely masked out)
     */
    @Override
    public void preProcess(DataSet dataSet) {
        INDArray labels = dataSet.getLabels();
        Preconditions.checkState(labels.rank() == 3,
                "LabelLastTimeStepPreProcessor expects rank 3 labels, got rank %s labels with shape %ndShape",
                labels.rank(), labels);

        INDArray labelsMask = dataSet.getLabelsMaskArray();
        INDArray lastStepLabels;
        if (labelsMask == null) {
            // No mask: take the final index along dimension 2 for every example.
            // dup() so the result does not remain a view into the original 3d labels.
            lastStepLabels = labels.get(NDArrayIndex.all(), NDArrayIndex.all(),
                    NDArrayIndex.point(labels.size(2) - 1)).dup();
        } else {
            // For each example, the last index along mask dimension 1 whose value is > 0.
            long[] lastStepIdxs = BooleanIndexing.lastIndex(labelsMask, Conditions.greaterThan(0), 1)
                    .data().asLong();

            // Allocate with the labels' own data type (not a hard-coded FLOAT) so the masked path
            // returns the same dtype as the unmasked path, without losing precision for DOUBLE labels.
            lastStepLabels = Nd4j.create(labels.dataType(), labels.size(0), labels.size(1));
            for (int i = 0; i < lastStepIdxs.length; i++) {
                long lastStep = lastStepIdxs[i];
                // A negative index means no mask value for this example was > 0 (see message below).
                Preconditions.checkState(lastStep >= 0,
                        "Invalid last time step index: example %s in minibatch is entirely masked out"
                                + " (label mask is all 0s, meaning no label data is present for this example)",
                        i);
                lastStepLabels.putRow(i, labels.get(NDArrayIndex.point(i), NDArrayIndex.all(),
                        NDArrayIndex.point(lastStep)));
            }
        }

        dataSet.setLabels(lastStepLabels);
        // The new labels have no time dimension, so the old per-time-step mask no longer applies.
        dataSet.setLabelsMaskArray(null);
    }
}
