package us.ihmc.trajectoryOptimization;

import java.lang.Enum;
import org.ejml.data.DMatrixRMaj;
import us.ihmc.commons.MathTools;
import us.ihmc.commons.PrintTools;

/**
 * Differential dynamic programming (DDP) solver for a {@link DiscreteHybridDynamics} system.
 *
 * <p>The backward pass builds quadratic approximations of the Hamiltonian along the current trajectory and
 * derives per-step feedback gains and feed-forward terms from them. The forward pass rolls the dynamics out with
 * the updated controls. If that rollout diverges, the forward pass is retried with a smaller
 * {@code lineSearchGain}, and it gives up after failing at the minimum gain.
 */
public class DDPSolver<E extends Enum> extends AbstractDDPSolver<E> implements DDPSolverInterface<E> {
    /** Amount subtracted from the line-search gain after each diverged forward pass. */
    private static final double LINE_SEARCH_SCALING = 0.1d;
    /** Line-search gain used for the first forward-pass attempt. */
    private static final double LINE_SEARCH_START_GAIN = 1.0d;
    /** Smallest line-search gain; the forward pass gives up after a rollout at this gain also fails. */
    private static final double LINE_SEARCH_GAIN_MINIMUM = 0.0d;
    /** Per-element deviation from the nominal state beyond which a rollout is treated as diverged. */
    private static final double STATE_DIVERGENCE_THRESHOLD = 1.0E20d;

    /** Scratch trajectory the forward pass rolls out into; copied to the caller's trajectory only on success. */
    private final DiscreteOptimizationData previousSequence;

    /**
     * Creates a solver for the given dynamics, passing {@code false} as the base-class flag.
     *
     * @param discreteHybridDynamics system dynamics to optimize over
     */
    public DDPSolver(DiscreteHybridDynamics<E> discreteHybridDynamics) {
        this(discreteHybridDynamics, false);
    }

    /**
     * Creates a solver for the given dynamics.
     *
     * @param discreteHybridDynamics system dynamics to optimize over; its state and control vector sizes also size
     *        the internal rollout buffer
     * @param z flag forwarded unchanged to the {@link AbstractDDPSolver} constructor (NOTE(review): meaning is
     *        defined by the base class, confirm there)
     */
    public DDPSolver(DiscreteHybridDynamics<E> discreteHybridDynamics, boolean z) {
        super(discreteHybridDynamics, z);
        this.previousSequence = new DiscreteOptimizationSequence(discreteHybridDynamics.getStateVectorSize(),
                                                                 discreteHybridDynamics.getControlVectorSize());
    }

    /**
     * Initializes the solver from an LQR solution and resets the internal rollout buffer to match
     * {@code trajectory}.
     *
     * <p>NOTE(review): the parameter names reflect the base-class fields these arguments appear to populate.
     * Confirm them against {@code AbstractDDPSolver.initializeFromLQRSolution}.
     *
     * @param dynamicsState hybrid dynamics state (enum mode) forwarded to the base class
     * @param costFunction tracking cost forwarded to the base class
     * @param trajectory initial trajectory; also used to size/zero the rollout buffer
     * @param desiredTrajectory desired trajectory forwarded to the base class
     * @param feedForwards sequence forwarded to the base class in the feed-forward slot
     * @param feedbackGains caller-supplied feedback gain sequence
     * @param constants sequence forwarded to the base class in the constants slot
     */
    @Override
    public void initializeFromLQRSolution(E dynamicsState, LQTrackingCostFunction<E> costFunction, DiscreteOptimizationData trajectory,
                                          DiscreteOptimizationData desiredTrajectory, DiscreteSequence feedForwards,
                                          DiscreteSequence feedbackGains, DiscreteSequence constants) {
        // Forward the caller's gains. Passing this.feedBackGainSequence here would silently ignore the argument.
        super.initializeFromLQRSolution(dynamicsState, costFunction, trajectory, desiredTrajectory, feedForwards, feedbackGains, constants);
        this.previousSequence.setZero(trajectory);
    }

    /**
     * Delegates to the base-class initialization and resets the internal rollout buffer to match
     * {@code desiredTrajectory}.
     *
     * @param initialState matrix forwarded to the base class (NOTE(review): presumably the initial state, confirm)
     * @param desiredTrajectory optimization data forwarded to the base class; also sizes/zeroes the rollout buffer
     * @param constants sequence forwarded to the base class
     */
    @Override
    public void initializeSequencesFromDesireds(DMatrixRMaj initialState, DiscreteOptimizationData desiredTrajectory, DiscreteSequence constants) {
        super.initializeSequencesFromDesireds(initialState, desiredTrajectory, constants);
        this.previousSequence.setZero(desiredTrajectory);
    }

