package us.ihmc.commonWalkingControlModules.capturePoint.lqrControl;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import us.ihmc.commonWalkingControlModules.controllerCore.command.inverseDynamics.LinearMomentumRateCostCommand;
import us.ihmc.commonWalkingControlModules.dynamicPlanning.comPlanning.ContactStateProvider;
import us.ihmc.commonWalkingControlModules.dynamicPlanning.comPlanning.SettableContactStateProvider;
import us.ihmc.commonWalkingControlModules.momentumBasedController.optimization.JointAccelerationIntegrationCalculator;
import us.ihmc.commons.lists.RecyclingArrayList;
import us.ihmc.euclid.referenceFrame.ReferenceFrame;
import us.ihmc.euclid.referenceFrame.interfaces.FramePoint3DReadOnly;
import us.ihmc.robotics.math.trajectories.core.Polynomial3D;
import us.ihmc.robotics.math.trajectories.interfaces.Polynomial3DBasics;
import us.ihmc.robotics.math.trajectories.interfaces.Polynomial3DReadOnly;
import us.ihmc.yoVariables.euclid.referenceFrame.YoFramePoint3D;
import us.ihmc.yoVariables.euclid.referenceFrame.YoFrameVector3D;
import us.ihmc.yoVariables.providers.DoubleProvider;
import us.ihmc.yoVariables.registry.YoRegistry;

/**
 * LQR-based linear momentum controller for jumping motions.
 *
 * <p>A piecewise-polynomial VRP trajectory and a matching list of contact states are supplied through
 * {@link #setVRPTrajectory(List, List)}. The trajectory is stored relative to its final VRP position, and the
 * value-function terms {@code S1} (6x6) and {@code s2} (6x1) are propagated backwards over the segments using
 * algebraic, differential, or flight solutions depending on whether each segment is load bearing. On each call to
 * {@link #computeControlInput(DMatrixRMaj, double)} the feedback law {@code u = K1 x + k2} is evaluated for the
 * relative state {@code x}, and a {@link LinearMomentumRateCostCommand} is built from the same quantities.
 */
public class LQRJumpMomentumController {
    /** Integration step handed to the differential S1/S2 segments. */
    private static final double discreteDt = 0.005d;
    /** Vertical gravity value handed to the flight S2 solution. */
    private static final double gravityZ = -9.81d;
    /**
     * Upper bound on the time at which a single segment is evaluated.
     * NOTE(review): presumably guards against an open-ended (very long or infinite) final segment — confirm.
     */
    private static final double maxTimeInSegment = 10.0d;
    /** Time within a segment at which backward-propagated S1/S2 functions are sampled: the segment start. */
    private static final double segmentStartTime = 0.0d;

    static final double defaultVrpTrackingWeight = 1000000.0d;
    static final double defaultMomentumRateWeight = 1.0E-5d;

    private final double totalMass;
    private final YoRegistry registry = new YoRegistry(getClass().getSimpleName());
    private final YoFrameVector3D yoK2 = new YoFrameVector3D("k2", ReferenceFrame.getWorldFrame(), registry);
    private final YoFrameVector3D feedbackForce = new YoFrameVector3D("feedbackForce", ReferenceFrame.getWorldFrame(), registry);
    private final YoFramePoint3D relativeCoMPosition = new YoFramePoint3D("relativeCoMPosition", ReferenceFrame.getWorldFrame(), registry);
    private final YoFrameVector3D relativeCoMVelocity = new YoFrameVector3D("relativeCoMVelocity", ReferenceFrame.getWorldFrame(), registry);
    private final YoFramePoint3D finalVRPPosition = new YoFramePoint3D("finalVRPPosition", ReferenceFrame.getWorldFrame(), registry);
    private final YoFramePoint3D referenceVRPPosition = new YoFramePoint3D("referenceVRPPosition", ReferenceFrame.getWorldFrame(), registry);
    private final YoFramePoint3D feedbackVRPPosition = new YoFramePoint3D("feedbackVRPPosition", ReferenceFrame.getWorldFrame(), registry);

