package tl.lin.data.cfd;

import com.google.common.collect.Maps;
import java.lang.Comparable;
import java.util.Iterator;
import java.util.Map;
import tl.lin.data.fd.Object2IntFrequencyDistribution;
import tl.lin.data.fd.Object2IntFrequencyDistributionEntry;
import tl.lin.data.map.HMapKL;
import tl.lin.data.pair.PairOfObjectInt;

/* loaded from: input_file:tl/lin/data/cfd/Object2IntConditionalFrequencyDistributionEntry.class */
/**
 * Conditional frequency distribution that keeps one
 * {@link Object2IntFrequencyDistribution} per conditioning key, plus a running
 * marginal count per event key and a grand total of all counts.
 *
 * <p>Throughout this class the first key argument ({@code k}) is the event and the
 * second ({@code k2}) is the conditioning key: the per-condition distributions are
 * keyed by {@code k2}, the marginals by {@code k}.
 *
 * @param <K> type of both the event keys and the conditioning keys
 */
public class Object2IntConditionalFrequencyDistributionEntry<K extends Comparable<K>> implements Object2IntConditionalFrequencyDistribution<K> {
    /** Conditioning key -> distribution of event counts observed under that key. */
    private final Map<K, Object2IntFrequencyDistribution<K>> distributions = Maps.newHashMap();

    /** Event key -> count of that event accumulated over every conditioning key. */
    private final HMapKL<K> marginals = new HMapKL<>();

    /** Sum of every count stored in {@link #distributions}. */
    private long sumOfAllCounts = 0;

    /**
     * Sets the count of event {@code k} under condition {@code k2} to {@code i},
     * adjusting the marginal count of {@code k} and the grand total by the
     * difference between the new and the previously stored count.
     *
     * @param k event key
     * @param k2 conditioning key
     * @param i new count
     */
    @Override
    public void set(K k, K k2, int i) {
        Object2IntFrequencyDistribution<K> conditional = this.distributions.get(k2);
        if (conditional == null) {
            // First observation under this condition: start a fresh distribution.
            conditional = new Object2IntFrequencyDistributionEntry<>();
            conditional.set(k, i);
            this.distributions.put(k2, conditional);
            this.marginals.increment(k, i);
            this.sumOfAllCounts += i;
            return;
        }

        // The distribution is mutated in place, so it does not need to be re-inserted.
        int previous = conditional.get(k);
        conditional.set(k, i);
        this.marginals.increment(k, i - previous);
        this.sumOfAllCounts = this.sumOfAllCounts - previous + i;
    }

    /**
     * Adds one to the count of event {@code k} under condition {@code k2}.
     *
     * @param k event key
     * @param k2 conditioning key
     */
    @Override
    public void increment(K k, K k2) {
        increment(k, k2, 1);
    }

    /**
     * Adds {@code i} to the count of event {@code k} under condition {@code k2}.
     *
     * @param k event key
     * @param k2 conditioning key
     * @param i amount to add to the current count
     */
    @Override
    public void increment(K k, K k2, int i) {
        // get() yields 0 when the condition is unknown, so a single path covers both cases.
        set(k, k2, get(k, k2) + i);
    }

    /**
     * Returns the count of event {@code k} under condition {@code k2}.
     *
     * @param k event key
     * @param k2 conditioning key
     * @return the stored count, or 0 when no distribution exists for {@code k2}
     */
    @Override
    public int get(K k, K k2) {
        Object2IntFrequencyDistribution<K> conditional = this.distributions.get(k2);
        if (conditional == null) {
            return 0;
        }
        return conditional.get(k);
    }

    /**
     * Returns the count of event {@code k} accumulated over every conditioning key.
     *
     * @param k event key
     * @return the marginal count as reported by the marginals map
     */
    @Override
    public long getMarginalCount(K k) {
        return this.marginals.get(k);
    }

    /**
     * Returns the distribution of events observed under condition {@code k}.
     *
     * <p>For a known condition this is the internal, live distribution object (not a
     * copy). For an unknown condition a new empty distribution is returned; it is not
     * stored in this object.
     *
     * @param k conditioning key
     * @return the distribution for {@code k}, never {@code null}
     */
    @Override
    public Object2IntFrequencyDistribution<K> getConditionalDistribution(K k) {
        Object2IntFrequencyDistribution<K> conditional = this.distributions.get(k);
        if (conditional == null) {
            return new Object2IntFrequencyDistributionEntry<>();
        }
        return conditional;
    }

    /**
     * Returns the sum of all counts over all events and all conditions.
     *
     * @return the running grand total
     */
    @Override
    public long getSumOfAllCounts() {
        return this.sumOfAllCounts;
    }

    /**
     * Verifies the internal bookkeeping by recomputing it from the per-condition
     * distributions: each distribution's reported sum, the grand total, and the
     * marginal count of every event.
     *
     * @throws RuntimeException if any recomputed value disagrees with the stored one
     */
    @Override
    public void check() {
        // Marginals rebuilt from scratch by walking every conditional distribution.
        Object2IntFrequencyDistributionEntry<K> recomputed = new Object2IntFrequencyDistributionEntry<>();
        long total = 0;

        for (Object2IntFrequencyDistribution<K> conditional : this.distributions.values()) {
            long conditionalSum = 0;
            for (PairOfObjectInt<K> entry : conditional) {
                conditionalSum += entry.getRightElement();
                recomputed.increment(entry.getLeftElement(), entry.getRightElement());
            }
            if (conditionalSum != conditional.getSumOfCounts()) {
                throw new RuntimeException("Internal Error!");
            }
            total += conditional.getSumOfCounts();
        }

        if (total != getSumOfAllCounts()) {
            throw new RuntimeException("Internal Error! Got " + total + ", Expected " + getSumOfAllCounts());
        }

        // Every recomputed marginal must match the incrementally maintained one.
        for (PairOfObjectInt<K> entry : recomputed) {
            if (entry.getRightElement() != this.marginals.get(entry.getLeftElement())) {
                throw new RuntimeException("Internal Error!");
            }
        }

        // NOTE(review): this pass compares the recomputed table with itself (iteration
        // vs. lookup). It was presumably meant to cross-check `marginals` in the other
        // direction - confirm against the upstream source before changing it.
        for (PairOfObjectInt<K> entry : recomputed) {
            if (entry.getRightElement() != recomputed.get(entry.getLeftElement())) {
                throw new RuntimeException("Internal Error!");
            }
        }
    }
}