    /**
     * Runs the DDP backward pass over steps {@code endIndex} down to {@code startIndex} (inclusive).
     *
     * <p>If {@code costFunction} is non-null, the value function at {@code endIndex} is first seeded with the cost's
     * state Hessian and gradient. Each step then updates the Hamiltonian approximations, computes that step's
     * feedback gain and feed-forward term, and (for steps after index 0) propagates the value approximation to the
     * previous step.
     *
     * @param dynamicsState hybrid dynamics state (enum mode) passed to the cost and approximation helpers
     * @param startIndex first (lowest) step index processed
     * @param endIndex last (highest) step index processed; the terminal value function is seeded here
     * @param costFunction tracking cost used to seed the terminal value function; may be {@code null} to skip seeding
     * @param trajectory trajectory whose state/control at {@code endIndex} seed the terminal value function
     * @return {@code false} as soon as the gain computation fails for some step, {@code true} otherwise
     */
    @Override
    public boolean backwardPass(E dynamicsState, int startIndex, int endIndex, LQTrackingCostFunction<E> costFunction,
                                DiscreteOptimizationData trajectory) {
        DiscreteData stateSequence = trajectory.getStateSequence();
        DiscreteData controlSequence = trajectory.getControlSequence();
        DiscreteData desiredStateSequence = this.desiredSequence.getStateSequence();
        DiscreteData desiredControlSequence = this.desiredSequence.getControlSequence();

        if (costFunction != null) {
            DMatrixRMaj terminalConstants = (DMatrixRMaj) this.constantsSequence.get(endIndex);
            costFunction.getCostStateHessian(dynamicsState, controlSequence.get(endIndex), stateSequence.get(endIndex), terminalConstants,
                                             (DMatrixRMaj) this.valueStateHessianSequence.get(endIndex));
            costFunction.getCostStateGradient(dynamicsState, controlSequence.get(endIndex), stateSequence.get(endIndex),
                                              desiredControlSequence.get(endIndex), desiredStateSequence.get(endIndex), terminalConstants,
                                              (DMatrixRMaj) this.valueStateGradientSequence.get(endIndex));
        }

        boolean success = true;
        for (int step = endIndex; step >= startIndex; step--) {
            DMatrixRMaj feedBackGain = (DMatrixRMaj) this.feedBackGainSequence.get(step);
            updateHamiltonianApproximations(dynamicsState, step, (DMatrixRMaj) this.costStateGradientSequence.get(step),
                                            (DMatrixRMaj) this.costControlGradientSequence.get(step),
                                            (DMatrixRMaj) this.costStateHessianSequence.get(step),
                                            (DMatrixRMaj) this.costControlHessianSequence.get(step),
                                            (DMatrixRMaj) this.costStateControlHessianSequence.get(step),
                                            (DMatrixRMaj) this.dynamicsStateGradientSequence.get(step),
                                            (DMatrixRMaj) this.dynamicsControlGradientSequence.get(step),
                                            (DMatrixRMaj) this.valueStateGradientSequence.get(step),
                                            (DMatrixRMaj) this.valueStateHessianSequence.get(step), this.hamiltonianStateGradient,
                                            this.hamiltonianControlGradient, this.hamiltonianStateHessian, this.hamiltonianControlHessian,
                                            this.hamiltonianStateControlHessian, this.hamiltonianControlStateHessian);
            success = computeFeedbackGainAndFeedForwardTerms(this.hamiltonianControlGradient, this.hamiltonianControlHessian,
                                                             this.hamiltonianControlStateHessian, feedBackGain,
                                                             (DMatrixRMaj) this.feedForwardSequence.get(step));
            if (!success) {
                break;
            }
            // There is no step before index 0 to propagate the value approximation into.
            if (step > 0) {
                computePreviousValueApproximation(this.hamiltonianStateGradient, this.hamiltonianControlGradient, this.hamiltonianStateHessian,
                                                  this.hamiltonianStateControlHessian, feedBackGain,
                                                  (DMatrixRMaj) this.valueStateGradientSequence.get(step - 1),
                                                  (DMatrixRMaj) this.valueStateHessianSequence.get(step - 1));
            }
        }
        return success;
    }

