package us.ihmc.commonWalkingControlModules.modelPredictiveController.core;

import gnu.trove.list.TIntList;
import org.ejml.data.DMatrixRMaj;
import us.ihmc.commonWalkingControlModules.controllerCore.command.ConstraintType;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.ForceObjectiveCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.ForceRateTrackingCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.ForceTrackingCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.MPCCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.MPCCommandList;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.MPCCommandType;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.MPCContinuityCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.MPCValueCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.NormalForceBoundCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.RhoBoundCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.RhoObjectiveCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.RhoRateTrackingCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.RhoTrackingCommand;
import us.ihmc.commonWalkingControlModules.modelPredictiveController.commands.VRPTrackingCommand;
import us.ihmc.commonWalkingControlModules.momentumBasedController.optimization.NativeQPInputTypeA;
import us.ihmc.commonWalkingControlModules.momentumBasedController.optimization.NativeQPInputTypeC;
import us.ihmc.convexOptimization.quadraticProgram.InverseMatrixCalculator;
import us.ihmc.matrixlib.NativeMatrix;
import us.ihmc.robotics.time.ExecutionTimer;
import us.ihmc.yoVariables.registry.YoRegistry;
import us.ihmc.yoVariables.variable.YoBoolean;
import us.ihmc.yoVariables.variable.YoDouble;
import us.ihmc.yoVariables.variable.YoInteger;

/* loaded from: input_file:us/ihmc/commonWalkingControlModules/modelPredictiveController/core/LinearMPCQPSolver.class */
/**
 * Assembles and solves the quadratic program of the linear model predictive controller.
 *
 * <p>MPC commands are converted into QP inputs by an {@link MPCQPInputCalculator}. They are then added
 * to an {@link MPCQPSolver} in one of two forms:
 * <ul>
 *   <li>as objectives or constraints, through {@link #qpInputTypeA};</li>
 *   <li>as direct Hessian/gradient costs, through {@link #qpInputTypeC}.</li>
 * </ul>
 *
 * <p>Every {@link #solve()} first adds a value regularization on the CoM and rho coefficients of each
 * segment. While rate regularization is enabled, it also adds a rate regularization toward
 * {@link #previousSolution}.
 *
 * <p>The expected call sequence is inferred from the API and not enforced:
 * {@link #initialize()}, then the {@code submit*} methods, then {@link #solve()}, then
 * {@link #getSolution()}.
 */
public class LinearMPCQPSolver {
    /** Number of CoM coefficients per segment, as regularized and block-inverted by this class. */
    private static final int COM_COEFFICIENTS_PER_SEGMENT = 6;
    /** Iteration cap used until {@link #setMaxNumberOfIterations(int)} is called. */
    private static final int DEFAULT_MAX_NUMBER_OF_ITERATIONS = 100;
    /**
     * Initial row count of {@link #solverOutput}. {@link #initialize()} reshapes the output to the real
     * problem size, so this value is only visible if the solution is read before initialization.
     */
    private static final int INITIAL_SOLUTION_SIZE = 114;

    // YoVariables are created in the same order as before, so the registry layout is unchanged.
    protected final YoRegistry registry = new YoRegistry(getClass().getSimpleName());
    private final ExecutionTimer qpSolverTimer = new ExecutionTimer("mpcSolverTimer", 0.5, registry);
    /** When true, {@link #solve()} also adds rate regularization toward {@link #previousSolution}. */
    private final YoBoolean addRateRegularization = new YoBoolean("AddRateRegularization", registry);
    private final YoBoolean foundSolution = new YoBoolean("foundSolution", registry);

    /** Scratch input filled by the input calculator for objective/constraint style commands. */
    public final NativeQPInputTypeA qpInputTypeA = new NativeQPInputTypeA(0);
    /** Scratch input filled by the input calculator for direct-cost style commands. */
    public final NativeQPInputTypeC qpInputTypeC = new NativeQPInputTypeC(0);

