package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlowException;
import org.tensorflow.Tensors;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.interop.tensorflow.TensorflowTrainer;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.SkeletalTrainerProvenance;
import org.tribuo.provenance.TrainerProvenance;

/* loaded from: input_file:org/tribuo/interop/tensorflow/TensorflowCheckpointTrainer.class */
public final class TensorflowCheckpointTrainer<T extends Output<T>> implements Trainer<T> {
    private static final Logger logger = Logger.getLogger(TensorflowCheckpointTrainer.class.getName());
    public static final String MODEL_FILENAME = "model";

    @Config(mandatory = true, description = "Path to the protobuf containing the graph.")
    private Path graphPath;
    private byte[] graphDef;

    @Config(mandatory = true, description = "Feature extractor.")
    private ExampleTransformer<T> exampleTransformer;

    @Config(mandatory = true, description = "Response extractor.")
    private OutputTransformer<T> outputTransformer;

    @Config(description = "Minibatch size.")
    private int minibatchSize;

    @Config(description = "Number of SGD epochs to run.")
    private int epochs;

    @Config(description = "Logging interval to print out the loss.")
    private int loggingInterval;

    @Config(description = "Path to write out the checkpoints.")
    private Path checkpointRootPath;
    private int trainInvocationCounter;

    /* loaded from: input_file:org/tribuo/interop/tensorflow/TensorflowCheckpointTrainer$TensorflowCheckpointTrainerProvenance.class */
    public static final class TensorflowCheckpointTrainerProvenance extends SkeletalTrainerProvenance {
        private static final long serialVersionUID = 1;
        public static final String GRAPH_HASH = "graph-hash";
        public static final String GRAPH_LAST_MOD = "graph-last-modified";
        private final HashProvenance graphHash;
        private final DateTimeProvenance graphLastModified;

        <T extends Output<T>> TensorflowCheckpointTrainerProvenance(TensorflowCheckpointTrainer<T> tensorflowCheckpointTrainer) {
            super(tensorflowCheckpointTrainer);
            this.graphHash = new HashProvenance(DEFAULT_HASH_TYPE, "graph-hash", ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE, ((TensorflowCheckpointTrainer) tensorflowCheckpointTrainer).graphPath));
            this.graphLastModified = new DateTimeProvenance("graph-last-modified", OffsetDateTime.ofInstant(Instant.ofEpochMilli(((TensorflowCheckpointTrainer) tensorflowCheckpointTrainer).graphPath.toFile().lastModified()), ZoneId.systemDefault()));
        }

        public TensorflowCheckpointTrainerProvenance(Map<String, Provenance> map) {
            this(extractTFProvenanceInfo(map));
        }

        private TensorflowCheckpointTrainerProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo extractedInfo) {
            super(extractedInfo);
            this.graphHash = (HashProvenance) extractedInfo.instanceValues.get("graph-hash");
            this.graphLastModified = (DateTimeProvenance) extractedInfo.instanceValues.get("graph-last-modified");
        }

        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
            Map<String, PrimitiveProvenance<?>> instanceValues = super.getInstanceValues();
            instanceValues.put(this.graphHash.getKey(), this.graphHash);
            instanceValues.put(this.graphLastModified.getKey(), this.graphLastModified);
            return instanceValues;
        }

        protected static SkeletalConfiguredObjectProvenance.ExtractedInfo extractTFProvenanceInfo(Map<String, Provenance> map) {
            SkeletalConfiguredObjectProvenance.ExtractedInfo extractProvenanceInfo = SkeletalTrainerProvenance.extractProvenanceInfo(map);
            extractProvenanceInfo.instanceValues.put("graph-hash", ObjectProvenance.checkAndExtractProvenance(map, "graph-hash", HashProvenance.class, TensorflowTrainer.TensorflowTrainerProvenance.class.getSimpleName()));
            extractProvenanceInfo.instanceValues.put("graph-last-modified", ObjectProvenance.checkAndExtractProvenance(map, "graph-last-modified", DateTimeProvenance.class, TensorflowTrainer.TensorflowTrainerProvenance.class.getSimpleName()));
            return extractProvenanceInfo;
        }
    }

    private TensorflowCheckpointTrainer() {
        this.minibatchSize = 1;
        this.epochs = 5;
        this.loggingInterval = 100;
        this.checkpointRootPath = Paths.get("/tmp/", new String[0]);
        this.trainInvocationCounter = 0;
    }

    public TensorflowCheckpointTrainer(Path path, Path path2, ExampleTransformer<T> exampleTransformer, OutputTransformer<T> outputTransformer, int i, int i2) throws IOException {
        this.minibatchSize = 1;
        this.epochs = 5;
        this.loggingInterval = 100;
        this.checkpointRootPath = Paths.get("/tmp/", new String[0]);
        this.trainInvocationCounter = 0;
        this.graphPath = path;
        this.checkpointRootPath = path2;
        this.exampleTransformer = exampleTransformer;
        this.outputTransformer = outputTransformer;
        this.minibatchSize = i;
        this.epochs = i2;
        postConfig();
    }

    public void postConfig() throws IOException {
        this.graphDef = Files.readAllBytes(this.graphPath);
    }

    /* JADX WARN: Failed to calculate best type for var: r20v1 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.calculateFromBounds(FixTypesVisitor.java:156)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.setBestType(FixTypesVisitor.java:133)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.deduceType(FixTypesVisitor.java:238)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.tryDeduceTypes(FixTypesVisitor.java:221)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Failed to calculate best type for var: r20v1 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.calculateFromBounds(TypeInferenceVisitor.java:145)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.setBestType(TypeInferenceVisitor.java:123)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.lambda$runTypePropagation$2(TypeInferenceVisitor.java:101)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.runTypePropagation(TypeInferenceVisitor.java:101)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.visit(TypeInferenceVisitor.java:75)
     */
    /* JADX WARN: Failed to calculate best type for var: r21v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.calculateFromBounds(FixTypesVisitor.java:156)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.setBestType(FixTypesVisitor.java:133)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.deduceType(FixTypesVisitor.java:238)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.tryDeduceTypes(FixTypesVisitor.java:221)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Failed to calculate best type for var: r21v0 ??
    java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.InsnArg.getType()" because "changeArg" is null
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.moveListener(TypeUpdate.java:439)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.runListeners(TypeUpdate.java:232)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.requestUpdate(TypeUpdate.java:212)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeForSsaVar(TypeUpdate.java:183)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.updateTypeChecked(TypeUpdate.java:112)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:83)
    	at jadx.core.dex.visitors.typeinference.TypeUpdate.apply(TypeUpdate.java:56)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.calculateFromBounds(TypeInferenceVisitor.java:145)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.setBestType(TypeInferenceVisitor.java:123)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.lambda$runTypePropagation$2(TypeInferenceVisitor.java:101)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.runTypePropagation(TypeInferenceVisitor.java:101)
    	at jadx.core.dex.visitors.typeinference.TypeInferenceVisitor.visit(TypeInferenceVisitor.java:75)
     */
    /* JADX WARN: Multi-variable type inference failed. Error: java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.RegisterArg.getSVar()" because the return value of "jadx.core.dex.nodes.InsnNode.getResult()" is null
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.collectRelatedVars(AbstractTypeConstraint.java:31)
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.<init>(AbstractTypeConstraint.java:19)
    	at jadx.core.dex.visitors.typeinference.TypeSearch$1.<init>(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeMoveConstraint(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeConstraint(TypeSearch.java:361)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.collectConstraints(TypeSearch.java:341)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.run(TypeSearch.java:60)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.runMultiVariableSearch(FixTypesVisitor.java:116)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Not initialized variable reg: 20, insn: 0x0353: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r20 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) A[TRY_LEAVE], block:B:130:0x0353 */
    /* JADX WARN: Not initialized variable reg: 21, insn: 0x0358: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r21 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]), block:B:132:0x0358 */
    /* JADX WARN: Type inference failed for: r20v1, types: [org.tensorflow.Session] */
    /* JADX WARN: Type inference failed for: r21v0, types: [java.lang.Throwable] */
    public Model<T> train(Dataset<T> dataset, Map<String, Provenance> map) {
        ?? r20;
        ?? r21;
        try {
            Path createTempDirectory = Files.createTempDirectory(this.checkpointRootPath, "tensorflow-checkpoint", new FileAttribute[0]);
            ImmutableFeatureMap featureIDMap = dataset.getFeatureIDMap();
            ImmutableOutputInfo<T> outputIDInfo = dataset.getOutputIDInfo();
            ArrayList arrayList = new ArrayList();
            this.trainInvocationCounter++;
            try {
                Graph graph = new Graph();
                Throwable th = null;
                try {
                    try {
                        Session session = new Session(graph);
                        Throwable th2 = null;
                        Tensor create = Tensor.create(true);
                        Throwable th3 = null;
                        try {
                            Tensor create2 = Tensors.create(createTempDirectory.toString() + "/" + MODEL_FILENAME);
                            Throwable th4 = null;
                            try {
                                try {
                                    graph.importGraphDef(this.graphDef);
                                    session.runner().addTarget(TensorflowTrainer.INIT).run();
                                    logger.info("Initialised the model parameters");
                                    int i = 0;
                                    for (int i2 = 0; i2 < this.epochs; i2++) {
                                        logger.log(Level.INFO, "Starting epoch " + i2);
                                        Tensor create3 = Tensor.create(Integer.valueOf(i2));
                                        int i3 = 0;
                                        while (i3 < dataset.size()) {
                                            arrayList.clear();
                                            for (int i4 = i3; i4 < i3 + this.minibatchSize && i4 < dataset.size(); i4++) {
                                                arrayList.add(dataset.getExample(i4));
                                            }
                                            Tensor<?> transform = this.exampleTransformer.transform(arrayList, featureIDMap);
                                            Tensor<?> transform2 = this.outputTransformer.transform(arrayList, outputIDInfo);
                                            Tensor tensor = (Tensor) session.runner().feed(TensorflowModel.INPUT_NAME, transform).feed(TensorflowTrainer.TARGET, transform2).feed(TensorflowTrainer.EPOCH, create3).feed(TensorflowTrainer.IS_TRAINING, create).addTarget(TensorflowTrainer.TRAIN).fetch(TensorflowTrainer.TRAINING_LOSS).run().get(0);
                                            if (i % this.loggingInterval == 0) {
                                                logger.log(Level.INFO, "Training loss = " + tensor.floatValue());
                                            }
                                            transform.close();
                                            transform2.close();
                                            tensor.close();
                                            i++;
                                            i3 += this.minibatchSize;
                                        }
                                        create3.close();
                                    }
                                    session.runner().feed("save/Const", create2).addTarget("save/control_dependency").run();
                                    TensorflowCheckpointModel tensorflowCheckpointModel = new TensorflowCheckpointModel("tf-model", new ModelProvenance(TensorflowCheckpointModel.class.getName(), OffsetDateTime.now(), dataset.getProvenance(), m8getProvenance(), map), featureIDMap, outputIDInfo, graph.toGraphDef(), createTempDirectory.toString(), this.exampleTransformer, this.outputTransformer);
                                    if (create2 != null) {
                                        if (0 != 0) {
                                            try {
                                                create2.close();
                                            } catch (Throwable th5) {
                                                th4.addSuppressed(th5);
                                            }
                                        } else {
                                            create2.close();
                                        }
                                    }
                                    if (session != null) {
                                        if (0 != 0) {
                                            try {
                                                session.close();
                                            } catch (Throwable th6) {
                                                th2.addSuppressed(th6);
                                            }
                                        } else {
                                            session.close();
                                        }
                                    }
                                    return tensorflowCheckpointModel;
                                } finally {
                                }
                            } catch (Throwable th7) {
                                if (create2 != null) {
                                    if (th4 != null) {
                                        try {
                                            create2.close();
                                        } catch (Throwable th8) {
                                            th4.addSuppressed(th8);
                                        }
                                    } else {
                                        create2.close();
                                    }
                                }
                                throw th7;
                            }
                        } finally {
                            if (create != null) {
                                if (0 != 0) {
                                    try {
                                        create.close();
                                    } catch (Throwable th9) {
                                        th3.addSuppressed(th9);
                                    }
                                } else {
                                    create.close();
                                }
                            }
                        }
                    } finally {
                        if (graph != null) {
                            if (0 != 0) {
                                try {
                                    graph.close();
                                } catch (Throwable th10) {
                                    th.addSuppressed(th10);
                                }
                            } else {
                                graph.close();
                            }
                        }
                    }
                } catch (Throwable th11) {
                    if (r20 != 0) {
                        if (r21 != 0) {
                            try {
                                r20.close();
                            } catch (Throwable th12) {
                                r21.addSuppressed(th12);
                            }
                        } else {
                            r20.close();
                        }
                    }
                    throw th11;
                }
            } catch (TensorFlowException e) {
                logger.log(Level.SEVERE, "TensorFlow threw an error", e);
                throw new IllegalStateException(e);
            }
        } catch (IOException e2) {
            logger.log(Level.SEVERE, "Failed to create checkpoint directory at path " + this.checkpointRootPath, (Throwable) e2);
            throw new IllegalStateException("Failed to create checkpoint directory at path " + this.checkpointRootPath, e2);
        }
    }

    public int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    public String toString() {
        return "TensorflowCheckpointTrainer(graphPath=" + this.graphPath.toString() + ",checkpointRootPath=" + this.checkpointRootPath.toString() + ",exampleTransformer=" + this.exampleTransformer.toString() + ",outputTransformer" + this.outputTransformer.toString() + ",minibatchSize=" + this.minibatchSize + ",epochs=" + this.epochs + ")";
    }

    /* renamed from: getProvenance, reason: merged with bridge method [inline-methods] */
    public TrainerProvenance m8getProvenance() {
        return new TensorflowCheckpointTrainerProvenance(this);
    }
}
