package org.nd4j.evaluation.regression;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.camel.util.URISupport;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce.same.ASum;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Triple;
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

/**
 * Evaluation for regression outputs.
 *
 * <p>Per-column running sums (and running label/prediction means) are accumulated over one or more
 * calls to {@link #eval(INDArray, INDArray, INDArray)}. From these sums the class reports, for every
 * output column: mean squared error (MSE), mean absolute error (MAE), root mean squared error (RMSE),
 * relative squared error (RSE), Pearson correlation (PC) and the coefficient of determination (R^2).
 * Because only sums are stored, evaluations computed on separate data can be combined with
 * {@link #merge(RegressionEvaluation)}.
 */
public class RegressionEvaluation extends BaseEvaluation<RegressionEvaluation> {

    /** Default number of digits printed after the decimal point by {@link #stats()}. */
    public static final int DEFAULT_PRECISION = 5;

    /** Axis passed to {@code BaseEvaluation.reshapeAndExtractNotMasked} in {@link #eval}. */
    protected int axis = 1;

    /** True once the accumulator arrays have been allocated for a known column count. */
    private boolean initialized;

    /** One name per output column; used for the rows of {@link #stats()}. */
    private List<String> columnNames;

    /** Number of digits after the decimal point in the {@code %e} cells of {@link #stats()}. */
    private long precision;

    /** Number of examples seen per column (sum of mask values when a mask is supplied). */
    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    private INDArray exampleCountPerColumn;

    /** Sum of (masked) labels per column. */
    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    private INDArray labelsSumPerColumn;

    /** Sum of (prediction - label)^2 per column. */
    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    private INDArray sumSquaredErrorsPerColumn;

    /** Per-column reduction of (prediction - label) with the ND4J {@code ASum} op (absolute errors). */
    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    private INDArray sumAbsErrorsPerColumn;

    /** Running mean of the (masked) labels per column. */
    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    private INDArray currentMean;

    /** Running mean of the (masked) predictions per column. */
    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    private INDArray currentPredictionMean;

    /** Sum of label * prediction per column. */
    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    private INDArray sumOfProducts;

    /** Sum of label^2 per column. */
    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    private INDArray sumSquaredLabels;

    /** Sum of prediction^2 per column. */
    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    private INDArray sumSquaredPredicted;

    /** Sum of labels per column (receives the same increments as labelsSumPerColumn; used by R^2). */
    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    private INDArray sumLabels;

    /** Metrics that can be requested through {@link #getValue(IMetric)} / {@link #scoreForMetric}. */
    public enum Metric implements IMetric {
        MSE,
        MAE,
        RMSE,
        RSE,
        PC,
        R2;

        @Override
        public Class<? extends IEvaluation> getEvaluationClass() {
            return RegressionEvaluation.class;
        }

        /** Returns false for R2 and PC (higher is better) and true for the error metrics. */
        @Override
        public boolean minimize() {
            return this != R2 && this != PC;
        }
    }

    /**
     * Creates an un-initialized evaluation with the given configuration; used by {@link #newInstance()}.
     *
     * @param axis        axis passed to {@code reshapeAndExtractNotMasked}
     * @param columnNames column names (may be null; defaults are generated on first eval)
     * @param precision   digits after the decimal point in {@link #stats()}
     */
    protected RegressionEvaluation(int axis, List<String> columnNames, long precision) {
        this.axis = axis;
        this.columnNames = columnNames;
        this.precision = precision;
    }

    /** Creates an evaluation whose column count is taken from the first call to {@code eval}. */
    public RegressionEvaluation() {
        this((List<String>) null, DEFAULT_PRECISION);
    }

    /**
     * Creates an evaluation for a known number of columns, named {@code col_0 .. col_(n-1)}.
     *
     * @param nColumns number of output columns
     */
    public RegressionEvaluation(long nColumns) {
        this(createDefaultColumnNames(nColumns), DEFAULT_PRECISION);
    }

    /**
     * Creates an evaluation for a known number of columns with a custom {@link #stats()} precision.
     *
     * @param nColumns  number of output columns
     * @param precision digits after the decimal point in {@link #stats()}
     */
    public RegressionEvaluation(long nColumns, long precision) {
        this(createDefaultColumnNames(nColumns), precision);
    }

