package org.tribuo.regression.baseline;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.baseline.DummyRegressionTrainer;

/* loaded from: input_file:org/tribuo/regression/baseline/DummyRegressionModel.class */
/**
 * A baseline regression model whose predictions do not depend on the example's features.
 * <p>
 * For the {@code CONSTANT}, {@code MEAN}, {@code MEDIAN} and {@code QUARTILE} dummy types every
 * example receives the same stored {@link Regressor}. For the {@code GAUSSIAN} dummy type each
 * output dimension is sampled independently using the stored per-dimension means and spread
 * values, drawn from a {@link Random} seeded at construction time. The RNG is shared by all calls
 * to {@link #predict(Example)}, so the sampled sequence depends on the order of those calls.
 * <p>
 * NOTE(review): the decompiler reports the constructors were originally package-private; they
 * are kept public here to avoid narrowing the visible API.
 */
public class DummyRegressionModel extends Model<Regressor> {
    private static final long serialVersionUID = 2L;

    /** The baseline strategy this model applies. */
    private final DummyRegressionTrainer.DummyType dummyType;
    /** The fixed prediction for non-GAUSSIAN types; {@code null} for GAUSSIAN models. */
    private final Regressor output;
    /** The seed used to initialise {@link #rng}; only meaningful for GAUSSIAN models. */
    private final long seed;
    /** The sampling source for GAUSSIAN models; {@code null} for all other types. */
    private final Random rng;
    /** Per-dimension means, aligned with {@link #dimensionNames}; empty for non-GAUSSIAN models. */
    private final double[] means;
    /** Per-dimension spread values, aligned with {@link #dimensionNames}; empty for non-GAUSSIAN models. */
    private final double[] variances;
    /** Output dimension names; empty for non-GAUSSIAN models. */
    private final String[] dimensionNames;

    /**
     * Constructs a GAUSSIAN dummy model which samples its predictions.
     * <p>
     * The arrays are defensively copied.
     *
     * @param modelProvenance The model provenance.
     * @param immutableFeatureMap The feature domain.
     * @param immutableOutputInfo The output domain.
     * @param seed The seed for the sampling RNG.
     * @param means The per-dimension means, aligned with {@code dimensionNames}.
     * @param variances The per-dimension spread values, aligned with {@code dimensionNames}.
     * @param dimensionNames The output dimension names.
     * @throws NullPointerException If any array is null.
     * @throws IllegalArgumentException If the three arrays do not have the same length.
     */
    public DummyRegressionModel(ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<Regressor> immutableOutputInfo, long seed, double[] means, double[] variances, String[] dimensionNames) {
        super("dummy-GAUSSIAN-regression", modelProvenance, immutableFeatureMap, immutableOutputInfo, false);
        Objects.requireNonNull(means, "means");
        Objects.requireNonNull(variances, "variances");
        Objects.requireNonNull(dimensionNames, "dimensionNames");
        // predict indexes all three arrays by dimension, so a mismatch would otherwise only
        // surface later as an ArrayIndexOutOfBoundsException.
        if ((means.length != dimensionNames.length) || (variances.length != dimensionNames.length)) {
            throw new IllegalArgumentException("means, variances and dimensionNames must have the same length, found "
                    + means.length + ", " + variances.length + ", " + dimensionNames.length);
        }
        this.dummyType = DummyRegressionTrainer.DummyType.GAUSSIAN;
        this.output = null;
        this.seed = seed;
        this.rng = new Random(seed);
        this.means = Arrays.copyOf(means, means.length);
        this.variances = Arrays.copyOf(variances, variances.length);
        this.dimensionNames = Arrays.copyOf(dimensionNames, dimensionNames.length);
    }

