package ai.libs.jaicore.ml.tsc.classifier.trees;

import ai.libs.jaicore.basic.TimeOut;
import ai.libs.jaicore.basic.algorithm.IRandomAlgorithmConfig;
import ai.libs.jaicore.basic.algorithm.events.AlgorithmEvent;
import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.graph.TreeNode;
import ai.libs.jaicore.ml.tsc.classifier.ASimplifiedTSCLearningAlgorithm;
import ai.libs.jaicore.ml.tsc.classifier.trees.TimeSeriesTreeClassifier;
import ai.libs.jaicore.ml.tsc.dataset.TimeSeriesDataset;
import ai.libs.jaicore.ml.tsc.features.TimeSeriesFeature;
import ai.libs.jaicore.ml.tsc.util.TimeSeriesUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.aeonbits.owner.Config;
import org.apache.commons.lang3.ArrayUtils;

/* loaded from: input_file:ai/libs/jaicore/ml/tsc/classifier/trees/TimeSeriesTreeLearningAlgorithm.class */
public class TimeSeriesTreeLearningAlgorithm extends ASimplifiedTSCLearningAlgorithm<Integer, TimeSeriesTreeClassifier> {
    /** Number of threshold candidates per feature type; equals the literal 20 that {@code tree} passes to {@code generateThresholdCandidates}. */
    public static final int NUM_THRESH_CANDIDATES = 20;
    /** Weight of the margin term in the split score; equals the factor used in {@code calculateEntrance}. (The misspelling is part of the public name.) */
    public static final double ENTROPY_APLHA = 1.0E-22d;
    /** Tolerance below which a value is treated as zero in the stop criterion of {@code tree} and in {@code calculateDeltaEntropy}. */
    private static final double PRECISION_DELTA = 1.0E-9d;
    /** Cache of computed feature arrays, keyed by the long key built in {@code transformInstances}; only created by the training entry point when feature caching is enabled. */
    private HashMap<Long, double[]> transformedFeaturesCache;
    // NOTE(review): presumably the bias-correction flag handed to TimeSeriesFeature.getFeatures
    // (transformInstances passes a literal true there) - confirm.
    public static final boolean USE_BIAS_CORRECTION = true;

    /* loaded from: input_file:ai/libs/jaicore/ml/tsc/classifier/trees/TimeSeriesTreeLearningAlgorithm$ITimeSeriesTreeConfig.class */
    /**
     * Configuration of the time series tree learner. Adds the maximum tree depth and the
     * feature caching switch to the inherited random algorithm configuration.
     */
    public interface ITimeSeriesTreeConfig extends IRandomAlgorithmConfig {
        /** Property key of the maximum tree depth. */
        String K_MAXDEPTH = "maxdepth";

        /** Property key of the feature caching switch. */
        String K_FEATURECACHING = "featurecaching";

        /**
         * Maximum depth of the tree. Defaults to -1; the learner stops splitting at depth
         * {@code maxDepth - 1}, which is never reached for non-negative depths with the default.
         *
         * @return the configured maximum depth
         */
        @Config.Key(K_MAXDEPTH)
        @Config.DefaultValue("-1")
        int maxDepth();

        /**
         * Whether computed interval features are cached between transformations. Defaults to false.
         *
         * @return {@code true} if feature caching is enabled
         */
        @Config.Key(K_FEATURECACHING)
        @Config.DefaultValue("false")
        boolean useFeatureCaching();
    }

    /**
     * Creates the learning algorithm for the given configuration, classifier and training data.
     * All three arguments are handed to the super constructor unchanged.
     *
     * @param iTimeSeriesTreeConfig configuration (seed, maximum depth, feature caching)
     * @param timeSeriesTreeClassifier classifier whose tree is filled by the training
     * @param timeSeriesDataset dataset to train on
     */
    public TimeSeriesTreeLearningAlgorithm(ITimeSeriesTreeConfig iTimeSeriesTreeConfig, TimeSeriesTreeClassifier timeSeriesTreeClassifier, TimeSeriesDataset timeSeriesDataset) {
        super(iTimeSeriesTreeConfig, timeSeriesTreeClassifier, timeSeriesDataset);
        // The feature cache field keeps its default value (null) until training creates it.
    }

