package org.tribuo.interop.tensorflow.sequence;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.Closeable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.interop.tensorflow.TensorflowUtil;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.sequence.SequenceExample;
import org.tribuo.sequence.SequenceModel;

/* loaded from: input_file:org/tribuo/interop/tensorflow/sequence/TensorflowSequenceModel.class */
/**
 * A {@link SequenceModel} backed by a TensorFlow {@link Graph} and {@link Session}.
 * <p>
 * Each {@link SequenceExample} is encoded into named input tensors by a
 * {@link SequenceExampleTransformer}. The graph's prediction op is run with those inputs, and
 * the resulting tensor is decoded into predictions by a {@link SequenceOutputTransformer}.
 * <p>
 * The graph and session are native resources. They are released by {@link #close()}. On Java
 * serialization, the graph definition and the session state (via {@link TensorflowUtil}) are
 * written out, and both are rebuilt on deserialization.
 *
 * @param <T> the output type
 */
public class TensorflowSequenceModel<T extends Output<T>> extends SequenceModel<T> implements Closeable {
    private static final long serialVersionUID = 1;

    /** The TensorFlow graph; rebuilt from its graph definition on deserialization. */
    private transient Graph modelGraph;
    /** The session over {@link #modelGraph}; rebuilt on deserialization. */
    private transient Session session;

    /** Converts a sequence example into named input tensors. */
    protected final SequenceExampleTransformer<T> exampleTransformer;
    /** Converts the prediction tensor back into Tribuo predictions. */
    protected final SequenceOutputTransformer<T> outputTransformer;
    /** Name of the op run once after the graph is loaded, before the saved state is restored. */
    protected final String initOp;
    /** Name of the op fetched to produce predictions. */
    protected final String predictOp;

    /**
     * Builds the model, loads the supplied graph definition, runs the init op and restores the
     * supplied tensor values into the new session.
     * <p>
     * NOTE(review): the decompiler reported this constructor as originally package-private; it is
     * kept {@code public} here for compatibility — confirm the intended visibility.
     *
     * @param name               the model name
     * @param provenance         the model provenance
     * @param featureMap         the feature domain
     * @param outputInfo         the output domain
     * @param graphDef           the serialized TensorFlow graph definition
     * @param exampleTransformer encodes examples into input tensors
     * @param outputTransformer  decodes the prediction tensor
     * @param initOp             the name of the initialisation op
     * @param predictOp          the name of the prediction op
     * @param tensorMap          saved tensor values handed to {@link TensorflowUtil#deserialise}
     */
    public TensorflowSequenceModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<T> outputInfo, byte[] graphDef, SequenceExampleTransformer<T> exampleTransformer, SequenceOutputTransformer<T> outputTransformer, String initOp, String predictOp, Map<String, Object> tensorMap) {
        super(name, provenance, featureMap, outputInfo);
        this.exampleTransformer = exampleTransformer;
        this.outputTransformer = outputTransformer;
        this.initOp = initOp;
        this.predictOp = predictOp;
        // initOp must be assigned before this call, as loadGraph runs it.
        loadGraph(graphDef, tensorMap);
    }

    /**
     * Predicts the outputs for every element of the supplied sequence.
     * <p>
     * All tensors created for the call (encoded inputs and the fetched output) are closed before
     * returning, including when encoding-side TensorFlow execution or decoding throws.
     *
     * @param sequenceExample the sequence to predict
     * @return the decoded predictions, as produced by the output transformer
     */
    public List<Prediction<T>> predict(SequenceExample<T> sequenceExample) {
        Map<String, Tensor<?>> inputs = exampleTransformer.encode(sequenceExample, featureIDMap);
        try {
            Session.Runner runner = session.runner();
            for (Map.Entry<String, Tensor<?>> input : inputs.entrySet()) {
                runner.feed(input.getKey(), input.getValue());
            }
            // Only one op is fetched, so the result list has a single entry.
            try (Tensor<?> output = runner.fetch(predictOp).run().get(0)) {
                return outputTransformer.decode(output, sequenceExample, outputIDMap);
            }
        } finally {
            // Input tensors hold native memory; release them whatever happened above.
            for (Tensor<?> input : inputs.values()) {
                input.close();
            }
        }
    }

    /**
     * Feature importances are not available for this model.
     *
     * @param i ignored
     * @return an empty map
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int i) {
        return Collections.emptyMap();
    }

    /**
     * Releases the TensorFlow session and graph. The graph is closed even if closing the session
     * throws.
     */
    @Override // java.io.Closeable, java.lang.AutoCloseable
    public void close() {
        try {
            if (session != null) {
                session.close();
            }
        } finally {
            if (modelGraph != null) {
                modelGraph.close();
            }
        }
    }

    /**
     * Writes the default fields, then the graph definition bytes, then the session state
     * produced by {@link TensorflowUtil#serialise}.
     * NOTE(review): this presumably fails if the model has already been closed — confirm.
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.defaultWriteObject();
        objectOutputStream.writeObject(modelGraph.toGraphDef());
        objectOutputStream.writeObject(TensorflowUtil.serialise(modelGraph, session));
    }

    /**
     * Reads the state written by {@link #writeObject} and rebuilds the graph and session from it.
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        byte[] graphDef = (byte[]) objectInputStream.readObject();
        // The second object is what writeObject stored from TensorflowUtil.serialise; it is
        // consumed as a Map<String, Object>, matching the constructor's parameter type.
        @SuppressWarnings("unchecked")
        Map<String, Object> tensorMap = (Map<String, Object>) objectInputStream.readObject();
        loadGraph(graphDef, tensorMap);
    }

    /**
     * Creates the graph and session, runs {@link #initOp} and restores the saved tensor values.
     * The fields are only assigned once everything succeeded; on failure the partially built
     * native resources are closed before the exception propagates.
     *
     * @param graphDef  the serialized graph definition
     * @param tensorMap the saved tensor values passed to {@link TensorflowUtil#deserialise}
     */
    private void loadGraph(byte[] graphDef, Map<String, Object> tensorMap) {
        Graph graph = new Graph();
        Session newSession = null;
        boolean loaded = false;
        try {
            graph.importGraphDef(graphDef);
            newSession = new Session(graph);
            newSession.runner().addTarget(initOp).run();
            TensorflowUtil.deserialise(newSession, tensorMap);
            loaded = true;
        } finally {
            if (!loaded) {
                if (newSession != null) {
                    newSession.close();
                }
                graph.close();
            }
        }
        this.modelGraph = graph;
        this.session = newSession;
    }
}
