package us.ihmc.trajectoryOptimization;

import java.lang.Enum;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.dense.row.factory.LinearSolverFactory_DDRM;
import org.ejml.interfaces.linsol.LinearSolverDense;
import us.ihmc.commons.lists.RecyclingArrayList;

/**
 * Discrete, time-varying LQR solver that tracks a desired state/control sequence.
 *
 * <p>Typical use:
 * <ol>
 *   <li>{@link #setDesiredSequence} loads the reference trajectory and sizes the working sequences.</li>
 *   <li>{@link #solveRiccatiEquation} runs the backward recursion, producing a feedback gain {@code K_t}
 *       and a feedforward term {@code k_t} for every step.</li>
 *   <li>{@link #computeOptimalSequences} rolls the policy {@code u_t = K_t x_t + k_t} forward through
 *       the dynamics gradients.</li>
 * </ol>
 *
 * <p>All scratch matrices are mutable fields reused across calls, so an instance should not be shared
 * between threads without external synchronization.
 *
 * @param <E> mode identifier (an enum) forwarded unchanged to the dynamics and cost functions
 */
public class DiscreteTimeVaryingTrackingLQRSolver<E extends Enum> implements LQRSolverInterface<E> {
    private final DiscreteOptimizationData optimalSequence;
    private final DiscreteOptimizationData desiredSequence;
    private final DiscreteSequence feedbackGainSequence;
    private final DiscreteSequence feedforwardSequence;
    private final DiscreteSequence constantsSequence;
    /** Per-step quadratic value term S1 (stateSize x stateSize). */
    private final RecyclingArrayList<DMatrixRMaj> s1Sequence;
    /** Per-step linear value term S2 (1 x stateSize). */
    private final RecyclingArrayList<DMatrixRMaj> s2Sequence;
    private final LinearSolverDense<DMatrixRMaj> linearSolver = LinearSolverFactory_DDRM.linear(0);
    private final DMatrixRMaj Q;
    private final DMatrixRMaj R;
    private final DMatrixRMaj Qf;
    private final DMatrixRMaj A;
    private final DMatrixRMaj B;
    private final DMatrixRMaj G;
    private final DMatrixRMaj G_inv;
    /** Closed-loop transition A + B K for the current step. */
    private final DMatrixRMaj H;
    private final DiscreteHybridDynamics<E> dynamics;
    private final LQTrackingCostFunction<E> costFunction;
    private final LQTrackingCostFunction<E> terminalCostFunction;
    private final DMatrixRMaj tempMatrix = new DMatrixRMaj(0, 0);
    private final DMatrixRMaj tempMatrix2 = new DMatrixRMaj(0, 0);
    /** Scratch for {@link #addMultQuad}; kept separate so callers may pass tempMatrix/tempMatrix2 in. */
    private final DMatrixRMaj tempMatrix3 = new DMatrixRMaj(0, 0);
    /** When true, intermediate matrices are checked for NaN/infinite entries. */
    private final boolean debug;

    /**
     * Creates a solver with debug validation disabled.
     *
     * @param dynamics             discrete dynamics supplying state/control gradients
     * @param costFunction         running cost supplying the state and control Hessians Q and R
     * @param terminalCostFunction terminal cost supplying the final state Hessian Qf
     */
    public DiscreteTimeVaryingTrackingLQRSolver(DiscreteHybridDynamics<E> dynamics, LQTrackingCostFunction<E> costFunction,
                                                LQTrackingCostFunction<E> terminalCostFunction) {
        this(dynamics, costFunction, terminalCostFunction, false);
    }