    /**
     * Creates an evaluation with the given column names; null or empty means "unknown yet".
     *
     * @param columnNames one name per output column
     */
    public RegressionEvaluation(String... columnNames) {
        this(columnNames == null || columnNames.length == 0 ? null : Arrays.asList(columnNames),
                DEFAULT_PRECISION);
    }

    /**
     * Creates an evaluation with the given column names and default precision.
     *
     * @param columnNames one name per output column; null or empty means "unknown yet"
     */
    public RegressionEvaluation(List<String> columnNames) {
        this(columnNames, DEFAULT_PRECISION);
    }

    /**
     * Creates an evaluation with the given column names and precision. When names are supplied, the
     * accumulators are allocated immediately (one entry per name).
     *
     * @param columnNames one name per output column; null or empty means "unknown yet"
     * @param precision   digits after the decimal point in {@link #stats()}
     */
    public RegressionEvaluation(List<String> columnNames, long precision) {
        this.precision = precision;
        if (columnNames == null || columnNames.isEmpty()) {
            this.initialized = false;
        } else {
            this.columnNames = columnNames;
            initialize(columnNames.size());
        }
    }

    /** Sets the axis passed to {@code reshapeAndExtractNotMasked} on subsequent evals. */
    public void setAxis(int axis) {
        this.axis = axis;
    }

    /** Returns the axis passed to {@code reshapeAndExtractNotMasked}. */
    public int getAxis() {
        return this.axis;
    }

    /**
     * Discards all accumulated statistics. The evaluation reports no data until the next
     * {@code eval} or {@code merge}.
     */
    @Override
    public void reset() {
        if (labelsSumPerColumn != null) {
            // Replace the old sums with zeros: merge() only checks labelsSumPerColumn for null, so
            // keeping the stale arrays would let pre-reset results leak into a later merge.
            initialize((int) labelsSumPerColumn.length());
        }
        this.initialized = false;
    }

    /**
     * Allocates zeroed DOUBLE accumulators for {@code nColumns} columns, generating default column
     * names when none (or the wrong number) are set.
     */
    private void initialize(int nColumns) {
        if (columnNames == null || columnNames.size() != nColumns) {
            columnNames = createDefaultColumnNames(nColumns);
        }
        exampleCountPerColumn = Nd4j.zeros(DataType.DOUBLE, nColumns);
        labelsSumPerColumn = Nd4j.zeros(DataType.DOUBLE, nColumns);
        sumSquaredErrorsPerColumn = Nd4j.zeros(DataType.DOUBLE, nColumns);
        sumAbsErrorsPerColumn = Nd4j.zeros(DataType.DOUBLE, nColumns);
        currentMean = Nd4j.zeros(DataType.DOUBLE, nColumns);
        currentPredictionMean = Nd4j.zeros(DataType.DOUBLE, nColumns);
        sumOfProducts = Nd4j.zeros(DataType.DOUBLE, nColumns);
        sumSquaredLabels = Nd4j.zeros(DataType.DOUBLE, nColumns);
        sumSquaredPredicted = Nd4j.zeros(DataType.DOUBLE, nColumns);
        sumLabels = Nd4j.zeros(DataType.DOUBLE, nColumns);
        initialized = true;
    }

    /** Returns {@code ["col_0", ..., "col_(n-1)"]}. */
    private static List<String> createDefaultColumnNames(long nColumns) {
        List<String> names = new ArrayList<>((int) nColumns);
        for (int i = 0; i < nColumns; i++) {
            names.add("col_" + i);
        }
        return names;
    }

    /** Evaluates without a mask; equivalent to {@code eval(labels, predictions, null)}. */
    @Override
    public void eval(INDArray labels, INDArray predictions) {
        eval(labels, predictions, (INDArray) null);
    }

