package io.trino.operator.aggregation;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.airlift.bytecode.DynamicClassLoader;
import io.airlift.slice.Slice;
import io.trino.metadata.FunctionArgumentDefinition;
import io.trino.metadata.FunctionBinding;
import io.trino.metadata.FunctionKind;
import io.trino.metadata.FunctionMetadata;
import io.trino.metadata.Signature;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.aggregation.AggregationMetadata;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongState;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateFactory;
import io.trino.operator.aggregation.state.LongDecimalWithOverflowAndLongStateSerializer;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.TypeSignatureParameter;
import io.trino.spi.type.UnscaledDecimal128Arithmetic;
import io.trino.util.Reflection;
import java.lang.invoke.MethodHandle;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
import java.util.List;
import java.util.Optional;

/* loaded from: input_file:io/trino/operator/aggregation/DecimalAverageAggregation.class */
/**
 * The {@code avg} aggregate for {@code decimal(p, s)} arguments; the result has the same
 * {@code decimal(p, s)} type as the argument.
 *
 * <p>The accumulator state ({@link LongDecimalWithOverflowAndLongState}) carries three values:
 * a row counter (the "long"), a running sum of the <em>unscaled</em> input values kept in a
 * {@link Slice}, and an overflow counter that collects the value returned by
 * {@link UnscaledDecimal128Arithmetic#addWithOverflow} for every addition. The average is
 * reconstructed from these three values in {@link #average}.
 */
public class DecimalAverageAggregation extends SqlAggregationFunction {
    private static final String NAME = "avg";
    public static final DecimalAverageAggregation DECIMAL_AVERAGE_AGGREGATION = new DecimalAverageAggregation();

    // Handles to the static callbacks below. The leading Type/DecimalType parameter of the input
    // and output handles is bound to the concrete argument type in generateAggregation.
    private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = Reflection.methodHandle(DecimalAverageAggregation.class, "inputShortDecimal", Type.class, LongDecimalWithOverflowAndLongState.class, Block.class, Integer.TYPE);
    private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = Reflection.methodHandle(DecimalAverageAggregation.class, "inputLongDecimal", Type.class, LongDecimalWithOverflowAndLongState.class, Block.class, Integer.TYPE);
    private static final MethodHandle SHORT_DECIMAL_OUTPUT_FUNCTION = Reflection.methodHandle(DecimalAverageAggregation.class, "outputShortDecimal", DecimalType.class, LongDecimalWithOverflowAndLongState.class, BlockBuilder.class);
    private static final MethodHandle LONG_DECIMAL_OUTPUT_FUNCTION = Reflection.methodHandle(DecimalAverageAggregation.class, "outputLongDecimal", DecimalType.class, LongDecimalWithOverflowAndLongState.class, BlockBuilder.class);
    private static final MethodHandle COMBINE_FUNCTION = Reflection.methodHandle(DecimalAverageAggregation.class, "combine", LongDecimalWithOverflowAndLongState.class, LongDecimalWithOverflowAndLongState.class);

    private static final BigInteger TWO = new BigInteger("2");
    // Weight applied to each unit of the state's overflow counter when the sum is rebuilt in
    // average(): 2 << 126 == 2^127.
    private static final BigInteger OVERFLOW_MULTIPLIER = TWO.shiftLeft(126);

    /**
     * Registers {@code avg(decimal(p, s)) -> decimal(p, s)} as an aggregate function described as
     * "Calculates the average value".
     */
    public DecimalAverageAggregation() {
        super(
                new FunctionMetadata(
                        new Signature(NAME, decimalSignature(), ImmutableList.of(decimalSignature())),
                        true,
                        ImmutableList.of(new FunctionArgumentDefinition(false)),
                        false,
                        true,
                        "Calculates the average value",
                        FunctionKind.AGGREGATE),
                true,
                false);
    }

    /** Builds the {@code decimal(p, s)} type signature with type variables {@code p} and {@code s}. */
    private static TypeSignature decimalSignature() {
        return new TypeSignature("decimal", TypeSignatureParameter.typeVariable("p"), TypeSignatureParameter.typeVariable("s"));
    }

    /**
     * Returns the single intermediate type: the serialized form reported by
     * {@link LongDecimalWithOverflowAndLongStateSerializer}.
     */
    @Override // io.trino.metadata.SqlAggregationFunction
    public List<TypeSignature> getIntermediateTypes(FunctionBinding functionBinding) {
        return ImmutableList.of(new LongDecimalWithOverflowAndLongStateSerializer().getSerializedType().getTypeSignature());
    }