    /**
     * Creates a solver.
     *
     * @param dynamics             discrete dynamics supplying state/control gradients
     * @param costFunction         running cost supplying the state and control Hessians Q and R
     * @param terminalCostFunction terminal cost supplying the final state Hessian Qf
     * @param debug                if true, throw a {@link RuntimeException} as soon as a non-finite
     *                             matrix entry is detected
     */
    public DiscreteTimeVaryingTrackingLQRSolver(DiscreteHybridDynamics<E> dynamics, LQTrackingCostFunction<E> costFunction,
                                                LQTrackingCostFunction<E> terminalCostFunction, boolean debug) {
        this.dynamics = dynamics;
        this.costFunction = costFunction;
        this.terminalCostFunction = terminalCostFunction;
        this.debug = debug;

        int stateSize = dynamics.getStateVectorSize();
        int controlSize = dynamics.getControlVectorSize();
        int constantSize = dynamics.getConstantVectorSize();

        this.Q = new DMatrixRMaj(stateSize, stateSize);
        this.Qf = new DMatrixRMaj(stateSize, stateSize);
        this.R = new DMatrixRMaj(controlSize, controlSize);
        this.G_inv = new DMatrixRMaj(controlSize, controlSize);
        this.G = new DMatrixRMaj(controlSize, controlSize);
        this.A = new DMatrixRMaj(stateSize, stateSize);
        this.B = new DMatrixRMaj(stateSize, controlSize);
        this.H = new DMatrixRMaj(stateSize, stateSize);

        this.optimalSequence = new DiscreteOptimizationSequence(stateSize, controlSize);
        this.desiredSequence = new DiscreteOptimizationSequence(stateSize, controlSize);
        this.feedbackGainSequence = new DiscreteSequence(controlSize, stateSize);
        this.feedforwardSequence = new DiscreteSequence(controlSize);
        this.constantsSequence = new DiscreteSequence(constantSize);
        this.s1Sequence = new RecyclingArrayList<>(1000, new VariableVectorBuilder(stateSize, stateSize));
        this.s2Sequence = new RecyclingArrayList<>(1000, new VariableVectorBuilder(1, stateSize));

        this.feedbackGainSequence.clear();
        this.feedforwardSequence.clear();
        this.s1Sequence.clear();
        this.s2Sequence.clear();
    }

    /**
     * Loads the reference trajectory and prepares the working sequences for a new solve.
     *
     * <p>The optimal sequence is zeroed to match the reference, one S1/S2 entry is allocated per
     * reference step, and the first optimal state is set to {@code initialState}.
     *
     * @param desired      desired state/control sequence to track
     * @param constants    per-step constant vectors forwarded to the dynamics and cost functions
     * @param initialState state copied into index 0 of the optimal sequence
     */
    @Override // us.ihmc.trajectoryOptimization.LQRSolverInterface
    public void setDesiredSequence(DiscreteOptimizationData desired, DiscreteSequence constants, DMatrixRMaj initialState) {
        int length = desired.size();

        desiredSequence.set(desired);
        optimalSequence.setZero(desired);
        constantsSequence.set(constants);
        feedbackGainSequence.setLength(length);
        feedforwardSequence.setLength(length);

        s1Sequence.clear();
        s2Sequence.clear();
        for (int t = 0; t < length; t++) {
            s1Sequence.add();
            s2Sequence.add();
        }

        optimalSequence.setState(0, initialState);
    }