    /**
     * Not supported by this algorithm.
     *
     * @param obj ignored
     * @throws UnsupportedOperationException always
     */
    @Override // ai.libs.jaicore.ml.tsc.classifier.ASimplifiedTSCLearningAlgorithm
    public void registerListener(Object obj) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this algorithm.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public int getNumCPUs() {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this algorithm.
     *
     * @param i ignored
     * @throws UnsupportedOperationException always
     */
    public void setNumCPUs(int i) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this algorithm.
     *
     * @param j ignored
     * @param timeUnit ignored
     * @throws UnsupportedOperationException always
     */
    public void setTimeout(long j, TimeUnit timeUnit) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this algorithm.
     *
     * @param timeOut ignored
     * @throws UnsupportedOperationException always
     */
    public void setTimeout(TimeOut timeOut) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this algorithm.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public TimeOut getTimeout() {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this algorithm; it cannot be executed step by step.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override // ai.libs.jaicore.ml.tsc.classifier.ASimplifiedTSCLearningAlgorithm
    public AlgorithmEvent nextWithException() {
        throw new UnsupportedOperationException();
    }

    /**
     * Trains the tree on the univariate input dataset and returns the fitted classifier.
     *
     * <p>When feature caching is enabled a fresh feature cache is created before the tree is
     * grown. The root node is grown with an initial reference entropy of 2.0 at depth 0.
     *
     * @return the classifier whose tree has been filled by the training
     * @throws IllegalArgumentException if the dataset is empty or has no value matrix / no instances
     * @throws UnsupportedOperationException if the dataset is multivariate
     */
    /* renamed from: call, reason: merged with bridge method [inline-methods] */
    public TimeSeriesTreeClassifier m100call() {
        TimeSeriesDataset timeSeriesDataset = (TimeSeriesDataset) getInput();
        if (timeSeriesDataset.isEmpty()) {
            throw new IllegalArgumentException("The dataset used for training must not be empty!");
        }
        if (timeSeriesDataset.isMultivariate()) {
            throw new UnsupportedOperationException("Multivariate instances are not supported yet.");
        }
        double[][] values = timeSeriesDataset.getValuesOrNull(0);
        // A missing matrix is reported the same way as a matrix without rows instead of failing with an NPE.
        if (values == null || values.length == 0) {
            throw new IllegalArgumentException("The traning data's matrix must contain at least one instance!");
        }
        if (getConfig().useFeatureCaching()) {
            // sampleIntervals draws at most sqrt(m) lengths with at most sqrt(m) starts each, so at most
            // m intervals per instance are cached. The estimate is computed in long and capped: the former
            // int product m * m * n could overflow to a negative capacity or request a gigantic table.
            final int maxInitialCapacity = 1 << 16;
            long expectedEntries = (long) values[0].length * values.length;
            this.transformedFeaturesCache = new HashMap<>((int) Math.min(expectedEntries, maxInitialCapacity));
        }
        tree(values, timeSeriesDataset.getTargets(), 2.0d, getClassifier().getRootNode(), 0);
        return getClassifier();
    }

    /**
     * Not supported by this algorithm; it cannot be iterated event by event.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override // ai.libs.jaicore.ml.tsc.classifier.ASimplifiedTSCLearningAlgorithm
    public Iterator<AlgorithmEvent> iterator() {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by this algorithm.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override // ai.libs.jaicore.ml.tsc.classifier.ASimplifiedTSCLearningAlgorithm
    public boolean hasNext() {
        throw new UnsupportedOperationException();
    }

    /**
     * Always fails because this algorithm cannot be enumerated.
     *
     * @return never returns normally
     * @throws NoSuchElementException always
     */
    @Override // ai.libs.jaicore.ml.tsc.classifier.ASimplifiedTSCLearningAlgorithm
    /* renamed from: next */
    public AlgorithmEvent mo71next() {
        throw new NoSuchElementException("Cannot enumerate this algorithm!");
    }

    /**
     * Not supported by this algorithm.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // ai.libs.jaicore.ml.tsc.classifier.ASimplifiedTSCLearningAlgorithm
    public void cancel() {
        throw new UnsupportedOperationException();
    }

    /**
     * Recursively grows the subtree rooted at {@code treeNode} from the given instances.
     *
     * <p>Intervals are sampled with the configured seed, every instance is transformed into its
     * interval features, and for each feature type the interval/threshold pair with the highest
     * score ({@link #calculateEntrance(double, double)}) is remembered. The feature type with the
     * highest delta entropy ({@link #getBestSplitIndex(double[])}) defines the split. Both children
     * are grown with the chosen delta entropy as their reference entropy. Because the same seed is
     * used at every node, nodes with the same series length sample the same intervals.
     *
     * <p>The node becomes a leaf predicting the mode of {@code iArr} when the best delta entropy is
     * (almost) zero, when depth {@code maxDepth - 1} is reached, when a non-root node's best delta
     * entropy (almost) equals {@code d}, or when the chosen split would leave one child without any
     * instance. The last case previously recursed into an empty child (index out of bounds) or
     * recursed without end on the unchanged subset.
     *
     * @param dArr training series of this node, one row per instance
     * @param iArr class labels, aligned with the rows of {@code dArr}
     * @param d reference entropy from which the weighted child entropy is subtracted
     * @param treeNode node to be filled with either a split decision or a class prediction
     * @param i depth of {@code treeNode}; 0 for the root
     * @throws IllegalArgumentException if {@code dArr} contains no instance
     */
    public void tree(double[][] dArr, int[] iArr, double d, TreeNode<TimeSeriesTreeClassifier.TimeSeriesTreeNodeDecisionFunction> treeNode, int i) {
        if (dArr.length == 0) {
            throw new IllegalArgumentException("Cannot grow a tree node without any training instances.");
        }
        int numInstances = iArr.length;
        int numFeatureTypes = TimeSeriesFeature.NUM_FEATURE_TYPES;
        ITimeSeriesTreeConfig config = getConfig();

        Pair<List<Integer>, List<Integer>> intervals = sampleIntervals(dArr[0].length, config.seed());
        double[][][] transformed = transformInstances(dArr, intervals);
        List<List<Double>> thresholdCandidates = generateThresholdCandidates(intervals, NUM_THRESH_CANDIDATES, transformed);
        // Distinct class labels occurring at this node.
        List<Integer> classes = new ArrayList<>(new HashSet<Integer>(Arrays.asList(ArrayUtils.toObject(iArr))));

        // Best split found so far, tracked separately for each feature type.
        double[] bestEntrance = new double[numFeatureTypes];
        Arrays.fill(bestEntrance, Integer.MIN_VALUE);
        double[] bestDeltaEntropy = new double[numFeatureTypes];
        int[] bestInterval = new int[numFeatureTypes];
        double[] bestThreshold = new double[numFeatureTypes];

        List<Integer> intervalStarts = intervals.getX();
        List<Integer> intervalEnds = intervals.getY();
        for (int interval = 0; interval < intervalStarts.size(); interval++) {
            for (int feature = 0; feature < numFeatureTypes; feature++) {
                double[] featureValues = transformed[feature][interval];
                for (double threshold : thresholdCandidates.get(feature)) {
                    double deltaEntropy = calculateDeltaEntropy(featureValues, iArr, threshold, classes, d);
                    double entrance = calculateEntrance(deltaEntropy, calculateMargin(featureValues, threshold));
                    if (entrance > bestEntrance[feature]) {
                        bestEntrance[feature] = entrance;
                        bestDeltaEntropy[feature] = deltaEntropy;
                        bestInterval[feature] = interval;
                        bestThreshold[feature] = threshold;
                    }
                }
            }
        }

        int chosenFeature = getBestSplitIndex(bestDeltaEntropy);
        double chosenDeltaEntropy = bestDeltaEntropy[chosenFeature];
        int chosenInterval = bestInterval[chosenFeature];
        double chosenThreshold = bestThreshold[chosenFeature];

        TimeSeriesTreeClassifier.TimeSeriesTreeNodeDecisionFunction decision = treeNode.getValue();
        boolean noGain = Math.abs(chosenDeltaEntropy) <= PRECISION_DELTA;
        boolean depthLimitReached = i == config.maxDepth() - 1;
        boolean gainEqualsReference = i != 0 && Math.abs(chosenDeltaEntropy - d) <= PRECISION_DELTA;
        if (noGain || depthLimitReached || gainEqualsReference) {
            decision.classPrediction = TimeSeriesUtil.getMode(iArr);
            return;
        }

        Pair<List<Integer>, List<Integer>> childDataIndices = getChildDataIndices(transformed, numInstances, chosenFeature, chosenInterval, chosenThreshold);
        List<Integer> leftIndices = childDataIndices.getX();
        List<Integer> rightIndices = childDataIndices.getY();
        if (leftIndices.isEmpty() || rightIndices.isEmpty()) {
            // The split does not separate anything; growing further would hit an empty child.
            decision.classPrediction = TimeSeriesUtil.getMode(iArr);
            return;
        }

        decision.f = TimeSeriesFeature.FeatureType.values()[chosenFeature];
        decision.t1 = intervalStarts.get(chosenInterval).intValue();
        decision.t2 = intervalEnds.get(chosenInterval).intValue();
        decision.threshold = chosenThreshold;

        // Children reference the parent's rows instead of copying them.
        double[][] leftData = new double[leftIndices.size()][];
        int[] leftTargets = new int[leftIndices.size()];
        for (int k = 0; k < leftIndices.size(); k++) {
            int source = leftIndices.get(k).intValue();
            leftData[k] = dArr[source];
            leftTargets[k] = iArr[source];
        }
        double[][] rightData = new double[rightIndices.size()][];
        int[] rightTargets = new int[rightIndices.size()];
        for (int k = 0; k < rightIndices.size(); k++) {
            int source = rightIndices.get(k).intValue();
            rightData[k] = dArr[source];
            rightTargets[k] = iArr[source];
        }

        TreeNode<TimeSeriesTreeClassifier.TimeSeriesTreeNodeDecisionFunction> leftChild = treeNode.addChild(new TimeSeriesTreeClassifier.TimeSeriesTreeNodeDecisionFunction());
        TreeNode<TimeSeriesTreeClassifier.TimeSeriesTreeNodeDecisionFunction> rightChild = treeNode.addChild(new TimeSeriesTreeClassifier.TimeSeriesTreeNodeDecisionFunction());
        tree(leftData, leftTargets, chosenDeltaEntropy, leftChild, i + 1);
        tree(rightData, rightTargets, chosenDeltaEntropy, rightChild, i + 1);
    }

    /**
     * Partitions the instance indices {@code 0 .. i-1} by comparing one transformed feature value
     * per instance against a threshold.
     *
     * @param dArr transformed features, indexed as [feature type][interval][instance]
     * @param i number of instances to partition
     * @param i2 index of the feature type used for the split
     * @param i3 index of the interval used for the split
     * @param d split threshold
     * @return pair whose X holds the indices with value {@code <= d} and whose Y holds all others,
     *         both in ascending order
     */
    public static Pair<List<Integer>, List<Integer>> getChildDataIndices(double[][][] dArr, int i, int i2, int i3, double d) {
        List<Integer> atOrBelowThreshold = new ArrayList<>();
        List<Integer> aboveThreshold = new ArrayList<>();
        for (int instance = 0; instance < i; instance++) {
            double value = dArr[i2][i3][instance];
            List<Integer> side = value <= d ? atOrBelowThreshold : aboveThreshold;
            side.add(Integer.valueOf(instance));
        }
        return new Pair<>(atOrBelowThreshold, aboveThreshold);
    }

    /**
     * Returns the index of the largest value in the given array. Ties are broken by shuffling the
     * tied indices with a {@link Random} seeded from the configuration and taking the first one.
     *
     * @param dArr one value per feature type
     * @return index of a maximal entry
     * @throws IllegalArgumentException if the array length differs from the number of feature types
     *         or no entry reaches the lower bound {@code Integer.MIN_VALUE}
     */
    public int getBestSplitIndex(double[] dArr) {
        if (dArr.length != TimeSeriesFeature.NUM_FEATURE_TYPES) {
            throw new IllegalArgumentException("A delta entropy star value has to be given for each feature type!");
        }

        // First pass: largest value, bounded from below by Integer.MIN_VALUE.
        double maximum = Integer.MIN_VALUE;
        for (double value : dArr) {
            if (value > maximum) {
                maximum = value;
            }
        }

        // Second pass: every index that attains this maximum, in ascending order.
        List<Integer> maximalIndices = new ArrayList<>();
        for (int index = 0; index < dArr.length; index++) {
            if (dArr[index] == maximum) {
                maximalIndices.add(Integer.valueOf(index));
            }
        }

        if (maximalIndices.isEmpty()) {
            throw new IllegalArgumentException("Could not find any maximum delta entropy star for any feature type for the given array " + Arrays.toString(dArr) + ".");
        }
        if (maximalIndices.size() > 1) {
            Collections.shuffle(maximalIndices, new Random(getConfig().seed()));
        }
        return maximalIndices.get(0).intValue();
    }

    /**
     * Computes the entropy gain of splitting the instances at a threshold: the reference entropy
     * minus the size-weighted entropy (natural logarithm) of the two resulting children.
     *
     * <p>Instances with a value {@code <= d} go to the first child, all others to the second.
     * Class shares and child weights are computed in floating point; the former integer division
     * truncated every share to 0 or 1, so each child entropy evaluated to 0. An empty input
     * yields {@code d2} instead of a division by zero.
     *
     * @param dArr feature value of every instance
     * @param iArr class label of every instance, aligned with {@code dArr}
     * @param d split threshold
     * @param list distinct class labels; must contain every label occurring in {@code iArr}
     * @param d2 reference (parent) entropy
     * @return {@code d2} minus the weighted entropy of the two children
     * @throws IllegalArgumentException if {@code dArr} and {@code iArr} differ in length
     */
    public static double calculateDeltaEntropy(double[] dArr, int[] iArr, double d, List<Integer> list, double d2) {
        if (dArr.length != iArr.length) {
            throw new IllegalArgumentException("The number of data values must be the same as the number of target values!");
        }
        int numClasses = list.size();
        int[][] classCounts = new int[2][numClasses];
        int[] childSizes = new int[2];
        for (int instance = 0; instance < dArr.length; instance++) {
            int child = dArr[instance] <= d ? 0 : 1;
            classCounts[child][list.indexOf(Integer.valueOf(iArr[instance]))]++;
            childSizes[child]++;
        }

        double weightedChildEntropy = 0.0d;
        for (int child = 0; child < 2; child++) {
            if (childSizes[child] == 0) {
                // An empty child has neither entropy nor weight.
                continue;
            }
            double entropy = 0.0d;
            for (int c = 0; c < numClasses; c++) {
                int count = classCounts[child][c];
                if (count > 0) {
                    // Classes without instances contribute nothing (0 * log 0 is taken as 0).
                    double share = (double) count / childSizes[child];
                    entropy -= share * Math.log(share);
                }
            }
            weightedChildEntropy += ((double) childSizes[child] / dArr.length) * entropy;
        }
        return d2 - weightedChildEntropy;
    }

    /**
     * Combines a delta entropy and a margin into one split score: the delta entropy plus the
     * margin scaled by 1.0E-22.
     *
     * @param d delta entropy of the split
     * @param d2 margin of the split
     * @return {@code d + 1.0E-22 * d2}
     */
    public static double calculateEntrance(double d, double d2) {
        final double marginWeight = 1.0E-22d;
        double weightedMargin = marginWeight * d2;
        return d + weightedMargin;
    }

    /**
     * Returns the smallest absolute distance between any of the given values and the threshold.
     *
     * @param dArr values to compare against the threshold
     * @param d threshold
     * @return the minimal distance, or {@link Double#MAX_VALUE} if no value yields a smaller
     *         distance (for example for an empty array)
     */
    public static double calculateMargin(double[] dArr, double d) {
        double smallestDistance = Double.MAX_VALUE;
        for (int index = 0; index < dArr.length; index++) {
            double distance = Math.abs(dArr[index] - d);
            // Plain comparison on purpose: a NaN distance never replaces the current minimum.
            smallestDistance = distance < smallestDistance ? distance : smallestDistance;
        }
        return smallestDistance;
    }

    /**
     * Computes the interval features of every instance for every sampled interval.
     *
     * <p>With feature caching enabled, computed feature arrays are stored in and read from the
     * instance-level cache. The cache is created on demand, so this method also works when it is
     * called without the training entry point having created it. The cache key is built in
     * {@code long} arithmetic, because the former {@code int} computation could overflow for long
     * series.
     *
     * @param dArr series to transform, one row per instance
     * @param pair sampled intervals; X holds the start indices, Y the matching end indices
     * @return features indexed as [feature type][interval][instance]
     */
    public double[][][] transformInstances(double[][] dArr, Pair<List<Integer>, List<Integer>> pair) {
        List<Integer> intervalStarts = pair.getX();
        List<Integer> intervalEnds = pair.getY();
        int numIntervals = intervalStarts.size();
        int numInstances = dArr.length;
        double[][][] transformed = new double[TimeSeriesFeature.NUM_FEATURE_TYPES][numIntervals][numInstances];

        boolean useFeatureCaching = getConfig().useFeatureCaching();
        if (useFeatureCaching && this.transformedFeaturesCache == null) {
            this.transformedFeaturesCache = new HashMap<>();
        }

        for (int instance = 0; instance < numInstances; instance++) {
            double[] series = dArr[instance];
            long seriesLength = series.length;
            for (int interval = 0; interval < numIntervals; interval++) {
                int t1 = intervalStarts.get(interval).intValue();
                int t2 = intervalEnds.get(interval).intValue();
                double[] features;
                if (useFeatureCaching) {
                    // NOTE(review): the key uses the row index within dArr. tree() passes child subsets
                    // here, so the same index can denote a different series than at the parent node,
                    // and indices >= series length collide with other intervals. A key tied to the
                    // series itself would need a different cache type - confirm and fix at field level.
                    Long key = Long.valueOf(instance + (seriesLength * t1) + (seriesLength * seriesLength * t2));
                    features = this.transformedFeaturesCache.get(key);
                    if (features == null) {
                        features = TimeSeriesFeature.getFeatures(series, t1, t2, USE_BIAS_CORRECTION);
                        this.transformedFeaturesCache.put(key, features);
                    }
                } else {
                    features = TimeSeriesFeature.getFeatures(series, t1, t2, USE_BIAS_CORRECTION);
                }
                transformed[0][interval][instance] = features[0];
                transformed[1][interval][instance] = features[1];
                transformed[2][interval][instance] = features[2];
            }
        }
        return transformed;
    }

    /**
     * Generates equally spaced threshold candidates for every feature type.
     *
     * <p>For each feature type the minimum and maximum over all intervals and instances are
     * determined, and {@code i} candidates are placed strictly between them at equal distances
     * {@code (max - min) / (i + 1)}.
     *
     * <p>The maximum search starts at {@code -Double.MAX_VALUE}. The former start value
     * {@code Integer.MIN_VALUE} produced a wrong range whenever all values of a feature type
     * were below -2^31.
     *
     * @param pair sampled intervals; only the number of intervals (size of X) is used
     * @param i number of candidates per feature type, at least 1
     * @param dArr transformed features, indexed as [feature type][interval][instance]
     * @return one list of {@code i} candidates per feature type
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public static List<List<Double>> generateThresholdCandidates(Pair<List<Integer>, List<Integer>> pair, int i, double[][][] dArr) {
        if (i < 1) {
            throw new IllegalArgumentException("At least one candidate must be calculated!");
        }
        int numFeatureTypes = TimeSeriesFeature.NUM_FEATURE_TYPES;
        int numInstances = dArr[0][0].length;
        int numIntervals = pair.getX().size();

        List<List<Double>> candidatesPerFeature = new ArrayList<>(numFeatureTypes);
        for (int feature = 0; feature < numFeatureTypes; feature++) {
            double min = Double.MAX_VALUE;
            double max = -Double.MAX_VALUE;
            for (int interval = 0; interval < numIntervals; interval++) {
                for (int instance = 0; instance < numInstances; instance++) {
                    double value = dArr[feature][interval][instance];
                    if (value < min) {
                        min = value;
                    }
                    if (value > max) {
                        max = value;
                    }
                }
            }

            double step = (max - min) / (i + 1);
            List<Double> candidates = new ArrayList<>(i);
            for (int k = 1; k <= i; k++) {
                candidates.add(Double.valueOf(min + (k * step)));
            }
            candidatesPerFeature.add(candidates);
        }
        return candidatesPerFeature;
    }

    /**
     * Samples intervals within a series of length {@code i}.
     *
     * <p>First {@code floor(sqrt(i))} distinct interval lengths are drawn from {@code 1..i}. For
     * each length {@code w}, {@code floor(sqrt(i - w + 1))} distinct start indices are drawn from
     * {@code 0..i-w}. Every draw uses the same seed.
     *
     * @param i length of the series, at least 1
     * @param i2 seed used for every sampling step
     * @return pair whose X holds the start indices and whose Y holds the matching inclusive end
     *         indices ({@code start + w - 1})
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public static Pair<List<Integer>, List<Integer>> sampleIntervals(int i, int i2) {
        if (i < 1) {
            throw new IllegalArgumentException("The series' length m must be greater than zero.");
        }
        List<Integer> starts = new ArrayList<>();
        List<Integer> ends = new ArrayList<>();

        List<Integer> possibleLengths = IntStream.rangeClosed(1, i).boxed().collect(Collectors.toList());
        List<Integer> sampledLengths = randomlySampleNoReplacement(possibleLengths, (int) Math.sqrt(i), i2);
        for (Integer sampledLength : sampledLengths) {
            int intervalLength = sampledLength.intValue();
            List<Integer> possibleStarts = IntStream.rangeClosed(0, i - intervalLength).boxed().collect(Collectors.toList());
            int numStarts = (int) Math.sqrt((i - intervalLength) + 1.0d);
            for (Integer start : randomlySampleNoReplacement(possibleStarts, numStarts, i2)) {
                starts.add(start);
                ends.add(Integer.valueOf((start.intValue() + intervalLength) - 1));
            }
        }
        return new Pair<>(starts, ends);
    }

    /**
     * Draws {@code i} distinct elements from the given list by shuffling a copy with a seeded
     * {@link Random} and returning its first {@code i} elements. The input list is not modified.
     *
     * @param list list to sample from
     * @param i number of elements to draw, between 1 and {@code list.size()}
     * @param i2 seed of the random generator
     * @return the sampled elements, as a view onto the shuffled copy
     * @throws IllegalArgumentException if {@code list} is null or {@code i} is out of range
     */
    public static List<Integer> randomlySampleNoReplacement(List<Integer> list, int i, int i2) {
        if (list == null) {
            throw new IllegalArgumentException("The list to be sampled from must not be null!");
        }
        boolean sampleSizeValid = i >= 1 && i <= list.size();
        if (!sampleSizeValid) {
            throw new IllegalArgumentException("Sample size must lower equals the size of the list to be sampled from without replacement and greater zero.");
        }
        List<Integer> shuffled = new ArrayList<>(list);
        Random generator = new Random(i2);
        Collections.shuffle(shuffled, generator);
        return shuffled.subList(0, i);
    }
}
