package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import org.tensorflow.Tensor;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;

/* loaded from: input_file:org/tribuo/interop/tensorflow/LabelTransformer.class */
/**
 * Converts between Tribuo classification {@link Label}s and TensorFlow tensors.
 * <p>
 * Model outputs are read from a rank-2 float tensor of shape {@code [batchSize, numLabels]}, where
 * column {@code j} holds the score for the label with id {@code j} in the supplied
 * {@link ImmutableOutputInfo}. Ground-truth labels are written as an int tensor of label ids.
 */
public class LabelTransformer implements OutputTransformer<Label> {
    private static final long serialVersionUID = 1;
    private static final Logger logger = Logger.getLogger(LabelTransformer.class.getName());

    /**
     * Converts a single-row output tensor into a {@link Prediction}.
     *
     * @param tensor A float tensor of shape {@code [1, numLabels]}.
     * @param immutableOutputInfo The label domain; column {@code j} maps to label id {@code j}.
     * @param i An integer forwarded unchanged to the {@link Prediction} constructor.
     * @param example The example this prediction is for.
     * @return A prediction holding every label's score, with the highest-scoring label as the output.
     * @throws IllegalArgumentException If the tensor is not rank 2, its second dimension does not
     *                                  match the domain size, or it does not contain exactly one row.
     */
    @Override // org.tribuo.interop.tensorflow.OutputTransformer
    public Prediction<Label> transformToPrediction(Tensor<?> tensor, ImmutableOutputInfo<Label> immutableOutputInfo, int i, Example<Label> example) {
        float[] scores = getSingleRow(tensor, immutableOutputInfo);
        return generatePrediction(scores, immutableOutputInfo, i, example);
    }

    /**
     * Builds a prediction from one row of scores.
     * <p>
     * The predicted label is chosen with {@link #argmax(float[])}, so it agrees with
     * {@link #generateLabel(float[], ImmutableOutputInfo)} for the same row.
     */
    private Prediction<Label> generatePrediction(float[] scores, ImmutableOutputInfo<Label> outputInfo, int i, Example<Label> example) {
        int best = argmax(scores);
        Map<String, Label> labelScores = new HashMap<>();
        Label predicted = null;
        for (int j = 0; j < scores.length; j++) {
            Label candidate = new Label(outputInfo.getOutput(j).getLabel(), scores[j]);
            labelScores.put(candidate.getLabel(), candidate);
            if (j == best) {
                // Reuse the same instance that is stored in the score map.
                predicted = candidate;
            }
        }
        return new Prediction<>(predicted, labelScores, i, example, true);
    }

    /**
     * Converts a single-row output tensor into the highest-scoring {@link Label}.
     * <p>
     * Named {@code transformToOutput2} by the decompiler; the interface method
     * {@code transformToOutput} below delegates here.
     *
     * @param tensor A float tensor of shape {@code [1, numLabels]}.
     * @param immutableOutputInfo The label domain; column {@code j} maps to label id {@code j}.
     * @return The highest-scoring label, carrying its score.
     * @throws IllegalArgumentException If the tensor has the wrong shape or does not contain exactly one row.
     */
    public Label transformToOutput2(Tensor<?> tensor, ImmutableOutputInfo<Label> immutableOutputInfo) {
        return generateLabel(getSingleRow(tensor, immutableOutputInfo), immutableOutputInfo);
    }

    /**
     * Returns the highest-scoring label in a row, carrying that label's actual score.
     */
    private Label generateLabel(float[] scores, ImmutableOutputInfo<Label> outputInfo) {
        int best = argmax(scores);
        return new Label(outputInfo.getOutput(best).getLabel(), scores[best]);
    }