    private double vrpTrackingWeight = defaultVrpTrackingWeight;
    private double momentumRateWeight = defaultMomentumRateWeight;

    private final LQRCommonValues lqrCommonValues = new LQRCommonValues();
    private final AlgebraicS1Function finalS1Function = new AlgebraicS1Function();

    /** Holds N^T + B^T S1; used for K1 and for the momentum-rate cost gradient. */
    private final DMatrixRMaj Nb = new DMatrixRMaj(3, 6);
    private final DMatrixRMaj S1 = new DMatrixRMaj(6, 6);
    private final DMatrixRMaj s2 = new DMatrixRMaj(6, 1);
    private final DMatrixRMaj K1 = new DMatrixRMaj(3, 6);
    private final DMatrixRMaj k2 = new DMatrixRMaj(3, 1);
    private final DMatrixRMaj u = new DMatrixRMaj(3, 1);
    private final DMatrixRMaj R1InverseDQ = new DMatrixRMaj(3, 3);
    private final DMatrixRMaj R1InverseBTranspose = new DMatrixRMaj(3, 6);
    private final DMatrixRMaj finalVRPState = new DMatrixRMaj(6, 1);
    private final DMatrixRMaj relativeState = new DMatrixRMaj(6, 1);
    private final DMatrixRMaj relativeDesiredVRP = new DMatrixRMaj(3, 1);
    private final DMatrixRMaj linearMomentumRateGradient = new DMatrixRMaj(1, 3);
    private final DMatrixRMaj linearMomentumRateHessian = new DMatrixRMaj(3, 3);

    /** VRP segments shifted so that the final VRP position is the origin. */
    final RecyclingArrayList<Polynomial3D> relativeVRPTrajectories = new RecyclingArrayList<>(() -> new Polynomial3D(4));
    /** Contact state for each entry of {@link #relativeVRPTrajectories}, same indexing. */
    final RecyclingArrayList<SettableContactStateProvider> contactStateProviders = new RecyclingArrayList<>(SettableContactStateProvider::new);

    private boolean shouldUpdateP = true;
    private boolean shouldUpdateCosts = true;

    private final List<S1Function> reversedS1FunctionList = new ArrayList<>();
    /** S1 function for each segment, indexed like {@link #relativeVRPTrajectories}. */
    private final List<S1Function> s1FunctionList = new ArrayList<>();
    private final List<S2Segment> reversedS2FunctionList = new ArrayList<>();
    /** S2 segment for each segment, indexed like {@link #relativeVRPTrajectories}. */
    private final List<S2Segment> s2FunctionList = new ArrayList<>();
    private final LinearMomentumRateCostCommand momentumRateCostCommand = new LinearMomentumRateCostCommand();

    /**
     * Creates the controller without attaching its registry to a parent.
     *
     * @param omegaProvider value read once and passed to {@code LQRCommonValues#computeDynamicsMatrix}
     *                      (presumably the CoM natural frequency — confirm)
     * @param totalMass     mass used to convert between accelerations and momentum rates
     */
    public LQRJumpMomentumController(DoubleProvider omegaProvider, double totalMass) {
        this(omegaProvider, totalMass, null);
    }

    /**
     * Creates the controller, computes the dynamics and cost matrices, and optionally attaches the internal
     * registry to {@code parentRegistry}.
     *
     * @param omegaProvider  value read once and passed to {@code LQRCommonValues#computeDynamicsMatrix}
     *                       (presumably the CoM natural frequency — confirm)
     * @param totalMass      mass used to convert between accelerations and momentum rates
     * @param parentRegistry registry to attach to, or {@code null} to leave the internal registry detached
     */
    public LQRJumpMomentumController(DoubleProvider omegaProvider, double totalMass, YoRegistry parentRegistry) {
        this.totalMass = totalMass;
        computeDynamicsMatrix(omegaProvider.getValue());
        computeP();
        if (parentRegistry != null) {
            parentRegistry.addChild(registry);
        }
    }

