package us.ihmc.commonWalkingControlModules.dynamicPlanning;

import java.lang.Enum;
import java.util.Random;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import us.ihmc.matrixlib.MatrixTestTools;
import us.ihmc.robotics.random.RandomGeometry;
import us.ihmc.trajectoryOptimization.LQTrackingCostFunction;

/* loaded from: input_file:us/ihmc/commonWalkingControlModules/dynamicPlanning/TrackingCostFunctionTest.class */
/**
 * Shared numerical checks for {@link LQTrackingCostFunction} implementations.
 *
 * <p>Subclasses provide the problem dimensions, the hybrid states to exercise and the cost function
 * under test. Each {@code test...Numerically} method evaluates an analytic derivative reported by the
 * cost function at a random (but seeded) point. It then compares that derivative with a forward
 * finite-difference approximation built from the next-lower-order quantity: the cost itself for
 * gradients, and the analytic gradients for Hessians.
 *
 * @param <E> enum type identifying the hybrid states of the cost function
 */
public abstract class TrackingCostFunctionTest<E extends Enum<E>> {
    /** Seed used for every check; the generator is re-created per hybrid state, so all states see the same inputs. */
    private static final long RANDOM_SEED = 1738L;

    /** Lower bound on the finite-difference step used for the state gradient. */
    private static final double MINIMUM_STATE_GRADIENT_STEP = 1.0E-7d;

    /** Lower bound on the finite-difference step used for the control gradient. */
    private static final double MINIMUM_CONTROL_GRADIENT_STEP = 1.0E-9d;

    /** Fixed perturbation applied when differentiating the analytic gradients to approximate Hessians. */
    private static final double HESSIAN_PERTURBATION = 1.0E-9d;

    /** Gradient tolerance, relative to the absolute element sum of the numerical gradient. */
    private static final double GRADIENT_RELATIVE_TOLERANCE = 0.001d;

    /** Floor on the gradient tolerance, so it cannot collapse to zero when the gradient elements sum to ~0. */
    private static final double GRADIENT_MINIMUM_TOLERANCE = 1.0E-10d;

    /** Hessian tolerance, relative to the absolute trace of the numerical Hessian. */
    private static final double HESSIAN_RELATIVE_TOLERANCE = 1.0E-5d;

    /** @return how many hybrid states to check; indices {@code 0 .. n-1} are passed to {@link #getHybridState(int)} */
    public abstract int getNumberOfStates();

    /** @return number of rows of the state vectors passed to the cost function */
    public abstract int getStateVectorSize();

    /** @return number of rows of the control vectors passed to the cost function */
    public abstract int getControlVectorSize();

    /** @return number of rows of the constants vector passed to the cost function */
    public abstract int getConstantVectorSize();

    /**
     * @param i index in {@code [0, getNumberOfStates())}
     * @return the hybrid state to check for that index
     */
    public abstract E getHybridState(int i);

    /** @return the cost function under test */
    public abstract LQTrackingCostFunction<E> getCostFunction();

    /** Implementation-specific check of the cost value itself. */
    public abstract void testCost();

    /**
     * Checks {@code getCostStateGradient} against a forward finite difference of {@code getCost} with
     * respect to each state element, for every hybrid state.
     */
    public void testCostStateGradientNumerically() {
        LQTrackingCostFunction<E> costFunction = getCostFunction();
        for (int stateIndex = 0; stateIndex < getNumberOfStates(); stateIndex++) {
            E hybridState = getHybridState(stateIndex);
            CostInputs inputs = nextRandomInputs();
            int stateSize = getStateVectorSize();

            DMatrixRMaj analyticGradient = new DMatrixRMaj(stateSize, 1);
            DMatrixRMaj numericalGradient = new DMatrixRMaj(stateSize, 1);

            double nominalCost = evaluateCost(costFunction, hybridState, inputs.control, inputs.state, inputs);
            double step = finiteDifferenceStep(nominalCost, MINIMUM_STATE_GRADIENT_STEP);
            evaluateStateGradient(costFunction, hybridState, inputs.control, inputs.state, inputs, analyticGradient);

            for (int row = 0; row < stateSize; row++) {
                // Perturb a fresh copy so each entry is differenced from the same nominal state.
                DMatrixRMaj perturbedState = new DMatrixRMaj(inputs.state);
                perturbedState.add(row, 0, step);
                double perturbedCost = evaluateCost(costFunction, hybridState, inputs.control, perturbedState, inputs);
                numericalGradient.set(row, 0, (perturbedCost - nominalCost) / step);
            }

            MatrixTestTools.assertMatrixEquals(numericalGradient, analyticGradient, gradientTolerance(numericalGradient));
        }
    }