    /**
     * Returns the index of the largest score.
     * <p>
     * Ties resolve to the lowest index. NaN scores never compare greater than anything, so they
     * are skipped. If no score exceeds negative infinity (for example, every score is NaN), index 0
     * is returned.
     *
     * @param scores The scores for one example.
     * @return The index of the largest score.
     * @throws IllegalArgumentException If {@code scores} is empty.
     */
    private static int argmax(float[] scores) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("Cannot select a label from an empty score vector");
        }
        int best = 0;
        float bestScore = Float.NEGATIVE_INFINITY;
        for (int j = 0; j < scores.length; j++) {
            if (scores[j] > bestScore) {
                best = j;
                bestScore = scores[j];
            }
        }
        return best;
    }

    /**
     * Extracts the only row of a prediction tensor.
     *
     * @throws IllegalArgumentException If the tensor does not contain exactly one row, or has the
     *                                  wrong shape (see {@link #getBatchPredictions}).
     */
    private float[] getSingleRow(Tensor<?> tensor, ImmutableOutputInfo<Label> outputInfo) {
        float[][] batch = getBatchPredictions(tensor, outputInfo);
        if (batch.length != 1) {
            throw new IllegalArgumentException("Supplied tensor has the wrong number of results, expected 1, predictions.length = " + batch.length);
        }
        return batch[0];
    }

    /**
     * Validates the tensor shape and copies its contents into a Java array.
     *
     * @param tensor A float tensor expected to have shape {@code [batchSize, numLabels]}.
     * @param outputInfo The label domain; its size must equal {@code numLabels}.
     * @return The scores, indexed as {@code [example][labelId]}.
     * @throws IllegalArgumentException If the tensor is not rank 2, or its second dimension does
     *                                  not equal {@code outputInfo.size()}.
     */
    private float[][] getBatchPredictions(Tensor<?> tensor, ImmutableOutputInfo<Label> outputInfo) {
        long[] shape = tensor.shape();
        if (shape.length != 2) {
            throw new IllegalArgumentException("Supplied tensor has the wrong number of dimensions, shape = " + Arrays.toString(shape));
        }
        int numLabels = (int) shape[1];
        if (numLabels != outputInfo.size()) {
            // The mismatch can go either way, so don't claim "too many".
            throw new IllegalArgumentException("Supplied tensor has the wrong number of elements, tensor.length = " + numLabels + ", outputIDInfo.size() = " + outputInfo.size());
        }
        return (float[][]) tensor.expect(Float.class).copyTo(new float[(int) shape[0]][numLabels]);
    }

    /**
     * Converts a batch output tensor into one {@link Prediction} per example.
     *
     * @param tensor A float tensor of shape {@code [batchSize, numLabels]}.
     * @param immutableOutputInfo The label domain.
     * @param iArr Per-example integers, forwarded unchanged to each {@link Prediction}.
     * @param list The examples, in the same order as the tensor rows.
     * @return The predictions, in row order.
     * @throws IllegalArgumentException If the tensor has the wrong shape, or its row count does not
     *                                  match both {@code iArr.length} and {@code list.size()}.
     */
    @Override // org.tribuo.interop.tensorflow.OutputTransformer
    public List<Prediction<Label>> transformToBatchPrediction(Tensor<?> tensor, ImmutableOutputInfo<Label> immutableOutputInfo, int[] iArr, List<Example<Label>> list) {
        float[][] batch = getBatchPredictions(tensor, immutableOutputInfo);
        if (batch.length != list.size() || batch.length != iArr.length) {
            throw new IllegalArgumentException("Invalid number of predictions received from Tensorflow, expected " + iArr.length + ", received " + batch.length + ", examples.size() = " + list.size());
        }
        List<Prediction<Label>> predictions = new ArrayList<>(batch.length);
        for (int row = 0; row < batch.length; row++) {
            predictions.add(generatePrediction(batch[row], immutableOutputInfo, iArr[row], list.get(row)));
        }
        return predictions;
    }

    /**
     * Converts a batch output tensor into the highest-scoring label for each row.
     *
     * @param tensor A float tensor of shape {@code [batchSize, numLabels]}.
     * @param immutableOutputInfo The label domain.
     * @return One label per row, in row order.
     * @throws IllegalArgumentException If the tensor has the wrong shape.
     */
    @Override // org.tribuo.interop.tensorflow.OutputTransformer
    public List<Label> transformToBatchOutput(Tensor<?> tensor, ImmutableOutputInfo<Label> immutableOutputInfo) {
        float[][] batch = getBatchPredictions(tensor, immutableOutputInfo);
        List<Label> labels = new ArrayList<>(batch.length);
        for (float[] row : batch) {
            labels.add(generateLabel(row, immutableOutputInfo));
        }
        return labels;
    }

    /**
     * Looks up the id of a label in the domain.
     *
     * @throws IllegalArgumentException If the domain returns -1 for the label (i.e. it is unknown).
     */
    private int innerTransform(Label label, ImmutableOutputInfo<Label> immutableOutputInfo) {
        int id = immutableOutputInfo.getID(label);
        if (id == -1) {
            throw new IllegalArgumentException("Label " + label + " isn't known by the supplied outputIDInfo, " + immutableOutputInfo.toString());
        }
        return id;
    }

    /**
     * Encodes a single label as a one-element int tensor containing its id.
     *
     * @throws IllegalArgumentException If the label is not in the domain.
     */
    @Override // org.tribuo.interop.tensorflow.OutputTransformer
    public Tensor<?> transform(Label label, ImmutableOutputInfo<Label> immutableOutputInfo) {
        return Tensor.create(new int[]{innerTransform(label, immutableOutputInfo)});
    }

    /**
     * Encodes the labels of a list of examples as an int tensor of ids, in list order.
     *
     * @throws IllegalArgumentException If any example's label is not in the domain.
     */
    @Override // org.tribuo.interop.tensorflow.OutputTransformer
    public Tensor<?> transform(List<Example<Label>> list, ImmutableOutputInfo<Label> immutableOutputInfo) {
        int[] ids = new int[list.size()];
        int idx = 0;
        // Iterate rather than index so non-random-access lists stay linear-time.
        for (Example<Label> example : list) {
            ids[idx++] = innerTransform(example.getOutput(), immutableOutputInfo);
        }
        return Tensor.create(ids);
    }

    /**
     * Declares that the produced scores should be treated as probabilities.
     * This class does not normalise the scores itself.
     */
    @Override // org.tribuo.interop.tensorflow.OutputTransformer
    public boolean generatesProbabilities() {
        return true;
    }

    @Override
    public String toString() {
        return "LabelTransformer()";
    }

    /**
     * Returns the provenance of this transformer.
     * <p>
     * Decompiler-renamed from {@code getProvenance}; kept so existing callers still resolve.
     */
    public ConfiguredObjectProvenance m4getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "OutputTransformer");
    }

    /**
     * Returns the provenance of this transformer under its original name
     * (the decompiler comment records the rename from {@code getProvenance}).
     * NOTE(review): presumably this implements the provenance method required by the
     * OutputTransformer interface hierarchy; confirm, then consider adding {@code @Override}.
     */
    public ConfiguredObjectProvenance getProvenance() {
        return m4getProvenance();
    }

    /**
     * Implements the interface method by delegating to {@link #transformToOutput2}
     * (decompiler-generated bridge).
     */
    @Override // org.tribuo.interop.tensorflow.OutputTransformer
    public /* bridge */ /* synthetic */ Label transformToOutput(Tensor tensor, ImmutableOutputInfo<Label> immutableOutputInfo) {
        return transformToOutput2((Tensor<?>) tensor, immutableOutputInfo);
    }
}