    private final YoInteger numberOfActiveVariables = new YoInteger("numberOfActiveMPCVariables", registry);
    private final YoInteger numberOfIterations = new YoInteger("numberOfMPCIterations", registry);
    private final YoInteger numberOfEqualityConstraints = new YoInteger("numberOfMPCEqualityConstraints", registry);
    private final YoInteger numberOfInequalityConstraints = new YoInteger("numberOfMPCInequalityConstraints", registry);
    private final YoInteger numberOfConstraints = new YoInteger("numberOfMPCConstraints", registry);

    private final YoDouble comCoefficientRegularization = new YoDouble("comCoefficientRegularization", registry);
    private final YoDouble rhoCoefficientRegularization = new YoDouble("rhoCoefficientRegularization", registry);
    private final YoDouble comRateCoefficientRegularization = new YoDouble("comRateCoefficientRegularization", registry);
    private final YoDouble rhoRateCoefficientRegularization = new YoDouble("rhoRateCoefficientRegularization", registry);

    /** Reference solution used by the rate regularization; set via {@link #setPreviousSolution(DMatrixRMaj)}. */
    protected final NativeMatrix previousSolution = new NativeMatrix(0, 0);
    /** Output of the last {@link #solve()}; returned by reference from {@link #getSolution()}. */
    protected final NativeMatrix solverOutput = new NativeMatrix(INITIAL_SOLUTION_SIZE, 1);

    public final MPCQPSolver qpSolver;
    private final LinearMPCIndexHandler indexHandler;
    private final MPCQPInputCalculator inputCalculator;
    protected final double dt;
    /** {@code dt * dt}; the rate regularization weights are divided by this value. */
    protected final double dt2;

    private int problemSize;
    private boolean resetActiveSet = false;
    private boolean useWarmStart = false;
    private int maxNumberOfIterations = DEFAULT_MAX_NUMBER_OF_ITERATIONS;

    /**
     * Creates a solver that inverts the Hessian with a {@link BlockInverseCalculator}.
     *
     * <p>The calculator is built from two per-segment functions:
     * <ul>
     *   <li>the CoM coefficient start index of the segment;</li>
     *   <li>a block size equal to the segment's rho coefficient count plus the CoM coefficients.</li>
     * </ul>
     *
     * @param indexHandler maps segments to variable indices; must not be null
     * @param dt time step, stored in {@link #dt}
     * @param inputCalculatorParameter forwarded unchanged to the {@link MPCQPInputCalculator} constructor
     *        (NOTE(review): its meaning is defined by that class)
     * @param parentRegistry registry this solver's own registry is attached to
     */
    public LinearMPCQPSolver(LinearMPCIndexHandler indexHandler, double dt, double inputCalculatorParameter, YoRegistry parentRegistry) {
        // Evaluating the method reference throws NullPointerException if indexHandler is null.
        this(indexHandler,
             dt,
             inputCalculatorParameter,
             new BlockInverseCalculator(indexHandler,
                                        indexHandler::getComCoefficientStartIndex,
                                        segment -> indexHandler.getRhoCoefficientsInSegment(segment) + COM_COEFFICIENTS_PER_SEGMENT),
             parentRegistry);
    }