    /**
     * Constructs a dummy model which always predicts the supplied regressor.
     *
     * @param modelProvenance The model provenance.
     * @param immutableFeatureMap The feature domain.
     * @param immutableOutputInfo The output domain.
     * @param dummyType The dummy type; must not be {@code GAUSSIAN}.
     * @param regressor The regressor returned for every example.
     * @throws NullPointerException If {@code dummyType} or {@code regressor} is null.
     * @throws IllegalArgumentException If {@code dummyType} is {@code GAUSSIAN}. Use the seeded
     *                                  constructor for that type.
     */
    public DummyRegressionModel(ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<Regressor> immutableOutputInfo, DummyRegressionTrainer.DummyType dummyType, Regressor regressor) {
        super("dummy-" + dummyType + "-regression", modelProvenance, immutableFeatureMap, immutableOutputInfo, false);
        this.dummyType = Objects.requireNonNull(dummyType, "dummyType");
        // A GAUSSIAN model built here would have no RNG and would fail inside predict.
        if (dummyType == DummyRegressionTrainer.DummyType.GAUSSIAN) {
            throw new IllegalArgumentException("GAUSSIAN dummy models must be constructed with a seed, means, variances and dimension names");
        }
        this.output = Objects.requireNonNull(regressor, "regressor");
        // Placeholder values: the seed and arrays are only read for GAUSSIAN models.
        this.seed = 12345L;
        this.rng = null;
        this.means = new double[0];
        this.variances = new double[0];
        this.dimensionNames = new String[0];
    }

    /**
     * Predicts an output for the example without inspecting its features.
     *
     * @param example The example (only attached to the returned prediction).
     * @return The stored regressor for fixed-output types, or a freshly sampled regressor for GAUSSIAN.
     * @throws IllegalStateException If the dummy type is not recognised.
     */
    @Override
    public Prediction<Regressor> predict(Example<Regressor> example) {
        switch (dummyType) {
            case CONSTANT:
            case MEAN:
            case MEDIAN:
            case QUARTILE:
                return new Prediction<>(output, 0, example);
            case GAUSSIAN:
                return new Prediction<>(sampleGaussian(), 0, example);
            default:
                throw new IllegalStateException("Unknown dummyType " + dummyType);
        }
    }

    /**
     * Draws one value per output dimension from the shared RNG.
     * <p>
     * This advances {@link #rng}.
     *
     * @return A regressor with one sampled value per dimension name.
     */
    private Regressor sampleGaussian() {
        Regressor.DimensionTuple[] dimensions = new Regressor.DimensionTuple[dimensionNames.length];
        for (int i = 0; i < dimensionNames.length; i++) {
            // NOTE(review): the standard-normal draw is scaled by variances[i]. A N(mean, variance)
            // sample would scale by Math.sqrt(variances[i]). Confirm whether the trainer stores
            // variances or standard deviations in this array before changing it.
            double sample = (rng.nextGaussian() * variances[i]) + means[i];
            dimensions[i] = new Regressor.DimensionTuple(dimensionNames[i], sample);
        }
        return new Regressor(dimensions);
    }

    /**
     * Returns a single synthetic {@code "BIAS"} feature with weight 1.0 under the
     * {@code "ALL_OUTPUTS"} key.
     *
     * @param n The number of features requested; zero yields an empty map.
     * @return The feature map described above.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        if (n == 0) {
            return Collections.emptyMap();
        }
        List<Pair<String, Double>> bias = Collections.singletonList(new Pair<>("BIAS", 1.0));
        return Collections.singletonMap("ALL_OUTPUTS", bias);
    }

    /**
     * Explains a prediction using the single synthetic bias feature.
     * <p>
     * For GAUSSIAN models this calls {@link #predict(Example)}, which consumes values from the
     * shared RNG.
     *
     * @param example The example to explain.
     * @return An excuse containing a new prediction and the bias feature.
     */
    @Override
    public Optional<Excuse<Regressor>> getExcuse(Example<Regressor> example) {
        return Optional.of(new Excuse<>(example, predict(example), getTopFeatures(1)));
    }

    /**
     * Copies this model with new provenance.
     * <p>
     * NOTE(review): {@code newName} is not used. The copy's name is derived by the constructor
     * from the dummy type, as in the original implementation. For GAUSSIAN models the copy's RNG
     * restarts from the stored seed rather than continuing this model's sample stream.
     *
     * @param newName The requested name (unused).
     * @param newProvenance The provenance for the copy.
     * @return A copy of this model.
     * @throws IllegalStateException If the dummy type is not recognised.
     */
    @Override
    protected Model<Regressor> copy(String newName, ModelProvenance newProvenance) {
        switch (dummyType) {
            case CONSTANT:
            case MEAN:
            case MEDIAN:
            case QUARTILE:
                return new DummyRegressionModel(newProvenance, featureIDMap, outputIDInfo, dummyType, output.copy());
            case GAUSSIAN:
                return new DummyRegressionModel(newProvenance, featureIDMap, outputIDInfo, seed, means, variances, dimensionNames);
            default:
                throw new IllegalStateException("Unknown dummyType " + dummyType);
        }
    }
}