    /**
     * Checks {@code getCostControlGradient} against a forward finite difference of {@code getCost} with
     * respect to each control element, for every hybrid state.
     */
    public void testCostControlGradientNumerically() {
        LQTrackingCostFunction<E> costFunction = getCostFunction();
        for (int stateIndex = 0; stateIndex < getNumberOfStates(); stateIndex++) {
            E hybridState = getHybridState(stateIndex);
            CostInputs inputs = nextRandomInputs();
            int controlSize = getControlVectorSize();

            DMatrixRMaj analyticGradient = new DMatrixRMaj(controlSize, 1);
            DMatrixRMaj numericalGradient = new DMatrixRMaj(controlSize, 1);

            double nominalCost = evaluateCost(costFunction, hybridState, inputs.control, inputs.state, inputs);
            double step = finiteDifferenceStep(nominalCost, MINIMUM_CONTROL_GRADIENT_STEP);
            evaluateControlGradient(costFunction, hybridState, inputs.control, inputs.state, inputs, analyticGradient);

            for (int row = 0; row < controlSize; row++) {
                DMatrixRMaj perturbedControl = new DMatrixRMaj(inputs.control);
                perturbedControl.add(row, 0, step);
                double perturbedCost = evaluateCost(costFunction, hybridState, perturbedControl, inputs.state, inputs);
                numericalGradient.set(row, 0, (perturbedCost - nominalCost) / step);
            }

            MatrixTestTools.assertMatrixEquals(numericalGradient, analyticGradient, gradientTolerance(numericalGradient));
        }
    }

    /**
     * Checks {@code getCostStateHessian} (state x state) against a forward finite difference of the
     * analytic state gradient with respect to each state element.
     */
    public void testCostStateHessianNumerically() {
        LQTrackingCostFunction<E> costFunction = getCostFunction();
        for (int stateIndex = 0; stateIndex < getNumberOfStates(); stateIndex++) {
            E hybridState = getHybridState(stateIndex);
            CostInputs inputs = nextRandomInputs();
            int stateSize = getStateVectorSize();

            DMatrixRMaj analyticHessian = new DMatrixRMaj(stateSize, stateSize);
            DMatrixRMaj numericalHessian = new DMatrixRMaj(stateSize, stateSize);
            DMatrixRMaj nominalGradient = new DMatrixRMaj(stateSize, 1);
            DMatrixRMaj perturbedGradient = new DMatrixRMaj(stateSize, 1);

            costFunction.getCostStateHessian(hybridState, inputs.control, inputs.state, inputs.constants, analyticHessian);
            evaluateStateGradient(costFunction, hybridState, inputs.control, inputs.state, inputs, nominalGradient);

            for (int column = 0; column < stateSize; column++) {
                DMatrixRMaj perturbedState = new DMatrixRMaj(inputs.state);
                perturbedState.add(column, 0, HESSIAN_PERTURBATION);
                evaluateStateGradient(costFunction, hybridState, inputs.control, perturbedState, inputs, perturbedGradient);
                setFiniteDifferenceColumn(numericalHessian, column, perturbedGradient, nominalGradient);
            }

            assertHessianEquals(numericalHessian, analyticHessian);
        }
    }