    /** Sets the VRP tracking weight; the cost matrices are recomputed on the next {@link #setVRPTrajectory}. */
    public void setVRPTrackingWeight(double weight) {
        vrpTrackingWeight = weight;
        shouldUpdateP = true;
    }

    /** Sets the momentum rate weight; the cost matrices are recomputed on the next {@link #setVRPTrajectory}. */
    public void setMomentumRateWeight(double weight) {
        momentumRateWeight = weight;
        shouldUpdateP = true;
    }

    private void computeDynamicsMatrix(double omega) {
        lqrCommonValues.computeDynamicsMatrix(omega);
        shouldUpdateP = true;
    }

    /**
     * Stores a new VRP trajectory and recomputes the S1 and S2 segments for it.
     *
     * <p>The end of the last segment (its end time, capped at {@link #maxTimeInSegment}) becomes the final VRP
     * position. Every segment is stored shifted by minus that position.
     *
     * @param vrpTrajectories VRP polynomials, one per segment, in time order
     * @param contactStates   contact state for each segment; must have the same size as {@code vrpTrajectories}
     * @throws IllegalArgumentException if the two lists differ in size or are empty
     */
    public void setVRPTrajectory(List<? extends Polynomial3DReadOnly> vrpTrajectories, List<? extends ContactStateProvider> contactStates) {
        if (vrpTrajectories.size() != contactStates.size()) {
            throw new IllegalArgumentException("The contacts don't match the trajectory.");
        }
        // Validate before clearing so a bad call does not wipe the currently stored trajectory.
        if (vrpTrajectories.isEmpty()) {
            throw new IllegalArgumentException("The VRP trajectory must contain at least one segment.");
        }

        relativeVRPTrajectories.clear();
        contactStateProviders.clear();

        Polynomial3DReadOnly lastTrajectory = vrpTrajectories.get(vrpTrajectories.size() - 1);
        lastTrajectory.compute(Math.min(maxTimeInSegment, lastTrajectory.getTimeInterval().getEndTime()));
        finalVRPPosition.set(lastTrajectory.getPosition());
        finalVRPPosition.get(finalVRPState);

        for (int i = 0; i < vrpTrajectories.size(); i++) {
            Polynomial3DBasics relativeTrajectory = relativeVRPTrajectories.add();
            relativeTrajectory.set(vrpTrajectories.get(i));
            relativeTrajectory.shiftTrajectory(-finalVRPPosition.getX(), -finalVRPPosition.getY(), -finalVRPPosition.getZ());
            contactStateProviders.add().set(contactStates.get(i));
        }

        if (shouldUpdateP) {
            computeP();
        }
        computeS1Segments();
        computeS2Segments();
    }

    /** Recomputes the equivalent cost values from the current weights and resets the final (algebraic) S1 function. */
    void computeP() {
        lqrCommonValues.computeEquivalentCostValues(momentumRateWeight, vrpTrackingWeight);
        finalS1Function.set(lqrCommonValues);
        shouldUpdateP = false;
    }

