package io.trino.operator.aggregation;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.DynamicClassLoader;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionBinding;
import io.trino.metadata.FunctionDependencies;
import io.trino.metadata.FunctionDependencyDeclaration;
import io.trino.metadata.FunctionKind;
import io.trino.metadata.FunctionMetadata;
import io.trino.metadata.Signature;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.ParametricFunctionHelpers;
import io.trino.operator.ParametricImplementationsGroup;
import io.trino.operator.aggregation.AggregationMetadata;
import io.trino.operator.aggregation.state.StateCompiler;
import io.trino.operator.annotations.ImplementationDependency;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.function.AccumulatorStateFactory;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import java.lang.invoke.MethodHandle;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:io/trino/operator/aggregation/ParametricAggregation.class */
/**
 * An aggregation function whose concrete implementation is chosen, per bound signature, from a
 * {@link ParametricImplementationsGroup} of {@link AggregationImplementation} candidates.
 */
public class ParametricAggregation extends SqlAggregationFunction {
    private final ParametricImplementationsGroup<AggregationImplementation> implementations;

    /**
     * Creates a parametric aggregation backed by the given implementation group.
     *
     * @param signature the declared signature of the aggregation
     * @param aggregationHeader supplies the name, hidden flag, description, decomposability and
     *        order sensitivity of the aggregation
     * @param parametricImplementationsGroup the candidate implementations; must report itself nullable
     * @param z passed through as the last {@link FunctionMetadata} constructor argument
     *        (NOTE(review): presumably the deprecation flag — confirm against FunctionMetadata)
     * @throws NullPointerException if {@code aggregationHeader} or {@code parametricImplementationsGroup} is null
     * @throws IllegalArgumentException if the implementation group is not nullable
     */
    public ParametricAggregation(Signature signature, AggregationHeader aggregationHeader, ParametricImplementationsGroup<AggregationImplementation> parametricImplementationsGroup, boolean z) {
        // The null checks happen inside createFunctionMetadata, which is evaluated before the header is
        // dereferenced for the remaining super-constructor arguments; a check placed after super(...)
        // could never fire.
        super(createFunctionMetadata(signature, aggregationHeader, parametricImplementationsGroup, z),
                aggregationHeader.isDecomposable(),
                aggregationHeader.isOrderSensitive());
        Preconditions.checkArgument(parametricImplementationsGroup.isNullable(), "currently aggregates are required to be nullable");
        this.implementations = Objects.requireNonNull(parametricImplementationsGroup, "implementations is null");
    }

    /**
     * Validates the constructor inputs and builds the {@link FunctionMetadata} for an aggregate.
     */
    private static FunctionMetadata createFunctionMetadata(Signature signature, AggregationHeader header, ParametricImplementationsGroup<AggregationImplementation> implementations, boolean z) {
        Objects.requireNonNull(header, "details is null");
        Objects.requireNonNull(implementations, "implementations is null");
        return new FunctionMetadata(
                signature,
                header.getName(),
                true,
                implementations.getArgumentDefinitions(),
                header.isHidden(),
                true,
                header.getDescription().orElse(""),
                FunctionKind.AGGREGATE,
                z);
    }

    /**
     * Declares the function dependencies of every exact, specialized and generic implementation.
     */
    @Override
    public FunctionDependencyDeclaration getFunctionDependencies() {
        FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder builder = FunctionDependencyDeclaration.builder();
        declareDependencies(builder, implementations.getExactImplementations().values());
        declareDependencies(builder, implementations.getSpecializedImplementations());
        declareDependencies(builder, implementations.getGenericImplementations());
        return builder.build();
    }

    /**
     * Adds the input, remove-input, combine and output dependencies of each implementation to the builder.
     * Remove-input dependencies are included because {@link #specialize} binds them against the
     * resolved {@link FunctionDependencies}.
     */
    private static void declareDependencies(FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder builder, Collection<AggregationImplementation> candidates) {
        for (AggregationImplementation implementation : candidates) {
            for (ImplementationDependency dependency : implementation.getInputDependencies()) {
                dependency.declareDependencies(builder);
            }
            for (ImplementationDependency dependency : implementation.getRemoveInputDependencies()) {
                dependency.declareDependencies(builder);
            }
            for (ImplementationDependency dependency : implementation.getCombineDependencies()) {
                dependency.declareDependencies(builder);
            }
            for (ImplementationDependency dependency : implementation.getOutputDependencies()) {
                dependency.declareDependencies(builder);
            }
        }
    }

    /**
     * Returns the single intermediate type: the serialized type of the matching implementation's state class.
     */
    @Override
    public List<TypeSignature> getIntermediateTypes(FunctionBinding functionBinding) {
        Class<?> stateClass = findMatchingImplementation(functionBinding.getBoundSignature()).getStateClass();
        return ImmutableList.of(StateCompiler.getSerializedType(stateClass).getTypeSignature());
    }

