package org.deeplearning4j.rl4j.network.ac;

import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.util.Objects;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.class */
/**
 * {@link IActorCritic} implementation backed by a single DL4J {@link ComputationGraph}.
 *
 * <p>The graph is fed one input array (graph input 0) and produces/learns all of its
 * outputs at once, so actor and critic heads live in the same graph.
 */
public class ActorCriticCompGraph implements IActorCritic {
    /** The wrapped computation graph; never {@code null}. */
    protected final ComputationGraph cg;

    /**
     * Wraps an existing computation graph.
     *
     * @param computationGraph the graph to wrap
     * @throws NullPointerException if {@code computationGraph} is {@code null}
     */
    public ActorCriticCompGraph(ComputationGraph computationGraph) {
        // Fail fast here rather than with an NPE on the first fit/output call.
        this.cg = Objects.requireNonNull(computationGraph, "computationGraph");
    }

    /**
     * Fits the graph on a single input array against the given label arrays.
     *
     * @param input  the features, passed as the graph's only input
     * @param labels one label array per graph output
     */
    @Override
    public void fit(INDArray input, INDArray[] labels) {
        this.cg.fit(new INDArray[] {input}, labels);
    }

    /**
     * Runs a forward pass on a single input array.
     *
     * @param input the features, passed as the graph's only input
     * @return the arrays returned by {@link ComputationGraph#output}, one per graph output
     */
    @Override
    public INDArray[] outputAll(INDArray input) {
        return this.cg.output(new INDArray[] {input});
    }

    /**
     * Returns a new wrapper around a clone of the underlying graph.
     *
     * <p>The decompiler renamed this method from {@code clone} because it was merged
     * with a bridge method. The name is kept so existing call sites still resolve.
     *
     * @return an independent {@code ActorCriticCompGraph} around {@code cg.clone()}
     */
    @Override
    public ActorCriticCompGraph m8clone() {
        return new ActorCriticCompGraph(this.cg.clone());
    }

    /**
     * Computes the gradient (and score) of the graph for one input/label pair.
     *
     * <p>Side effect: sets graph input 0 and the labels on {@link #cg}. The score
     * computed here is what {@link #getLatestScore()} subsequently returns.
     *
     * @param input  the features, set as graph input 0
     * @param labels one label array per graph output
     * @return a single-element array holding the graph's gradient
     */
    @Override
    public Gradient[] gradient(INDArray input, INDArray[] labels) {
        this.cg.setInput(0, input);
        this.cg.setLabels(labels);
        this.cg.computeGradientAndScore();
        return new Gradient[] {this.cg.gradient()};
    }

    /**
     * Applies a previously computed gradient to the graph's parameters.
     *
     * <p>Only {@code gradients[0]} is used. It is first transformed by the graph's updater,
     * then subtracted in place from the parameter view.
     *
     * @param gradients gradients as returned by {@link #gradient}; only the first is used
     * @param batchSize batch size forwarded to the updater
     */
    @Override
    public void applyGradient(Gradient[] gradients, int batchSize) {
        Gradient gradient = gradients[0];
        // NOTE(review): the iteration argument is fixed at 1. If the configured updater or
        // learning-rate schedule depends on the iteration count, this may not be intended.
        // Confirm against the DL4J version in use.
        this.cg.getUpdater().update(this.cg, gradient, 1, batchSize);
        // The updater has rewritten the gradient in place; apply it as a descent step.
        this.cg.params().subi(gradient.gradient());
    }

    /**
     * Returns the graph's current score, i.e. whatever {@link ComputationGraph#score()}
     * reports after the most recent fit or gradient computation.
     *
     * @return the latest score
     */
    @Override
    public double getLatestScore() {
        return this.cg.score();
    }

    /**
     * Serializes the graph, including updater state, to the given stream.
     *
     * @param outputStream destination stream
     * @throws UncheckedIOException if serialization fails (previously swallowed)
     */
    @Override
    public void save(OutputStream outputStream) {
        try {
            // true = also persist the updater state (ModelSerializer's saveUpdater flag).
            ModelSerializer.writeModel(this.cg, outputStream, true);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to save ActorCriticCompGraph to stream", e);
        }
    }

    /**
     * Serializes the graph, including updater state, to the given file path.
     *
     * @param path destination file path
     * @throws UncheckedIOException if serialization fails (previously swallowed)
     */
    @Override
    public void save(String path) {
        try {
            // true = also persist the updater state (ModelSerializer's saveUpdater flag).
            ModelSerializer.writeModel(this.cg, path, true);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to save ActorCriticCompGraph to " + path, e);
        }
    }
}