    /**
     * Builds one S1 function per segment by sweeping backwards from the last segment.
     *
     * <p>The last segment uses {@link #finalS1Function}. A load-bearing segment reuses that algebraic function
     * unless a flight segment occurs later in time, in which case a {@link DifferentialS1Segment} is integrated back
     * from the S1 value at the start of the following segment. Flight segments use a {@link FlightS1Function}.
     * Afterwards {@link #s1FunctionList} is indexed like {@link #relativeVRPTrajectories}.
     */
    void computeS1Segments() {
        reversedS1FunctionList.clear();
        s1FunctionList.clear();

        int lastSegmentIndex = relativeVRPTrajectories.size() - 1;
        if (lastSegmentIndex < 0) {
            reversedS1FunctionList.add(finalS1Function);
        } else {
            finalS1Function.compute(segmentStartTime, S1);
            reversedS1FunctionList.add(finalS1Function);

            boolean flightOccursLater = false;
            for (int i = lastSegmentIndex - 1; i >= 0; i--) {
                Polynomial3D trajectory = relativeVRPTrajectories.get(i);
                S1Function s1Function;
                if (contactStateProviders.get(i).getContactState().isLoadBearing()) {
                    if (flightOccursLater) {
                        // S1 currently holds the value at the start of segment i + 1; integrate back from it.
                        DifferentialS1Segment differentialSegment = new DifferentialS1Segment(discreteDt);
                        differentialSegment.set(lqrCommonValues, S1, trajectory.getTimeInterval().getDuration());
                        s1Function = differentialSegment;
                    } else {
                        s1Function = finalS1Function;
                    }
                } else {
                    flightOccursLater = true;
                    FlightS1Function flightFunction = new FlightS1Function();
                    flightFunction.set(S1, trajectory.getDuration());
                    s1Function = flightFunction;
                }
                // Leave S1 at the start of this segment, ready for the previous one.
                s1Function.compute(segmentStartTime, S1);
                reversedS1FunctionList.add(s1Function);
            }
        }

        addInReverseOrder(reversedS1FunctionList, s1FunctionList);
    }

    /**
     * Builds one S2 segment per trajectory segment by sweeping backwards.
     *
     * <p>The trailing run of load-bearing segments is solved together by an {@link AlgebraicS2Function}. Each
     * earlier segment is then propagated back from the s2 value at the start of the segment after it: a
     * {@link DifferentialS2Segment} for load-bearing segments, a {@link FlightS2Function} for flight segments.
     * Requires {@link #computeS1Segments()} to have been run for the current trajectory.
     */
    void computeS2Segments() {
        reversedS2FunctionList.clear();
        s2FunctionList.clear();

        int numberOfSegments = relativeVRPTrajectories.size();

        int numberOfTrailingContactSegments = 0;
        for (int i = numberOfSegments - 1; i >= 0 && contactStateProviders.get(i).getContactState().isLoadBearing(); i--) {
            numberOfTrailingContactSegments++;
        }
        int firstTrailingContactSegment = numberOfSegments - numberOfTrailingContactSegments;

        List<Polynomial3D> trailingContactTrajectories = new ArrayList<>();
        for (int i = firstTrailingContactSegment; i < numberOfSegments; i++) {
            trailingContactTrajectories.add(relativeVRPTrajectories.get(i));
        }

        s2.zero();
        AlgebraicS2Function algebraicS2Function = new AlgebraicS2Function();
        finalS1Function.compute(segmentStartTime, S1);
        lqrCommonValues.computeS2ConstantStateMatrices(S1);
        algebraicS2Function.set(s2, trailingContactTrajectories, lqrCommonValues);
        algebraicS2Function.compute(0, segmentStartTime, s2);
        for (int i = numberOfTrailingContactSegments - 1; i >= 0; i--) {
            reversedS2FunctionList.add(algebraicS2Function.getSegment(i));
        }

        for (int i = firstTrailingContactSegment - 1; i >= 0; i--) {
            Polynomial3D trajectory = relativeVRPTrajectories.get(i);
            // Index lookup: s1FunctionList is aligned with relativeVRPTrajectories, so no hashing of the
            // (mutable, recycled) polynomial objects is needed.
            S1Function s1Function = s1FunctionList.get(i);
            if (contactStateProviders.get(i).getContactState().isLoadBearing()) {
                DifferentialS2Segment differentialSegment = new DifferentialS2Segment(discreteDt);
                differentialSegment.set(s1Function, trajectory, lqrCommonValues, s2);
                differentialSegment.compute(segmentStartTime, s2);
                reversedS2FunctionList.add(differentialSegment);
            } else {
                double duration = trajectory.getTimeInterval().getDuration();
                // The flight solution needs S1 at the end of the flight segment.
                s1Function.compute(duration, S1);
                FlightS2Function flightFunction = new FlightS2Function(gravityZ);
                flightFunction.set(S1, s2, duration);
                flightFunction.compute(segmentStartTime, s2);
                reversedS2FunctionList.add(flightFunction);
            }
        }

        addInReverseOrder(reversedS2FunctionList, s2FunctionList);
    }

