package org.broadinstitute.hellbender.utils.mcmc;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.TDistribution;
import org.apache.commons.math3.primes.Primes;
import org.apache.commons.math3.random.RandomGenerator;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.param.ParamUtils;

/* loaded from: input_file:org/broadinstitute/hellbender/utils/mcmc/MinibatchSliceSampler.class */
/**
 * Slice sampler for a single real-valued parameter {@code x} whose unnormalized log density is
 * {@code logPrior(x) + sum over data of logLikelihood(datum, x)}.
 *
 * <p>Rather than evaluating the likelihood of every datum for each proposal, the slice-height test is
 * performed on successively larger, randomly ordered minibatches of the data. After each minibatch, a
 * t-test on the mean per-datum log-likelihood difference decides whether enough data have been seen to
 * stop early. NOTE(review): this appears to follow the sequential test of Korattikara, Chen &amp; Welling,
 * "Austerity in MCMC Land"; confirm. An approximation threshold of 0 never stops early, so every
 * proposal then uses all of the data.
 *
 * <p>The prior and the per-datum log likelihoods at the current sample are cached in unsynchronized
 * mutable fields. Per-datum values are keyed by the datum itself, so equal data points share an entry.
 * A single instance should therefore not be shared between threads.
 *
 * @param <DATA> type of a single data point
 */
public final class MinibatchSliceSampler<DATA> extends AbstractSliceSampler {
    private final List<DATA> data;
    private final Function<Double, Double> logPrior;
    private final BiFunction<DATA, Double, Double> logLikelihood;
    private final int minibatchSize;
    private final double approxThreshold;
    private final int numDataPoints;

    // Values computed at the most recent xSample. They are reused for every proposal tested against
    // that same sample and are reset whenever xSample changes.
    private Double xSampleCache = null;
    private Double logPriorCache = null;
    private Map<DATA, Double> logLikelihoodsCache = null;

    /**
     * Creates a minibatch slice sampler with explicit bounds on {@code x}.
     *
     * @param rng             random generator, passed to {@link AbstractSliceSampler} and also used to
     *                        shuffle the data when more than one minibatch is needed
     * @param data            data points; a defensive unmodifiable copy is kept
     * @param logPrior        log prior density of {@code x}
     * @param logLikelihood   log likelihood of a single datum given {@code x}
     * @param xMin            lower bound on {@code x}; proposals below it are rejected
     * @param xMax            upper bound on {@code x}; proposals above it are rejected
     * @param width           passed through to {@link AbstractSliceSampler}
     *                        (presumably the initial slice width; confirm there)
     * @param minibatchSize   number of data points per minibatch; must be greater than 1
     * @param approxThreshold non-negative tail-probability threshold below which the sequential test
     *                        stops early; 0 disables early stopping
     */
    public MinibatchSliceSampler(final RandomGenerator rng,
                                 final List<DATA> data,
                                 final Function<Double, Double> logPrior,
                                 final BiFunction<DATA, Double, Double> logLikelihood,
                                 final double xMin,
                                 final double xMax,
                                 final double width,
                                 final int minibatchSize,
                                 final double approxThreshold) {
        super(rng, xMin, xMax, width);
        Utils.nonNull(data);
        Utils.nonNull(logPrior);
        Utils.nonNull(logLikelihood);
        Utils.validateArg(minibatchSize > 1, "Minibatch size must be greater than 1.");
        ParamUtils.isPositiveOrZero(approxThreshold, "Minibatch approximation threshold must be non-negative.");
        this.data = Collections.unmodifiableList(new ArrayList<>(data));
        this.logPrior = logPrior;
        this.logLikelihood = logLikelihood;
        this.minibatchSize = minibatchSize;
        this.approxThreshold = approxThreshold;
        this.numDataPoints = this.data.size();
    }

    /**
     * Creates a minibatch slice sampler with {@code x} unbounded
     * ({@code xMin = -infinity}, {@code xMax = +infinity}).
     *
     * @see #MinibatchSliceSampler(RandomGenerator, List, Function, BiFunction, double, double, double, int, double)
     */
    public MinibatchSliceSampler(final RandomGenerator rng,
                                 final List<DATA> data,
                                 final Function<Double, Double> logPrior,
                                 final BiFunction<DATA, Double, Double> logLikelihood,
                                 final double width,
                                 final int minibatchSize,
                                 final double approxThreshold) {
        this(rng, data, logPrior, logLikelihood, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY,
                width, minibatchSize, approxThreshold);
    }

