package us.ihmc.commonWalkingControlModules.dynamicPlanning.lipm;

import org.ejml.data.DMatrixRMaj;
import us.ihmc.commonWalkingControlModules.momentumBasedController.optimization.JointAccelerationIntegrationCalculator;
import us.ihmc.euclid.referenceFrame.FramePoint3D;
import us.ihmc.euclid.referenceFrame.FrameVector3D;
import us.ihmc.robotics.math.trajectories.generators.MultipleWaypointsPositionTrajectoryGenerator;
import us.ihmc.trajectoryOptimization.DefaultDiscreteState;
import us.ihmc.trajectoryOptimization.DiscreteHybridDynamics;
import us.ihmc.trajectoryOptimization.DiscreteOptimizationTrajectory;
import us.ihmc.trajectoryOptimization.DiscreteSequence;
import us.ihmc.trajectoryOptimization.DiscreteTimeVaryingTrackingLQRSolver;
import us.ihmc.trajectoryOptimization.LQRSolverInterface;
import us.ihmc.trajectoryOptimization.LQTrackingCostFunction;
import us.ihmc.trajectoryOptimization.SimpleDDPSolver;

/* loaded from: input_file:us/ihmc/commonWalkingControlModules/dynamicPlanning/lipm/LIPMDDPCalculator.class */
/**
 * Plans a trajectory for {@link LIPMDynamics} by building a desired state/control sequence from a
 * waypoint trajectory, solving a time-varying tracking LQR problem on it, and using that LQR
 * solution to initialize a {@link SimpleDDPSolver}.
 *
 * <p>Typical use: {@link #initialize} once per plan, then {@link #solve}, then read
 * {@link #getOptimalTrajectory()}. Instances reuse internal buffers and are not safe for concurrent use.
 */
public class LIPMDDPCalculator {
    /**
     * Start time of the discretized plan, and the value returned by {@link #getValue()}. The decompiled
     * original referenced this constant from an unrelated class (its name does not describe a time);
     * it is kept verbatim so behavior is unchanged.
     * NOTE(review): presumably equal to 0.0 — confirm and replace with a literal.
     */
    private static final double PLAN_START_TIME = JointAccelerationIntegrationCalculator.DEFAULT_VELOCITY_REFERENCE_ALPHA;

    private final DiscreteHybridDynamics<DefaultDiscreteState> dynamics;
    /** Nominal (requested) time step. */
    private double deltaT;
    /** Time step actually handed to the dynamics: the plan duration split evenly into whole steps. */
    private double modifiedDeltaT;
    private final DiscreteOptimizationTrajectory desiredTrajectory;
    private final DiscreteOptimizationTrajectory optimalTrajectory;
    private final DiscreteSequence constantSequence;
    /** Number of whole steps chosen by the most recent {@link #computeDeltaT(double)} call. */
    private int numberOfTimeSteps;
    private final SimpleDDPSolver<DefaultDiscreteState> ddpSolver;
    private final LQRSolverInterface<DefaultDiscreteState> lqrSolver;
    private final double mass;
    private final double gravityZ;
    private final LQTrackingCostFunction<DefaultDiscreteState> costFunction = new LIPMSimpleCostFunction();
    private final LQTrackingCostFunction<DefaultDiscreteState> terminalCostFunction = new LIPMTerminalCostFunction();
    private final FramePoint3D tempPoint = new FramePoint3D();
    private final FrameVector3D tempVector = new FrameVector3D();

    /**
     * @param deltaT   nominal time step; also passed to the {@link LIPMDynamics} constructor
     * @param mass     mass passed to the dynamics and used for the nominal vertical control {@code mass * gravityZ}
     * @param gravityZ gravity term passed to the dynamics and used for the nominal vertical control
     */
    public LIPMDDPCalculator(double deltaT, double mass, double gravityZ) {
        this.dynamics = new LIPMDynamics(deltaT, mass, gravityZ);
        this.deltaT = deltaT;
        this.mass = mass;
        this.gravityZ = gravityZ;
        this.ddpSolver = new SimpleDDPSolver<>(this.dynamics, true);
        // The cost-function fields are initialized by their declarations, which run before this body.
        this.lqrSolver = new DiscreteTimeVaryingTrackingLQRSolver(this.dynamics, this.costFunction, this.terminalCostFunction);
        int stateVectorSize = this.dynamics.getStateVectorSize();
        int controlVectorSize = this.dynamics.getControlVectorSize();
        int constantVectorSize = this.dynamics.getConstantVectorSize();
        this.desiredTrajectory = new DiscreteOptimizationTrajectory(stateVectorSize, controlVectorSize);
        this.optimalTrajectory = new DiscreteOptimizationTrajectory(stateVectorSize, controlVectorSize);
        this.constantSequence = new DiscreteSequence(constantVectorSize);
    }

    /**
     * Sets the nominal time step and immediately applies it to the dynamics. The next
     * {@link #initialize} call re-derives the applied step from the trajectory duration.
     */
    public void setDeltaT(double deltaT) {
        this.dynamics.setTimeStepSize(deltaT);
        this.deltaT = deltaT;
        this.modifiedDeltaT = deltaT;
    }

