package org.tribuo.regression.impl;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.SplittableRandom;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.regression.Regressor;

/* loaded from: input_file:org/tribuo/regression/impl/SkeletalIndependentRegressionSparseTrainer.class */
/**
 * Base class for sparse regression trainers which fit one independent model per output dimension.
 *
 * <p>Each call to {@link #train(Dataset, Map)} does the following:
 * <ul>
 *   <li>Converts every example into a {@link SparseVector} (optionally with a bias feature, see
 *       {@link #useBias()}).</li>
 *   <li>Collects the example weights.</li>
 *   <li>Builds one target array per output dimension.</li>
 *   <li>Calls {@link #trainDimension} once per dimension.</li>
 * </ul>
 * The per-dimension results are keyed by the dimension's first name, in output-domain iteration
 * order, and handed to {@link #createModel}.
 *
 * <p>The mutable trainer state (the RNG and the invocation counter) is guarded by {@code this}.
 *
 * @param <T> the per-dimension model type produced by {@link #trainDimension}.
 */
public abstract class SkeletalIndependentRegressionSparseTrainer<T> implements SparseTrainer<Regressor> {
    /** Source RNG; each train call takes an independent split of it. Guarded by {@code this}. */
    private SplittableRandom rng;

    @Config(description = "Seed for the RNG, may be unused.")
    private long seed = 1;

    /** Number of times {@code train} has been called. Guarded by {@code this}. */
    private int trainInvocationCounter = 0;

    protected SkeletalIndependentRegressionSparseTrainer() {
    }

    /**
     * (Re)initialises the RNG from the configured seed.
     *
     * <p>Presumably invoked by the configuration system or by subclass constructors once
     * {@code seed} is set. If it is never called, the RNG is created lazily on the first train call.
     */
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(this.seed);
    }

    /**
     * Trains a model with no extra run provenance.
     *
     * @param dataset the training data.
     * @return the trained model produced by {@link #createModel}.
     */
    public SkeletalIndependentRegressionSparseModel train(Dataset<Regressor> dataset) {
        return train(dataset, Collections.emptyMap());
    }

    /**
     * Trains one model per output dimension and combines them via {@link #createModel}.
     *
     * <p>Output dimensions missing from an example's {@link Regressor} are left at {@code 0.0}
     * in that dimension's target array.
     *
     * @param dataset       the training data; must not contain unknown outputs.
     * @param runProvenance extra provenance recorded in the resulting {@link ModelProvenance}.
     * @return the trained model produced by {@link #createModel}.
     * @throws IllegalArgumentException if the dataset contains unknown outputs.
     */
    public SkeletalIndependentRegressionSparseModel train(Dataset<Regressor> dataset, Map<String, Provenance> runProvenance) {
        if (dataset.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }

        // Snapshot per-call state under the lock so concurrent train calls get distinct RNG streams
        // and a consistent provenance/counter pairing.
        SplittableRandom localRng;
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            if (rng == null) {
                // postConfig() was never called (e.g. direct construction); fall back to the seed.
                rng = new SplittableRandom(seed);
            }
            localRng = rng.split();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }

        ImmutableOutputInfo<Regressor> outputInfo = dataset.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = dataset.getFeatureIDMap();
        Set<Regressor> domain = outputInfo.getDomain();
        int numExamples = dataset.size();
        boolean bias = useBias();

        float[] weights = new float[numExamples];
        // outputs[dimensionId][exampleIndex]
        double[][] outputs = new double[outputInfo.size()][numExamples];
        SparseVector[] inputs = new SparseVector[numExamples];

        int i = 0;
        for (Example<Regressor> example : dataset) {
            inputs[i] = SparseVector.createSparseVector(example, featureIDMap, bias);
            weights[i] = example.getWeight();
            for (Regressor.DimensionTuple dimension : example.getOutput()) {
                outputs[outputInfo.getID(dimension)][i] = dimension.getValue();
            }
            i++;
        }

        Map<String, T> models = new LinkedHashMap<>();
        for (Regressor dimension : domain) {
            double[] target = outputs[outputInfo.getID(dimension)];
            models.put(dimension.getNames()[0], trainDimension(target, inputs, weights, localRng));
        }

        ModelProvenance provenance = new ModelProvenance(getModelClassName(), OffsetDateTime.now(), dataset.getProvenance(), trainerProvenance, runProvenance);
        return createModel(models, provenance, featureIDMap, outputInfo);
    }

    /**
     * Returns how many times {@code train} has been called on this trainer.
     *
     * <p>Synchronized so the read sees the increments made under the same lock in {@code train}.
     *
     * @return the train invocation count.
     */
    public synchronized int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Builds the final model from the per-dimension models.
     *
     * @param models       per-dimension models keyed by dimension name, in output-domain order.
     * @param provenance   the provenance of the model being built.
     * @param featureIDMap the feature domain of the training data.
     * @param outputInfo   the output domain of the training data.
     * @return the combined model.
     */
    protected abstract SkeletalIndependentRegressionSparseModel createModel(Map<String, T> models, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputInfo);

    /**
     * Trains a model for a single output dimension.
     *
     * @param outputs  target value for this dimension, one entry per example.
     * @param features sparse feature vectors, one per example (aligned with {@code outputs}).
     * @param weights  example weights, one per example (aligned with {@code outputs}).
     * @param rng      an RNG split dedicated to the current train call.
     * @return the trained per-dimension model.
     */
    protected abstract T trainDimension(double[] outputs, SparseVector[] features, float[] weights, SplittableRandom rng);

    /**
     * Reports whether a bias feature is added when converting examples to sparse vectors.
     *
     * @return true if a bias feature should be added.
     */
    protected abstract boolean useBias();

    /**
     * Returns the model class name recorded in the model provenance.
     *
     * @return the model class name.
     */
    protected abstract String getModelClassName();

    // The methods below are decompiler-renamed forms of the compiler's synthetic bridge methods.
    // Each one simply forwards to the typed train overloads above.

    /* renamed from: train, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ SparseModel m24train(Dataset dataset, Map map) {
        return train((Dataset<Regressor>) dataset, (Map<String, Provenance>) map);
    }

    /* renamed from: train, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ SparseModel m25train(Dataset dataset) {
        return train((Dataset<Regressor>) dataset);
    }

    /* renamed from: train, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Model m26train(Dataset dataset, Map map) {
        return train((Dataset<Regressor>) dataset, (Map<String, Provenance>) map);
    }

    /* renamed from: train, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Model m27train(Dataset dataset) {
        return train((Dataset<Regressor>) dataset);
    }
}
