package us.ihmc.trajectoryOptimization;

import java.lang.Enum;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import us.ihmc.commons.lists.RecyclingArrayList;
import us.ihmc.matrixlib.DiagonalMatrixTools;

/* loaded from: input_file:us/ihmc/trajectoryOptimization/ContinuousTrackingLQRSolver.class */
public class ContinuousTrackingLQRSolver<E extends Enum> implements LQRSolverInterface<E> {
    /** First constructor argument given to the recycling lists that hold S1/S2. */
    private static final int RECYCLING_LIST_INITIAL_SIZE = 1000;

    /** Rollout result: the optimal state/control trajectory. State 0 is the initial state. */
    private final DiscreteOptimizationData optimalSequence;
    /** Reference trajectory being tracked. */
    private final DiscreteOptimizationData desiredSequence;
    /** Per-step feedback gains K = -R^-1 B^T S2, each controlSize x stateSize. */
    private final DiscreteSequence feedbackGainSequence;
    /** Per-step constant vectors handed to the cost functions. */
    private final DiscreteSequence constantsSequence;
    /** Per-step quadratic value-function term S2, each stateSize x stateSize. */
    private final RecyclingArrayList<DMatrixRMaj> S2Sequence;
    /** Per-step linear value-function term S1, each stateSize x 1. */
    private final RecyclingArrayList<DMatrixRMaj> S1Sequence;

    // Cost Hessians and dynamics, evaluated once per solveRiccatiEquation call.
    private final DMatrixRMaj Q;
    private final DMatrixRMaj R;
    private final DMatrixRMaj R_inv;
    private final DMatrixRMaj Qf;
    private final DMatrixRMaj A;
    private final DMatrixRMaj B;

    // Scratch storage reused across steps.
    private final DMatrixRMaj S2Dot;
    private final DMatrixRMaj S1Dot;
    private final DMatrixRMaj XDot;
    private final DMatrixRMaj S2BR_invBT;
    private final DMatrixRMaj R_invBT;
    private final DMatrixRMaj tempMatrix;

    private final DiscreteHybridDynamics<E> dynamics;
    private final LQTrackingCostFunction<E> costFunction;
    private final LQTrackingCostFunction<E> terminalCostFunction;
    private final double deltaT;
    private final boolean debug;

    /**
     * Creates a solver with debug validity checks disabled.
     *
     * @param discreteHybridDynamics dynamics supplying the continuous A and B matrices and the vector sizes
     * @param lQTrackingCostFunction running cost; its state/control Hessians provide Q and R
     * @param lQTrackingCostFunction2 terminal cost; its state Hessian provides Qf
     * @param d explicit-Euler time step for both the backward Riccati pass and the forward rollout
     */
    public ContinuousTrackingLQRSolver(DiscreteHybridDynamics<E> discreteHybridDynamics, LQTrackingCostFunction<E> lQTrackingCostFunction,
                                       LQTrackingCostFunction<E> lQTrackingCostFunction2, double d) {
        this(discreteHybridDynamics, lQTrackingCostFunction, lQTrackingCostFunction2, d, false);
    }

