package org.nd4j.evaluation.custom;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import lombok.NonNull;
import org.apache.camel.util.URISupport;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.guava.collect.Lists;

/* loaded from: input_file:org/nd4j/evaluation/custom/CustomEvaluation.class */
/**
 * An evaluation whose per-batch computation and merge strategy are supplied by the caller.
 *
 * <p>Each call to {@link #eval} hands its arguments to the {@link EvaluationLambda} and appends the
 * returned value to {@link #getEvaluations()}. {@link #merge} combines this instance's result list
 * with another's via the {@link MergeLambda}. {@link #getValue} reduces the accumulated list to a
 * single {@code double} using a {@link Metric}'s {@link ResultLambda}.
 *
 * @param <T> type of the per-batch result produced by the evaluation lambda
 */
public class CustomEvaluation<T> extends BaseEvaluation<CustomEvaluation> {

    /** Computes one result of type {@code T} per {@link #eval} call. */
    @NonNull
    private EvaluationLambda<T> evaluationLambda;

    /** Combines two result lists when evaluations are merged. */
    @NonNull
    private MergeLambda<T> mergeLambda;

    /** Results accumulated by {@link #eval}, in call order (unless re-ordered by a merge). */
    private List<T> evaluations = new ArrayList<>();

    /**
     * A metric for {@link CustomEvaluation} that reduces the accumulated per-batch results to a
     * single value.
     *
     * @param <T> element type of the evaluation results this metric consumes
     */
    public static class Metric<T> implements IMetric {

        @NonNull
        private final ResultLambda<T> getResult;
        private final boolean minimize;

        /**
         * Creates a metric.
         *
         * @param resultLambda reduces the list of evaluation results to a single value
         * @param minimize {@code true} if smaller values of this metric are better
         * @throws NullPointerException if {@code resultLambda} is null
         */
        public Metric(@NonNull ResultLambda<T> resultLambda, boolean minimize) {
            this.getResult = Objects.requireNonNull(resultLambda, "getResult is marked @NonNull but is null");
            this.minimize = minimize;
        }

        /**
         * Creates a metric for which larger values are better ({@code minimize() == false}).
         *
         * @param resultLambda reduces the list of evaluation results to a single value
         * @throws NullPointerException if {@code resultLambda} is null
         */
        public Metric(@NonNull ResultLambda<T> resultLambda) {
            this(resultLambda, false);
        }

        /**
         * Returns a metric giving the arithmetic mean of the results. An empty result list yields
         * {@code NaN} (0.0 / 0).
         *
         * @param minimize {@code true} if smaller values of this metric are better
         */
        public static Metric<Double> doubleAverage(boolean minimize) {
            return new Metric<>(new ResultLambda<Double>() {
                @Override
                public double toResult(List<Double> data) {
                    double sum = 0.0;
                    for (Double value : data) {
                        sum += value;
                    }
                    return sum / data.size();
                }
            }, minimize);
        }

        /**
         * Returns a metric giving the largest result. {@code NaN} entries are ignored. If there is
         * no non-NaN entry (including an empty list), the metric value is {@code NaN}.
         *
         * @param minimize {@code true} if smaller values of this metric are better
         */
        public static Metric<Double> doubleMax(boolean minimize) {
            return new Metric<>(new ResultLambda<Double>() {
                @Override
                public double toResult(List<Double> data) {
                    return extremum(data, true);
                }
            }, minimize);
        }

        /**
         * Returns a metric giving the smallest result. {@code NaN} entries are ignored. If there is
         * no non-NaN entry (including an empty list), the metric value is {@code NaN}.
         *
         * @param minimize {@code true} if smaller values of this metric are better
         */
        public static Metric<Double> doubleMin(boolean minimize) {
            return new Metric<>(new ResultLambda<Double>() {
                @Override
                public double toResult(List<Double> data) {
                    return extremum(data, false);
                }
            }, minimize);
        }

        /** Returns the largest (or smallest) non-NaN value in {@code data}, or NaN if there is none. */
        private static double extremum(List<Double> data, boolean findMax) {
            // Seed from the data itself rather than from 0.0, so all-negative (max) or
            // all-positive (min) inputs produce the correct answer.
            double best = Double.NaN;
            for (Double boxed : data) {
                double value = boxed;
                if (Double.isNaN(value)) {
                    continue;
                }
                if (Double.isNaN(best) || (findMax ? value > best : value < best)) {
                    best = value;
                }
            }
            return best;
        }

        @Override
        public Class<? extends IEvaluation> getEvaluationClass() {
            return CustomEvaluation.class;
        }

        @Override
        public boolean minimize() {
            return this.minimize;
        }

        /** Returns the lambda that reduces the evaluation results to this metric's value. */
        @NonNull
        public ResultLambda<T> getGetResult() {
            return this.getResult;
        }
    }