    /**
     * Evaluation with per-example metadata is not supported by this class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void eval(INDArray labels, INDArray predictions, INDArray mask,
                     List<? extends Serializable> recordMetaData) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Accumulates statistics for one batch of labels and predictions.
     *
     * @param labels      true values
     * @param predictions predicted values
     * @param mask        optional mask (may be null); when {@code reshapeAndExtractNotMasked} returns a
     *                    mask, labels and predictions are multiplied by it and its column sums are
     *                    added to the example counts (presumably 0/1 values - TODO confirm)
     * @throws IllegalArgumentException if the column count differs from the configured columns
     */
    @Override
    public void eval(INDArray labels, INDArray predictions, INDArray mask) {
        Triple<INDArray, INDArray, INDArray> extracted =
                BaseEvaluation.reshapeAndExtractNotMasked(labels, predictions, mask, this.axis);
        INDArray lbl = extracted.getFirst();
        INDArray pred = extracted.getSecond();
        INDArray msk = extracted.getThird();

        if (lbl.dataType() != pred.dataType()) {
            lbl = lbl.castTo(pred.dataType());
        }
        if (!initialized) {
            initialize((int) lbl.size(1));
        }
        if (columnNames.size() != lbl.size(1) || columnNames.size() != pred.size(1)) {
            throw new IllegalArgumentException("Number of the columns of labels and predictions must match specification ("
                    + columnNames.size() + "). Got " + lbl.size(1) + " and " + pred.size(1));
        }
        if (msk != null) {
            // Masked-out entries become zero and contribute nothing to the sums below
            lbl = lbl.mul(msk);
            pred = pred.mul(msk);
        }

        DataType accType = labelsSumPerColumn.dataType();
        INDArray labelSums = lbl.sum(0).castTo(accType);
        INDArray predictionSums = pred.sum(0).castTo(accType);
        INDArray diff = pred.sub(lbl);
        INDArray absErrorSums = Nd4j.getExecutioner().exec((ReduceOp) new ASum(diff, 0));
        INDArray squaredErrorSums = diff.mul(diff).sum(0);

        labelsSumPerColumn.addi(labelSums);
        sumAbsErrorsPerColumn.addi(absErrorSums.castTo(accType));
        sumSquaredErrorsPerColumn.addi(squaredErrorSums.castTo(accType));
        sumOfProducts.addi(lbl.mul(pred).sum(0).castTo(accType));
        sumSquaredLabels.addi(lbl.mul(lbl).sum(0).castTo(accType));
        sumSquaredPredicted.addi(pred.mul(pred).sum(0).castTo(accType));

        INDArray newCount = msk == null
                ? exampleCountPerColumn.add(Long.valueOf(lbl.size(0)))
                : exampleCountPerColumn.add(msk.sum(0).castTo(accType));
        // Running means: (oldMean * oldCount + batchSum) / newCount.
        // NOTE(review): a column whose count is still 0 after this batch yields NaN here.
        currentMean.muliRowVector(exampleCountPerColumn).addi(labelSums).diviRowVector(newCount);
        currentPredictionMean.muliRowVector(exampleCountPerColumn).addi(predictionSums).diviRowVector(newCount);
        exampleCountPerColumn = newCount;
        sumLabels.addi(labelSums);
    }

    /**
     * Adds the statistics of {@code other} into this evaluation. A no-op if {@code other} holds no
     * accumulators; if this evaluation holds none, copies of other's statistics are taken.
     *
     * @param other evaluation to merge in (not modified)
     */
    @Override
    public void merge(RegressionEvaluation other) {
        if (other.labelsSumPerColumn == null) {
            return;
        }
        if (labelsSumPerColumn == null) {
            copyStatisticsFrom(other);
        } else {
            addStatisticsFrom(other);
        }
        initialized = initialized || other.initialized;
    }