    /**
     * Selects the implementation matching the bound signature, generates its state serializer/factory,
     * binds its method handles to the resolved dependencies, and wraps the result in an
     * {@link InternalAggregationFunction}.
     */
    @Override
    public InternalAggregationFunction specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies) {
        Signature signature = getFunctionMetadata().getSignature();
        BoundSignature boundSignature = functionBinding.getBoundSignature();
        AggregationImplementation implementation = findMatchingImplementation(boundSignature);
        List<Type> argumentTypes = boundSignature.getArgumentTypes();
        Type returnType = boundSignature.getReturnType();

        // Generated state classes are defined in a loader built from the implementation's class loader
        // and this class's loader.
        DynamicClassLoader classLoader = new DynamicClassLoader(implementation.getDefinitionClass().getClassLoader(), getClass().getClassLoader());
        Class<?> stateClass = implementation.getStateClass();
        AccumulatorStateSerializer<?> stateSerializer = StateCompiler.generateStateSerializer(stateClass, classLoader);
        AccumulatorStateFactory<?> stateFactory = StateCompiler.generateStateFactory(stateClass, classLoader);

        MethodHandle inputFunction = ParametricFunctionHelpers.bindDependencies(implementation.getInputFunction(), implementation.getInputDependencies(), functionBinding, functionDependencies);
        Optional<MethodHandle> removeInputFunction = implementation.getRemoveInputFunction()
                .map(handle -> ParametricFunctionHelpers.bindDependencies(handle, implementation.getRemoveInputDependencies(), functionBinding, functionDependencies));
        MethodHandle combineFunction = ParametricFunctionHelpers.bindDependencies(implementation.getCombineFunction(), implementation.getCombineDependencies(), functionBinding, functionDependencies);
        MethodHandle outputFunction = ParametricFunctionHelpers.bindDependencies(implementation.getOutputFunction(), implementation.getOutputDependencies(), functionBinding, functionDependencies);

        AggregationMetadata metadata = new AggregationMetadata(
                AggregationUtils.generateAggregationName(signature.getName(), returnType.getTypeSignature(), signaturesFromTypes(argumentTypes)),
                buildParameterMetadata(implementation.getInputParameterMetadataTypes(), argumentTypes),
                inputFunction,
                removeInputFunction,
                combineFunction,
                outputFunction,
                ImmutableList.of(new AggregationMetadata.AccumulatorStateDescriptor(stateClass, stateSerializer, stateFactory)),
                returnType);

        return new InternalAggregationFunction(
                signature.getName(),
                argumentTypes,
                ImmutableList.of(stateSerializer.getSerializedType()),
                returnType,
                new LazyAccumulatorFactoryBinder(metadata, classLoader));
    }

    @VisibleForTesting
    public ParametricImplementationsGroup<AggregationImplementation> getImplementations() {
        return implementations;
    }

    /**
     * Finds the implementation for a bound signature. An exact match on the signature wins; otherwise
     * exactly one generic implementation whose types are assignable must exist.
     *
     * NOTE(review): specialized implementations are declared in getFunctionDependencies but are not
     * consulted here — confirm this is intentional.
     *
     * @throws TrinoException AMBIGUOUS_FUNCTION_CALL if several generic implementations match,
     *         FUNCTION_IMPLEMENTATION_MISSING if none does
     */
    private AggregationImplementation findMatchingImplementation(BoundSignature boundSignature) {
        Signature signature = boundSignature.toSignature();
        Optional<AggregationImplementation> found = Optional.empty();
        if (implementations.getExactImplementations().containsKey(signature)) {
            found = Optional.of(implementations.getExactImplementations().get(signature));
        } else {
            for (AggregationImplementation candidate : implementations.getGenericImplementations()) {
                if (candidate.areTypesAssignable(boundSignature)) {
                    if (found.isPresent()) {
                        throw new TrinoException(StandardErrorCode.AMBIGUOUS_FUNCTION_CALL, String.format("Ambiguous function call (%s) for %s", boundSignature, getFunctionMetadata().getSignature()));
                    }
                    found = Optional.of(candidate);
                }
            }
        }
        return found.orElseThrow(() -> new TrinoException(StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING, String.format("Unsupported type parameters (%s) for %s", boundSignature, getFunctionMetadata().getSignature())));
    }

    /** Maps each type to its type signature, preserving order. */
    private static List<TypeSignature> signaturesFromTypes(List<Type> types) {
        return types.stream()
                .map(Type::getTypeSignature)
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Builds parameter metadata in declaration order. STATE and BLOCK_INDEX parameters carry no type;
     * each input-channel parameter consumes the next entry of {@code inputTypes}. Parameter types not
     * listed in the switch are skipped.
     */
    private static List<AggregationMetadata.ParameterMetadata> buildParameterMetadata(List<AggregationMetadata.ParameterMetadata.ParameterType> parameterTypes, List<Type> inputTypes) {
        ImmutableList.Builder<AggregationMetadata.ParameterMetadata> builder = ImmutableList.builder();
        int inputIndex = 0;
        for (AggregationMetadata.ParameterMetadata.ParameterType parameterType : parameterTypes) {
            switch (parameterType) {
                case STATE:
                case BLOCK_INDEX:
                    builder.add(new AggregationMetadata.ParameterMetadata(parameterType));
                    break;
                case INPUT_CHANNEL:
                case BLOCK_INPUT_CHANNEL:
                case NULLABLE_BLOCK_INPUT_CHANNEL:
                    builder.add(new AggregationMetadata.ParameterMetadata(parameterType, inputTypes.get(inputIndex++)));
                    break;
            }
        }
        return builder.build();
    }
}
