package org.nd4j.autodiff.validation;

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.camel.util.URISupport;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/autodiff/validation/GradCheckUtil.class */
/**
 * Gradient-checking utilities for {@link SameDiff} graphs.
 *
 * <p>The main entry point compares the gradients stored after {@code execBackwards()} against central-difference
 * numerical estimates, perturbing every element of every variable that is not an op output. All checks require the
 * ND4J data type to be {@code DOUBLE}.
 */
public class GradCheckUtil {
    private static final Logger log = LoggerFactory.getLogger(GradCheckUtil.class);
    private static final boolean DEFAULT_PRINT = true;
    private static final boolean DEFAULT_EXIT_FIRST_FAILURE = false;
    private static final boolean DEFAULT_DEBUG_MODE = false;
    private static final double DEFAULT_EPS = 1.0E-5d;
    private static final double DEFAULT_MAX_REL_ERROR = 1.0E-5d;
    private static final double DEFAULT_MIN_ABS_ERROR = 1.0E-6d;

    /**
     * Compares a central-difference estimate, taken over every element of every array in {@code inputParameters},
     * against the summed last output of the graph that owns {@code function}.
     *
     * <p>NOTE(review): the reference value is the sum of the last array returned by {@code eval}, i.e. a forward
     * output rather than an analytic gradient, and {@code wrt} is never used, so this overload is not a true
     * gradient check. Prefer the {@code checkGradients(SameDiff, ...)} overloads.
     *
     * @param function        variable whose owning SameDiff instance is evaluated
     * @param wrt             currently unused
     * @param epsilon         perturbation size; must be in (0, 0.1]
     * @param maxRelError     maximum allowed difference per element; must be in (0, 0.25]
     * @param print           whether to log a summary line
     * @param inputParameters arrays passed to {@code eval}, keyed by name; each is copied before being perturbed
     * @return {@code true} if no element exceeded {@code maxRelError}
     * @throws IllegalArgumentException if {@code epsilon} or {@code maxRelError} is out of range
     * @throws IllegalStateException    if the data type in context is not DOUBLE
     */
    public static boolean checkGradients(SDVariable function, SDVariable wrt, double epsilon, double maxRelError,
                                         boolean print, Map<String, INDArray> inputParameters) {
        if (epsilon <= 0.0d || epsilon > 0.1d) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (maxRelError <= 0.0d || maxRelError > 0.25d) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError);
        }
        DataBuffer.Type dataType = DataTypeUtil.getDtypeFromContext();
        if (dataType != DataBuffer.Type.DOUBLE) {
            throw new IllegalStateException("Cannot perform gradient check: Datatype is not set to double precision (is: " + dataType + "). Double precision must be used for gradient checks. Set DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE); before using GradientCheckUtil");
        }

        SameDiff sameDiff = function.getSameDiff();
        double reference = sumOfLastOutput(SameDiff.create(sameDiff).eval(inputParameters));

        long totalChecked = 0;
        int totalFailures = 0;
        double maxError = 0.0d;
        for (Map.Entry<String, INDArray> entry : inputParameters.entrySet()) {
            // Perturb a private copy; every other input is passed through unchanged.
            INDArray params = entry.getValue().dup();
            Map<String, INDArray> evalParams = new HashMap<>(inputParameters);
            evalParams.put(entry.getKey(), params);

            long nParams = params.length();
            for (int i = 0; i < nParams; i++) {
                double original = params.getDouble(i);
                double scorePlus;
                double scoreMinus;
                try {
                    params.putScalar(i, original + epsilon);
                    // Reduce to a scalar immediately, in case eval() hands back reused arrays.
                    scorePlus = sumOfLastOutput(sameDiff.eval(evalParams));
                    params.putScalar(i, original - epsilon);
                    scoreMinus = sumOfLastOutput(sameDiff.eval(evalParams));
                } finally {
                    // Restore so perturbations never accumulate across elements.
                    params.putScalar(i, original);
                }
                double numerical = (scorePlus - scoreMinus) / (2.0d * epsilon);
                double error = Math.abs(reference - numerical);
                if (error > maxError) {
                    maxError = error;
                }
                if (error > maxRelError) {
                    totalFailures++;
                }
                totalChecked++;
            }
        }
        if (print) {
            log.info("GradientCheckUtil.checkGradients(): " + totalChecked + " params checked, " + (totalChecked - totalFailures) + " passed, " + totalFailures + " failed. Largest error = " + maxError);
        }
        return totalFailures == 0;
    }

    /** Sums the last array of an {@code eval} result into a double. */
    private static double sumOfLastOutput(INDArray[] outputs) {
        return outputs[outputs.length - 1].sumNumber().doubleValue();
    }

    /**
     * Runs a gradient check configured entirely by a {@link TestCase}; internal-state validation is always run.
     *
     * @param testCase source of the graph and all grad-check settings
     * @return {@code true} if every checked parameter passed
     */
    public static boolean checkGradients(TestCase testCase) {
        return checkGradients(testCase.sameDiff(), testCase.gradCheckEpsilon(), testCase.gradCheckMaxRelativeError(),
                testCase.gradCheckMinAbsError(), testCase.gradCheckPrint(), testCase.gradCheckDefaultExitFirstFailure(),
                false, testCase.gradCheckDebugMode(), testCase.gradCheckSkipVariables());
    }

    /**
     * Runs a gradient check with default tolerances, printing enabled and no early exit.
     *
     * @param sameDiff graph to check; its scalar end result is differentiated
     * @return {@code true} if every checked parameter passed
     */
    public static boolean checkGradients(SameDiff sameDiff) {
        return checkGradients(sameDiff, DEFAULT_PRINT, DEFAULT_EXIT_FIRST_FAILURE);
    }

    /**
     * Runs a gradient check with default settings, skipping the named variables.
     *
     * @param sameDiff      graph to check
     * @param skipVariables names of variables whose elements are not perturbed; {@code null} skips none
     * @return {@code true} if every checked parameter passed
     */
    public static boolean checkGradients(SameDiff sameDiff, String... skipVariables) {
        Set<String> skip = null;
        if (skipVariables != null) {
            skip = new HashSet<>();
            Collections.addAll(skip, skipVariables);
        }
        return checkGradients(sameDiff, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, DEFAULT_PRINT,
                DEFAULT_EXIT_FIRST_FAILURE, false, DEFAULT_DEBUG_MODE, skip);
    }

    /**
     * Runs a gradient check with default tolerances.
     *
     * @param sameDiff           graph to check
     * @param print              whether to log per-parameter results and a summary
     * @param exitOnFirstFailure whether to return {@code false} as soon as one parameter fails
     * @return {@code true} if every checked parameter passed
     */
    public static boolean checkGradients(SameDiff sameDiff, boolean print, boolean exitOnFirstFailure) {
        return checkGradients(sameDiff, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, print, exitOnFirstFailure);
    }

    /**
     * Runs a gradient check with validation enabled, debug mode off and no skipped variables.
     *
     * @param sameDiff           graph to check
     * @param eps                perturbation size for the central difference
     * @param maxRelError        maximum relative error before a parameter is considered for failure
     * @param minAbsError        absolute errors below this always pass
     * @param print              whether to log per-parameter results and a summary
     * @param exitOnFirstFailure whether to return {@code false} as soon as one parameter fails
     * @return {@code true} if every checked parameter passed
     */
    public static boolean checkGradients(SameDiff sameDiff, double eps, double maxRelError, double minAbsError,
                                         boolean print, boolean exitOnFirstFailure) {
        return checkGradients(sameDiff, eps, maxRelError, minAbsError, print, exitOnFirstFailure, false, false, null);
    }

    /**
     * Compares SameDiff's analytic gradients with central-difference estimates for every element of every variable
     * that is not an output of some op.
     *
     * <p>A parameter fails when its relative error exceeds {@code maxRelError} AND its absolute error is at least
     * {@code minAbsError}. If debug mode was enabled here, it is disabled again on every exit path, including
     * exceptions and the early return.
     *
     * @param sameDiff           graph to check; its end result must be a single element
     * @param eps                perturbation size for the central difference
     * @param maxRelError        maximum relative error
     * @param minAbsError        absolute errors below this always pass
     * @param print              whether to log per-parameter results and a summary
     * @param exitOnFirstFailure whether to return {@code false} as soon as one parameter fails
     * @param skipValidation     if {@code false}, {@link #validateInternalState(SameDiff, boolean)} is run first
     * @param debugMode          whether to enable SameDiff debug mode for the duration of the check
     * @param skipVariables      names of variables whose elements are not perturbed; may be {@code null}
     * @return {@code true} if every checked parameter passed
     * @throws IllegalStateException if the data type is not DOUBLE, a required array or gradient is missing, the
     *                               output is not scalar, shapes mismatch, or a gradient is NaN/infinite
     */
    public static boolean checkGradients(SameDiff sameDiff, double eps, double maxRelError, double minAbsError,
                                         boolean print, boolean exitOnFirstFailure, boolean skipValidation,
                                         boolean debugMode, Set<String> skipVariables) {
        boolean wasDebugMode = sameDiff.isDebugMode();
        if (debugMode) {
            sameDiff.enableDebugMode();
        }
        try {
            return runGradientCheck(sameDiff, eps, maxRelError, minAbsError, print, exitOnFirstFailure,
                    skipValidation, skipVariables);
        } finally {
            if (debugMode && !wasDebugMode) {
                sameDiff.disableDebugging();
            }
        }
    }

    /** Performs the actual check; the caller owns the debug-mode lifecycle. */
    private static boolean runGradientCheck(SameDiff sameDiff, double eps, double maxRelError, double minAbsError,
                                            boolean print, boolean exitOnFirstFailure, boolean skipValidation,
                                            Set<String> skipVariables) {
        if (!skipValidation) {
            validateInternalState(sameDiff, true);
        }
        if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
            throw new IllegalStateException("Data type must be set to double");
        }

        Set<String> opOutputs = opOutputNames(sameDiff);
        for (SDVariable variable : sameDiff.variables()) {
            if (!opOutputs.contains(variable.getVarName()) && variable.getArr() == null) {
                throw new IllegalStateException("Variable \"" + variable.getVarName() + "\" does not have array associated with it");
            }
        }

        INDArray endResult = sameDiff.execAndEndResult();
        if (endResult.length() != 1) {
            throw new IllegalStateException("Output variable is not a scalar - has shape " + Arrays.toString(endResult.shape()));
        }

        sameDiff.execBackwards();
        Map<String, INDArray> gradients = analyticGradients(sameDiff, opOutputs);

        int totalFailures = 0;
        int totalChecked = 0;
        double maxError = 0.0d;
        for (SDVariable variable : sameDiff.variables()) {
            String varName = variable.getVarName();
            if (opOutputs.contains(varName)) {
                continue;
            }
            if (skipVariables != null && skipVariables.contains(varName)) {
                log.info("Grad check: skipping variable \"{}\"", varName);
                continue;
            }

            INDArray arr = variable.getArr();
            INDArray grad = gradients.get(varName);
            long length = arr.length();
            if (print) {
                log.info("Starting test for variable \"{}\" with {} values", varName, length);
            }

            NdIndexIterator iterator = new NdIndexIterator('c', arr.shape());
            int paramIdx = 0;
            while (iterator.hasNext()) {
                long[] index = iterator.next();
                String position = positionString(index);
                totalChecked++;

                double original = arr.getDouble(index);
                double scorePlus;
                double scoreMinus;
                try {
                    arr.putScalar(index, original + eps);
                    scorePlus = sameDiff.execAndEndResult().getDouble(0L);
                    arr.putScalar(index, original - eps);
                    scoreMinus = sameDiff.execAndEndResult().getDouble(0L);
                } finally {
                    // Always restore, even if a forward pass throws.
                    arr.putScalar(index, original);
                }

                double numericalGrad = (scorePlus - scoreMinus) / (2.0d * eps);
                double analyticGrad = grad.getDouble(index);
                if (Double.isInfinite(numericalGrad) || Double.isNaN(numericalGrad)) {
                    throw new IllegalStateException("Numerical gradient was " + numericalGrad + " for variable \"" + varName + "\", parameter " + paramIdx + " of " + length + " (position: " + position + ")");
                }
                if (Double.isInfinite(analyticGrad) || Double.isNaN(analyticGrad)) {
                    throw new IllegalStateException("Analytic (SameDiff) gradient was " + analyticGrad + " for variable \"" + varName + "\", parameter " + paramIdx + " of " + length + " (position: " + position + ")");
                }

                double relError = (numericalGrad == 0.0d && analyticGrad == 0.0d)
                        ? 0.0d
                        : Math.abs(analyticGrad - numericalGrad) / (Math.abs(analyticGrad) + Math.abs(numericalGrad));
                if (relError > maxError) {
                    maxError = relError;
                }

                if (relError > maxRelError || Double.isNaN(relError)) {
                    double absError = Math.abs(analyticGrad - numericalGrad);
                    if (absError >= minAbsError) {
                        if (print) {
                            log.info("Param " + paramIdx + " (" + varName + position + ") FAILED: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + ", absError=" + absError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                        }
                        if (exitOnFirstFailure) {
                            return false;
                        }
                        totalFailures++;
                    } else if (print) {
                        log.info("Param " + paramIdx + " (" + varName + position + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsError);
                    }
                } else if (print) {
                    log.info("Param " + paramIdx + " (" + varName + position + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError);
                }
                paramIdx++;
            }
        }

        if (print) {
            log.info("GradCheckUtil.checkGradients(): " + totalChecked + " params checked, " + (totalChecked - totalFailures) + " passed, " + totalFailures + " failed. Largest relative error = " + maxError);
        }
        return totalFailures == 0;
    }

    /** Collects the names of all variables produced as an output of some function in the graph. */
    private static Set<String> opOutputNames(SameDiff sameDiff) {
        Set<String> names = new HashSet<>();
        for (DifferentialFunction fn : sameDiff.functions()) {
            for (SDVariable output : fn.outputVariables()) {
                names.add(output.getVarName());
            }
        }
        return names;
    }

    /**
     * Copies the gradient array of every variable not in {@code opOutputs}, validating that each gradient exists and
     * has the same shape as its variable.
     */
    private static Map<String, INDArray> analyticGradients(SameDiff sameDiff, Set<String> opOutputs) {
        Map<String, INDArray> gradients = new HashMap<>();
        for (SDVariable variable : sameDiff.variables()) {
            String varName = variable.getVarName();
            if (opOutputs.contains(varName)) {
                continue;
            }
            SDVariable gradVar = sameDiff.grad(varName);
            if (gradVar == null) {
                throw new IllegalStateException("Null gradient variable for \"" + varName + "\"");
            }
            INDArray gradArr = gradVar.getArr();
            if (gradArr == null) {
                throw new IllegalStateException("Null gradient array encountered for variable: " + varName);
            }
            if (!Arrays.equals(variable.getArr().shape(), gradArr.shape())) {
                throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" + varName + "\": shape " + Arrays.toString(variable.getArr().shape()) + " vs. gradient shape " + Arrays.toString(gradArr.shape()));
            }
            gradients.put(varName, gradArr.dup());
        }
        return gradients;
    }

    /** Formats an element index without spaces, e.g. {@code [0,2]}. */
    private static String positionString(long[] index) {
        return Arrays.toString(index).replace(" ", "");
    }

    /**
     * Sanity-checks SameDiff's private bookkeeping, which is read via reflection.
     *
     * <p>The checks are:
     * <ul>
     *   <li>variable objects and variable names are unique;</li>
     *   <li>{@code incomingArgsReverse} and {@code outgoingArgsReverse} hold exactly one entry per function;</li>
     *   <li>every arg name in those maps is a known variable;</li>
     *   <li>each variable is an output of at most one op;</li>
     *   <li>{@code variableMap} agrees with {@code variables()}.</li>
     * </ul>
     *
     * @param sameDiff      graph to validate
     * @param checkGradFn   if {@code true}, creates the "grad" function when absent (a side effect on
     *                      {@code sameDiff}), validates it, and checks that it contains every original function
     * @throws IllegalStateException if any check fails
     */
    public static void validateInternalState(SameDiff sameDiff, boolean checkGradFn) {
        DifferentialFunction[] functions = sameDiff.functions();
        List<SDVariable> variables = sameDiff.variables();

        Preconditions.checkState(variables.size() == new HashSet<>(variables).size(), "Duplicate variables in variables() list");
        Set<String> varNames = new HashSet<>();
        for (SDVariable variable : variables) {
            if (!varNames.add(variable.getVarName())) {
                throw new IllegalStateException("Variable with name " + variable.getVarName() + " already encountered");
            }
        }

        Map<String, String[]> incomingArgsReverse = getObject("incomingArgsReverse", sameDiff, SameDiff.class);
        Map<String, String[]> outgoingArgsReverse = getObject("outgoingArgsReverse", sameDiff, SameDiff.class);
        Preconditions.checkState(functions.length == incomingArgsReverse.size(), "All functions not present in incomingArgsReverse");
        Preconditions.checkState(functions.length == outgoingArgsReverse.size(), "All functions not present in outgoingArgsReverse");
        for (DifferentialFunction fn : functions) {
            String fnName = fn.getOwnName();
            Preconditions.checkState(incomingArgsReverse.containsKey(fnName), fnName + " not present in incomingArgsReverse");
            Preconditions.checkState(outgoingArgsReverse.containsKey(fnName), fnName + " not present in outgoingArgsReverse");
            for (String in : incomingArgsReverse.get(fnName)) {
                Preconditions.checkState(varNames.contains(in), "Variable " + in + " in incomingArgsReverse value not a known variable name");
            }
            for (String out : outgoingArgsReverse.get(fnName)) {
                Preconditions.checkState(varNames.contains(out), "Variable " + out + " in outgoingArgsReverse value not a known variable name");
            }
        }

        // Each variable may be produced by at most one op.
        Map<String, String> producerOf = new HashMap<>();
        for (Map.Entry<String, String[]> entry : outgoingArgsReverse.entrySet()) {
            for (String out : entry.getValue()) {
                if (producerOf.containsKey(out)) {
                    throw new IllegalStateException("Already saw variable \"" + out + "\" as output for op \"" + producerOf.get(out) + "\": expected variables to be present as an output only once; also seen as output for op \"" + entry.getKey() + "\"");
                }
                producerOf.put(out, entry.getKey());
            }
        }

        Map<String, SDVariable> variableMap = getObject("variableMap", sameDiff, SameDiff.class);
        Preconditions.checkState(variables.size() == variableMap.size(), "Variable map size check failed");
        for (Map.Entry<String, SDVariable> entry : variableMap.entrySet()) {
            Preconditions.checkState(entry.getKey().equals(entry.getValue().getVarName()), "Name not equal");
        }

        if (checkGradFn) {
            if (sameDiff.getFunction("grad") == null) {
                sameDiff.createGradFunction();
            }
            SameDiff gradFn = sameDiff.getFunction("grad");
            validateInternalState(gradFn, false);
            for (DifferentialFunction fn : functions) {
                Preconditions.checkNotNull(gradFn.getFunctionById(fn.getOwnName()), "DifferentialFunction " + fn.getOwnName() + " from original SameDiff instance not present in grad fn");
            }
        }
    }

    /**
     * Reads a (possibly private) field declared on {@code declaringClass} from {@code target} via reflection.
     *
     * @throws RuntimeException wrapping the reflective failure if the field is missing or inaccessible
     */
    @SuppressWarnings("unchecked")
    private static <T> T getObject(String fieldName, Object target, Class<?> declaringClass) {
        try {
            Field field = declaringClass.getDeclaredField(fieldName);
            field.setAccessible(true);
            return (T) field.get(target);
        } catch (NoSuchFieldException | IllegalAccessException e) {
            throw new RuntimeException("Could not read field \"" + fieldName + "\" of " + declaringClass.getName(), e);
        }
    }
}