    /** Takes independent copies of all of {@code other}'s statistics (this evaluation holds none). */
    private void copyStatisticsFrom(RegressionEvaluation other) {
        columnNames = other.columnNames;
        precision = other.precision;
        // dup() so later in-place updates on either evaluation cannot affect the other
        exampleCountPerColumn = other.exampleCountPerColumn.dup();
        labelsSumPerColumn = other.labelsSumPerColumn.dup();
        sumSquaredErrorsPerColumn = other.sumSquaredErrorsPerColumn.dup();
        sumAbsErrorsPerColumn = other.sumAbsErrorsPerColumn.dup();
        currentMean = other.currentMean.dup();
        currentPredictionMean = other.currentPredictionMean.dup();
        sumOfProducts = other.sumOfProducts.dup();
        sumSquaredLabels = other.sumSquaredLabels.dup();
        sumSquaredPredicted = other.sumSquaredPredicted.dup();
        sumLabels = other.sumLabels.dup();
    }

    /** Adds {@code other}'s sums to ours and recombines the running means by example count. */
    private void addStatisticsFrom(RegressionEvaluation other) {
        INDArray combinedCount = exampleCountPerColumn.add(other.exampleCountPerColumn);

        labelsSumPerColumn.addi(other.labelsSumPerColumn);
        sumSquaredErrorsPerColumn.addi(other.sumSquaredErrorsPerColumn);
        sumAbsErrorsPerColumn.addi(other.sumAbsErrorsPerColumn);
        currentMean.muliRowVector(exampleCountPerColumn)
                .addi(other.currentMean.mulRowVector(other.exampleCountPerColumn))
                .diviRowVector(combinedCount);
        currentPredictionMean.muliRowVector(exampleCountPerColumn)
                .addi(other.currentPredictionMean.mulRowVector(other.exampleCountPerColumn))
                .diviRowVector(combinedCount);
        sumOfProducts.addi(other.sumOfProducts);
        sumSquaredLabels.addi(other.sumSquaredLabels);
        sumSquaredPredicted.addi(other.sumSquaredPredicted);
        sumLabels.addi(other.sumLabels);
        // Updated last: the mean recombination above needs the pre-merge count
        exampleCountPerColumn.addi(other.exampleCountPerColumn);
    }

    /**
     * Returns a fixed-width table with one row per column and the metrics MSE, MAE, RMSE, RSE, PC and
     * R^2 printed in scientific notation with {@code precision} fractional digits, or
     * {@code "RegressionEvaluation: No Data"} if not initialized.
     */
    @Override
    public String stats() {
        if (!initialized) {
            return "RegressionEvaluation: No Data";
        }
        if (columnNames == null) {
            columnNames = createDefaultColumnNames(numColumns());
        }
        int longestName = 0;
        for (String name : columnNames) {
            longestName = Math.max(longestName, name.length());
        }
        int nameWidth = longestName + 5;
        long valueWidth = precision + 10;

        String nameCell = "%-" + nameWidth + "s";
        String headerCell = "%-" + valueWidth + "s";
        String valueCell = "%-" + valueWidth + "." + precision + "e";
        StringBuilder headerFormat = new StringBuilder(nameCell);
        StringBuilder rowFormat = new StringBuilder(nameCell);
        for (int k = 0; k < 6; k++) {
            headerFormat.append(headerCell);
            rowFormat.append(valueCell);
        }
        String rowPattern = rowFormat.toString();

        StringBuilder out = new StringBuilder();
        out.append(String.format(headerFormat.toString(), "Column", "MSE", "MAE", "RMSE", "RSE", "PC", "R^2"));
        out.append("\n");
        for (int c = 0; c < columnNames.size(); c++) {
            out.append(String.format(rowPattern, columnNames.get(c),
                    meanSquaredError(c), meanAbsoluteError(c), rootMeanSquaredError(c),
                    relativeSquaredError(c), pearsonCorrelation(c), rSquared(c)));
            out.append("\n");
        }
        return out.toString();
    }

    /**
     * Returns the number of columns: the number of column names if set, otherwise the length of the
     * example-count array, otherwise 0.
     */
    public int numColumns() {
        if (columnNames != null) {
            return columnNames.size();
        }
        if (exampleCountPerColumn == null) {
            return 0;
        }
        // length() rather than size(1): initialize() allocates rank-1 arrays, which have no dimension 1
        return (int) exampleCountPerColumn.length();
    }

    /** Mean squared error of the given column: sum of squared errors / example count. */
    public double meanSquaredError(int column) {
        return sumSquaredErrorsPerColumn.getDouble(column) / exampleCountPerColumn.getDouble(column);
    }

