package org.tribuo.regression.example;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.tribuo.ConfigurableDataSource;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.MutableDataset;
import org.tribuo.OutputFactory;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.ConfiguredDataSourceProvenance;
import org.tribuo.provenance.DataSourceProvenance;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;

/**
 * A data source that produces a two-feature regression problem with a nonlinear response.
 * <p>
 * For each sample:
 * <ul>
 *   <li>{@code x_0} is drawn uniformly from {@code [xZeroMin, xZeroMax)}.</li>
 *   <li>{@code x_1} is drawn uniformly from {@code [xOneMin, xOneMax)}.</li>
 *   <li>A single output dimension named {@code "Y"} is emitted with value
 * <pre>
 *   variance*z + w_0*x_0 + w_1*x_1 + w_2*x_0*x_1 + w_3*x_1^3 + intercept
 * </pre>
 *   where {@code z} is a standard normal draw.</li>
 * </ul>
 * The features are named {@code "X_0"} and {@code "X_1"}.
 * <p>
 * All examples are generated eagerly in {@link #postConfig()}. They come from a {@link Random}
 * seeded with {@code seed}, so a given configuration always yields the same examples.
 * <p>
 * Defaults:
 * <ul>
 *   <li>weights {@code {1, 1, 1, 1}}</li>
 *   <li>intercept {@code 0}</li>
 *   <li>variance {@code 1}</li>
 *   <li>both feature ranges {@code [-2, 2)}</li>
 *   <li>seed {@code 12345}</li>
 * </ul>
 * {@code numSamples} has no default and must be configured.
 * <p>
 * NOTE(review): the noise term is {@code variance * z}. This means the noise <em>standard
 * deviation</em> equals {@code variance}, so the noise variance is {@code variance^2}. It is left
 * unchanged so that seeded datasets stay reproducible. Confirm whether {@code Math.sqrt(variance)}
 * was intended. (With the default {@code variance = 1} both readings coincide.)
 */
public class NonlinearGaussianDataSource implements ConfigurableDataSource<Regressor> {

    @Config(mandatory = true, description = "The number of samples to draw.")
    private int numSamples;

    @Config(description = "The feature weights. Must be a 4 element array.")
    private float[] weights = new float[]{1.0f, 1.0f, 1.0f, 1.0f};

    @Config(description = "The y-intercept of the line.")
    private float intercept = 0.0f;

    @Config(description = "The variance of the noise gaussian.")
    private float variance = 1.0f;

    @Config(description = "The minimum value of x_0.")
    private float xZeroMin = -2.0f;

    @Config(description = "The maximum value of x_0.")
    private float xZeroMax = 2.0f;

    @Config(description = "The minimum value of x_1.")
    private float xOneMin = -2.0f;

    @Config(description = "The maximum value of x_1.")
    private float xOneMax = 2.0f;

    @Config(description = "The RNG seed.")
    private long seed = 12345L;

    /** The generated examples; an unmodifiable list assigned by {@link #postConfig()}. */
    private List<Example<Regressor>> examples;

    private final RegressionFactory factory = new RegressionFactory();

    /** Names of the two generated features, shared by every example. */
    private static final String[] FEATURE_NAMES = {"X_0", "X_1"};

    /**
     * Provenance for {@link NonlinearGaussianDataSource}.
     */
    public static class NonlinearGaussianDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance {
        private static final long serialVersionUID = 1L;

        /**
         * Captures the provenance of the supplied data source, tagged with the type string {@code "DataSource"}.
         * @param host The data source to describe.
         */
        NonlinearGaussianDataSourceProvenance(NonlinearGaussianDataSource host) {
            super(host, "DataSource");
        }

        /**
         * Rebuilds a provenance from a map of provenance entries.
         * <p>
         * This is presumably the path used when unmarshalling a stored provenance; confirm against OLCUT.
         * @param map The provenance entries, which must contain {@code "class-name"} and
         *            {@code "host-short-name"} string entries.
         */
        public NonlinearGaussianDataSourceProvenance(Map<String, Provenance> map) {
            this(extractProvenanceInfo(map));
        }

        private NonlinearGaussianDataSourceProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo info) {
            super(info);
        }