    /**
     * Builds the desired sequence from {@code trajectoryGenerator}, solves the tracking LQR problem on
     * it and seeds the DDP solver with the LQR result.
     *
     * @param initialState        state vector handed to the LQR solver; its element 2 is also added to
     *                            every desired z value. NOTE(review): presumably the generator's z is
     *                            relative to the initial height — confirm.
     * @param trajectoryGenerator waypoint trajectory sampled from {@link #PLAN_START_TIME} to its last waypoint time
     */
    public void initialize(DMatrixRMaj initialState, MultipleWaypointsPositionTrajectoryGenerator trajectoryGenerator) {
        double duration = trajectoryGenerator.getLastWaypointTime();
        this.modifiedDeltaT = computeDeltaT(duration);
        this.dynamics.setTimeStepSize(this.modifiedDeltaT);
        this.desiredTrajectory.setTrajectoryDuration(PLAN_START_TIME, duration, this.deltaT);
        this.optimalTrajectory.setTrajectoryDuration(PLAN_START_TIME, duration, this.deltaT);
        this.constantSequence.setLength(this.desiredTrajectory.size());

        double heightOffset = initialState.get(2);
        int lastIndex = this.desiredTrajectory.size() - 1;
        // Fill exactly the knots the LQR/DDP solvers consume (0..lastIndex) so none is left stale or
        // indexed past the sequence. Times are computed directly to avoid accumulated rounding drift.
        for (int i = 0; i <= lastIndex; i++) {
            setDesiredKnot(i, PLAN_START_TIME + i * this.modifiedDeltaT, heightOffset, trajectoryGenerator);
        }

        this.lqrSolver.setDesiredSequence(this.desiredTrajectory, this.constantSequence, initialState);
        this.lqrSolver.solveRiccatiEquation(DefaultDiscreteState.DEFAULT, 0, lastIndex);
        this.lqrSolver.computeOptimalSequences(DefaultDiscreteState.DEFAULT, 0, lastIndex);
        this.lqrSolver.getOptimalSequence(this.optimalTrajectory);
        this.ddpSolver.initializeFromLQRSolution(DefaultDiscreteState.DEFAULT, this.costFunction, this.optimalTrajectory, this.desiredTrajectory, this.constantSequence, this.lqrSolver.getOptimalFeedbackGainSequence(), this.lqrSolver.getOptimalFeedForwardControlSequence());
    }

    /**
     * Samples the generator at {@code time} and writes the desired state
     * {@code [x, y, z + heightOffset, vx, vy, vz]} and nominal control {@code [x, y, mass * gravityZ]}
     * into knot {@code index} of the desired trajectory.
     */
    private void setDesiredKnot(int index, double time, double heightOffset, MultipleWaypointsPositionTrajectoryGenerator generator) {
        generator.compute(time);
        this.tempPoint.setIncludingFrame(generator.getPosition());
        this.tempVector.setIncludingFrame(generator.getVelocity());

        DMatrixRMaj state = this.desiredTrajectory.getState(index);
        state.set(0, this.tempPoint.getX());
        state.set(1, this.tempPoint.getY());
        state.set(2, this.tempPoint.getZ() + heightOffset);
        state.set(3, this.tempVector.getX());
        state.set(4, this.tempVector.getY());
        state.set(5, this.tempVector.getZ());

        // NOTE(review): controls 0/1 are set to the horizontal position, presumably a contact point
        // directly under the mass; control 2 is the constant mass * gravityZ — confirm against LIPMDynamics.
        DMatrixRMaj control = this.desiredTrajectory.getControl(index);
        control.set(0, this.tempPoint.getX());
        control.set(1, this.tempPoint.getY());
        control.set(2, this.mass * this.gravityZ);
    }

    /**
     * Runs the DDP solver from its LQR-based initialization and copies its optimal sequence into the
     * trajectory returned by {@link #getOptimalTrajectory()}.
     *
     * @return the value returned by {@code SimpleDDPSolver.computeSequence}
     */
    public int solve() {
        int result = this.ddpSolver.computeSequence(DefaultDiscreteState.DEFAULT, this.costFunction, this.terminalCostFunction);
        this.optimalTrajectory.set(this.ddpSolver.getOptimalSequence());
        return result;
    }

    /**
     * Splits {@code duration} into whole steps no longer than {@link #deltaT} (at least one step) and
     * returns the resulting evenly spaced step size. Also records the step count.
     */
    private double computeDeltaT(double duration) {
        // Clamp to one step: a duration shorter than deltaT would otherwise give zero steps and an
        // Infinity/NaN step size that is then fed to the dynamics.
        this.numberOfTimeSteps = Math.max(1, (int) Math.floor(duration / this.deltaT));
        return duration / this.numberOfTimeSteps;
    }

    /** Returns the time step currently applied to the dynamics. */
    public double getDT() {
        return this.modifiedDeltaT;
    }

    /** Returns the internal optimal-trajectory buffer (not a copy); it is overwritten by later calls. */
    public DiscreteOptimizationTrajectory getOptimalTrajectory() {
        return this.optimalTrajectory;
    }

    /**
     * Returns the constant {@link #PLAN_START_TIME}; it does not report any solver cost.
     * NOTE(review): looks like a placeholder — confirm what callers expect.
     */
    public double getValue() {
        return PLAN_START_TIME;
    }
}