    /** Mean absolute error of the given column: sum of absolute errors / example count. */
    public double meanAbsoluteError(int column) {
        return sumAbsErrorsPerColumn.getDouble(column) / exampleCountPerColumn.getDouble(column);
    }

    /** Root mean squared error of the given column. */
    public double rootMeanSquaredError(int column) {
        return Math.sqrt(sumSquaredErrorsPerColumn.getDouble(column) / exampleCountPerColumn.getDouble(column));
    }

    /**
     * @deprecated Use {@link #pearsonCorrelation(int)}, which this method delegates to.
     */
    @Deprecated
    public double correlationR2(int column) {
        return pearsonCorrelation(column);
    }

    /**
     * Pearson correlation between labels and predictions for the given column, computed from the
     * running sums: (Σlp − n·mean_p·mean_l) / (sqrt(Σl² − n·mean_l²) · sqrt(Σp² − n·mean_p²)).
     */
    public double pearsonCorrelation(int column) {
        double sumProducts = sumOfProducts.getDouble(column);
        double predMean = currentPredictionMean.getDouble(column);
        double labelMean = currentMean.getDouble(column);
        double sumSqLabels = sumSquaredLabels.getDouble(column);
        double sumSqPredicted = sumSquaredPredicted.getDouble(column);
        double n = exampleCountPerColumn.getDouble(column);
        double covariance = sumProducts - n * predMean * labelMean;
        return covariance / (Math.sqrt(sumSqLabels - n * labelMean * labelMean)
                * Math.sqrt(sumSqPredicted - n * predMean * predMean));
    }

    /**
     * Coefficient of determination for the given column: (SST − SSE) / SST, where
     * SST = Σl² + mean·(n·mean − 2·Σl) = Σ(l − mean)².
     */
    public double rSquared(int column) {
        double sumSqLabels = sumSquaredLabels.getDouble(column);
        double mean = currentMean.getDouble(column);
        double sst = sumSqLabels + mean * (exampleCountPerColumn.getDouble(column) * mean
                - 2.0d * sumLabels.getDouble(column));
        return (sst - sumSquaredErrorsPerColumn.getDouble(column)) / sst;
    }

    /**
     * Relative squared error for the given column: Σ(p − l)² / Σ(l − mean)², or positive infinity
     * when the denominator's magnitude does not exceed {@code Nd4j.EPS_THRESHOLD}.
     */
    public double relativeSquaredError(int column) {
        double sumSqLabels = sumSquaredLabels.getDouble(column);
        // Σp² − 2Σlp + Σl² = Σ(p − l)²
        double numerator = sumSquaredPredicted.getDouble(column) - 2.0d * sumOfProducts.getDouble(column)
                + sumSqLabels;
        double mean = currentMean.getDouble(column);
        double denominator = sumSqLabels - exampleCountPerColumn.getDouble(column) * mean * mean;
        if (Math.abs(denominator) > Nd4j.EPS_THRESHOLD) {
            return numerator / denominator;
        }
        return Double.POSITIVE_INFINITY;
    }

    /** Mean of {@link #meanSquaredError(int)} over all columns. */
    public double averageMeanSquaredError() {
        return averageOf(Metric.MSE);
    }

    /** Mean of {@link #meanAbsoluteError(int)} over all columns. */
    public double averageMeanAbsoluteError() {
        return averageOf(Metric.MAE);
    }

    /** Mean of {@link #rootMeanSquaredError(int)} over all columns. */
    public double averagerootMeanSquaredError() {
        return averageOf(Metric.RMSE);
    }

    /** Mean of {@link #relativeSquaredError(int)} over all columns. */
    public double averagerelativeSquaredError() {
        return averageOf(Metric.RSE);
    }

    /**
     * @deprecated Use {@link #averagePearsonCorrelation()}, which this method delegates to.
     */
    @Deprecated
    public double averagecorrelationR2() {
        return averagePearsonCorrelation();
    }

    /** Mean of {@link #pearsonCorrelation(int)} over all columns. */
    public double averagePearsonCorrelation() {
        return averageOf(Metric.PC);
    }