    /**
     * Creates a solver and allocates all working storage from the dynamics' vector sizes.
     *
     * @param discreteHybridDynamics dynamics supplying the continuous A and B matrices and the vector sizes
     * @param lQTrackingCostFunction running cost; its state/control Hessians provide Q and R
     * @param lQTrackingCostFunction2 terminal cost; its state Hessian provides Qf
     * @param d explicit-Euler time step for both the backward Riccati pass and the forward rollout
     * @param z when true, throw a RuntimeException as soon as a computed value is NaN or infinite
     */
    public ContinuousTrackingLQRSolver(DiscreteHybridDynamics<E> discreteHybridDynamics, LQTrackingCostFunction<E> lQTrackingCostFunction,
                                       LQTrackingCostFunction<E> lQTrackingCostFunction2, double d, boolean z) {
        this.dynamics = discreteHybridDynamics;
        this.costFunction = lQTrackingCostFunction;
        this.terminalCostFunction = lQTrackingCostFunction2;
        this.deltaT = d;
        this.debug = z;

        int stateSize = discreteHybridDynamics.getStateVectorSize();
        int controlSize = discreteHybridDynamics.getControlVectorSize();
        int constantSize = discreteHybridDynamics.getConstantVectorSize();

        this.Q = new DMatrixRMaj(stateSize, stateSize);
        this.Qf = new DMatrixRMaj(stateSize, stateSize);
        this.R = new DMatrixRMaj(controlSize, controlSize);
        this.R_inv = new DMatrixRMaj(controlSize, controlSize);
        this.A = new DMatrixRMaj(stateSize, stateSize);
        this.B = new DMatrixRMaj(stateSize, controlSize);

        this.S2Dot = new DMatrixRMaj(stateSize, stateSize);
        this.S1Dot = new DMatrixRMaj(stateSize, 1);
        this.XDot = new DMatrixRMaj(stateSize, 1);
        this.S2BR_invBT = new DMatrixRMaj(stateSize, stateSize);
        this.R_invBT = new DMatrixRMaj(controlSize, stateSize);
        this.tempMatrix = new DMatrixRMaj(0, 0);

        this.optimalSequence = new DiscreteOptimizationSequence(stateSize, controlSize);
        this.desiredSequence = new DiscreteOptimizationSequence(stateSize, controlSize);
        this.constantsSequence = new DiscreteSequence(constantSize, 1);
        this.feedbackGainSequence = new DiscreteSequence(controlSize, stateSize);
        this.S2Sequence = new RecyclingArrayList<>(RECYCLING_LIST_INITIAL_SIZE, new VariableVectorBuilder(stateSize, stateSize));
        this.S1Sequence = new RecyclingArrayList<>(RECYCLING_LIST_INITIAL_SIZE, new VariableVectorBuilder(stateSize, 1));

        this.feedbackGainSequence.clear();
        this.S2Sequence.clear();
        this.S1Sequence.clear();
    }

    /**
     * Loads a new reference trajectory and resets the solver for it.
     *
     * <p>The optimal sequence is zeroed to the reference's layout and its first state is set to the
     * given initial state. One S2/S1 entry and one feedback-gain entry is allocated per reference step.
     *
     * @param discreteOptimizationData reference state/control trajectory to track
     * @param discreteSequence per-step constants passed to the cost functions
     * @param dMatrixRMaj initial state, copied into optimal state 0
     */
    @Override
    public void setDesiredSequence(DiscreteOptimizationData discreteOptimizationData, DiscreteSequence discreteSequence, DMatrixRMaj dMatrixRMaj) {
        this.desiredSequence.set(discreteOptimizationData);
        this.optimalSequence.setZero(discreteOptimizationData);

        int horizonLength = discreteOptimizationData.size();
        this.S2Sequence.clear();
        this.S1Sequence.clear();
        this.feedbackGainSequence.setLength(horizonLength);
        this.constantsSequence.set(discreteSequence);
        for (int t = 0; t < horizonLength; t++) {
            this.S2Sequence.add();
            this.S1Sequence.add();
        }

        this.optimalSequence.setState(0, dMatrixRMaj);
    }

    /**
     * Integrates the continuous-time tracking Riccati equations backwards from {@code i2} down to {@code i}.
     *
     * <p>The terminal conditions are S2(i2) = Qf and S1(i2) = -2 Qf x_d(i2). Each step evaluates
     * <pre>
     *   S2Dot = Q + S2 A + A^T S2 - S2 B R^-1 B^T S2
     *   S1Dot = -2 Q x_d + A^T S1 - S2 B R^-1 B^T S1 + 2 S2 B u_d
     * </pre>
     * and steps {@code S(t-1) = S(t) + deltaT * SDot}. Because the integration runs backwards in time,
     * SDot is the negated time derivative.
     *
     * <p>Q, R, Qf, A and B are evaluated once, at index {@code i2}, and held constant over the horizon.
     *
     * @param e value forwarded to the cost functions
     * @param i first index of the horizon (the backward pass stops here)
     * @param i2 terminal index of the horizon
     * @throws RuntimeException in debug mode, if a computed S2/S1 entry contains NaN or infinity
     */
    @Override
    public void solveRiccatiEquation(E e, int i, int i2) {
        int stateSize = this.dynamics.getStateVectorSize();
        int controlSize = this.dynamics.getControlVectorSize();

        evaluateCostAndDynamics(e, i2);

        DMatrixRMaj terminalS2 = this.S2Sequence.get(i2);
        DMatrixRMaj terminalS1 = this.S1Sequence.get(i2);
        terminalS2.set(this.Qf);
        CommonOps_DDRM.mult(-2.0d, this.Qf, this.desiredSequence.getState(i2), terminalS1);

        for (int t = i2; t > i; t--) {
            DMatrixRMaj s2 = this.S2Sequence.get(t);
            DMatrixRMaj s1 = this.S1Sequence.get(t);

            computeRiccatiDerivatives(s2, s1, this.desiredSequence.getState(t), this.desiredSequence.getControl(t), stateSize, controlSize);

            // Explicit Euler step backwards in time.
            DMatrixRMaj s2Previous = this.S2Sequence.get(t - 1);
            DMatrixRMaj s1Previous = this.S1Sequence.get(t - 1);
            s2Previous.set(s2);
            s1Previous.set(s1);
            CommonOps_DDRM.addEquals(s2Previous, this.deltaT, this.S2Dot);
            CommonOps_DDRM.addEquals(s1Previous, this.deltaT, this.S1Dot);

            // Validate the values just produced. Checking s2/s1 here would re-check old values
            // and would never check the entries at index i.
            if (this.debug && (isAnyInvalid(s2Previous) || isAnyInvalid(s1Previous))) {
                throw new RuntimeException("The computed Riccati equation solutions are ill-conditioned.");
            }
        }
    }