    /**
     * Decides whether the log density at {@code xProposed} exceeds the slice height, defined as the log
     * density at {@code xSample} minus {@code z}.
     *
     * <p>With no data this compares log priors exactly. Otherwise the per-datum log-likelihood
     * differences {@code logLikelihood(datum, xProposed) - logLikelihood(datum, xSample)} are accumulated
     * minibatch by minibatch. The loop stops either when every datum has been included, which makes the
     * decision exact, or when the t-test tail probability falls below {@code approxThreshold}.
     *
     * @param xProposed proposed value of {@code x}
     * @param xSample   current sample; prior and per-datum log likelihoods at this value are cached
     * @param z         offset below the log density at {@code xSample} that defines the slice height
     * @return {@code false} if {@code xProposed} is outside {@code [xMin, xMax]}; otherwise whether the
     *         (possibly approximate) log density at {@code xProposed} exceeds the slice height
     */
    @Override
    boolean isGreaterThanSliceHeight(final double xProposed, final double xSample, final double z) {
        if (xProposed < xMin || xMax < xProposed) {
            return false;
        }

        if (xSampleCache == null || xSampleCache.doubleValue() != xSample) {
            xSampleCache = xSample;
            logPriorCache = logPrior.apply(xSample);
            logLikelihoodsCache = new HashMap<>(numDataPoints);
        }
        if (xSampleCache == null || logPriorCache == null || logLikelihoodsCache == null) {
            throw new GATKException.ShouldNeverReachHereException("Cache for xSample is in an invalid state.");
        }

        final double logPriorProposed = logPrior.apply(xProposed);
        if (numDataPoints == 0) {
            return logPriorProposed > logPriorCache - z;
        }

        // The proposal is accepted when the mean per-datum log-likelihood difference exceeds mu0;
        // multiplying both sides by numDataPoints recovers the exact slice-height comparison.
        final double mu0 = (logPriorCache - logPriorProposed - z) / numDataPoints;

        // Ceiling division, so the final, possibly partial, minibatch is also evaluated.
        final int numMinibatches = numDataPoints / minibatchSize + (numDataPoints % minibatchSize == 0 ? 0 : 1);
        // Shuffling only matters when the test may stop before all the data have been seen.
        final Iterator<DATA> dataIterator = numMinibatches > 1 ? lazyShuffleIterator(rng, data) : data.iterator();

        int numDataIncluded = 0;
        double meanDifference = 0.;
        double meanSquaredDifference = 0.;
        for (int minibatchIndex = 0; minibatchIndex < numMinibatches; minibatchIndex++) {
            final int batchSize = Math.min(minibatchSize, numDataPoints - numDataIncluded);
            double sumDifference = 0.;
            double sumSquaredDifference = 0.;
            for (int j = 0; j < batchSize; j++) {
                final DATA datum = dataIterator.next();
                final double difference = logLikelihood.apply(datum, xProposed)
                        - logLikelihoodsCache.computeIfAbsent(datum, d -> logLikelihood.apply(d, xSample));
                sumDifference += difference;
                sumSquaredDifference += difference * difference;
            }

            // Running means over all data included so far.
            meanDifference = (numDataIncluded * meanDifference + sumDifference) / (numDataIncluded + batchSize);
            meanSquaredDifference = (numDataIncluded * meanSquaredDifference + sumSquaredDifference) / (numDataIncluded + batchSize);
            numDataIncluded += batchSize;

            if (numDataIncluded == numDataPoints) {
                // All data seen: the comparison below is exact, no test needed.
                break;
            }

            // t-test of meanDifference against mu0. Double division is required: integer division would
            // make the finite-population correction identically 1.
            final double finitePopulationCorrection = Math.sqrt(1. - (double) numDataIncluded / numDataPoints);
            // Clamp rounding noise; a slightly negative variance would otherwise produce NaN.
            final double variance = Math.max(meanSquaredDifference - meanDifference * meanDifference, 0.);
            final double standardError = finitePopulationCorrection * Math.sqrt(variance / (numDataIncluded - 1));
            final double tStatistic = Math.abs((meanDifference - mu0) / standardError);
            final double tailProbability = 1. - new TDistribution((RandomGenerator) null, numDataIncluded - 1)
                    .cumulativeProbability(tStatistic);
            if (tailProbability < approxThreshold) {
                break;
            }
        }
        return meanDifference > mu0;
    }

    /**
     * Returns an iterator over a pseudo-random permutation of {@code list} without copying or shuffling it.
     *
     * <p>Indices are generated by repeatedly adding a random step modulo the smallest prime
     * {@code p >= list.size()}. Residues that are not valid indices are skipped. Any step in
     * {@code [1, p - 1]} is coprime to {@code p}, so all residues are visited before any repeats and the
     * first {@code list.size()} results are a permutation. A step equal to {@code p} must be excluded: it
     * is 0 modulo {@code p} and would return one element forever. Drawing from {@code [1, size]} allows
     * exactly that step whenever {@code size} is prime.
     *
     * <p>Requires a non-empty list. {@code next()} throws {@link NoSuchElementException} once every
     * element has been returned.
     */
    private static <T> Iterator<T> lazyShuffleIterator(final RandomGenerator randomGenerator, final List<T> list) {
        final int size = list.size();
        final int modulus = Primes.nextPrime(size);
        final int step = randomGenerator.nextInt(Math.min(size, modulus - 1)) + 1;
        return new Iterator<T>() {
            private int numSeen = 0;
            private int index = step;

            @Override
            public boolean hasNext() {
                return numSeen < size;
            }

            @Override
            public T next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                do {
                    // Long arithmetic avoids int overflow when the modulus is close to Integer.MAX_VALUE.
                    index = (int) (((long) index + step) % modulus);
                } while (index >= size);
                numSeen++;
                return list.get(index);
            }
        };
    }

    /** Delegates to {@code AbstractSliceSampler.sample(double, int)}. */
    @Override
    public List sample(final double xInitial, final int numSamples) {
        return super.sample(xInitial, numSamples);
    }

    /** Delegates to {@code AbstractSliceSampler.sample(double)}. */
    @Override
    public double sample(final double xInitial) {
        return super.sample(xInitial);
    }
}