    /**
     * Checks {@code getCostControlHessian} (control x control) against a forward finite difference of
     * the analytic control gradient with respect to each control element.
     */
    public void testCostControlHessianNumerically() {
        LQTrackingCostFunction<E> costFunction = getCostFunction();
        for (int stateIndex = 0; stateIndex < getNumberOfStates(); stateIndex++) {
            E hybridState = getHybridState(stateIndex);
            CostInputs inputs = nextRandomInputs();
            int controlSize = getControlVectorSize();

            DMatrixRMaj analyticHessian = new DMatrixRMaj(controlSize, controlSize);
            DMatrixRMaj numericalHessian = new DMatrixRMaj(controlSize, controlSize);
            DMatrixRMaj nominalGradient = new DMatrixRMaj(controlSize, 1);
            DMatrixRMaj perturbedGradient = new DMatrixRMaj(controlSize, 1);

            costFunction.getCostControlHessian(hybridState, inputs.control, inputs.state, inputs.constants, analyticHessian);
            evaluateControlGradient(costFunction, hybridState, inputs.control, inputs.state, inputs, nominalGradient);

            for (int column = 0; column < controlSize; column++) {
                DMatrixRMaj perturbedControl = new DMatrixRMaj(inputs.control);
                perturbedControl.add(column, 0, HESSIAN_PERTURBATION);
                evaluateControlGradient(costFunction, hybridState, perturbedControl, inputs.state, inputs, perturbedGradient);
                setFiniteDifferenceColumn(numericalHessian, column, perturbedGradient, nominalGradient);
            }

            assertHessianEquals(numericalHessian, analyticHessian);
        }
    }

    /**
     * Checks {@code getCostControlGradientOfStateGradient} (state x control) against a forward finite
     * difference of the analytic state gradient with respect to each control element.
     */
    public void testCostStateControlHessianNumerically() {
        LQTrackingCostFunction<E> costFunction = getCostFunction();
        for (int stateIndex = 0; stateIndex < getNumberOfStates(); stateIndex++) {
            E hybridState = getHybridState(stateIndex);
            CostInputs inputs = nextRandomInputs();
            int stateSize = getStateVectorSize();
            int controlSize = getControlVectorSize();

            DMatrixRMaj analyticHessian = new DMatrixRMaj(stateSize, controlSize);
            DMatrixRMaj numericalHessian = new DMatrixRMaj(stateSize, controlSize);
            DMatrixRMaj nominalGradient = new DMatrixRMaj(stateSize, 1);
            DMatrixRMaj perturbedGradient = new DMatrixRMaj(stateSize, 1);

            costFunction.getCostControlGradientOfStateGradient(hybridState, inputs.control, inputs.state, inputs.constants, analyticHessian);
            evaluateStateGradient(costFunction, hybridState, inputs.control, inputs.state, inputs, nominalGradient);

            for (int column = 0; column < controlSize; column++) {
                DMatrixRMaj perturbedControl = new DMatrixRMaj(inputs.control);
                perturbedControl.add(column, 0, HESSIAN_PERTURBATION);
                evaluateStateGradient(costFunction, hybridState, perturbedControl, inputs.state, inputs, perturbedGradient);
                setFiniteDifferenceColumn(numericalHessian, column, perturbedGradient, nominalGradient);
            }

            // NOTE(review): this block is non-square; the tolerance relies on how CommonOps_DDRM.trace
            // treats non-square input (diagonal of the leading square part) -- confirm that is intended.
            assertHessianEquals(numericalHessian, analyticHessian);
        }
    }

    /** Random evaluation point shared by all checks: column vectors drawn by {@link #nextRandomInputs()}. */
    private static final class CostInputs {
        final DMatrixRMaj state;
        final DMatrixRMaj control;
        // Presumably the tracking targets (they are passed in the "desired" slots of getCost); TODO confirm.
        final DMatrixRMaj desiredState;
        final DMatrixRMaj desiredControl;
        final DMatrixRMaj constants;

        CostInputs(DMatrixRMaj state, DMatrixRMaj control, DMatrixRMaj desiredState, DMatrixRMaj desiredControl, DMatrixRMaj constants) {
            this.state = state;
            this.control = control;
            this.desiredState = desiredState;
            this.desiredControl = desiredControl;
            this.constants = constants;
        }
    }