    /**
     * Specializes the aggregation for the only argument type of the bound signature.
     */
    @Override // io.trino.metadata.SqlAggregationFunction
    public InternalAggregationFunction specialize(FunctionBinding functionBinding) {
        return generateAggregation((Type) Iterables.getOnlyElement(functionBinding.getBoundSignature().getArgumentTypes()));
    }

    /**
     * Assembles the aggregation for one concrete decimal type, choosing the short or long
     * input/output callbacks according to {@link DecimalType#isShort()}.
     *
     * @param type the argument (and result) type
     * @throws IllegalArgumentException if {@code type} is not a {@link DecimalType}
     */
    private static InternalAggregationFunction generateAggregation(Type type) {
        Preconditions.checkArgument(type instanceof DecimalType, "type must be Decimal");
        DynamicClassLoader classLoader = new DynamicClassLoader(DecimalAverageAggregation.class.getClassLoader());
        List<Type> inputTypes = ImmutableList.of(type);
        LongDecimalWithOverflowAndLongStateSerializer stateSerializer = new LongDecimalWithOverflowAndLongStateSerializer();

        boolean shortDecimal = ((DecimalType) type).isShort();
        MethodHandle inputFunction = shortDecimal ? SHORT_DECIMAL_INPUT_FUNCTION : LONG_DECIMAL_INPUT_FUNCTION;
        MethodHandle outputFunction = shortDecimal ? SHORT_DECIMAL_OUTPUT_FUNCTION : LONG_DECIMAL_OUTPUT_FUNCTION;

        List<TypeSignature> inputSignatures = inputTypes.stream()
                .map(Type::getTypeSignature)
                .collect(ImmutableList.toImmutableList());

        AggregationMetadata metadata = new AggregationMetadata(
                AggregationUtils.generateAggregationName(NAME, type.getTypeSignature(), inputSignatures),
                createInputParameterMetadata(type),
                inputFunction.bindTo(type),
                Optional.empty(),
                COMBINE_FUNCTION,
                outputFunction.bindTo(type),
                ImmutableList.of(new AggregationMetadata.AccumulatorStateDescriptor(LongDecimalWithOverflowAndLongState.class, stateSerializer, new LongDecimalWithOverflowAndLongStateFactory())),
                type);

        Type intermediateType = stateSerializer.getSerializedType();
        return new InternalAggregationFunction(NAME, inputTypes, ImmutableList.of(intermediateType), type, AccumulatorCompiler.generateAccumulatorFactoryBinder(metadata, classLoader));
    }

