package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.Int128;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.Iterator;
import java.util.List;
import org.testng.Assert;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/**
 * Tests for {@link DecimalAverageAggregation}: input of long decimals, combining partial states,
 * and computing the final average, with emphasis on how the 128-bit running sum wraps and the
 * separate overflow counter records the number of wraps.
 *
 * <p>Runs single-threaded because every test mutates the shared {@code state} field, which is
 * recreated before each test method by {@link #setUp()}.
 */
@Test(singleThreaded = true)
public class TestDecimalAverageAggregation {
    private static final BigInteger TWO = new BigInteger("2");
    private static final BigInteger ONE_HUNDRED = new BigInteger("100");
    private static final BigInteger TWO_HUNDRED = new BigInteger("200");

    /** Default type used by the tests that do not specify one explicitly. */
    private static final DecimalType TYPE = DecimalType.createDecimalType(38, 0);

    /** Per-test accumulator; reset in {@link #setUp()}. */
    private LongDecimalWithOverflowAndLongState state;

    @BeforeMethod
    public void setUp() {
        this.state = new LongDecimalWithOverflowAndLongStateFactory().createSingleState();
    }

    @Test
    public void testOverflow() {
        addToState(this.state, TWO.pow(126));
        Assert.assertEquals(this.state.getLong(), 1L);
        Assert.assertEquals(this.state.getOverflow(), 0L);
        Assert.assertEquals(getDecimal(this.state), Int128.valueOf(TWO.pow(126)));

        // 2^126 + 2^126 = 2^127 does not fit in a signed 128-bit value: the stored sum wraps
        // and the overflow counter is incremented.
        addToState(this.state, TWO.pow(126));
        Assert.assertEquals(this.state.getLong(), 2L);
        Assert.assertEquals(this.state.getOverflow(), 1L);
        Assert.assertEquals(getDecimal(this.state), Int128.valueOf(Long.MIN_VALUE, 0L));

        assertAverageEquals(TWO.pow(126));
    }

    @Test
    public void testUnderflow() {
        addToState(this.state, Decimals.MIN_UNSCALED_DECIMAL.toBigInteger());
        Assert.assertEquals(this.state.getLong(), 1L);
        Assert.assertEquals(this.state.getOverflow(), 0L);
        Assert.assertEquals(getDecimal(this.state), Decimals.MIN_UNSCALED_DECIMAL);

        // Adding the minimum value a second time wraps in the negative direction.
        addToState(this.state, Decimals.MIN_UNSCALED_DECIMAL.toBigInteger());
        Assert.assertEquals(this.state.getLong(), 2L);
        Assert.assertEquals(this.state.getOverflow(), -1L);
        Assert.assertEquals(getDecimal(this.state), Int128.valueOf(7604722348854507275L, -1374799102801346558L));

        assertAverageEquals(Decimals.MIN_UNSCALED_DECIMAL.toBigInteger());
    }

    @Test
    public void testUnderflowAfterOverflow() {
        addToState(this.state, TWO.pow(126));
        addToState(this.state, TWO.pow(126));
        addToState(this.state, TWO.pow(125));
        Assert.assertEquals(this.state.getOverflow(), 1L);
        Assert.assertEquals(getDecimal(this.state), Int128.valueOf(-6917529027641081856L, 0L));

        // Subtracting back below the wrap point must bring the overflow counter back to zero.
        addToState(this.state, TWO.pow(126).negate());
        addToState(this.state, TWO.pow(126).negate());
        addToState(this.state, TWO.pow(126).negate());
        Assert.assertEquals(this.state.getOverflow(), 0L);
        Assert.assertEquals(getDecimal(this.state), Int128.valueOf(TWO.pow(125).negate()));

        assertAverageEquals(TWO.pow(125).negate().divide(BigInteger.valueOf(6L)));
    }

    @Test
    public void testCombineOverflow() {
        addToState(this.state, TWO.pow(126));
        addToState(this.state, TWO.pow(126));

        LongDecimalWithOverflowAndLongState otherState = new LongDecimalWithOverflowAndLongStateFactory().createSingleState();
        addToState(otherState, TWO.pow(126));
        addToState(otherState, TWO.pow(126));

        DecimalAverageAggregation.combine(this.state, otherState);
        Assert.assertEquals(this.state.getLong(), 4L);
        Assert.assertEquals(this.state.getOverflow(), 1L);
        Assert.assertEquals(getDecimal(this.state), Int128.ZERO);

        BigInteger expectedSum = BigInteger.ZERO
                .add(TWO.pow(126))
                .add(TWO.pow(126))
                .add(TWO.pow(126))
                .add(TWO.pow(126));
        assertAverageEquals(expectedSum.divide(BigInteger.valueOf(4L)));
    }

    @Test
    public void testCombineUnderflow() {
        addToState(this.state, TWO.pow(125).negate());
        addToState(this.state, TWO.pow(126).negate());

        LongDecimalWithOverflowAndLongState otherState = new LongDecimalWithOverflowAndLongStateFactory().createSingleState();
        addToState(otherState, TWO.pow(125).negate());
        addToState(otherState, TWO.pow(126).negate());

        DecimalAverageAggregation.combine(this.state, otherState);
        Assert.assertEquals(this.state.getLong(), 4L);
        Assert.assertEquals(this.state.getOverflow(), -1L);
        Assert.assertEquals(getDecimal(this.state), Int128.valueOf(4611686018427387904L, 0L));

        BigInteger expectedSum = BigInteger.ZERO
                .add(TWO.pow(126))
                .add(TWO.pow(126))
                .add(TWO.pow(125))
                .add(TWO.pow(125))
                .negate();
        assertAverageEquals(expectedSum.divide(BigInteger.valueOf(4L)));
    }