    /**
     * Creates a solver with an explicit inverse-Hessian calculator.
     *
     * @param indexHandler maps segments to variable indices
     * @param dt time step, stored in {@link #dt}
     * @param inputCalculatorParameter forwarded unchanged to the {@link MPCQPInputCalculator} constructor
     *        (NOTE(review): its meaning is defined by that class)
     * @param inverseMatrixCalculator calculator installed on the QP solver; if {@code null}, the solver's
     *        own setting is left untouched
     * @param parentRegistry registry this solver's own registry is attached to
     */
    public LinearMPCQPSolver(LinearMPCIndexHandler indexHandler, double dt, double inputCalculatorParameter,
                             InverseMatrixCalculator<NativeMatrix> inverseMatrixCalculator, YoRegistry parentRegistry) {
        this.indexHandler = indexHandler;
        this.dt = dt;
        this.dt2 = dt * dt;

        rhoCoefficientRegularization.set(1.0E-5);
        comCoefficientRegularization.set(1.0E-5);
        rhoRateCoefficientRegularization.set(1.0E-10);
        comRateCoefficientRegularization.set(1.0E-10);

        qpSolver = new MPCQPSolver();
        qpSolver.setConvergenceThreshold(5.0E-6);
        qpSolver.setConvergenceThresholdForLagrangeMultipliers(1.0E-4);
        if (inverseMatrixCalculator != null) {
            qpSolver.setInverseHessianCalculator(inverseMatrixCalculator);
        }
        qpSolver.setResetActiveSetOnSizeChange(false);

        inputCalculator = new MPCQPInputCalculator(indexHandler, inputCalculatorParameter);
        parentRegistry.addChild(registry);
    }

    /** Sets the weight of the value regularization applied to every segment's CoM coefficients. */
    public void setComCoefficientRegularizationWeight(double weight) {
        comCoefficientRegularization.set(weight);
    }

    /** Sets the weight of the value regularization applied to every segment's rho coefficients. */
    public void setRhoCoefficientRegularizationWeight(double weight) {
        rhoCoefficientRegularization.set(weight);
    }

    /** Sets the CoM rate-regularization weight; it is divided by {@code dt^2} when applied. */
    public void setComRateCoefficientRegularizationWeight(double weight) {
        comRateCoefficientRegularization.set(weight);
    }

    /** Sets the rho rate-regularization weight; it is divided by {@code dt^2} when applied. */
    public void setRhoRateCoefficientRegularizationWeight(double weight) {
        rhoRateCoefficientRegularization.set(weight);
    }

    /** Enables or disables warm starting of the QP solver on the next {@link #solve()}. */
    public void setUseWarmStart(boolean useWarmStart) {
        this.useWarmStart = useWarmStart;
    }

    /** Sets the iteration cap passed to the QP solver on the next {@link #solve()}. */
    public void setMaxNumberOfIterations(int maxNumberOfIterations) {
        this.maxNumberOfIterations = maxNumberOfIterations;
    }

    /**
     * Requests an active-set reset. The request is consumed by the next {@link #solve()} that runs with
     * warm start enabled.
     */
    public void notifyResetActiveSet() {
        resetActiveSet = true;
    }

    /**
     * Copies {@code solution} into {@link #previousSolution} and enables rate regularization toward it.
     *
     * @param solution reference solution for the rate regularization
     */
    public void setPreviousSolution(DMatrixRMaj solution) {
        previousSolution.set(solution);
        addRateRegularization.set(true);
    }

    /** Returns whether an active-set reset was requested, and clears the request. */
    private boolean pollResetActiveSet() {
        boolean requested = resetActiveSet;
        resetActiveSet = false;
        return requested;
    }

    /**
     * Sizes the QP inputs, the solver output and the QP solver to the index handler's current total
     * problem size, and disables rate regularization.
     */
    public void initialize() {
        problemSize = indexHandler.getTotalProblemSize();
        qpInputTypeA.setNumberOfVariables(problemSize);
        qpInputTypeC.setNumberOfVariables(problemSize);
        solverOutput.reshape(problemSize, 1);
        resetRateRegularization();
        qpSolver.initialize(problemSize);
    }

    /** Disables rate regularization until a previous solution is set or a solve succeeds. */
    public void resetRateRegularization() {
        addRateRegularization.set(false);
    }

    /** Adds the value regularization, plus the rate regularization when it is enabled. */
    private void addCoefficientRegularization() {
        addValueRegularization();
        if (addRateRegularization.getBooleanValue()) {
            addRateRegularization();
        }
    }

