package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.tensorflow.Tensor;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.regression.Regressor;

/**
 * An {@link OutputTransformer} which converts between {@link Regressor} outputs and TensorFlow
 * float tensors.
 * <p>
 * Predictions are read from a rank 2 float tensor of shape {@code [batchSize, numDimensions]},
 * where column {@code j} holds the value for the regression dimension whose id in the supplied
 * {@link ImmutableOutputInfo} is {@code j}. This transformer does not generate probabilities.
 */
public class RegressorTransformer implements OutputTransformer<Regressor> {
    private static final long serialVersionUID = 1L;

    /**
     * Converts a single-example prediction tensor into a {@link Prediction}.
     *
     * @param tensor       A float tensor of shape {@code [1, outputIDInfo.size()]}.
     * @param outputIDInfo The output id mapping used to name each column.
     * @param numUsed      Passed through unchanged to the {@link Prediction} constructor
     *                     (presumably the number of features used — confirm against Prediction).
     * @param example      The example this prediction is for.
     * @return The prediction.
     * @throws IllegalArgumentException If the tensor shape is invalid.
     */
    @Override
    public Prediction<Regressor> transformToPrediction(Tensor<?> tensor, ImmutableOutputInfo<Regressor> outputIDInfo, int numUsed, Example<Regressor> example) {
        return new Prediction<>(transformToOutput(tensor, outputIDInfo), numUsed, example);
    }

    /**
     * Converts a single-example prediction tensor into a {@link Regressor}.
     *
     * @param tensor       A float tensor of shape {@code [1, outputIDInfo.size()]}.
     * @param outputIDInfo The output id mapping used to name each column.
     * @return A regressor with one named value per column.
     * @throws IllegalArgumentException If the tensor is not rank 2, does not contain exactly one
     *                                  row, or its width differs from {@code outputIDInfo.size()}.
     */
    @Override
    public Regressor transformToOutput(Tensor<?> tensor, ImmutableOutputInfo<Regressor> outputIDInfo) {
        float[][] predictions = getBatchPredictions(tensor, outputIDInfo.size());
        if (predictions.length != 1) {
            throw new IllegalArgumentException("Supplied tensor has too many results, predictions.length = " + predictions.length);
        }
        return toRegressor(dimensionNames(outputIDInfo), predictions[0]);
    }

    /**
     * Alias for {@link #transformToOutput(Tensor, ImmutableOutputInfo)}, retained so existing
     * callers of this decompiler-generated name keep working.
     *
     * @param tensor       A float tensor of shape {@code [1, outputIDInfo.size()]}.
     * @param outputIDInfo The output id mapping used to name each column.
     * @return A regressor with one named value per column.
     * @deprecated Use {@link #transformToOutput(Tensor, ImmutableOutputInfo)}.
     */
    @Deprecated
    public Regressor transformToOutput2(Tensor<?> tensor, ImmutableOutputInfo<Regressor> outputIDInfo) {
        return transformToOutput(tensor, outputIDInfo);
    }

    /**
     * Reads a rank 2 float tensor which must have exactly one column.
     * <p>
     * Retained with its original single-column restriction for any package-level callers; the
     * transformer itself uses {@link #getBatchPredictions(Tensor, int)}.
     *
     * @param tensor The tensor to read.
     * @return The tensor contents as {@code [batchSize][1]}.
     * @throws IllegalArgumentException If the tensor is not rank 2 or has more than one column.
     */
    float[][] getBatchPredictions(Tensor<?> tensor) {
        return getBatchPredictions(tensor, 1);
    }

    /**
     * Reads a rank 2 float tensor of shape {@code [batchSize, expectedOutputDim]} into a Java array.
     *
     * @param tensor            The tensor to read.
     * @param expectedOutputDim The required size of the second dimension.
     * @return The tensor contents as {@code [batchSize][expectedOutputDim]}.
     * @throws IllegalArgumentException If the tensor is not rank 2, or its second dimension is not
     *                                  {@code expectedOutputDim}.
     */
    float[][] getBatchPredictions(Tensor<?> tensor, int expectedOutputDim) {
        long[] shape = tensor.shape();
        if (shape.length != 2) {
            throw new IllegalArgumentException("Supplied tensor has the wrong number of dimensions, shape = " + Arrays.toString(shape));
        }
        int batchSize = (int) shape[0];
        int outputDim = (int) shape[1];
        // The width must match the output domain, otherwise columns cannot be mapped to dimension ids.
        if (outputDim != expectedOutputDim) {
            throw new IllegalArgumentException("Supplied tensor has an incorrect number of output dimensions, tensor.shape[1] = " + outputDim + ", expected " + expectedOutputDim);
        }
        return tensor.expect(Float.class).copyTo(new float[batchSize][outputDim]);
    }