    @Test(dataProvider = "testNoOverflowDataProvider")
    public void testNoOverflow(List<BigInteger> list) {
        testNoOverflow(DecimalType.createDecimalType(38, 0), list);
        testNoOverflow(DecimalType.createDecimalType(38, 2), list);
    }

    /**
     * Feeds {@code list} into a fresh state of {@code decimalType} and checks that no overflow was
     * recorded, that the stored sum is exact, and that the average matches a {@link BigDecimal}
     * division rounded {@link RoundingMode#HALF_UP} at the type's scale.
     */
    private void testNoOverflow(DecimalType decimalType, List<BigInteger> list) {
        LongDecimalWithOverflowAndLongState localState = new LongDecimalWithOverflowAndLongStateFactory().createSingleState();
        for (BigInteger value : list) {
            addToState(decimalType, localState, value);
        }

        Assert.assertEquals(localState.getOverflow(), 0L);

        BigInteger sum = list.stream().reduce(BigInteger.ZERO, BigInteger::add);
        Assert.assertEquals(getDecimal(localState), Int128.valueOf(sum));

        BigDecimal expectedAverage = new BigDecimal(sum, decimalType.getScale())
                .divide(BigDecimal.valueOf(list.size()), decimalType.getScale(), RoundingMode.HALF_UP);
        Assert.assertEquals(decodeBigDecimal(decimalType, DecimalAverageAggregation.average(localState, decimalType)), expectedAverage);
    }

    /**
     * Input lists whose sums fit in the 128-bit accumulator without wrapping.
     *
     * @return one row per test invocation; each row holds a single {@code List<BigInteger>}
     */
    @DataProvider
    public static Object[][] testNoOverflowDataProvider() {
        // Must be an Object[][] (one row per invocation); the previous Object[] did not compile.
        return new Object[][] {
                {ImmutableList.of(BigInteger.TEN.pow(37), BigInteger.ZERO)},
                {ImmutableList.of(BigInteger.TEN.pow(37).negate(), BigInteger.ZERO)},
                {ImmutableList.of(TWO, BigInteger.ONE)},
                {ImmutableList.of(BigInteger.ZERO, BigInteger.ONE)},
                {ImmutableList.of(TWO.negate(), BigInteger.ONE.negate())},
                {ImmutableList.of(BigInteger.ONE.negate(), BigInteger.ZERO)},
                {ImmutableList.of(BigInteger.ONE.negate(), BigInteger.ZERO, BigInteger.ZERO)},
                {ImmutableList.of(TWO.negate(), BigInteger.ZERO, BigInteger.ZERO)},
                {ImmutableList.of(TWO.negate(), BigInteger.ZERO)},
                {ImmutableList.of(TWO_HUNDRED, ONE_HUNDRED)},
                {ImmutableList.of(BigInteger.ZERO, ONE_HUNDRED)},
                {ImmutableList.of(TWO_HUNDRED.negate(), ONE_HUNDRED.negate())},
                {ImmutableList.of(ONE_HUNDRED.negate(), BigInteger.ZERO)},
        };
    }

    /** Interprets an unscaled {@link Int128} as a {@link BigDecimal} with the type's scale and precision. */
    private static BigDecimal decodeBigDecimal(DecimalType decimalType, Int128 int128) {
        return new BigDecimal(int128.toBigInteger(), decimalType.getScale(), new MathContext(decimalType.getPrecision()));
    }

    /** Asserts the average of {@code state} under the default {@link #TYPE}. */
    private void assertAverageEquals(BigInteger bigInteger) {
        assertAverageEquals(bigInteger, TYPE);
    }

    /** Asserts that the unscaled average of {@code state} under {@code decimalType} equals {@code bigInteger}. */
    private void assertAverageEquals(BigInteger bigInteger, DecimalType decimalType) {
        Assert.assertEquals(DecimalAverageAggregation.average(this.state, decimalType).toBigInteger(), bigInteger);
    }

    /** Adds {@code bigInteger} to {@code longDecimalWithOverflowAndLongState} using the default {@link #TYPE}. */
    private static void addToState(LongDecimalWithOverflowAndLongState longDecimalWithOverflowAndLongState, BigInteger bigInteger) {
        addToState(TYPE, longDecimalWithOverflowAndLongState, bigInteger);
    }

    /**
     * Adds one value to the state through the input function matching the type's representation:
     * short decimals go through {@code inputShortDecimal} as a {@code long}; long decimals are
     * written to a one-position block and fed to {@code inputLongDecimal}.
     */
    private static void addToState(DecimalType decimalType, LongDecimalWithOverflowAndLongState longDecimalWithOverflowAndLongState, BigInteger bigInteger) {
        if (decimalType.isShort()) {
            DecimalAverageAggregation.inputShortDecimal(longDecimalWithOverflowAndLongState, Int128.valueOf(bigInteger).toLongExact());
            return;
        }
        BlockBuilder blockBuilder = decimalType.createFixedSizeBlockBuilder(1);
        decimalType.writeObject(blockBuilder, Int128.valueOf(bigInteger));
        DecimalAverageAggregation.inputLongDecimal(longDecimalWithOverflowAndLongState, blockBuilder.build(), 0);
    }

    /** Reads the two longs of the state's current 128-bit sum and returns them as an {@link Int128}. */
    private static Int128 getDecimal(LongDecimalWithOverflowAndLongState longDecimalWithOverflowAndLongState) {
        long[] decimalArray = longDecimalWithOverflowAndLongState.getDecimalArray();
        int offset = longDecimalWithOverflowAndLongState.getDecimalArrayOffset();
        return Int128.valueOf(decimalArray[offset], decimalArray[offset + 1]);
    }
}