    /**
     * Draws a fresh set of inputs from a newly seeded generator. The draw order (state, control,
     * desired state, desired control, constants) fixes the values, so it must not change.
     */
    private CostInputs nextRandomInputs() {
        Random random = new Random(RANDOM_SEED);
        DMatrixRMaj state = RandomGeometry.nextDenseMatrix64F(random, getStateVectorSize(), 1);
        DMatrixRMaj control = RandomGeometry.nextDenseMatrix64F(random, getControlVectorSize(), 1);
        DMatrixRMaj desiredState = RandomGeometry.nextDenseMatrix64F(random, getStateVectorSize(), 1);
        DMatrixRMaj desiredControl = RandomGeometry.nextDenseMatrix64F(random, getControlVectorSize(), 1);
        DMatrixRMaj constants = RandomGeometry.nextDenseMatrix64F(random, getConstantVectorSize(), 1);
        return new CostInputs(state, control, desiredState, desiredControl, constants);
    }

    /** Evaluates the cost at the given control/state, using the remaining arguments from {@code inputs}. */
    private double evaluateCost(LQTrackingCostFunction<E> costFunction, E hybridState, DMatrixRMaj control, DMatrixRMaj state, CostInputs inputs) {
        return costFunction.getCost(hybridState, control, state, inputs.desiredControl, inputs.desiredState, inputs.constants);
    }

    /** Packs the analytic state gradient at the given control/state into {@code gradientToPack}. */
    private void evaluateStateGradient(LQTrackingCostFunction<E> costFunction, E hybridState, DMatrixRMaj control, DMatrixRMaj state, CostInputs inputs,
                                       DMatrixRMaj gradientToPack) {
        costFunction.getCostStateGradient(hybridState, control, state, inputs.desiredControl, inputs.desiredState, inputs.constants, gradientToPack);
    }

    /** Packs the analytic control gradient at the given control/state into {@code gradientToPack}. */
    private void evaluateControlGradient(LQTrackingCostFunction<E> costFunction, E hybridState, DMatrixRMaj control, DMatrixRMaj state, CostInputs inputs,
                                         DMatrixRMaj gradientToPack) {
        costFunction.getCostControlGradient(hybridState, control, state, inputs.desiredControl, inputs.desiredState, inputs.constants, gradientToPack);
    }

    /**
     * Chooses a forward-difference step that scales with the order of magnitude of {@code cost}.
     * The step is {@code 10^(int)(log10(cost) + 1) * 1e-18}, but never below {@code minimumStep}.
     * The floor therefore wins unless the cost is very large. For a zero or negative cost,
     * {@code log10} yields -Infinity/NaN and the floor is also used.
     */
    private static double finiteDifferenceStep(double cost, double minimumStep) {
        double orderOfMagnitude = Math.pow(10.0d, (int) (Math.log10(cost) + 1.0d));
        return Math.max(orderOfMagnitude * 1.0E-18d, minimumStep);
    }

    /** Relative gradient tolerance with an absolute floor, shared by both gradient checks. */
    private static double gradientTolerance(DMatrixRMaj numericalGradient) {
        return Math.max(GRADIENT_MINIMUM_TOLERANCE, GRADIENT_RELATIVE_TOLERANCE * Math.abs(CommonOps_DDRM.elementSum(numericalGradient)));
    }

    /** Writes {@code (perturbed - nominal) / HESSIAN_PERTURBATION} into one column of {@code jacobianToPack}. */
    private static void setFiniteDifferenceColumn(DMatrixRMaj jacobianToPack, int column, DMatrixRMaj perturbedGradient, DMatrixRMaj nominalGradient) {
        for (int row = 0; row < jacobianToPack.getNumRows(); row++) {
            jacobianToPack.set(row, column, (perturbedGradient.get(row) - nominalGradient.get(row)) / HESSIAN_PERTURBATION);
        }
    }

    /** Asserts the analytic Hessian matches the numerical one within a tolerance relative to its trace. */
    private static void assertHessianEquals(DMatrixRMaj numericalHessian, DMatrixRMaj analyticHessian) {
        MatrixTestTools.assertMatrixEquals(numericalHessian, analyticHessian, HESSIAN_RELATIVE_TOLERANCE * Math.abs(CommonOps_DDRM.trace(numericalHessian)));
    }
}