    /** Adds value regularization on each segment's CoM and rho coefficients. */
    public void addValueRegularization() {
        double comWeight = comCoefficientRegularization.getValue();
        double rhoWeight = rhoCoefficientRegularization.getValue();
        for (int segment = 0; segment < indexHandler.getNumberOfSegments(); segment++) {
            qpSolver.addRegularization(indexHandler.getComCoefficientStartIndex(segment), COM_COEFFICIENTS_PER_SEGMENT, comWeight);
            qpSolver.addRegularization(indexHandler.getRhoCoefficientStartIndex(segment), indexHandler.getRhoCoefficientsInSegment(segment), rhoWeight);
        }
    }

    /**
     * Adds rate regularization toward {@link #previousSolution} on each segment's CoM and rho
     * coefficients. The configured weights are divided by {@code dt^2}.
     *
     * <p>The rho block is located with {@code getRhoCoefficientStartIndex}, the same way as in
     * {@link #addValueRegularization()}, rather than by assuming it sits right after the CoM block.
     */
    public void addRateRegularization() {
        double comWeight = comRateCoefficientRegularization.getValue() / dt2;
        double rhoWeight = rhoRateCoefficientRegularization.getValue() / dt2;
        for (int segment = 0; segment < indexHandler.getNumberOfSegments(); segment++) {
            qpSolver.addRateRegularization(indexHandler.getComCoefficientStartIndex(segment), COM_COEFFICIENTS_PER_SEGMENT, comWeight, previousSolution);
            qpSolver.addRateRegularization(indexHandler.getRhoCoefficientStartIndex(segment), indexHandler.getRhoCoefficientsInSegment(segment), rhoWeight, previousSolution);
        }
    }

    /** Submits every command of {@code commandList}, in order. */
    public void submitMPCCommandList(MPCCommandList commandList) {
        for (int i = 0; i < commandList.getNumberOfCommands(); i++) {
            submitMPCCommand(commandList.getCommand(i));
        }
    }

    /**
     * Dispatches {@code command} to the matching {@code submit*} method, based on its command type.
     *
     * @throws RuntimeException if the command type has no handler
     */
    public void submitMPCCommand(MPCCommand<?> command) {
        switch (command.getCommandType()) {
            case VALUE:
                submitMPCValueObjective((MPCValueCommand) command);
                break;
            case CONTINUITY:
                submitContinuityObjective((MPCContinuityCommand) command);
                break;
            case LIST:
                submitMPCCommandList((MPCCommandList) command);
                break;
            case RHO_VALUE:
                submitRhoValueCommand((RhoObjectiveCommand) command);
                break;
            case VRP_TRACKING:
                submitVRPTrackingCommand((VRPTrackingCommand) command);
                break;
            case RHO_BOUND:
                submitRhoBoundCommand((RhoBoundCommand) command);
                break;
            case NORMAL_FORCE_BOUND:
                submitNormalForceBoundCommand((NormalForceBoundCommand) command);
                break;
            case FORCE_VALUE:
                submitForceValueCommand((ForceObjectiveCommand) command);
                break;
            case FORCE_TRACKING:
                submitForceTrackingCommand((ForceTrackingCommand) command);
                break;
            case FORCE_RATE_TRACKING:
                submitForceRateTrackingCommand((ForceRateTrackingCommand) command);
                break;
            case RHO_TRACKING:
                submitRhoTrackingCommand((RhoTrackingCommand) command);
                break;
            case RHO_RATE_TRACKING:
                submitRhoRateTrackingCommand((RhoRateTrackingCommand) command);
                break;
            default:
                throw new RuntimeException("The command type: " + command.getCommandType() + " is not handled.");
        }
    }

    // Each submit method below asks the input calculator to fill a scratch input. A return value of -1
    // (or false) means nothing was produced. Otherwise the returned index is forwarded to addInput as
    // the offset.

