package tl.lin.data.cfd;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Iterator;
import tl.lin.data.fd.Int2IntFrequencyDistribution;
import tl.lin.data.fd.Int2IntFrequencyDistributionEntry;
import tl.lin.data.fd.Int2LongFrequencyDistributionEntry;
import tl.lin.data.map.HMapIVW;
import tl.lin.data.pair.PairOfInts;

/**
 * Conditional frequency distribution over {@code int} events conditioned on {@code int} values,
 * with {@code int} counts.
 *
 * <p>One {@link Int2IntFrequencyDistribution} is kept per conditioning value. Alongside these,
 * every call to {@link #set} incrementally maintains two derived quantities:
 * <ul>
 *   <li>the marginal count of each event, summed over all conditions and stored as a {@code long};
 *   <li>the sum of all counts in all conditional distributions.
 * </ul>
 * {@link #check()} can be used to verify that these derived quantities are consistent with the
 * per-condition distributions.
 */
public class Int2IntConditionalFrequencyDistributionEntry implements Int2IntConditionalFrequencyDistribution {
    /** Per-condition distributions, keyed by the conditioning value. */
    private final HMapIVW<Int2IntFrequencyDistribution> distributions = new HMapIVW<>();
    /** Marginal count of each event, summed over all conditions. */
    private final Int2LongFrequencyDistributionEntry marginals = new Int2LongFrequencyDistributionEntry();
    /** Sum of every count in every conditional distribution. */
    private long sumOfAllCounts = 0;

    /**
     * Sets the count of {@code key} under condition {@code cond} to {@code value}. The marginal
     * count of {@code key} and the overall sum are adjusted by the difference between
     * {@code value} and the previous count.
     *
     * @param key the event
     * @param cond the conditioning value
     * @param value the new count
     */
    @Override
    public void set(int key, int cond, int value) {
        if (!distributions.containsKey(cond)) {
            // First event seen under this condition: create its distribution.
            Int2IntFrequencyDistribution fd = new Int2IntFrequencyDistributionEntry();
            fd.set(key, value);
            distributions.put(cond, fd);
            marginals.increment(key, value);
            sumOfAllCounts += value;
            return;
        }

        // The stored distribution is updated in place.
        Int2IntFrequencyDistribution fd = distributions.get(cond);
        int previous = fd.get(key);
        fd.set(key, value);

        // Widen before subtracting: (value - previous) can overflow int, while the marginal
        // counts and the overall sum are long.
        long delta = (long) value - previous;
        marginals.increment(key, delta);
        sumOfAllCounts += delta;
    }

    /**
     * Increments the count of {@code key} under condition {@code cond} by one.
     *
     * @param key the event
     * @param cond the conditioning value
     */
    @Override
    public void increment(int key, int cond) {
        increment(key, cond, 1);
    }

    /**
     * Adds {@code value} to the current count (as returned by {@link #get}) of {@code key} under
     * condition {@code cond}.
     *
     * @param key the event
     * @param cond the conditioning value
     * @param value the amount to add
     */
    @Override
    public void increment(int key, int cond, int value) {
        set(key, cond, get(key, cond) + value);
    }

    /**
     * Returns the count of {@code key} under condition {@code cond}.
     *
     * @param key the event
     * @param cond the conditioning value
     * @return the count, or 0 if no distribution exists for {@code cond}
     */
    @Override
    public int get(int key, int cond) {
        if (!distributions.containsKey(cond)) {
            return 0;
        }
        return distributions.get(cond).get(key);
    }

    /**
     * Returns the marginal count of {@code key}, summed over all conditions.
     *
     * @param key the event
     * @return the marginal count
     */
    @Override
    public long getMarginalCount(int key) {
        return marginals.get(key);
    }

    /**
     * Returns the distribution stored for {@code cond}. If no distribution exists for
     * {@code cond}, returns a new empty distribution. That empty distribution is not stored, so
     * changes made to it are not reflected in this object.
     *
     * @param cond the conditioning value
     * @return the stored distribution for {@code cond}, or a fresh empty one
     */
    @Override
    public Int2IntFrequencyDistribution getConditionalDistribution(int cond) {
        if (distributions.containsKey(cond)) {
            return distributions.get(cond);
        }
        return new Int2IntFrequencyDistributionEntry();
    }

    /**
     * Returns the sum of all counts across all conditional distributions.
     *
     * @return the overall sum
     */
    @Override
    public long getSumOfAllCounts() {
        return sumOfAllCounts;
    }

    /**
     * Verifies the internal consistency of this object:
     * <ul>
     *   <li>each conditional distribution's entries sum to its reported sum of counts;
     *   <li>those per-condition sums add up to {@link #getSumOfAllCounts()};
     *   <li>the per-event totals recomputed from the conditional distributions match the stored
     *       marginal counts.
     * </ul>
     *
     * @throws RuntimeException if any of these invariants is violated
     */
    @Override
    public void check() {
        // Recompute the marginals from scratch while summing each conditional distribution.
        Int2IntFrequencyDistributionEntry recomputedMarginals = new Int2IntFrequencyDistributionEntry();
        long totalSum = 0;
        for (Int2IntFrequencyDistribution fd : distributions.values()) {
            long conditionalSum = 0;
            for (PairOfInts entry : fd) {
                conditionalSum += entry.getRightElement();
                recomputedMarginals.increment(entry.getLeftElement(), entry.getRightElement());
            }
            if (conditionalSum != fd.getSumOfCounts()) {
                throw new RuntimeException("Internal Error!");
            }
            totalSum += fd.getSumOfCounts();
        }

        if (totalSum != getSumOfAllCounts()) {
            throw new RuntimeException("Internal Error! Got " + totalSum + ", Expected " + getSumOfAllCounts());
        }

        for (PairOfInts entry : recomputedMarginals) {
            if (entry.getRightElement() != marginals.get(entry.getLeftElement())) {
                throw new RuntimeException("Internal Error!");
            }
        }
    }

    /**
     * Deserializes this object from {@code dataInput}. It reads the marginals, then the
     * conditional distributions, then the overall sum, which is the order {@link #write} emits.
     *
     * @param dataInput the source to read from
     * @throws IOException if reading fails
     */
    public void readFields(DataInput dataInput) throws IOException {
        marginals.readFields(dataInput);
        distributions.readFields(dataInput);
        sumOfAllCounts = dataInput.readLong();
    }

    /**
     * Serializes this object to {@code dataOutput}. It writes the marginals, then the
     * conditional distributions, then the overall sum as a {@code long}.
     *
     * @param dataOutput the sink to write to
     * @throws IOException if writing fails
     */
    public void write(DataOutput dataOutput) throws IOException {
        marginals.write(dataOutput);
        distributions.write(dataOutput);
        dataOutput.writeLong(sumOfAllCounts);
    }
}
