package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.Closeable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.OffsetDateTime;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
import org.tribuo.interop.ExternalDatasetProvenance;
import org.tribuo.interop.ExternalModel;
import org.tribuo.interop.ExternalTrainerProvenance;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;

/**
 * An {@link ExternalModel} backed by a TensorFlow {@link Graph} and a {@link Session} that runs it.
 * <p>
 * Examples are converted into an input {@link Tensor} by an {@link ExampleTransformer}. That tensor is fed
 * to the graph operation named {@code inputName}. The tensor fetched from the operation named
 * {@code outputName} is converted into {@link Prediction}s by an {@link OutputTransformer}.
 * <p>
 * The graph and session are native resources and are released by {@link #close()}. When serialized,
 * only the graph definition ({@link Graph#toGraphDef()}) is written, and a fresh session is created on
 * deserialization. NOTE(review): this presumably means the graph must be self-contained (e.g. a frozen
 * graph) for a deserialized model to predict correctly — confirm.
 * <p>
 * Decompiled from org/tribuo/interop/tensorflow/TensorflowExternalModel.class.
 *
 * @param <T> the output type.
 */
public final class TensorflowExternalModel<T extends Output<T>> extends ExternalModel<T, Tensor<?>, Tensor<?>> implements Closeable {
    private static final long serialVersionUID = 100;

    /** The TensorFlow graph. It is transient because writeObject/readObject serialize it as GraphDef bytes. */
    private transient Graph model;
    /** The session that executes {@link #model}. It is rebuilt in readObject. */
    private transient Session session;
    private final ExampleTransformer<T> featureTransformer;
    private final OutputTransformer<T> outputTransformer;
    /** Name of the graph operation the input tensor is fed to. */
    private final String inputName;
    /** Name of the graph operation whose output is fetched. */
    private final String outputName;