    /** Evaluates S1 at {@code time} and updates {@code Nb = N^T + B^T S1} and {@code K1 = -R1^-1 Nb}. */
    void computeS1AndK1(double time) {
        int segment = getSegmentNumber(time);
        s1FunctionList.get(segment).compute(computeTimeInSegment(time, segment), S1);
        Nb.set(lqrCommonValues.getNTranspose());
        CommonOps_DDRM.multAddTransA(lqrCommonValues.getB(), S1, Nb);
        CommonOps_DDRM.mult(-1.0d, lqrCommonValues.getR1Inverse(), Nb, K1);
    }

    /**
     * Evaluates the reference VRP and s2 at {@code time} and updates
     * {@code k2 = R1^-1 DQ vrp_rel - 0.5 R1^-1 B^T s2}.
     */
    void computeS2AndK2(double time) {
        int segment = getSegmentNumber(time);
        double timeInSegment = Math.min(maxTimeInSegment, computeTimeInSegment(time, segment));

        Polynomial3D trajectory = relativeVRPTrajectories.get(segment);
        trajectory.compute(timeInSegment);
        referenceVRPPosition.set(trajectory.getPosition());
        referenceVRPPosition.get(relativeDesiredVRP);
        referenceVRPPosition.add(finalVRPPosition);

        s2FunctionList.get(segment).compute(timeInSegment, s2);

        CommonOps_DDRM.mult(lqrCommonValues.getR1Inverse(), lqrCommonValues.getDQ(), R1InverseDQ);
        CommonOps_DDRM.multTransB(-0.5d, lqrCommonValues.getR1Inverse(), lqrCommonValues.getB(), R1InverseBTranspose);
        CommonOps_DDRM.mult(R1InverseDQ, relativeDesiredVRP, k2);
        CommonOps_DDRM.multAdd(R1InverseBTranspose, s2, k2);
        yoK2.set(k2);
    }

    /**
     * Computes the control input {@code u = K1 x + k2} for the given state and time, the resulting feedback VRP,
     * and the momentum-rate cost command returned by {@link #getMomentumRateCostCommand()}.
     *
     * @param currentState 6x1 state; rows 0-2 are read as CoM position and rows 3-5 as CoM velocity (world frame)
     * @param time         time since the start of the VRP trajectory
     * @throws IllegalStateException if no VRP trajectory has been set
     */
    public void computeControlInput(DMatrixRMaj currentState, double time) {
        shouldUpdateCosts = true;
        computeS1AndK1(time);
        computeS2AndK2(time);

        // Express the state relative to the final VRP position (position rows only).
        relativeState.set(currentState);
        for (int i = 0; i < 3; i++) {
            relativeState.add(i, 0, -finalVRPState.get(i));
        }
        relativeCoMPosition.set(relativeState);
        relativeCoMVelocity.set(3, relativeState);

        CommonOps_DDRM.mult(K1, relativeState, u);
        feedbackForce.set(u);
        CommonOps_DDRM.addEquals(u, k2);

        // relativeDesiredVRP is reused as a buffer for the feedback VRP: C x + D u.
        CommonOps_DDRM.mult(lqrCommonValues.getC(), relativeState, relativeDesiredVRP);
        CommonOps_DDRM.multAdd(lqrCommonValues.getD(), u, relativeDesiredVRP);
        feedbackVRPPosition.set(relativeDesiredVRP);
        feedbackVRPPosition.add(finalVRPPosition);

        // Compute the cost once here; getMomentumRateCostCommand() then returns it without recomputing.
        computeMomentumRateCostCommand();
        shouldUpdateCosts = false;
    }

    /** Returns the most recent control input (internal buffer, not a copy). */
    public DMatrixRMaj getU() {
        return u;
    }

    /** Returns the internal S1 matrix as last written (internal buffer, not a copy). */
    public DMatrixRMaj getCostHessian() {
        return S1;
    }

