package io.trino.operator.aggregation;

import com.google.common.annotations.VisibleForTesting;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.BlockIndex;
import io.trino.spi.function.BlockPosition;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.Description;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.LiteralParameters;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.Int128;
import io.trino.spi.type.Int128Math;
import io.trino.spi.type.Type;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;

/**
 * {@code avg} aggregation over {@code decimal(p, s)} inputs.
 * <p>
 * The state accumulates three things:
 * <ul>
 *   <li>a row counter ({@code getLong()}),</li>
 *   <li>a 128-bit running sum of the <em>unscaled</em> input values, stored as two longs
 *       (high word at {@code offset}, low word at {@code offset + 1}),</li>
 *   <li>an overflow counter ({@code getOverflow()}) holding the values returned by
 *       {@link Int128Math#addWithOverflow}; each unit is worth 2^128 unscaled units.</li>
 * </ul>
 * The exact unscaled sum is therefore {@code int128Sum + overflow * 2^128}.
 */
@AggregationFunction("avg")
@Description("Calculates the average value")
public final class DecimalAverageAggregation {
    /** Weight of one unit of the overflow counter: 2^128 (in unscaled units). */
    private static final BigInteger OVERFLOW_MULTIPLIER = BigInteger.ONE.shiftLeft(128);

    private DecimalAverageAggregation() {
    }

    /**
     * Adds one short (64-bit unscaled) decimal value to the running sum.
     *
     * @param state    aggregation state to update
     * @param block    block holding the input value
     * @param position position of the input value within {@code block}
     */
    @InputFunction
    @LiteralParameters({"p", "s"})
    public static void inputShortDecimal(
            @AggregationState LongDecimalWithOverflowAndLongState state,
            @BlockPosition @SqlType(value = "decimal(p, s)", nativeContainerType = long.class) Block block,
            @BlockIndex int position) {
        state.addLong(1L); // row counter
        state.setNotNull();

        long[] sum = state.getDecimalArray();
        int offset = state.getDecimalArrayOffset();

        long low = block.getLong(position, 0);
        // Sign-extend the 64-bit unscaled value into the high word of a 128-bit addend.
        long high = low >> 63;

        state.addOverflow(Int128Math.addWithOverflow(sum[offset], sum[offset + 1], high, low, sum, offset));
    }

    /**
     * Adds one long (128-bit unscaled) decimal value to the running sum.
     *
     * @param state    aggregation state to update
     * @param block    block holding the input value (high word at byte offset 0, low word at 8)
     * @param position position of the input value within {@code block}
     */
    @InputFunction
    @LiteralParameters({"p", "s"})
    public static void inputLongDecimal(
            @AggregationState LongDecimalWithOverflowAndLongState state,
            @BlockPosition @SqlType(value = "decimal(p, s)", nativeContainerType = Int128.class) Block block,
            @BlockIndex int position) {
        state.addLong(1L); // row counter
        state.setNotNull();

        long[] sum = state.getDecimalArray();
        int offset = state.getDecimalArrayOffset();

        long high = block.getLong(position, 0);
        long low = block.getLong(position, 8);

        state.addOverflow(Int128Math.addWithOverflow(sum[offset], sum[offset + 1], high, low, sum, offset));
    }

    /**
     * Merges {@code otherState} into {@code state}.
     * <p>
     * If {@code state} has not been marked non-null yet, it is overwritten with the contents of
     * {@code otherState}. Otherwise the sums are added and the overflow counters are combined.
     *
     * @param state      state to merge into (mutated)
     * @param otherState state to merge from (not modified)
     */
    @CombineFunction
    public static void combine(
            @AggregationState LongDecimalWithOverflowAndLongState state,
            @AggregationState LongDecimalWithOverflowAndLongState otherState) {
        state.addLong(otherState.getLong()); // row counter

        long[] sum = state.getDecimalArray();
        int offset = state.getDecimalArrayOffset();
        long[] otherSum = otherState.getDecimalArray();
        int otherOffset = otherState.getDecimalArrayOffset();

        if (state.isNotNull()) {
            long overflow = Int128Math.addWithOverflow(
                    sum[offset], sum[offset + 1],
                    otherSum[otherOffset], otherSum[otherOffset + 1],
                    sum, offset);
            // Fail loudly instead of silently wrapping the overflow counter.
            state.addOverflow(Math.addExact(overflow, otherState.getOverflow()));
            return;
        }

        state.setNotNull();
        sum[offset] = otherSum[otherOffset];
        sum[offset + 1] = otherSum[otherOffset + 1];
        state.setOverflow(otherState.getOverflow());
    }

    /**
     * Writes the average to {@code out}, or NULL when no rows were accumulated.
     *
     * @param type  the {@code decimal(p, s)} result type
     * @param state accumulated state
     * @param out   destination block builder
     */
    @OutputFunction("decimal(p,s)")
    public static void outputShortDecimal(
            @TypeParameter("decimal(p,s)") Type type,
            @AggregationState LongDecimalWithOverflowAndLongState state,
            BlockBuilder out) {
        DecimalType decimalType = (DecimalType) type;
        if (state.getLong() == 0) {
            out.appendNull();
            return;
        }
        Int128 average = average(state, decimalType);
        if (decimalType.isShort()) {
            Decimals.writeShortDecimal(out, average.toLongExact());
        }
        else {
            type.writeObject(out, average);
        }
    }

    /**
     * Computes the average (sum / row count) as an unscaled {@link Int128} in the scale of {@code decimalType}.
     * Halves are rounded away from zero.
     *
     * @param state       accumulated state; expected to have a non-zero row count
     * @param decimalType result type whose scale is applied
     * @return unscaled average
     * @throws ArithmeticException if the non-overflow result exceeds the decimal range,
     *                             or on division by a zero row count
     */
    @VisibleForTesting
    public static Int128 average(LongDecimalWithOverflowAndLongState state, DecimalType decimalType) {
        long[] sum = state.getDecimalArray();
        int offset = state.getDecimalArrayOffset();
        long count = state.getLong();
        long overflow = state.getOverflow();

        if (overflow != 0) {
            // The 128-bit remainder and the overflow term are both in *unscaled* units, so combine
            // them into one exact unscaled sum before applying the scale. Adding the overflow term
            // as an integral (scale 0) BigDecimal would inflate it by 10^scale whenever scale > 0.
            BigInteger unscaledSum = Int128.valueOf(sum[offset], sum[offset + 1]).toBigInteger()
                    .add(OVERFLOW_MULTIPLIER.multiply(BigInteger.valueOf(overflow)));
            int scale = decimalType.getScale();
            BigDecimal average = new BigDecimal(unscaledSum, scale)
                    .divide(BigDecimal.valueOf(count), scale, RoundingMode.HALF_UP);
            return Decimals.encodeScaledValue(average, scale);
        }

        // Divisor is the 128-bit value (high = 0, low = count); no rescaling on either side.
        Int128 average = Int128Math.divideRoundUp(sum[offset], sum[offset + 1], 0, 0L, count, 0);
        if (Decimals.overflows(average)) {
            throw new ArithmeticException("Decimal overflow");
        }
        return average;
    }
}