    /**
     * Describes the input callback's parameters after the type has been bound: the state, the
     * input block channel of {@code type}, and the block index.
     */
    private static List<AggregationMetadata.ParameterMetadata> createInputParameterMetadata(Type type) {
        return ImmutableList.of(
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.STATE),
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL, type),
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX));
    }

    /**
     * Input callback for short decimals: counts the row and adds the value read with
     * {@code type.getLong(block, i)} to the running unscaled sum.
     *
     * @param type the decimal type used to read the value
     * @param longDecimalWithOverflowAndLongState the accumulator state, updated in place
     * @param block the input block
     * @param i the position within {@code block}
     */
    public static void inputShortDecimal(Type type, LongDecimalWithOverflowAndLongState longDecimalWithOverflowAndLongState, Block block, int i) {
        longDecimalWithOverflowAndLongState.addLong(1L); // row counter
        Slice sum = sumSliceOf(longDecimalWithOverflowAndLongState);
        Slice value = UnscaledDecimal128Arithmetic.unscaledDecimal(type.getLong(block, i));
        // The sum slice is both an operand and the destination, so it is updated in place.
        long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, value, sum);
        longDecimalWithOverflowAndLongState.addOverflow(overflow);
    }

    /**
     * Input callback for long decimals: counts the row and adds the value read with
     * {@code type.getSlice(block, i)} to the running unscaled sum.
     *
     * @param type the decimal type used to read the value
     * @param longDecimalWithOverflowAndLongState the accumulator state, updated in place
     * @param block the input block
     * @param i the position within {@code block}
     */
    public static void inputLongDecimal(Type type, LongDecimalWithOverflowAndLongState longDecimalWithOverflowAndLongState, Block block, int i) {
        longDecimalWithOverflowAndLongState.addLong(1L); // row counter
        Slice sum = sumSliceOf(longDecimalWithOverflowAndLongState);
        // The sum slice is both an operand and the destination, so it is updated in place.
        long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, type.getSlice(block, i), sum);
        longDecimalWithOverflowAndLongState.addOverflow(overflow);
    }

    /** Returns the state's sum slice, installing a fresh {@code unscaledDecimal()} slice if none is set yet. */
    private static Slice sumSliceOf(LongDecimalWithOverflowAndLongState state) {
        Slice sum = state.getLongDecimal();
        if (sum == null) {
            sum = UnscaledDecimal128Arithmetic.unscaledDecimal();
            state.setLongDecimal(sum);
        }
        return sum;
    }

    /**
     * Combine callback: merges the second state into the first one.
     *
     * @param longDecimalWithOverflowAndLongState the state that receives the merged result
     * @param longDecimalWithOverflowAndLongState2 the state being merged in; it is only read
     */
    public static void combine(LongDecimalWithOverflowAndLongState longDecimalWithOverflowAndLongState, LongDecimalWithOverflowAndLongState longDecimalWithOverflowAndLongState2) {
        longDecimalWithOverflowAndLongState.addLong(longDecimalWithOverflowAndLongState2.getLong()); // row counter

        long overflow = longDecimalWithOverflowAndLongState2.getOverflow();
        Slice sum = longDecimalWithOverflowAndLongState.getLongDecimal();
        if (sum == null) {
            // No sum yet on the receiving side: adopt the other state's slice instance (no copy is made).
            longDecimalWithOverflowAndLongState.setLongDecimal(longDecimalWithOverflowAndLongState2.getLongDecimal());
        } else {
            // NOTE(review): assumes the other state carries a non-null sum whenever this state
            // already has one -- confirm against the state serializer / accumulator compiler.
            overflow += UnscaledDecimal128Arithmetic.addWithOverflow(sum, longDecimalWithOverflowAndLongState2.getLongDecimal(), sum);
        }
        longDecimalWithOverflowAndLongState.addOverflow(overflow);
    }

    /**
     * Output callback for short decimals: appends null when no rows were counted, otherwise the
     * unscaled value of the average as a long.
     *
     * @throws ArithmeticException if the unscaled average does not fit in a {@code long}
     */
    public static void outputShortDecimal(DecimalType decimalType, LongDecimalWithOverflowAndLongState longDecimalWithOverflowAndLongState, BlockBuilder blockBuilder) {
        if (longDecimalWithOverflowAndLongState.getLong() == 0) {
            blockBuilder.appendNull();
        } else {
            Decimals.writeShortDecimal(blockBuilder, average(longDecimalWithOverflowAndLongState, decimalType).unscaledValue().longValueExact());
        }
    }

    /**
     * Output callback for long decimals: appends null when no rows were counted, otherwise the
     * average written through {@link Decimals#writeBigDecimal}.
     */
    public static void outputLongDecimal(DecimalType decimalType, LongDecimalWithOverflowAndLongState longDecimalWithOverflowAndLongState, BlockBuilder blockBuilder) {
        if (longDecimalWithOverflowAndLongState.getLong() == 0) {
            blockBuilder.appendNull();
        } else {
            Decimals.writeBigDecimal(decimalType, blockBuilder, average(longDecimalWithOverflowAndLongState, decimalType));
        }
    }

    /**
     * Computes {@code sum / count} from the state, rounded half-up to the scale of
     * {@code decimalType}.
     *
     * <p>The sum is {@code slice + overflow * 2^127} expressed in unscaled units. Both terms are
     * therefore given the type's scale before they are added; constructing the overflow term with
     * scale 0 would inflate it by {@code 10^scale} for any decimal type with a non-zero scale.
     *
     * @param longDecimalWithOverflowAndLongState the state holding count, unscaled sum and overflow;
     *        the callers in this class only invoke this when the count is non-zero
     * @param decimalType supplies the scale of the inputs and of the result
     * @return the average with scale {@code decimalType.getScale()}
     * @throws ArithmeticException if the state's count is zero
     */
    @VisibleForTesting
    public static BigDecimal average(LongDecimalWithOverflowAndLongState longDecimalWithOverflowAndLongState, DecimalType decimalType) {
        int scale = decimalType.getScale();
        BigDecimal sum = new BigDecimal(Decimals.decodeUnscaledValue(longDecimalWithOverflowAndLongState.getLongDecimal()), scale);

        long overflow = longDecimalWithOverflowAndLongState.getOverflow();
        if (overflow != 0) {
            // The overflow counter is in unscaled units, exactly like the slice, so it must carry the same scale.
            BigInteger overflowUnscaled = OVERFLOW_MULTIPLIER.multiply(BigInteger.valueOf(overflow));
            sum = sum.add(new BigDecimal(overflowUnscaled, scale));
        }

        BigDecimal count = BigDecimal.valueOf(longDecimalWithOverflowAndLongState.getLong());
        return sum.divide(count, scale, RoundingMode.HALF_UP);
    }
}