    /**
     * Runs the backward Riccati recursion from {@code endIndex} down to {@code startIndex}.
     *
     * <p>Terminal condition: {@code S1_end = Qf}, {@code S2_end = -2 x_d^T Qf}. For each step
     * {@code t = endIndex-1 .. startIndex}, with primes denoting step {@code t+1} and A, B, Q, R
     * evaluated at the desired state/control:
     * <pre>
     *   G   = R + B^T S1' B
     *   K   = -G^-1 B^T S1' A
     *   k   = G^-1 (R u_d - 0.5 B^T S2'^T)
     *   H   = A + B K
     *   S1  = Q + K^T R K + H^T S1' H
     *   S2  = (S2' + 2 (B k)^T S1') H - 2 x_d^T Q - 2 u_d^T R K + 2 k^T R K
     * </pre>
     * K and k are written into the feedback-gain and feedforward sequences.
     *
     * @param dynamicState mode identifier forwarded to the dynamics and cost functions
     * @param startIndex   first step (inclusive) for which gains are computed
     * @param endIndex     terminal step; must be a valid index into the sequences set by
     *                     {@link #setDesiredSequence}
     * @throws RuntimeException in debug mode, if any input or computed matrix has a non-finite entry
     */
    @Override // us.ihmc.trajectoryOptimization.LQRSolverInterface
    public void solveRiccatiEquation(E dynamicState, int startIndex, int endIndex) {
        int stateSize = dynamics.getStateVectorSize();
        int controlSize = dynamics.getControlVectorSize();

        DMatrixRMaj terminalState = desiredSequence.getState(endIndex);
        terminalCostFunction.getCostStateHessian(dynamicState, terminalState, desiredSequence.getControl(endIndex),
                                                 (DMatrixRMaj) constantsSequence.get(endIndex), Qf);
        // Qf is computed once, so it is validated once (even when the loop below does not execute).
        if (debug && isAnyInvalid(Qf)) {
            throw new RuntimeException("The final state Hessian is invalid.");
        }

        s1Sequence.get(endIndex).set(Qf);
        CommonOps_DDRM.multTransA(-2.0, terminalState, Qf, s2Sequence.get(endIndex));

        for (int t = endIndex - 1; t >= startIndex; t--) {
            DMatrixRMaj desiredControl = desiredSequence.getControl(t);
            DMatrixRMaj desiredState = desiredSequence.getState(t);
            DMatrixRMaj constants = (DMatrixRMaj) constantsSequence.get(t);

            dynamics.getDynamicsStateGradient(dynamicState, desiredState, desiredControl, constants, A);
            dynamics.getDynamicsControlGradient(dynamicState, desiredState, desiredControl, constants, B);
            costFunction.getCostStateHessian(dynamicState, desiredState, desiredControl, constants, Q);
            costFunction.getCostControlHessian(dynamicState, desiredState, desiredControl, constants, R);

            if (debug) {
                if (isAnyInvalid(A)) {
                    throw new RuntimeException("The A matrix is invalid.");
                }
                if (isAnyInvalid(B)) {
                    throw new RuntimeException("The B matrix is invalid.");
                }
                if (isAnyInvalid(Q)) {
                    throw new RuntimeException("The state Hessian is invalid.");
                }
                if (isAnyInvalid(R)) {
                    throw new RuntimeException("The control Hessian is invalid.");
                }
            }

            DMatrixRMaj feedbackGain = (DMatrixRMaj) feedbackGainSequence.get(t);
            DMatrixRMaj feedforward = (DMatrixRMaj) feedforwardSequence.get(t);
            DMatrixRMaj s1 = s1Sequence.get(t);
            DMatrixRMaj s2 = s2Sequence.get(t);
            DMatrixRMaj s1Next = s1Sequence.get(t + 1);
            DMatrixRMaj s2Next = s2Sequence.get(t + 1);

            // G = R + B^T S1' B, then G_inv = G^-1
            G.set(R);
            addMultQuad(B, s1Next, B, G);
            linearSolver.setA(G);
            linearSolver.invert(G_inv);

            // K = -G^-1 B^T S1' A
            tempMatrix.reshape(controlSize, stateSize);
            CommonOps_DDRM.multTransB(G_inv, B, tempMatrix);
            tempMatrix2.reshape(stateSize, stateSize);
            CommonOps_DDRM.mult(s1Next, A, tempMatrix2);
            CommonOps_DDRM.mult(-1.0, tempMatrix, tempMatrix2, feedbackGain);

            // k = G^-1 (R u_d - 0.5 B^T S2'^T)
            tempMatrix.reshape(controlSize, 1);
            CommonOps_DDRM.multTransAB(-0.5, B, s2Next, tempMatrix);
            CommonOps_DDRM.multAdd(R, desiredControl, tempMatrix);
            CommonOps_DDRM.mult(G_inv, tempMatrix, feedforward);

            // S1 = Q + K^T R K + H^T S1' H, with closed-loop H = A + B K
            s1.set(Q);
            addMultQuad(feedbackGain, R, feedbackGain, s1);
            H.set(A);
            CommonOps_DDRM.multAdd(B, feedbackGain, H);
            addMultQuad(H, s1Next, H, s1);

            // S2 = (S2' + 2 (B k)^T S1') H - 2 x_d^T Q - 2 u_d^T R K + 2 k^T R K
            tempMatrix.reshape(stateSize, 1);
            tempMatrix2.set(s2Next);
            CommonOps_DDRM.mult(B, feedforward, tempMatrix);
            CommonOps_DDRM.multAddTransA(2.0, tempMatrix, s1Next, tempMatrix2);
            CommonOps_DDRM.mult(tempMatrix2, H, s2);
            CommonOps_DDRM.multAddTransA(-2.0, desiredState, Q, s2);
            tempMatrix.reshape(controlSize, stateSize);
            CommonOps_DDRM.mult(R, feedbackGain, tempMatrix);
            CommonOps_DDRM.multAddTransA(-2.0, desiredControl, tempMatrix, s2);
            CommonOps_DDRM.multAddTransA(2.0, feedforward, tempMatrix, s2);

            // Restored per-step check: the decompiled version left this unreachable after the loop.
            if (debug && (isAnyInvalid(s2) || isAnyInvalid(s1))) {
                throw new RuntimeException("The computed Riccati equation solutions are ill-conditioned.");
            }
        }
    }