    /**
     * Evaluates Q, R, Qf, A, B, R^-1 and R^-1 B^T using the optimal state/control and constants at {@code index}.
     * R is inverted with invertDiagonalMatrix, so it is assumed to be diagonal.
     */
    private void evaluateCostAndDynamics(E e, int index) {
        DMatrixRMaj state = this.optimalSequence.getState(index);
        DMatrixRMaj control = this.optimalSequence.getControl(index);
        DMatrixRMaj constants = (DMatrixRMaj) this.constantsSequence.get(index);

        this.costFunction.getCostStateHessian(e, state, control, constants, this.Q);
        this.costFunction.getCostControlHessian(e, state, control, constants, this.R);
        this.terminalCostFunction.getCostStateHessian(e, state, control, constants, this.Qf);
        this.dynamics.getContinuousAMatrix(this.A);
        this.dynamics.getContinuousBMatrix(this.B);

        DiagonalMatrixTools.invertDiagonalMatrix(this.R, this.R_inv);
        CommonOps_DDRM.multTransB(this.R_inv, this.B, this.R_invBT);
    }

    /** Writes the (negated) time derivatives of S2 and S1 at one step into S2Dot and S1Dot. */
    private void computeRiccatiDerivatives(DMatrixRMaj s2, DMatrixRMaj s1, DMatrixRMaj desiredState, DMatrixRMaj desiredControl,
                                           int stateSize, int controlSize) {
        // S2Dot = Q + S2 A + A^T S2 - (S2 B R^-1 B^T) S2
        this.S2Dot.set(this.Q);
        CommonOps_DDRM.multAdd(s2, this.A, this.S2Dot);
        CommonOps_DDRM.multAddTransA(this.A, s2, this.S2Dot);
        this.tempMatrix.reshape(stateSize, controlSize);
        CommonOps_DDRM.multTransA(s2, this.B, this.tempMatrix); // S2^T B, equal to S2 B for symmetric S2
        CommonOps_DDRM.mult(this.tempMatrix, this.R_invBT, this.S2BR_invBT);
        CommonOps_DDRM.multAdd(-1.0d, this.S2BR_invBT, s2, this.S2Dot);

        // S1Dot = -2 Q x_d + A^T S1 - (S2 B R^-1 B^T) S1 + 2 S2 B u_d
        CommonOps_DDRM.mult(-2.0d, this.Q, desiredState, this.S1Dot);
        CommonOps_DDRM.multAddTransA(this.A, s1, this.S1Dot);
        CommonOps_DDRM.multAdd(-1.0d, this.S2BR_invBT, s1, this.S1Dot);
        this.tempMatrix.reshape(stateSize, 1);
        CommonOps_DDRM.mult(this.B, desiredControl, this.tempMatrix);
        CommonOps_DDRM.multAdd(2.0d, s2, this.tempMatrix, this.S1Dot);
    }