    public void submitRhoValueCommand(RhoObjectiveCommand command) {
        int offset = inputCalculator.calculateCompactRhoValueCommand(qpInputTypeA, command);
        if (offset != -1) {
            addInput(qpInputTypeA, offset);
        }
    }

    public void submitMPCValueObjective(MPCValueCommand command) {
        int offset = inputCalculator.calculateCompactValueObjective(qpInputTypeA, command);
        if (offset != -1) {
            addInput(qpInputTypeA, offset);
        }
    }

    public void submitContinuityObjective(MPCContinuityCommand command) {
        int offset = inputCalculator.calculateContinuityObjective(qpInputTypeA, command);
        if (offset != -1) {
            addInput(qpInputTypeA, offset);
        }
    }

    public void submitVRPTrackingCommand(VRPTrackingCommand command) {
        int offset = inputCalculator.calculateCompactVRPTrackingObjective(qpInputTypeC, command);
        if (offset != -1) {
            addInput(qpInputTypeC, offset);
        }
    }

    /** Adds a rho bound as an inequality, using the command's slack variable weight. */
    public void submitRhoBoundCommand(RhoBoundCommand command) {
        int offset = inputCalculator.calculateRhoBoundCommandCompact(qpInputTypeA, command);
        if (offset != -1) {
            addInput(qpInputTypeA, offset, command.getSlackVariableWeight());
        }
    }

    public void submitNormalForceBoundCommand(NormalForceBoundCommand command) {
        int offset = inputCalculator.calculateNormalForceBoundCommandCompact(qpInputTypeA, command);
        if (offset != -1) {
            addInput(qpInputTypeA, offset);
        }
    }

    public void submitForceValueCommand(ForceObjectiveCommand command) {
        if (inputCalculator.calculateForceMinimizationObjective(qpInputTypeC, command)) {
            addInput(qpInputTypeC);
        }
    }

    public void submitForceTrackingCommand(ForceTrackingCommand command) {
        // NOTE(review): unlike submitRhoTrackingCommand, the calculator's return value is only used as a
        // success flag, and the input is added at offset 0. Confirm that the calculator fills a full-size
        // (non-compact) Hessian here.
        if (inputCalculator.calculateForceTrackingObjective(qpInputTypeC, command) != -1) {
            addInput(qpInputTypeC);
        }
    }

    public void submitForceRateTrackingCommand(ForceRateTrackingCommand command) {
        // NOTE(review): the returned value is used only as a success flag; the input is added at offset 0
        // (see submitForceTrackingCommand).
        if (inputCalculator.calculateForceRateTrackingObjective(qpInputTypeC, command) != -1) {
            addInput(qpInputTypeC);
        }
    }

    public void submitRhoTrackingCommand(RhoTrackingCommand command) {
        int offset = inputCalculator.calculateRhoTrackingObjective(qpInputTypeC, command);
        if (offset != -1) {
            addInput(qpInputTypeC, offset);
        }
    }

    public void submitRhoRateTrackingCommand(RhoRateTrackingCommand command) {
        int offset = inputCalculator.calculateRhoRateTrackingObjective(qpInputTypeC, command);
        if (offset != -1) {
            addInput(qpInputTypeC, offset);
        }
    }

    /** Adds {@code input} at offset 0, with no slack variable weight. */
    public void addInput(NativeQPInputTypeA input) {
        addInput(input, 0);
    }

    /** Adds {@code input} at {@code offset}, with no slack variable weight ({@code NaN}). */
    public void addInput(NativeQPInputTypeA input, int offset) {
        addInput(input, offset, Double.NaN);
    }