    /**
     * Rolls the computed policy forward for steps {@code startIndex .. endIndex-1}.
     *
     * <p>For each step: {@code u_t = K_t x_t + k_t} and {@code x_{t+1} = A_t x_t + B_t u_t}, where
     * {@code A_t}, {@code B_t} are the dynamics gradients evaluated at the desired (not optimal)
     * state/control. Results are written into the optimal sequence.
     *
     * @param dynamicState mode identifier forwarded to the dynamics
     * @param startIndex   first step (inclusive)
     * @param endIndex     last step (exclusive); the state at {@code endIndex} is written
     * @throws RuntimeException in debug mode, if a computed state or control has a non-finite entry
     */
    @Override // us.ihmc.trajectoryOptimization.LQRSolverInterface
    public void computeOptimalSequences(E dynamicState, int startIndex, int endIndex) {
        for (int t = startIndex; t < endIndex; t++) {
            DMatrixRMaj desiredState = desiredSequence.getState(t);
            DMatrixRMaj desiredControl = desiredSequence.getControl(t);
            DMatrixRMaj constants = (DMatrixRMaj) constantsSequence.get(t);

            dynamics.getDynamicsStateGradient(dynamicState, desiredState, desiredControl, constants, A);
            dynamics.getDynamicsControlGradient(dynamicState, desiredState, desiredControl, constants, B);

            DMatrixRMaj control = optimalSequence.getControl(t);
            DMatrixRMaj state = optimalSequence.getState(t);
            DMatrixRMaj nextState = optimalSequence.getState(t + 1);

            // u = K x + k
            control.set((DMatrixRMaj) feedforwardSequence.get(t));
            CommonOps_DDRM.multAdd((DMatrixRMaj) feedbackGainSequence.get(t), state, control);

            // x' = A x + B u
            CommonOps_DDRM.mult(A, state, nextState);
            CommonOps_DDRM.multAdd(B, control, nextState);

            if (debug) {
                if (isAnyInvalid(nextState)) {
                    throw new RuntimeException("The computed optimal state is ill-conditioned.");
                }
                if (isAnyInvalid(control)) {
                    throw new RuntimeException("The computed optimal control is ill-conditioned.");
                }
            }
        }
    }

    /** Copies the optimal state/control sequence into {@code sequenceToPack}. */
    @Override // us.ihmc.trajectoryOptimization.LQRSolverInterface
    public void getOptimalSequence(DiscreteOptimizationData sequenceToPack) {
        sequenceToPack.set(optimalSequence);
    }

    /** Returns the internal optimal sequence (not a copy). */
    @Override // us.ihmc.trajectoryOptimization.LQRSolverInterface
    public DiscreteOptimizationData getOptimalSequence() {
        return optimalSequence;
    }

    /** Returns the state part of the internal optimal sequence. */
    @Override // us.ihmc.trajectoryOptimization.LQRSolverInterface
    public DiscreteData getOptimalStateSequence() {
        return optimalSequence.getStateSequence();
    }

    /** Returns the control part of the internal optimal sequence. */
    @Override // us.ihmc.trajectoryOptimization.LQRSolverInterface
    public DiscreteData getOptimalControlSequence() {
        return optimalSequence.getControlSequence();
    }

    /** Returns the internal per-step feedback gains K (controlSize x stateSize each). */
    @Override // us.ihmc.trajectoryOptimization.LQRSolverInterface
    public DiscreteSequence getOptimalFeedbackGainSequence() {
        return feedbackGainSequence;
    }

    /** Returns the internal per-step feedforward terms k (controlSize x 1 each). */
    @Override // us.ihmc.trajectoryOptimization.LQRSolverInterface
    public DiscreteSequence getOptimalFeedForwardControlSequence() {
        return feedforwardSequence;
    }

    /** Returns the quadratic value term S1 at step 0 (the internal matrix, not a copy). */
    @Override // us.ihmc.trajectoryOptimization.LQRSolverInterface
    public DMatrixRMaj getValueHessian() {
        return s1Sequence.get(0);
    }

    /**
     * Accumulates {@code result += left^T * middle * right}.
     * Uses {@code tempMatrix3} as scratch, so none of the arguments may be {@code tempMatrix3}.
     */
    private void addMultQuad(DMatrixRMaj left, DMatrixRMaj middle, DMatrixRMaj right, DMatrixRMaj result) {
        tempMatrix3.reshape(left.numCols, middle.numCols);
        CommonOps_DDRM.multTransA(left, middle, tempMatrix3);
        CommonOps_DDRM.multAdd(tempMatrix3, right, result);
    }

    /** Returns true if any entry of {@code matrix} is NaN or infinite. */
    private boolean isAnyInvalid(DMatrixRMaj matrix) {
        int count = matrix.getNumElements();
        for (int index = 0; index < count; index++) {
            if (!Double.isFinite(matrix.get(index))) {
                return true;
            }
        }
        return false;
    }
}