    /**
     * Returns a {@link MergeLambda} that concatenates two result lists. It produces a new list
     * holding the first list's elements followed by the second's, and leaves both inputs
     * unmodified.
     *
     * @param <R> element type of the lists being merged
     */
    public static <R> MergeLambda<R> mergeConcatenate() {
        return new MergeLambda<R>() {
            @Override
            public List<R> merge(List<R> first, List<R> second) {
                List<R> merged = new ArrayList<>(first.size() + second.size());
                merged.addAll(first);
                merged.addAll(second);
                return merged;
            }
        };
    }

    /**
     * Creates an evaluation with an empty result list.
     *
     * @param evaluationLambda computes one result per {@link #eval} call
     * @param mergeLambda combines result lists in {@link #merge}
     * @throws NullPointerException if either argument is null
     */
    public CustomEvaluation(@NonNull EvaluationLambda<T> evaluationLambda, @NonNull MergeLambda<T> mergeLambda) {
        this.evaluationLambda = Objects.requireNonNull(evaluationLambda, "evaluationLambda is marked @NonNull but is null");
        this.mergeLambda = Objects.requireNonNull(mergeLambda, "mergeLambda is marked @NonNull but is null");
    }

    /** Passes the arguments to the evaluation lambda and records its result. */
    @Override
    public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
        this.evaluations.add(this.evaluationLambda.eval(labels, networkPredictions, maskArray, recordMetaData));
    }

    /**
     * Replaces this instance's results with the merge lambda's combination of this instance's
     * results and {@code other}'s.
     */
    @Override
    @SuppressWarnings("unchecked")
    public void merge(CustomEvaluation other) {
        // BaseEvaluation is parameterised with the raw CustomEvaluation type, so the element type of
        // other's results cannot be checked statically; callers must only merge evaluations of the same T.
        List<T> otherEvaluations = (List<T>) other.evaluations;
        this.evaluations = this.mergeLambda.merge(this.evaluations, otherEvaluations);
    }

    /** Discards all accumulated results. */
    @Override
    public void reset() {
        this.evaluations = new ArrayList<>();
    }

    /** Returns an empty string; this evaluation has no built-in textual summary. */
    @Override
    public String stats() {
        return "";
    }

    /**
     * Reduces the accumulated results using the given metric's result lambda.
     *
     * @throws IllegalStateException if {@code metric} is not a {@link CustomEvaluation.Metric}
     */
    @Override
    @SuppressWarnings("unchecked")
    public double getValue(IMetric metric) {
        if (!(metric instanceof Metric)) {
            throw new IllegalStateException("Can't get value for non-CustomEvaluation Metric " + metric);
        }
        // Unchecked: the metric's element type is assumed to match this evaluation's T.
        return ((Metric<T>) metric).getGetResult().toResult(this.evaluations);
    }

    /** Returns a fresh evaluation sharing this instance's lambdas but with no results. */
    @Override
    public CustomEvaluation<T> newInstance() {
        return new CustomEvaluation<>(this.evaluationLambda, this.mergeLambda);
    }

    @NonNull
    public EvaluationLambda<T> getEvaluationLambda() {
        return this.evaluationLambda;
    }

    @NonNull
    public MergeLambda<T> getMergeLambda() {
        return this.mergeLambda;
    }

    /** Returns the live (not copied) list of accumulated results. */
    public List<T> getEvaluations() {
        return this.evaluations;
    }

    /** @throws NullPointerException if {@code evaluationLambda} is null */
    public void setEvaluationLambda(@NonNull EvaluationLambda<T> evaluationLambda) {
        this.evaluationLambda = Objects.requireNonNull(evaluationLambda, "evaluationLambda is marked @NonNull but is null");
    }

    /** @throws NullPointerException if {@code mergeLambda} is null */
    public void setMergeLambda(@NonNull MergeLambda<T> mergeLambda) {
        this.mergeLambda = Objects.requireNonNull(mergeLambda, "mergeLambda is marked @NonNull but is null");
    }

    /** Replaces the result list; the given list is stored as-is, not copied. */
    public void setEvaluations(List<T> list) {
        this.evaluations = list;
    }

    @Override
    public String toString() {
        return "CustomEvaluation(evaluationLambda=" + getEvaluationLambda()
                + ", mergeLambda=" + getMergeLambda()
                + ", evaluations=" + getEvaluations() + ")";
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof CustomEvaluation)) {
            return false;
        }
        CustomEvaluation<?> other = (CustomEvaluation<?>) obj;
        return other.canEqual(this)
                && super.equals(obj)
                && Objects.equals(getEvaluationLambda(), other.getEvaluationLambda())
                && Objects.equals(getMergeLambda(), other.getMergeLambda())
                && Objects.equals(getEvaluations(), other.getEvaluations());
    }

    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof CustomEvaluation;
    }

    @Override
    public int hashCode() {
        // Same 59/43 mixing as the Lombok-generated original, so hash values are unchanged.
        int result = super.hashCode();
        result = result * 59 + hashOrDefault(getEvaluationLambda());
        result = result * 59 + hashOrDefault(getMergeLambda());
        result = result * 59 + hashOrDefault(getEvaluations());
        return result;
    }

    /** Returns {@code o.hashCode()}, or 43 for null (the Lombok convention). */
    private static int hashOrDefault(Object o) {
        return o == null ? 43 : o.hashCode();
    }
}