    /**
     * Adds {@code input} to the QP solver according to its constraint type.
     *
     * @param input task Jacobian/objective to add
     * @param offset index forwarded to the QP solver (presumably the column offset of a compact input;
     *        confirm against MPCQPSolver)
     * @param slackVariableWeight forwarded only for inequality constraints
     * @throws RuntimeException for an unsupported constraint type
     */
    public void addInput(NativeQPInputTypeA input, int offset, double slackVariableWeight) {
        switch (input.getConstraintType()) {
            case OBJECTIVE:
                if (input.useWeightScalar()) {
                    qpSolver.addObjective(input.taskJacobian, input.taskObjective, input.getWeightScalar(), offset);
                } else {
                    qpSolver.addObjective(input.taskJacobian, input.taskObjective, input.getTaskWeightMatrix(), offset);
                }
                break;
            case EQUALITY:
                qpSolver.addEqualityConstraint(input.taskJacobian, input.taskObjective, problemSize, offset);
                break;
            case LEQ_INEQUALITY:
                qpSolver.addMotionLesserOrEqualInequalityConstraint(input.taskJacobian, input.taskObjective, slackVariableWeight, problemSize, offset);
                break;
            case GEQ_INEQUALITY:
                qpSolver.addMotionGreaterOrEqualInequalityConstraint(input.taskJacobian, input.taskObjective, slackVariableWeight, problemSize, offset);
                break;
            default:
                throw new RuntimeException("Unexpected constraint type: " + input.getConstraintType());
        }
    }

    /** Adds the direct cost {@code input} at offset 0. */
    public void addInput(NativeQPInputTypeC input) {
        addInput(input, 0);
    }

    /**
     * Adds the direct Hessian/gradient cost of {@code input} at {@code offset}.
     *
     * @throws IllegalArgumentException if {@code input} uses a weight matrix instead of a scalar weight
     */
    public void addInput(NativeQPInputTypeC input, int offset) {
        if (!input.useWeightScalar()) {
            throw new IllegalArgumentException("Not yet implemented.");
        }
        qpSolver.addDirectObjective(input.directCostHessian, input.directCostGradient, input.getWeightScalar(), offset);
    }

    /**
     * Adds the coefficient regularization, then solves the assembled QP into {@link #solverOutput}.
     *
     * <p>Outcome handling:
     * <ul>
     *   <li>On success (no NaN in the output), {@code foundSolution} is set and rate regularization stays
     *       enabled.</li>
     *   <li>On failure, both are cleared and the iteration count is recorded as {@code -1}.</li>
     * </ul>
     *
     * @return {@code true} if the solver output contains no NaN
     */
    public boolean solve() {
        addCoefficientRegularization();

        numberOfEqualityConstraints.set(qpSolver.getNumberOfEqualityConstraints());
        numberOfInequalityConstraints.set(qpSolver.getNumberOfInequalityConstraints());
        numberOfConstraints.set(numberOfEqualityConstraints.getIntegerValue() + numberOfInequalityConstraints.getIntegerValue());

        qpSolverTimer.startMeasurement();
        qpSolver.setUseWarmStart(useWarmStart);
        qpSolver.setMaxNumberOfIterations(maxNumberOfIterations);
        // A pending reset request is only consumed when warm starting; otherwise it stays queued.
        if (useWarmStart && pollResetActiveSet()) {
            qpSolver.resetActiveSet();
        }
        numberOfActiveVariables.set(problemSize);
        numberOfIterations.set(qpSolver.solve(solverOutput));
        qpSolverTimer.stopMeasurement();

        boolean succeeded = !solverOutput.containsNaN();
        foundSolution.set(succeeded);
        addRateRegularization.set(succeeded);
        if (!succeeded) {
            numberOfIterations.set(-1);
        }
        return succeeded;
    }

    /** Returns the solver output matrix itself (not a copy); it is overwritten by each {@link #solve()}. */
    public NativeMatrix getSolution() {
        return solverOutput;
    }

    /** Forwards the active inequality indices to the QP solver. */
    public void setActiveInequalityIndices(TIntList activeInequalityIndices) {
        qpSolver.setActiveInequalityIndices(activeInequalityIndices);
    }

    /** Returns the QP solver's active inequality indices. */
    public TIntList getActiveInequalityIndices() {
        return qpSolver.getActiveInequalityIndices();
    }
}