    /** Returns the internal s2 vector as last written (internal buffer, not a copy). */
    public DMatrixRMaj getCostJacobian() {
        return s2;
    }

    /**
     * Returns the index of the segment containing {@code time}. Times past the end of the trajectory map to the
     * last segment instead of an invalid index.
     *
     * @throws IllegalStateException if no VRP trajectory has been set
     */
    private int getSegmentNumber(double time) {
        int numberOfSegments = relativeVRPTrajectories.size();
        if (numberOfSegments == 0) {
            throw new IllegalStateException("No VRP trajectory has been set.");
        }
        double elapsedBeforeSegment = 0.0d;
        for (int i = 0; i < numberOfSegments; i++) {
            double duration = relativeVRPTrajectories.get(i).getDuration();
            if (time - elapsedBeforeSegment <= duration) {
                return i;
            }
            elapsedBeforeSegment += duration;
        }
        return numberOfSegments - 1;
    }

    /** Returns {@code time} minus the summed durations of all segments before {@code segment}. */
    private double computeTimeInSegment(double time, int segment) {
        double elapsedBeforeSegment = 0.0d;
        for (int i = 0; i < segment; i++) {
            elapsedBeforeSegment += relativeVRPTrajectories.get(i).getDuration();
        }
        return time - elapsedBeforeSegment;
    }

    /** Appends the elements of {@code source} to {@code destination}, last element first. */
    private static <T> void addInReverseOrder(List<? extends T> source, List<? super T> destination) {
        for (int i = source.size() - 1; i >= 0; i--) {
            destination.add(source.get(i));
        }
    }

    DMatrixRMaj getA() {
        return lqrCommonValues.getA();
    }

    DMatrixRMaj getB() {
        return lqrCommonValues.getB();
    }

    DMatrixRMaj getC() {
        return lqrCommonValues.getC();
    }

    DMatrixRMaj getD() {
        return lqrCommonValues.getD();
    }

    DMatrixRMaj getQ() {
        return lqrCommonValues.getQ();
    }

    DMatrixRMaj getR() {
        return lqrCommonValues.getR();
    }

    DMatrixRMaj getK1() {
        return K1;
    }

    DMatrixRMaj getK2() {
        return k2;
    }

    S1Function getS1Segment(int segment) {
        return s1FunctionList.get(segment);
    }

    S2Segment getS2Segment(int segment) {
        return s2FunctionList.get(segment);
    }

    /** Returns the VRP implied by the last control input, in world coordinates. */
    public FramePoint3DReadOnly getFeedbackVRPPosition() {
        return feedbackVRPPosition;
    }

    /**
     * Fills the momentum-rate cost command.
     *
     * <p>Gradient (1x3) = {@code -2/m k2^T R1 + 2/m x^T Nb^T}; Hessian (3x3) = {@code 2/m^2 R1}.
     */
    private void computeMomentumRateCostCommand() {
        double inverseMass = 1.0d / totalMass;
        CommonOps_DDRM.multTransA(-2.0d * inverseMass, k2, lqrCommonValues.getR1(), linearMomentumRateGradient);
        CommonOps_DDRM.multAddTransAB(2.0d * inverseMass, relativeState, Nb, linearMomentumRateGradient);
        CommonOps_DDRM.scale(2.0d * inverseMass * inverseMass, lqrCommonValues.getR1(), linearMomentumRateHessian);
        momentumRateCostCommand.setLinearMomentumRateGradient(linearMomentumRateGradient);
        momentumRateCostCommand.setLinearMomentumRateHessian(linearMomentumRateHessian);
    }

    /**
     * Returns the momentum-rate cost command, recomputing it first if needed. The same command instance is
     * returned on every call.
     */
    public LinearMomentumRateCostCommand getMomentumRateCostCommand() {
        if (shouldUpdateCosts) {
            shouldUpdateCosts = false;
            computeMomentumRateCostCommand();
        }
        return momentumRateCostCommand;
    }
}
