package ai.libs.jaicore.ml.core.dataset.util;

import ai.libs.jaicore.basic.algorithm.AlgorithmExecutionCanceledException;
import ai.libs.jaicore.basic.algorithm.exceptions.AlgorithmException;
import ai.libs.jaicore.ml.core.dataset.DatasetCreationException;
import ai.libs.jaicore.ml.core.dataset.INumericLabeledAttributeArrayInstance;
import ai.libs.jaicore.ml.core.dataset.IOrderedLabeledAttributeArrayDataset;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.AttributeBasedStratiAmountSelectorAndAssigner;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.DiscretizationHelper;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.StratifiedSampling;
import java.util.Collections;
import java.util.Random;

/* loaded from: input_file:ai/libs/jaicore/ml/core/dataset/util/StratifiedSplit.class */
/**
 * Splits a dataset into a training portion, drawn with {@link StratifiedSampling}, and a test
 * portion consisting of the dataset's remaining instances.
 *
 * <p>Construct with the dataset and a seed, call {@link #doSplit(double)}, then read the results
 * via {@link #getTrainingData()} and {@link #getTestData()}.
 *
 * @param <I> the instance type
 * @param <L> the label type
 * @param <D> the dataset type
 */
public class StratifiedSplit<I extends INumericLabeledAttributeArrayInstance<L>, L, D extends IOrderedLabeledAttributeArrayDataset<I, L>> {

    /** The dataset to split. */
    private final D dataset;

    /** Training portion of the most recent successful split; {@code null} before the first one. */
    private D trainingData;

    /** Test portion of the most recent successful split; {@code null} before the first one. */
    private D testData;

    /** Seed for the {@link Random} handed to the sampler, making splits reproducible per seed. */
    private final long seed;

    /**
     * Creates a splitter for the given dataset.
     *
     * @param dataset the dataset to split
     * @param seed the seed used for the random source of every {@link #doSplit(double)} call
     */
    public StratifiedSplit(final D dataset, final long seed) {
        this.dataset = dataset;
        this.seed = seed;
    }

    /**
     * Performs the split. The training portion is the result of a {@link StratifiedSampling} run
     * seeded with this object's seed; the test portion holds all dataset instances not in it.
     * Both results are only published if the whole split succeeds.
     *
     * @param trainingPortion fraction of the dataset, in {@code [0, 1]}, to put into the training
     *     portion; the sample size is {@code (int) (trainingPortion * dataset.size())}
     * @throws IllegalArgumentException if {@code trainingPortion} is NaN or outside {@code [0, 1]}
     * @throws AlgorithmException if the sampling is cancelled or interrupted, or if an empty copy of
     *     the dataset cannot be created; the underlying exception is attached as the cause
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void doSplit(final double trainingPortion) throws AlgorithmException {
        // Written as a negated range check so that NaN is rejected as well.
        if (!(trainingPortion >= 0 && trainingPortion <= 1)) {
            throw new IllegalArgumentException("The training portion must be in [0, 1] but is " + trainingPortion);
        }
        Random random = new Random(this.seed);
        // NOTE(review): the stratification attribute index is getNumberOfAttributes(); presumably this
        // addresses the label stored after the regular attributes — confirm. The literal 10 is passed
        // through unchanged (presumably the number of discretization categories — confirm).
        AttributeBasedStratiAmountSelectorAndAssigner selectorAndAssigner = new AttributeBasedStratiAmountSelectorAndAssigner(
                Collections.singletonList(Integer.valueOf(this.dataset.getNumberOfAttributes())),
                DiscretizationHelper.DiscretizationStrategy.EQUAL_SIZE, 10);
        // The same object acts as both strati amount selector and strati assigner.
        StratifiedSampling sampling = new StratifiedSampling(selectorAndAssigner, selectorAndAssigner, random, this.dataset);
        sampling.setSampleSize((int) (trainingPortion * this.dataset.size()));
        try {
            D sampled = (D) sampling.m24call();
            // Test portion = dataset minus training portion. Assuming D follows the Collection.removeAll
            // contract, this is equality-based, so every instance equal to a sampled one is removed.
            D remaining = (D) this.dataset.createEmpty();
            remaining.addAll(this.dataset);
            remaining.removeAll(sampled);
            // Commit both halves together so a failure never leaves a mismatched pair behind.
            this.trainingData = sampled;
            this.testData = remaining;
        } catch (AlgorithmExecutionCanceledException e) {
            throw withCause(new AlgorithmException("Stratified split has been cancelled"), e);
        } catch (DatasetCreationException e) {
            throw withCause(new AlgorithmException("Could not create an empty copy of the given dataset."), e);
        } catch (InterruptedException e) {
            // Restore the interrupt flag for callers up the stack, and fail instead of returning
            // silently with no (or stale) split results.
            Thread.currentThread().interrupt();
            throw withCause(new AlgorithmException("Stratified split has been interrupted"), e);
        }
    }

    /**
     * Attaches {@code cause} to {@code exception} so the original stack trace is preserved.
     *
     * @param exception a freshly created exception without a cause
     * @param cause the exception that triggered it
     * @return {@code exception}, for use in a {@code throw} statement
     */
    private static AlgorithmException withCause(final AlgorithmException exception, final Throwable cause) {
        exception.initCause(cause);
        return exception;
    }

    /**
     * Returns the training portion of the most recent successful split.
     *
     * @return the training data, or {@code null} if {@link #doSplit(double)} has not yet succeeded
     */
    public D getTrainingData() {
        return this.trainingData;
    }

    /**
     * Returns the test portion of the most recent successful split.
     *
     * @return the test data, or {@code null} if {@link #doSplit(double)} has not yet succeeded
     */
    public D getTestData() {
        return this.testData;
    }
}