    /**
     * Converts a batch prediction tensor into a list of {@link Prediction}s.
     *
     * @param tensor       A float tensor of shape {@code [examples.size(), outputIDInfo.size()]}.
     * @param outputIDInfo The output id mapping used to name each column.
     * @param numUsed      Per-example values passed through to each {@link Prediction}.
     * @param examples     The examples the predictions are for, in row order.
     * @return One prediction per tensor row.
     * @throws IllegalArgumentException If the tensor shape is invalid, or the number of rows,
     *                                  examples and {@code numUsed} entries disagree.
     */
    @Override
    public List<Prediction<Regressor>> transformToBatchPrediction(Tensor<?> tensor, ImmutableOutputInfo<Regressor> outputIDInfo, int[] numUsed, List<Example<Regressor>> examples) {
        List<Regressor> outputs = transformToBatchOutput(tensor, outputIDInfo);
        if (outputs.size() != examples.size() || outputs.size() != numUsed.length) {
            throw new IllegalArgumentException("Invalid number of predictions received from Tensorflow, expected " + numUsed.length + " (examples.size() = " + examples.size() + "), received " + outputs.size());
        }
        List<Prediction<Regressor>> predictions = new ArrayList<>(outputs.size());
        for (int i = 0; i < outputs.size(); i++) {
            predictions.add(new Prediction<>(outputs.get(i), numUsed[i], examples.get(i)));
        }
        return predictions;
    }

    /**
     * Converts a batch prediction tensor into one {@link Regressor} per row.
     *
     * @param tensor       A float tensor of shape {@code [batchSize, outputIDInfo.size()]}.
     * @param outputIDInfo The output id mapping used to name each column.
     * @return One regressor per tensor row, in row order.
     * @throws IllegalArgumentException If the tensor is not rank 2 or its width differs from
     *                                  {@code outputIDInfo.size()}.
     */
    @Override
    public List<Regressor> transformToBatchOutput(Tensor<?> tensor, ImmutableOutputInfo<Regressor> outputIDInfo) {
        float[][] predictions = getBatchPredictions(tensor, outputIDInfo.size());
        String[] names = dimensionNames(outputIDInfo);
        List<Regressor> outputs = new ArrayList<>(predictions.length);
        for (float[] row : predictions) {
            outputs.add(toRegressor(names, row));
        }
        return outputs;
    }

    /**
     * Converts a single regressor into a rank 1 float tensor of its values.
     *
     * @param regressor    The regressor to convert.
     * @param outputIDInfo Unused by this implementation.
     * @return A float tensor of length {@code regressor.size()}.
     */
    @Override
    public Tensor<?> transform(Regressor regressor, ImmutableOutputInfo<Regressor> outputIDInfo) {
        double[] values = regressor.getValues();
        float[] target = new float[regressor.size()];
        for (int j = 0; j < target.length; j++) {
            target[j] = (float) values[j];
        }
        return Tensor.create(target);
    }

    /**
     * Converts the outputs of a batch of examples into a rank 2 float tensor of shape
     * {@code [examples.size(), outputIDInfo.size()]}.
     * <p>
     * NOTE(review): values are copied positionally from {@link Regressor#getValues()}, which
     * assumes that order matches the id order in {@code outputIDInfo} — confirm.
     *
     * @param examples     The examples whose outputs are converted.
     * @param outputIDInfo The output id mapping, used to size each row.
     * @return The target tensor.
     * @throws IllegalArgumentException If an example's output does not have
     *                                  {@code outputIDInfo.size()} dimensions.
     */
    @Override
    public Tensor<?> transform(List<Example<Regressor>> examples, ImmutableOutputInfo<Regressor> outputIDInfo) {
        int numDims = outputIDInfo.size();
        float[][] targets = new float[examples.size()][numDims];
        int i = 0;
        for (Example<Regressor> example : examples) {
            double[] values = example.getOutput().getValues();
            if (values.length != numDims) {
                throw new IllegalArgumentException("Example " + i + " has " + values.length + " output dimensions, expected " + numDims);
            }
            // Bound by the row width (output dimensionality), not the number of examples.
            for (int j = 0; j < numDims; j++) {
                targets[i][j] = (float) values[j];
            }
            i++;
        }
        return Tensor.create(targets);
    }

    @Override
    public boolean generatesProbabilities() {
        return false;
    }

    @Override
    public String toString() {
        return "RegressorTransformer()";
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "OutputTransformer");
    }

    /**
     * Alias for {@link #getProvenance()}, retained so existing callers of this
     * decompiler-generated name keep working.
     *
     * @return The provenance of this transformer.
     * @deprecated Use {@link #getProvenance()}.
     */
    @Deprecated
    public ConfiguredObjectProvenance m5getProvenance() {
        return getProvenance();
    }

    /**
     * Builds the dimension-name array indexed by output id.
     *
     * @param outputIDInfo The output id mapping.
     * @return {@code names[id]} is the first name of the regressor with that id.
     */
    private static String[] dimensionNames(ImmutableOutputInfo<Regressor> outputIDInfo) {
        String[] names = new String[outputIDInfo.size()];
        for (Pair<Integer, Regressor> pair : outputIDInfo) {
            names[pair.getA()] = pair.getB().getNames()[0];
        }
        return names;
    }

    /**
     * Widens one prediction row to doubles and wraps it in a {@link Regressor}.
     *
     * @param names The dimension names, indexed by output id.
     * @param row   The prediction row; must have at least {@code names.length} entries.
     * @return The regressor.
     */
    private static Regressor toRegressor(String[] names, float[] row) {
        double[] values = new double[names.length];
        for (int j = 0; j < names.length; j++) {
            values[j] = row[j];
        }
        return new Regressor(names, values);
    }
}