    /** Mean of {@link #rSquared(int)} over all columns. */
    public double averageRSquared() {
        return averageOf(Metric.R2);
    }

    /** Averages the per-column value of {@code metric} over all columns (NaN when there are none). */
    private double averageOf(Metric metric) {
        int n = numColumns();
        double total = 0.0d;
        for (int column = 0; column < n; column++) {
            total += columnValue(metric, column);
        }
        return total / n;
    }

    /** Dispatches to the per-column method for {@code metric}. */
    private double columnValue(Metric metric, int column) {
        switch (metric) {
            case MSE:
                return meanSquaredError(column);
            case MAE:
                return meanAbsoluteError(column);
            case RMSE:
                return rootMeanSquaredError(column);
            case RSE:
                return relativeSquaredError(column);
            case PC:
                return pearsonCorrelation(column);
            case R2:
                return rSquared(column);
            default:
                throw new IllegalStateException("Unknown metric: " + metric);
        }
    }

    /**
     * Returns the column-averaged value of a regression metric.
     *
     * @throws IllegalStateException if {@code metric} is not a {@link Metric}
     */
    @Override
    public double getValue(IMetric metric) {
        if (metric instanceof Metric) {
            return scoreForMetric((Metric) metric);
        }
        throw new IllegalStateException("Can't get value for non-regression Metric " + metric);
    }

    /** Returns the column-averaged value of {@code metric}. */
    public double scoreForMetric(Metric metric) {
        switch (metric) {
            case MSE:
                return averageMeanSquaredError();
            case MAE:
                return averageMeanAbsoluteError();
            case RMSE:
                return averagerootMeanSquaredError();
            case RSE:
                return averagerelativeSquaredError();
            case PC:
                return averagePearsonCorrelation();
            case R2:
                return averageRSquared();
            default:
                throw new IllegalStateException("Unknown metric: " + metric);
        }
    }

    /** Deserializes a RegressionEvaluation from JSON via {@code BaseEvaluation.fromJson}. */
    public static RegressionEvaluation fromJson(String json) {
        return (RegressionEvaluation) fromJson(json, RegressionEvaluation.class);
    }

    /** Returns a fresh, empty evaluation with the same axis, column names and precision. */
    @Override
    public RegressionEvaluation newInstance() {
        return new RegressionEvaluation(this.axis, this.columnNames, this.precision);
    }

    // ---- Plain accessors for the configuration and the accumulated statistics ----

    public boolean isInitialized() {
        return this.initialized;
    }

    public List<String> getColumnNames() {
        return this.columnNames;
    }

    public long getPrecision() {
        return this.precision;
    }

    public INDArray getExampleCountPerColumn() {
        return this.exampleCountPerColumn;
    }

    public INDArray getLabelsSumPerColumn() {
        return this.labelsSumPerColumn;
    }

    public INDArray getSumSquaredErrorsPerColumn() {
        return this.sumSquaredErrorsPerColumn;
    }

    public INDArray getSumAbsErrorsPerColumn() {
        return this.sumAbsErrorsPerColumn;
    }

    public INDArray getCurrentMean() {
        return this.currentMean;
    }

    public INDArray getCurrentPredictionMean() {
        return this.currentPredictionMean;
    }

    public INDArray getSumOfProducts() {
        return this.sumOfProducts;
    }

    public INDArray getSumSquaredLabels() {
        return this.sumSquaredLabels;
    }

    public INDArray getSumSquaredPredicted() {
        return this.sumSquaredPredicted;
    }

    public INDArray getSumLabels() {
        return this.sumLabels;
    }

    public void setInitialized(boolean initialized) {
        this.initialized = initialized;
    }

    public void setColumnNames(List<String> columnNames) {
        this.columnNames = columnNames;
    }

    public void setPrecision(long precision) {
        this.precision = precision;
    }

    public void setExampleCountPerColumn(INDArray exampleCountPerColumn) {
        this.exampleCountPerColumn = exampleCountPerColumn;
    }