    /**
     * Runs the DDP forward pass with a simple back-off line search.
     *
     * <p>The first rollout uses {@code lineSearchGain = 1.0}. Each rollout whose cost is not finite (the rollout
     * diverged, or the cost evaluated to NaN) lowers the gain by {@link #LINE_SEARCH_SCALING}, clamped at
     * {@link #LINE_SEARCH_GAIN_MINIMUM}, and retries. After a rollout at the minimum gain also fails, the pass gives
     * up. NOTE(review): {@code lineSearchGain} is presumably applied inside {@code computeUpdatedControl} in the base
     * class, confirm.
     *
     * @param dynamicsState hybrid dynamics state (enum mode) passed to the dynamics and cost
     * @param startIndex index at which {@code initialState} is placed and the rollout begins
     * @param endIndex exclusive end index of the rollout
     * @param costFunction tracking cost accumulated along the rollout
     * @param initialState state placed at {@code startIndex}
     * @param trajectory receives the rolled-out trajectory, but only when a rollout succeeds
     * @return the accumulated cost of the accepted rollout, or the non-finite cost of the last failed attempt
     *         (in which case {@code trajectory} is left untouched)
     */
    @Override
    public double forwardPass(E dynamicsState, int startIndex, int endIndex, LQTrackingCostFunction<E> costFunction,
                              DMatrixRMaj initialState, DiscreteOptimizationData trajectory) {
        this.lineSearchGain = LINE_SEARCH_START_GAIN;
        boolean lastAttempt = false;
        while (true) {
            double cost = solveForwardDDPPassInternal(dynamicsState, startIndex, endIndex, costFunction, initialState, this.previousSequence);
            // Only a finite cost is usable. A NaN cost is as invalid as the +infinity reported for a diverged rollout.
            if (Double.isFinite(cost)) {
                trajectory.set(this.previousSequence);
                return cost;
            }
            if (lastAttempt) {
                return cost;
            }
            this.lineSearchGain = Math.max(this.lineSearchGain - LINE_SEARCH_SCALING, LINE_SEARCH_GAIN_MINIMUM);
            PrintTools.info("Solution diverged, decrease line search gain to " + this.lineSearchGain + " and trying again.");
            // Math.max clamps the gain exactly onto the minimum, so the floor is reached after finitely many retries.
            lastAttempt = this.lineSearchGain <= LINE_SEARCH_GAIN_MINIMUM;
        }
    }

    /**
     * Rolls the dynamics forward once with the current gains and accumulates the tracking cost.
     *
     * @param dynamicsState hybrid dynamics state (enum mode) passed to the dynamics and cost
     * @param startIndex index at which {@code initialState} is written and the rollout begins
     * @param endIndex exclusive end index of the rollout
     * @param costFunction tracking cost evaluated at every visited step
     * @param initialState state written into {@code rollout} at {@code startIndex}
     * @param rollout output buffer; receives the rolled-out states and controls
     * @return the summed cost over {@code [startIndex, endIndex)}, or {@link Double#POSITIVE_INFINITY} as soon as a
     *         visited state diverges from the current optimal sequence
     */
    private double solveForwardDDPPassInternal(E dynamicsState, int startIndex, int endIndex, LQTrackingCostFunction<E> costFunction,
                                               DMatrixRMaj initialState, DiscreteOptimizationData rollout) {
        rollout.setState(startIndex, initialState);
        double cost = 0.0d;
        for (int step = startIndex; step < endIndex; step++) {
            DMatrixRMaj nominalState = this.optimalSequence.getState(step);
            DMatrixRMaj state = rollout.getState(step);
            DMatrixRMaj control = rollout.getControl(step);
            DMatrixRMaj constants = (DMatrixRMaj) this.constantsSequence.get(step);
            if (isStateDiverging(state, nominalState)) {
                return Double.POSITIVE_INFINITY;
            }
            computeUpdatedControl(nominalState, state, (DMatrixRMaj) this.feedBackGainSequence.get(step),
                                  (DMatrixRMaj) this.feedForwardSequence.get(step), this.optimalSequence.getControl(step), control);
            // The last desired index has no successor slot to propagate into.
            if (step < this.desiredSequence.size() - 1) {
                this.dynamics.getNextState(dynamicsState, state, control, constants, rollout.getState(step + 1));
            }
            cost += costFunction.getCost(dynamicsState, control, state, this.desiredSequence.getControl(step),
                                         this.desiredSequence.getState(step), constants);
        }
        return cost;
    }

    /**
     * Checks whether a rolled-out state has drifted implausibly far from the nominal state.
     *
     * @param state rolled-out state
     * @param nominalState state from the current optimal sequence at the same index
     * @return {@code true} if any element fails {@code MathTools.epsilonEquals} against its nominal counterpart with
     *         tolerance {@link #STATE_DIVERGENCE_THRESHOLD}. NOTE(review): NaN elements presumably fail that check too,
     *         confirm against MathTools.
     */
    private boolean isStateDiverging(DMatrixRMaj state, DMatrixRMaj nominalState) {
        int size = state.getNumElements();
        for (int k = 0; k < size; k++) {
            if (!MathTools.epsilonEquals(state.get(k), nominalState.get(k), STATE_DIVERGENCE_THRESHOLD)) {
                return true;
            }
        }
        return false;
    }
}
