package ai.libs.jaicore.ml.classification.multiclass.reduction.splitters;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.logging.LoggerUtil;
import ai.libs.jaicore.ml.WekaUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

/**
 * Splits the classes of a dataset into two groups that are grown from a randomly chosen pair of seed classes.
 *
 * <p>Two distinct classes are drawn at random and used as the seeds of the two groups. The configured classifier is
 * trained to separate the two seed groups. Every remaining class is then assigned to the group that the majority of
 * its instances are predicted to belong to.
 *
 * <p>Not safe for concurrent use: every call to {@code split} rebuilds the single shared classifier instance.
 */
public class RPNDSplitter implements ISplitter {
    private static final Logger logger = LoggerFactory.getLogger(RPNDSplitter.class);

    /** Source of randomness used to pick the two seed classes. */
    private final Random rand;

    /** Classifier that is rebuilt on every split to separate the two class groups. */
    private final Classifier rpndClassifier;

    /**
     * Creates a splitter.
     *
     * @param random source of randomness used to pick the two seed classes
     * @param classifier classifier used to separate the two class groups; it is rebuilt on every split
     */
    public RPNDSplitter(final Random random, final Classifier classifier) {
        this.rand = random;
        this.rpndClassifier = classifier;
    }

    /**
     * Splits the classes that actually occur in {@code instances} into two groups seeded by a random pair of classes.
     *
     * @param instances the dataset whose classes are split
     * @return a list with one group holding the only class if the dataset contains exactly one class; otherwise a
     *         list with the two groups
     * @throws IllegalArgumentException if the dataset contains no classes at all
     * @throws InterruptedException if the current thread is interrupted while classifying instances
     * @throws Exception if building the classifier fails
     */
    @Override // ai.libs.jaicore.ml.classification.multiclass.reduction.splitters.ISplitter
    public Collection<Collection<String>> split(final Instances instances) throws Exception {
        Collection<String> classes = WekaUtil.getClassesActuallyContainedInDataset(instances);
        if (classes.isEmpty()) {
            // Previously this surfaced as an opaque IndexOutOfBoundsException when picking the seeds.
            throw new IllegalArgumentException("Cannot split a dataset that contains no classes.");
        }
        if (classes.size() == 1) {
            // Nothing to split: the single class forms the only group.
            List<Collection<String>> singleGroup = new ArrayList<>(1);
            singleGroup.add(classes);
            return singleGroup;
        }

        List<String> shuffledClasses = new ArrayList<>(classes);
        Collections.shuffle(shuffledClasses, this.rand);

        Collection<String> firstGroup = new HashSet<>();
        firstGroup.add(shuffledClasses.get(0));
        Collection<String> secondGroup = new HashSet<>();
        secondGroup.add(shuffledClasses.get(1));
        return this.split(shuffledClasses, firstGroup, secondGroup, instances);
    }

    /**
     * Grows the two given class groups until every class in {@code classes} belongs to one of them.
     *
     * <p>The classifier is trained on the dataset produced by {@code WekaUtil.mergeClassesOfInstances} for the two
     * groups. Each class that is in neither group is then added to the group predicted for strictly more of its
     * instances. Ties go to {@code secondGroup}, including classes with no instances at all. Instances that cannot be
     * classified are logged and not counted.
     *
     * <p>{@code firstGroup} and {@code secondGroup} are modified in place and are themselves the returned groups.
     *
     * @param classes all classes to distribute between the two groups
     * @param firstGroup seed classes of the first group; extended in place
     * @param secondGroup seed classes of the second group; extended in place
     * @param instances the dataset used for training and for classifying the remaining classes
     * @return a list containing {@code firstGroup} followed by {@code secondGroup}
     * @throws InterruptedException if the current thread is interrupted while classifying instances
     * @throws Exception if building the classifier fails
     */
    public Collection<Collection<String>> split(final Collection<String> classes, final Collection<String> firstGroup,
            final Collection<String> secondGroup, final Instances instances) throws Exception {
        logger.info("Start creation of RPND split with basis {}/{} for classes {}", firstGroup, secondGroup, classes);
        Instances binaryInstances = WekaUtil.mergeClassesOfInstances(instances, firstGroup, secondGroup);
        logger.debug("Building classifier for separating the two class sets {} and {}", firstGroup, secondGroup);
        this.rpndClassifier.buildClassifier(binaryInstances);

        logger.info("Now classifying the items of the other classes");
        // Take a snapshot of the remaining classes up front, because both groups are mutated inside the loop.
        List<String> remainingClasses =
                new ArrayList<>(SetUtil.difference(SetUtil.difference(classes, firstGroup), secondGroup));
        for (String className : remainingClasses) {
            Instances instancesOfClass = WekaUtil.getInstancesOfClass(instances, className);
            logger.debug("Classify {} instances of class {}", instancesOfClass.size(), className);
            if (this.isMajorityPredictedAsFirstGroup(instancesOfClass)) {
                firstGroup.add(className);
            } else {
                secondGroup.add(className);
            }
        }

        List<Collection<String>> groups = new ArrayList<>(2);
        groups.add(firstGroup);
        groups.add(secondGroup);
        return groups;
    }

    /**
     * Classifies every instance of one class and reports whether strictly more of them were predicted as label
     * {@code 0.0} than as any other label.
     *
     * @param instancesOfClass the instances of a single original class
     * @return {@code true} if label {@code 0.0} wins the vote strictly; {@code false} on ties or an empty input
     * @throws InterruptedException if the current thread is interrupted between two classifications
     */
    private boolean isMajorityPredictedAsFirstGroup(final Instances instancesOfClass) throws InterruptedException {
        int votesForFirst = 0;
        int votesForSecond = 0;
        for (Instance instance : instancesOfClass) {
            // Thread.interrupted() clears the flag; the interruption is propagated as an InterruptedException instead.
            if (Thread.interrupted()) {
                throw new InterruptedException();
            }
            try {
                // NOTE(review): assumes mergeClassesOfInstances encodes the first group as class index 0 — confirm.
                if (this.rpndClassifier.classifyInstance(WekaUtil.getRefactoredInstance(instance)) == 0.0d) {
                    votesForFirst++;
                } else {
                    votesForSecond++;
                }
            } catch (Exception e) {
                // Best effort: an instance that cannot be classified is logged and simply not counted.
                logger.error(LoggerUtil.getExceptionInfo(e));
            }
        }
        return votesForFirst > votesForSecond;
    }
}