    /**
     * Creates a model from a feature-name mapping. Takes ownership of {@code graph} and opens a session on it.
     *
     * @param name the model name.
     * @param provenance the model provenance.
     * @param featureMap the Tribuo feature map.
     * @param outputInfo the Tribuo output info.
     * @param featureMapping map keyed by feature name; presumably values are the external model's feature indices.
     * @param graph the TensorFlow graph to run.
     * @param inputName the name of the input operation.
     * @param outputName the name of the output operation.
     * @param featureTransformer converts Tribuo features into an input tensor.
     * @param outputTransformer converts the output tensor into predictions.
     */
    private TensorflowExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap,
                                    ImmutableOutputInfo<T> outputInfo, Map<String, Integer> featureMapping,
                                    Graph graph, String inputName, String outputName,
                                    ExampleTransformer<T> featureTransformer, OutputTransformer<T> outputTransformer) {
        super(name, provenance, featureMap, outputInfo, outputTransformer.generatesProbabilities(), featureMapping);
        this.model = graph;
        this.session = new Session(graph);
        this.inputName = inputName;
        this.outputName = outputName;
        this.featureTransformer = featureTransformer;
        this.outputTransformer = outputTransformer;
    }

    /**
     * Creates a model from precomputed feature forward/backward index mappings (used by {@link #copy}).
     * Takes ownership of {@code graph} and opens a session on it.
     *
     * @param name the model name.
     * @param provenance the model provenance.
     * @param featureMap the Tribuo feature map.
     * @param outputInfo the Tribuo output info.
     * @param featureForwardMapping the forward feature index mapping.
     * @param featureBackwardMapping the backward feature index mapping.
     * @param graph the TensorFlow graph to run.
     * @param inputName the name of the input operation.
     * @param outputName the name of the output operation.
     * @param featureTransformer converts Tribuo features into an input tensor.
     * @param outputTransformer converts the output tensor into predictions.
     */
    private TensorflowExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap,
                                    ImmutableOutputInfo<T> outputInfo, int[] featureForwardMapping,
                                    int[] featureBackwardMapping, Graph graph, String inputName, String outputName,
                                    ExampleTransformer<T> featureTransformer, OutputTransformer<T> outputTransformer) {
        super(name, provenance, featureMap, outputInfo, featureForwardMapping, featureBackwardMapping, outputTransformer.generatesProbabilities());
        this.model = graph;
        this.session = new Session(graph);
        this.inputName = inputName;
        this.outputName = outputName;
        this.featureTransformer = featureTransformer;
        this.outputTransformer = outputTransformer;
    }

    /**
     * Converts a single feature vector into an input tensor.
     * <p>
     * NOTE(review): the decompiler renamed this from {@code convertFeatures} (merged with a bridge
     * method). It presumably should be named {@code convertFeatures} to implement the parent's method — confirm.
     *
     * @param input the feature vector.
     * @return the input tensor; the caller owns it.
     */
    public Tensor<?> m10convertFeatures(SparseVector input) {
        return this.featureTransformer.transform(input);
    }

    /**
     * Converts a batch of feature vectors into a single input tensor.
     *
     * @param input the feature vectors.
     * @return the input tensor; the caller owns it.
     */
    protected Tensor<?> convertFeaturesList(List<SparseVector> input) {
        return this.featureTransformer.transform(input);
    }

    /**
     * Runs the graph on the supplied input tensor and returns the tensor fetched from the output operation.
     * The input tensor is always closed, even if execution fails.
     *
     * @param input the input tensor; this method closes it.
     * @return the output tensor; the caller owns it.
     */
    public Tensor<?> externalPrediction(Tensor<?> input) {
        try {
            return this.session.runner().feed(this.inputName, input).fetch(this.outputName).run().get(0);
        } finally {
            input.close();
        }
    }

    /**
     * Converts an output tensor into a single prediction. The tensor is closed even if conversion fails.
     *
     * @param output the output tensor; this method closes it.
     * @param numValidFeatures the number of valid features in the example.
     * @param example the example that was predicted.
     * @return the prediction.
     */
    public Prediction<T> convertOutput(Tensor<?> output, int numValidFeatures, Example<T> example) {
        try {
            return this.outputTransformer.transformToPrediction(output, this.outputIDInfo, numValidFeatures, example);
        } finally {
            output.close();
        }
    }

    /**
     * Converts a batched output tensor into one prediction per example. The tensor is closed even if
     * conversion fails.
     *
     * @param output the output tensor; this method closes it.
     * @param numValidFeatures the number of valid features for each example.
     * @param examples the examples that were predicted.
     * @return the predictions.
     */
    public List<Prediction<T>> convertOutput(Tensor<?> output, int[] numValidFeatures, List<Example<T>> examples) {
        try {
            return this.outputTransformer.transformToBatchPrediction(output, this.outputIDInfo, numValidFeatures, examples);
        } finally {
            output.close();
        }
    }

    /**
     * Feature importances are not available for TensorFlow graphs.
     *
     * @param n the requested number of features (ignored).
     * @return an empty map.
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        return Collections.emptyMap();
    }

    /**
     * Copies this model by round-tripping the graph through its GraphDef bytes into a new {@link Graph}.
     * If construction fails, the new graph is closed. The transformers are shared with the copy.
     *
     * @param newName the name of the copy.
     * @param newProvenance the provenance of the copy.
     * @return the copy.
     */
    protected Model<T> copy(String newName, ModelProvenance newProvenance) {
        Graph graphCopy = importGraph(this.model.toGraphDef());
        try {
            return new TensorflowExternalModel<>(newName, newProvenance, this.featureIDMap, this.outputIDInfo,
                    this.featureForwardMapping, this.featureBackwardMapping, graphCopy, this.inputName,
                    this.outputName, this.featureTransformer, this.outputTransformer);
        } catch (RuntimeException e) {
            graphCopy.close();
            throw e;
        }
    }

    /**
     * Closes the session and then the graph. The graph is closed even if closing the session throws.
     */
    @Override // java.io.Closeable, java.lang.AutoCloseable
    public void close() {
        try {
            if (this.session != null) {
                this.session.close();
            }
        } finally {
            if (this.model != null) {
                this.model.close();
            }
        }
    }

    /**
     * Loads a serialized TensorFlow GraphDef from disk and wraps it in a model.
     *
     * @param factory the output factory.
     * @param featureMapping map keyed by feature name; presumably values are the external model's feature indices.
     * @param outputMapping map keyed by output; presumably values are the external model's output indices.
     * @param inputName the name of the input operation.
     * @param outputName the name of the output operation.
     * @param featureTransformer converts Tribuo features into an input tensor.
     * @param outputTransformer converts the output tensor into predictions.
     * @param filename path to the serialized GraphDef.
     * @param <T> the output type.
     * @return the model.
     * @throws IllegalArgumentException if the file cannot be read or its path cannot be converted to a URL.
     */
    public static <T extends Output<T>> TensorflowExternalModel<T> createTensorflowModel(OutputFactory<T> factory,
            Map<String, Integer> featureMapping, Map<T, Integer> outputMapping, String inputName, String outputName,
            ExampleTransformer<T> featureTransformer, OutputTransformer<T> outputTransformer, String filename) {
        Path path = Paths.get(filename);
        byte[] graphDef;
        URL location;
        try {
            graphDef = Files.readAllBytes(path);
            location = path.toUri().toURL();
        } catch (IOException e) {
            throw new IllegalArgumentException("Unable to load model from path " + filename, e);
        }
        // Build all Java-side state before allocating the native graph, so failures here leak nothing.
        ModelProvenance provenance = new ModelProvenance(TensorflowExternalModel.class.getName(), OffsetDateTime.now(),
                new ExternalDatasetProvenance("unknown-external-data", factory, false, featureMapping.size(), outputMapping.size()),
                new ExternalTrainerProvenance(location));
        ImmutableFeatureMap featureMap = ExternalModel.createFeatureMap(featureMapping.keySet());
        ImmutableOutputInfo<T> outputInfo = ExternalModel.createOutputInfo(factory, outputMapping);
        Graph graph = importGraph(graphDef);
        try {
            return new TensorflowExternalModel<>("external-model", provenance, featureMap, outputInfo, featureMapping,
                    graph, inputName, outputName, featureTransformer, outputTransformer);
        } catch (RuntimeException e) {
            graph.close();
            throw e;
        }
    }

    /**
     * Creates a new graph and imports the supplied GraphDef into it. The graph is closed if the import fails.
     *
     * @param graphDef the serialized graph definition.
     * @return a new graph that the caller owns.
     */
    private static Graph importGraph(byte[] graphDef) {
        Graph graph = new Graph();
        try {
            graph.importGraphDef(graphDef);
            return graph;
        } catch (RuntimeException e) {
            graph.close();
            throw e;
        }
    }

    /**
     * Writes the non-transient fields, then the graph as GraphDef bytes.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
        out.writeObject(this.model.toGraphDef());
    }

    /**
     * Reads the non-transient fields, then rebuilds the graph from its GraphDef bytes and opens a new session.
     * If the session cannot be created, the rebuilt graph is closed.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        byte[] graphDef = (byte[]) in.readObject();
        Graph graph = importGraph(graphDef);
        try {
            this.session = new Session(graph);
        } catch (RuntimeException e) {
            graph.close();
            throw e;
        }
        this.model = graph;
    }

    /**
     * Raw-typed forwarder to {@link #convertFeaturesList(List)}.
     * <p>
     * The decompiler renamed this from {@code convertFeaturesList} because of a collision with the
     * generic overload. It is kept for compatibility.
     *
     * @param input a list of {@link SparseVector}s.
     * @return the input tensor.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    protected /* bridge */ /* synthetic */ Object m9convertFeaturesList(List input) {
        return convertFeaturesList((List<SparseVector>) input);
    }
}