        /**
         * Splits a provenance map into the fields required by {@link SkeletalConfiguredObjectProvenance}.
         * <p>
         * The input map is copied, never mutated. The {@code "class-name"} and {@code "host-short-name"}
         * string entries are read from the copy. The copy is then passed on as the configured
         * parameters, with no instance values.
         * @param map The provenance entries.
         * @return The extracted info.
         */
        protected static SkeletalConfiguredObjectProvenance.ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
            Map<String, Provenance> configuredParameters = new HashMap<>(map);
            String provClassName = NonlinearGaussianDataSourceProvenance.class.getSimpleName();
            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters, "class-name", StringProvenance.class, provClassName).getValue();
            String hostTypeString = ObjectProvenance.checkAndExtractProvenance(configuredParameters, "host-short-name", StringProvenance.class, provClassName).getValue();
            return new SkeletalConfiguredObjectProvenance.ExtractedInfo(className, hostTypeString, configuredParameters, Collections.emptyMap());
        }
    }

    /**
     * No-arg constructor used by the configuration system.
     * <p>
     * Fields keep their declared defaults until configured.
     */
    private NonlinearGaussianDataSource() {}

    /**
     * Generates a data source from explicit parameters.
     * @param numSamples The number of samples to draw; must be non-negative.
     * @param weights The four feature weights {@code w_0..w_3}. The array is copied, so later changes
     *                to it have no effect.
     * @param intercept The additive intercept.
     * @param variance The noise scale; must be positive. See the class note on how it is applied.
     * @param xZeroMin The minimum value of {@code x_0}.
     * @param xZeroMax The maximum value of {@code x_0}; must exceed {@code xZeroMin}.
     * @param xOneMin The minimum value of {@code x_1}.
     * @param xOneMax The maximum value of {@code x_1}; must exceed {@code xOneMin}.
     * @param seed The RNG seed.
     * @throws NullPointerException If {@code weights} is null.
     * @throws PropertyException If any parameter is out of range.
     */
    public NonlinearGaussianDataSource(int numSamples, float[] weights, float intercept, float variance,
                                       float xZeroMin, float xZeroMax, float xOneMin, float xOneMax, long seed) {
        this.numSamples = numSamples;
        // Defensive copy so the caller cannot mutate our configuration afterwards.
        this.weights = weights.clone();
        this.intercept = intercept;
        this.variance = variance;
        this.xZeroMin = xZeroMin;
        this.xZeroMax = xZeroMax;
        this.xOneMin = xOneMin;
        this.xOneMax = xOneMax;
        this.seed = seed;
        postConfig();
    }

    /**
     * Validates the configuration and eagerly generates all examples.
     * <p>
     * Invoked by the public constructor. It is presumably also invoked by OLCUT after the
     * {@link Config} fields are injected.
     * @throws PropertyException If there are not exactly 4 weights, either feature range is empty or
     *                           NaN, the variance is not positive, or {@code numSamples} is negative.
     */
    @Override
    public void postConfig() {
        if (weights.length != 4) {
            throw new PropertyException("", "weights", "Must supply 4 weights, found " + weights.length);
        }
        // Negated comparisons so NaN bounds/variance are rejected instead of slipping through.
        if (!(xZeroMax > xZeroMin)) {
            throw new PropertyException("", "xZeroMax", "xZeroMax must be greater than xZeroMin, found xZeroMax = " + xZeroMax + ", xZeroMin = " + xZeroMin);
        }
        if (!(xOneMax > xOneMin)) {
            throw new PropertyException("", "xOneMax", "xOneMax must be greater than xOneMin, found xOneMax = " + xOneMax + ", xOneMin = " + xOneMin);
        }
        if (!(variance > 0.0f)) {
            throw new PropertyException("", "variance", "Variance must be positive, found variance = " + variance);
        }
        if (numSamples < 0) {
            throw new PropertyException("", "numSamples", "numSamples must be non-negative, found numSamples = " + numSamples);
        }

        Random rng = new Random(seed);
        // The range is computed as a float subtraction and then widened to double.
        double zeroRange = xZeroMax - xZeroMin;
        double oneRange = xOneMax - xOneMin;
        List<Example<Regressor>> generated = new ArrayList<>(numSamples);
        for (int i = 0; i < numSamples; i++) {
            // Draw order (x_0, x_1, then noise) fixes the RNG stream for a given seed.
            double xZero = (rng.nextDouble() * zeroRange) + xZeroMin;
            double xOne = (rng.nextDouble() * oneRange) + xOneMin;
            // Summed left to right starting from the noise term; reordering would change rounding.
            double y = (rng.nextGaussian() * variance)
                    + (weights[0] * xZero)
                    + (weights[1] * xOne)
                    + (weights[2] * xZero * xOne)
                    + (weights[3] * Math.pow(xOne, 3.0d))
                    + intercept;
            generated.add(new ArrayExample<>(new Regressor("Y", y), FEATURE_NAMES, new double[]{xZero, xOne}));
        }
        examples = Collections.unmodifiableList(generated);
    }

    /**
     * Returns the regression output factory.
     * @return The output factory.
     */
    @Override
    public OutputFactory<Regressor> getOutputFactory() {
        return factory;
    }

    /**
     * Returns a fresh provenance describing this data source's configuration.
     * @return The provenance.
     */
    @Override
    public DataSourceProvenance getProvenance() {
        return new NonlinearGaussianDataSourceProvenance(this);
    }

    /**
     * Iterates over the generated examples.
     * <p>
     * The backing list is unmodifiable, so {@code remove()} is unsupported.
     * @return An iterator over the examples.
     */
    @Override
    public Iterator<Example<Regressor>> iterator() {
        return examples.iterator();
    }

    /**
     * Convenience factory that builds a {@link MutableDataset} from a freshly generated data source.
     * @param numSamples The number of samples to draw; must be non-negative.
     * @param weights The four feature weights {@code w_0..w_3}.
     * @param intercept The additive intercept.
     * @param variance The noise scale; must be positive.
     * @param xZeroMin The minimum value of {@code x_0}.
     * @param xZeroMax The maximum value of {@code x_0}.
     * @param xOneMin The minimum value of {@code x_1}.
     * @param xOneMax The maximum value of {@code x_1}.
     * @param seed The RNG seed.
     * @return A dataset containing the generated examples.
     * @throws PropertyException If any parameter is out of range.
     */
    public static Dataset<Regressor> generateDataset(int numSamples, float[] weights, float intercept, float variance,
                                                     float xZeroMin, float xZeroMax, float xOneMin, float xOneMax, long seed) {
        return new MutableDataset<>(new NonlinearGaussianDataSource(numSamples, weights, intercept, variance,
                xZeroMin, xZeroMax, xOneMin, xOneMax, seed));
    }
}