    /**
     * Rolls the optimal control law forward from {@code i} to {@code i2}.
     *
     * <p>For each step:
     * <ul>
     *   <li>It stores the feedback gain K = -R^-1 B^T S2.</li>
     *   <li>It applies u = u_d - R^-1 B^T (S2 x + 0.5 S1).</li>
     *   <li>It integrates x(t+1) = x(t) + deltaT (A x + B u).</li>
     * </ul>
     * It relies on A, B, R^-1 B^T and the S2/S1 sequences, so {@link #solveRiccatiEquation} must be called first.
     *
     * @param e unused here; part of the solver interface
     * @param i first index of the rollout
     * @param i2 end index; the last state written is {@code i2}
     * @throws RuntimeException in debug mode, if a computed control or state contains NaN or infinity
     */
    @Override
    public void computeOptimalSequences(E e, int i, int i2) {
        int stateSize = this.dynamics.getStateVectorSize();
        for (int t = i; t < i2; t++) {
            DMatrixRMaj desiredControl = this.desiredSequence.getControl(t);
            DMatrixRMaj state = this.optimalSequence.getState(t);
            DMatrixRMaj control = this.optimalSequence.getControl(t);
            DMatrixRMaj nextState = this.optimalSequence.getState(t + 1);
            DMatrixRMaj feedbackGain = (DMatrixRMaj) this.feedbackGainSequence.get(t);
            DMatrixRMaj s2 = this.S2Sequence.get(t);
            DMatrixRMaj s1 = this.S1Sequence.get(t);

            // Gain matrix K = -R^-1 B^T S2 (controlSize x stateSize). Previously a control-sized
            // vector was written here instead.
            CommonOps_DDRM.mult(-1.0d, this.R_invBT, s2, feedbackGain);

            // u = u_d - R^-1 B^T (S2 x + 0.5 S1)
            control.set(desiredControl);
            this.tempMatrix.reshape(stateSize, 1);
            CommonOps_DDRM.mult(s2, state, this.tempMatrix);
            CommonOps_DDRM.addEquals(this.tempMatrix, 0.5d, s1);
            CommonOps_DDRM.multAdd(-1.0d, this.R_invBT, this.tempMatrix, control);

            // Explicit Euler step forward: x(t+1) = x(t) + deltaT (A x + B u)
            CommonOps_DDRM.mult(this.A, state, this.XDot);
            CommonOps_DDRM.multAdd(this.B, control, this.XDot);
            nextState.set(state);
            CommonOps_DDRM.addEquals(nextState, this.deltaT, this.XDot);

            if (this.debug) {
                if (isAnyInvalid(control)) {
                    throw new RuntimeException("The computed optimal control is ill-conditioned.");
                }
                // Check the state just integrated, so the final state at i2 is also covered.
                if (isAnyInvalid(nextState)) {
                    throw new RuntimeException("The computed optimal state is ill-conditioned.");
                }
            }
        }
    }

    /** Copies the current optimal trajectory into {@code discreteOptimizationData}. */
    @Override
    public void getOptimalSequence(DiscreteOptimizationData discreteOptimizationData) {
        discreteOptimizationData.set(this.optimalSequence);
    }

    /** Returns the internal optimal trajectory (not a copy). */
    @Override
    public DiscreteOptimizationData getOptimalSequence() {
        return this.optimalSequence;
    }

    /** Returns the internal optimal state trajectory. */
    @Override
    public DiscreteData getOptimalStateSequence() {
        return this.optimalSequence.getStateSequence();
    }

    /** Returns the internal optimal control trajectory. */
    @Override
    public DiscreteData getOptimalControlSequence() {
        return this.optimalSequence.getControlSequence();
    }

    /**
     * Not supported by this solver.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public DiscreteSequence getOptimalFeedbackGainSequence() {
        throw new UnsupportedOperationException("this isn't implemented correctly.");
    }

    /**
     * Returns the quadratic value-function term S2 at index 0 (stateSize x stateSize).
     *
     * <p>Previously this returned S1, the stateSize x 1 linear term. NOTE(review): the terminal
     * conditions imply a value function of the form x^T S2 x + x^T S1 + s0, whose exact Hessian is
     * 2 * S2. S2 is returned unscaled; confirm the scaling callers expect.
     */
    @Override
    public DMatrixRMaj getValueHessian() {
        return this.S2Sequence.get(0);
    }

    /**
     * Not supported by this solver.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public DiscreteSequence getOptimalFeedForwardControlSequence() {
        throw new UnsupportedOperationException("this isn't implemented correctly.");
    }

    /** Returns true if any of the matrix's elements is NaN or infinite. */
    private static boolean isAnyInvalid(DMatrixRMaj dMatrixRMaj) {
        int numElements = dMatrixRMaj.getNumElements();
        for (int index = 0; index < numElements; index++) {
            if (!Double.isFinite(dMatrixRMaj.get(index))) {
                return true;
            }
        }
        return false;
    }
}