    public void setLabelsSumPerColumn(INDArray labelsSumPerColumn) {
        this.labelsSumPerColumn = labelsSumPerColumn;
    }

    public void setSumSquaredErrorsPerColumn(INDArray sumSquaredErrorsPerColumn) {
        this.sumSquaredErrorsPerColumn = sumSquaredErrorsPerColumn;
    }

    public void setSumAbsErrorsPerColumn(INDArray sumAbsErrorsPerColumn) {
        this.sumAbsErrorsPerColumn = sumAbsErrorsPerColumn;
    }

    public void setCurrentMean(INDArray currentMean) {
        this.currentMean = currentMean;
    }

    public void setCurrentPredictionMean(INDArray currentPredictionMean) {
        this.currentPredictionMean = currentPredictionMean;
    }

    public void setSumOfProducts(INDArray sumOfProducts) {
        this.sumOfProducts = sumOfProducts;
    }

    public void setSumSquaredLabels(INDArray sumSquaredLabels) {
        this.sumSquaredLabels = sumSquaredLabels;
    }

    public void setSumSquaredPredicted(INDArray sumSquaredPredicted) {
        this.sumSquaredPredicted = sumSquaredPredicted;
    }

    public void setSumLabels(INDArray sumLabels) {
        this.sumLabels = sumLabels;
    }

    @Override
    public String toString() {
        return "RegressionEvaluation(axis=" + getAxis() + ", initialized=" + isInitialized()
                + ", columnNames=" + getColumnNames() + ", precision=" + getPrecision()
                + ", exampleCountPerColumn=" + getExampleCountPerColumn()
                + ", labelsSumPerColumn=" + getLabelsSumPerColumn()
                + ", sumSquaredErrorsPerColumn=" + getSumSquaredErrorsPerColumn()
                + ", sumAbsErrorsPerColumn=" + getSumAbsErrorsPerColumn()
                + ", currentMean=" + getCurrentMean()
                + ", currentPredictionMean=" + getCurrentPredictionMean()
                + ", sumOfProducts=" + getSumOfProducts()
                + ", sumSquaredLabels=" + getSumSquaredLabels()
                + ", sumSquaredPredicted=" + getSumSquaredPredicted()
                + ", sumLabels=" + getSumLabels() + ")";
    }

    /**
     * Equal when the superclass state, initialized flag, column names, precision and every statistics
     * array compare equal (the axis is not compared).
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof RegressionEvaluation)) {
            return false;
        }
        RegressionEvaluation other = (RegressionEvaluation) obj;
        if (!other.canEqual(this) || !super.equals(obj) || isInitialized() != other.isInitialized()) {
            return false;
        }
        List<String> names = getColumnNames();
        List<String> otherNames = other.getColumnNames();
        if (names == null ? otherNames != null : !names.equals(otherNames)) {
            return false;
        }
        if (getPrecision() != other.getPrecision()) {
            return false;
        }
        // Element-wise null-safe equals over the statistics arrays, in a fixed order
        return Arrays.equals(statisticArrays(), other.statisticArrays());
    }

    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof RegressionEvaluation;
    }

    /** Hash consistent with {@link #equals(Object)}; same 59/43/79/97 combination scheme as before. */
    @Override
    public int hashCode() {
        int result = super.hashCode() * 59 + (isInitialized() ? 79 : 97);
        List<String> names = getColumnNames();
        result = result * 59 + (names == null ? 43 : names.hashCode());
        long p = getPrecision();
        result = result * 59 + (int) ((p >>> 32) ^ p);
        for (INDArray arr : statisticArrays()) {
            result = result * 59 + (arr == null ? 43 : arr.hashCode());
        }
        return result;
    }

    /** The statistics arrays in the fixed order used by equals() and hashCode(). */
    private INDArray[] statisticArrays() {
        return new INDArray[] {
                getExampleCountPerColumn(), getLabelsSumPerColumn(), getSumSquaredErrorsPerColumn(),
                getSumAbsErrorsPerColumn(), getCurrentMean(), getCurrentPredictionMean(),
                getSumOfProducts(), getSumSquaredLabels(), getSumSquaredPredicted(), getSumLabels()
        };
    }
}
