package io.trino.operator.aggregation;

import io.trino.operator.aggregation.state.LongDecimalWithOverflowState;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowStateFactory;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BlockBuilderStatus;
import io.trino.spi.block.VariableWidthBlockBuilder;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Int128;
import java.math.BigInteger;
import org.assertj.core.api.Assertions;
import org.testng.Assert;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

@Test(singleThreaded = true)
/**
 * Tests the overflow bookkeeping of {@link DecimalSumAggregation} for DECIMAL(38, 0) inputs.
 *
 * <p>The aggregation state holds a 128-bit accumulator (read back here as an {@link Int128}
 * from two consecutive longs, high word first) plus a separate overflow counter. When a sum
 * leaves the Int128 range the accumulator wraps and the counter moves by +1 (overflow) or
 * -1 (underflow).
 */
public class TestDecimalSumAggregation {
    private static final BigInteger TWO = new BigInteger("2");
    private static final DecimalType TYPE = DecimalType.createDecimalType(38, 0);
    private LongDecimalWithOverflowState state;

    @BeforeMethod
    public void setUp() {
        this.state = new LongDecimalWithOverflowStateFactory().createSingleState();
    }

    @Test
    public void testOverflow() {
        addToState(this.state, TWO.pow(126));
        Assert.assertEquals(this.state.getOverflow(), 0L);
        Assert.assertEquals(getDecimal(this.state), Int128.valueOf(TWO.pow(126)));

        // 2^126 + 2^126 = 2^127, one past the Int128 maximum: the accumulator wraps to
        // -2^127 (high word Long.MIN_VALUE, low word 0) and the overflow counter becomes 1.
        addToState(this.state, TWO.pow(126));
        Assert.assertEquals(this.state.getOverflow(), 1L);
        Assert.assertEquals(getDecimal(this.state), Int128.valueOf(Long.MIN_VALUE, 0L));
    }

    @Test
    public void testUnderflow() {
        addToState(this.state, TWO.pow(126).negate());
        Assert.assertEquals(this.state.getOverflow(), 0L);
        Assert.assertEquals(getDecimal(this.state), Int128.valueOf(TWO.pow(126).negate()));

        // -2^126 + -2^126 = -2^127 is exactly the Int128 minimum, so it is still representable
        // and must NOT be counted as an underflow.
        addToState(this.state, TWO.pow(126).negate());
        Assert.assertEquals(this.state.getOverflow(), 0L);
        Assert.assertEquals(getDecimal(this.state), Int128.valueOf(Long.MIN_VALUE, 0L));

        // One more step goes below the minimum: -3 * 2^126 wraps (by +2^128) to 2^126 and the
        // overflow counter drops to -1, exercising the real underflow path of inputLongDecimal.
        addToState(this.state, TWO.pow(126).negate());
        Assert.assertEquals(this.state.getOverflow(), -1L);
        Assert.assertEquals(getDecimal(this.state), Int128.valueOf(TWO.pow(126)));
    }

    @Test
    public void testUnderflowAfterOverflow() {
        addToState(this.state, TWO.pow(126));
        addToState(this.state, TWO.pow(126));
        addToState(this.state, TWO.pow(125));
        // 2^127 + 2^125 wraps to -2^127 + 2^125 = -3 * 2^125, whose high word is
        // -3 * 2^61 = -6917529027641081856 (low word 0).
        Assert.assertEquals(this.state.getOverflow(), 1L);
        Assert.assertEquals(getDecimal(this.state), Int128.valueOf(-6917529027641081856L, 0L));

        // Subtracting 3 * 2^126 brings the true total back to -2^125, inside the Int128 range,
        // so the overflow counter returns to 0.
        addToState(this.state, TWO.pow(126).negate());
        addToState(this.state, TWO.pow(126).negate());
        addToState(this.state, TWO.pow(126).negate());
        Assert.assertEquals(this.state.getOverflow(), 0L);
        Assert.assertEquals(getDecimal(this.state), Int128.valueOf(TWO.pow(125).negate()));
    }

    @Test
    public void testCombineOverflow() {
        addToState(this.state, TWO.pow(125));
        addToState(this.state, TWO.pow(126));

        LongDecimalWithOverflowState other = new LongDecimalWithOverflowStateFactory().createSingleState();
        addToState(other, TWO.pow(125));
        addToState(other, TWO.pow(126));

        // Each state holds 3 * 2^125; the combined 3 * 2^126 = 2^127 + 2^126 wraps to -2^126,
        // whose high word is -2^62 = -4611686018427387904.
        DecimalSumAggregation.combine(this.state, other);
        Assert.assertEquals(this.state.getOverflow(), 1L);
        Assert.assertEquals(getDecimal(this.state), Int128.valueOf(-4611686018427387904L, 0L));
    }

    @Test
    public void testCombineUnderflow() {
        addToState(this.state, TWO.pow(125).negate());
        addToState(this.state, TWO.pow(126).negate());

        LongDecimalWithOverflowState other = new LongDecimalWithOverflowStateFactory().createSingleState();
        addToState(other, TWO.pow(125).negate());
        addToState(other, TWO.pow(126).negate());

        // Each state holds -3 * 2^125; the combined -3 * 2^126 wraps to 2^126, whose high word
        // is 2^62 = 4611686018427387904, and the counter records one underflow.
        DecimalSumAggregation.combine(this.state, other);
        Assert.assertEquals(this.state.getOverflow(), -1L);
        Assert.assertEquals(getDecimal(this.state), Int128.valueOf(4611686018427387904L, 0L));
    }

    @Test
    public void testOverflowOnOutput() {
        addToState(this.state, TWO.pow(126));
        addToState(this.state, TWO.pow(126));
        Assert.assertEquals(this.state.getOverflow(), 1L);

        // A state with a non-zero overflow counter cannot be emitted as a DECIMAL(38, 0) value.
        Assertions.assertThatThrownBy(() -> {
            DecimalSumAggregation.outputLongDecimal(this.state, new VariableWidthBlockBuilder((BlockBuilderStatus) null, 10, 100));
        }).isInstanceOf(ArithmeticException.class).hasMessage("Decimal overflow");
    }

    /**
     * Feeds a single value into the aggregation state through the regular input function.
     *
     * @param target the state to accumulate into
     * @param value the value to add; it is written as an {@link Int128} into a one-row block of {@link #TYPE}
     */
    private static void addToState(LongDecimalWithOverflowState target, BigInteger value) {
        BlockBuilder blockBuilder = TYPE.createFixedSizeBlockBuilder(1);
        TYPE.writeObject(blockBuilder, Int128.valueOf(value));
        // TYPE is DECIMAL(38, 0), so the long-decimal path is the one taken today; the branch
        // keeps the helper aligned with the input function matching TYPE's representation.
        if (TYPE.isShort()) {
            DecimalSumAggregation.inputShortDecimal(target, blockBuilder.build(), 0);
        } else {
            DecimalSumAggregation.inputLongDecimal(target, blockBuilder.build(), 0);
        }
    }

    /**
     * Reads the (possibly wrapped) 128-bit accumulator out of the state.
     *
     * @param source the state to inspect
     * @return the accumulator built from the two longs at the state's offset, high word first
     */
    private static Int128 getDecimal(LongDecimalWithOverflowState source) {
        long[] words = source.getDecimalArray();
        int offset = source.getDecimalArrayOffset();
        return Int128.valueOf(words[offset], words[offset + 1]);
    }
}
