/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.interpreter;

import com.google.common.collect.ImmutableList;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import org.apache.calcite.DataContext;
import org.apache.calcite.adapter.enumerable.AggImpState;
import org.apache.calcite.adapter.enumerable.JavaRowFormat;
import org.apache.calcite.adapter.enumerable.PhysType;
import org.apache.calcite.adapter.enumerable.PhysTypeImpl;
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
import org.apache.calcite.adapter.enumerable.impl.AggAddContextImpl;
import org.apache.calcite.adapter.java.JavaTypeFactory;
import org.apache.calcite.interpreter.AbstractSingleNode;
import org.apache.calcite.interpreter.Compiler;
import org.apache.calcite.interpreter.Context;
import org.apache.calcite.interpreter.JaninoRexCompiler;
import org.apache.calcite.interpreter.Row;
import org.apache.calcite.interpreter.Scalar;
import org.apache.calcite.interpreter.Sink;
import org.apache.calcite.linq4j.tree.BlockBuilder;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.ParameterExpression;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.schema.impl.AggregateFunctionImpl;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.validate.SqlConformanceEnum;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;

public class AggregateNode
extends AbstractSingleNode<Aggregate> {
    /** One {@link Grouping} per grouping set of the aggregate, in the order the rel lists them. */
    private final List<Grouping> groups = new ArrayList<>();
    /** Union of all grouping sets; its bits are the key columns carried by every output row. */
    private final ImmutableBitSet unionGroups;
    /** Width of an output row: one slot per bit of {@link #unionGroups} plus one per aggregate call. */
    private final int outputRowLength;
    /** One factory per aggregate call, in {@code getAggCallList()} order. */
    private final ImmutableList<AccumulatorFactory> accumulatorFactories;
    /** Data context obtained from the compiler when this node was built. */
    private final DataContext dataContext;

    /**
     * Creates an AggregateNode.
     *
     * <p>Builds one {@link Grouping} per grouping set of {@code rel}, computes the
     * union of all grouping sets (the key columns of every output row) and creates
     * one {@link AccumulatorFactory} per aggregate call.
     *
     * @param compiler compiler that supplies the data context
     * @param rel      aggregate relational expression to interpret
     */
    public AggregateNode(Compiler compiler, Aggregate rel) {
        super(compiler, rel);
        this.dataContext = compiler.getDataContext();
        ImmutableBitSet union = ImmutableBitSet.of();
        if (rel.getGroupSets() != null) {
            for (ImmutableBitSet group : rel.getGroupSets()) {
                union = union.union(group);
                this.groups.add(new Grouping(group));
            }
        }
        this.unionGroups = union;
        this.outputRowLength = this.unionGroups.cardinality() + rel.getAggCallList().size();
        // Typed builder, so the compiler checks that every element is an AccumulatorFactory.
        final ImmutableList.Builder<AccumulatorFactory> builder = ImmutableList.builder();
        for (AggregateCall aggregateCall : rel.getAggCallList()) {
            builder.add(this.getAccumulator(aggregateCall, false));
        }
        this.accumulatorFactories = builder.build();
    }

    /**
     * Drains the input, feeding every row to each grouping, and then asks each
     * grouping, in order, to emit its aggregated rows to the sink.
     *
     * @throws InterruptedException propagated from the source or the sink
     */
    @Override
    public void run() throws InterruptedException {
        for (Row row = this.source.receive(); row != null; row = this.source.receive()) {
            for (Grouping grouping : this.groups) {
                grouping.send(row);
            }
        }
        for (Grouping grouping : this.groups) {
            grouping.end(this.sink);
        }
    }

    /**
     * Returns a factory that creates a fresh {@link Accumulator} for {@code call}.
     *
     * <p>COUNT, SUM, $SUM0, MIN and MAX use the hand-written accumulators of this
     * class. Every other aggregate function goes through
     * {@link #generateAccumulator(AggregateCall)}.
     *
     * @param call         aggregate call to evaluate
     * @param ignoreFilter if true, the call's FILTER argument is ignored; used by the
     *                     recursive call that builds the accumulator which a
     *                     {@link FilterAccumulator} then wraps
     * @return factory producing one accumulator per group
     */
    private AccumulatorFactory getAccumulator(AggregateCall call, boolean ignoreFilter) {
        if (call.filterArg >= 0 && !ignoreFilter) {
            final AccumulatorFactory factory = this.getAccumulator(call, true);
            return () -> new FilterAccumulator(factory.get(), call.filterArg);
        }
        if (call.getAggregation() == SqlStdOperatorTable.COUNT) {
            return () -> new CountAccumulator(call);
        }
        if (call.getAggregation() == SqlStdOperatorTable.SUM) {
            // nullIfEmpty = true: a group that received no non-NULL argument yields NULL.
            return new UdaAccumulatorFactory(AggregateFunctionImpl.create(sumClass(call)), call, true);
        }
        if (call.getAggregation() == SqlStdOperatorTable.SUM0) {
            // nullIfEmpty = false: an empty group yields the implementation's initial total.
            return new UdaAccumulatorFactory(AggregateFunctionImpl.create(sumClass(call)), call, false);
        }
        if (call.getAggregation() == SqlStdOperatorTable.MIN) {
            return new UdaAccumulatorFactory(AggregateFunctionImpl.create(minClass(call)), call, true);
        }
        if (call.getAggregation() == SqlStdOperatorTable.MAX) {
            return new UdaAccumulatorFactory(AggregateFunctionImpl.create(maxClass(call)), call, true);
        }
        return this.generateAccumulator(call);
    }

    /** Picks the SUM / $SUM0 implementation class for the call's result type. */
    private static Class<?> sumClass(AggregateCall call) {
        switch (call.getType().getSqlTypeName()) {
            case DOUBLE:
            case REAL:
            case FLOAT:
                return DoubleSum.class;
            case DECIMAL:
                return BigDecimalSum.class;
            case INTEGER:
                return IntSum.class;
            default:
                return LongSum.class;
        }
    }

    /**
     * Picks the MIN implementation class for the call's result type.
     *
     * <p>NOTE(review): SUM groups FLOAT with DOUBLE/REAL, but MIN/MAX pair FLOAT with
     * the Float classes and REAL with the Double classes. Confirm which Java type the
     * interpreter actually produces for each SQL type.
     */
    private static Class<?> minClass(AggregateCall call) {
        switch (call.getType().getSqlTypeName()) {
            case INTEGER:
                return MinInt.class;
            case FLOAT:
                return MinFloat.class;
            case DOUBLE:
            case REAL:
                return MinDouble.class;
            case DECIMAL:
                return MinBigDecimal.class;
            case BOOLEAN:
                return MinBoolean.class;
            default:
                return MinLong.class;
        }
    }

    /**
     * Picks the MAX implementation class for the call's result type. BOOLEAN is
     * handled explicitly, as it is for MIN; otherwise it would fall through to
     * {@link MaxLong}, whose comparison cannot accept {@code Boolean} values.
     */
    private static Class<?> maxClass(AggregateCall call) {
        switch (call.getType().getSqlTypeName()) {
            case INTEGER:
                return MaxInt.class;
            case FLOAT:
                return MaxFloat.class;
            case DOUBLE:
            case REAL:
                return MaxDouble.class;
            case DECIMAL:
                return MaxBigDecimal.class;
            case BOOLEAN:
                return MaxBoolean.class;
            default:
                return MaxLong.class;
        }
    }

    /**
     * Builds an accumulator for an aggregate function that has no hand-written
     * implementation. The enumerable {@code AggImplementor} generates the "add"
     * step, and {@link JaninoRexCompiler} compiles it into a {@link Scalar}.
     *
     * <p>NOTE(review): only the add scalar is generated. The init and end scalars are
     * passed as {@code null}, and {@code ScalarAccumulator.end()} dereferences
     * {@code endScalar}, so this path looks incomplete. Confirm before relying on it.
     */
    private AccumulatorFactory generateAccumulator(AggregateCall call) {
        final JavaTypeFactory typeFactory = (JavaTypeFactory) ((Aggregate) this.rel).getCluster().getTypeFactory();
        final AggImpState agg = new AggImpState(0, call, false);
        final int stateSize = agg.state.size();
        final BlockBuilder addBlock = new BlockBuilder();
        final PhysType inputPhysType = PhysTypeImpl.of(typeFactory, ((Aggregate) this.rel).getInput().getRowType(), JavaRowFormat.ARRAY);
        // Row type of the accumulator: one field per state variable of the implementor.
        final RelDataTypeFactory.FieldInfoBuilder accTypeBuilder = typeFactory.builder();
        for (Expression expression : agg.state) {
            ((RelDataTypeFactory.Builder) accTypeBuilder).add("a", typeFactory.createJavaType((Class) expression.getType()));
        }
        final PhysType accPhysType = PhysTypeImpl.of(typeFactory, accTypeBuilder.build(), JavaRowFormat.ARRAY);
        final ParameterExpression inParameter = Expressions.parameter(inputPhysType.getJavaRowType(), "in");
        final ParameterExpression accParameter = Expressions.parameter(accPhysType.getJavaRowType(), "acc");
        // Rebind each state variable to the matching field of the "acc" parameter.
        final List<Expression> accumulator = new ArrayList<>(stateSize);
        for (int j = 0; j < stateSize; ++j) {
            accumulator.add(accPhysType.fieldReference(accParameter, j));
        }
        agg.state = accumulator;
        final AggAddContextImpl addContext = new AggAddContextImpl(addBlock, accumulator) {

            @Override
            public List<RexNode> rexArguments() {
                final List<RexNode> args = new ArrayList<>();
                for (int index : agg.call.getArgList()) {
                    args.add(RexInputRef.of(index, inputPhysType.getRowType()));
                }
                return args;
            }

            @Override
            public RexNode rexFilterArgument() {
                return agg.call.filterArg < 0 ? null : RexInputRef.of(agg.call.filterArg, inputPhysType.getRowType());
            }

            @Override
            public RexToLixTranslator rowTranslator() {
                final SqlConformanceEnum conformance = SqlConformanceEnum.DEFAULT;
                return RexToLixTranslator.forAggregation(typeFactory, this.currentBlock(), new RexToLixTranslator.InputGetterImpl(Collections.singletonList(Pair.of(inParameter, inputPhysType))), conformance);
            }
        };
        agg.implementor.implementAdd(agg.context, addContext);
        final ParameterExpression contextParameter = Expressions.parameter(Context.class, "context");
        final ParameterExpression outputValuesParameter = Expressions.parameter(Object[].class, "outputValues");
        final Scalar addScalar = JaninoRexCompiler.baz(contextParameter, outputValuesParameter, addBlock.toBlock());
        return new ScalarAccumulatorDef(null, addScalar, null, ((Aggregate) this.rel).getInput().getRowType().getFieldCount(), stateSize, this.dataContext);
    }

    /** MAX implementation for BOOLEAN values. Under {@code Boolean} ordering FALSE &lt; TRUE, so the seed is FALSE. */
    public static class MaxBoolean {
        public Boolean init() {
            return Boolean.FALSE;
        }

        public Boolean add(Boolean accumulator, Boolean value) {
            return accumulator.compareTo(value) > 0 ? accumulator : value;
        }

        public Boolean merge(Boolean accumulator0, Boolean accumulator1) {
            return this.add(accumulator0, accumulator1);
        }

        public Boolean result(Boolean accumulator) {
            return accumulator;
        }
    }

    /**
     * Wraps another accumulator and forwards only the rows whose filter column
     * holds TRUE. Rows where it is FALSE or NULL are skipped.
     */
    private static class FilterAccumulator
    implements Accumulator {
        private final Accumulator accumulator;
        private final int filterArg;

        /**
         * @param accumulator accumulator that receives the rows passing the filter
         * @param filterArg   ordinal of the filter column in the input row
         */
        FilterAccumulator(Accumulator accumulator, int filterArg) {
            this.accumulator = accumulator;
            this.filterArg = filterArg;
        }

        @Override
        public void send(Row row) {
            // equals() rather than ==: a TRUE Boolean that is not the canonical
            // Boolean.TRUE instance must still pass the filter.
            if (Boolean.TRUE.equals(row.getValues()[this.filterArg])) {
                this.accumulator.send(row);
            }
        }

        @Override
        public Object end() {
            return this.accumulator.end();
        }
    }

    /**
     * Accumulator that reflectively drives an aggregate class exposing
     * {@code init}/{@code add}/{@code result} methods, as described by its
     * factory's {@link AggregateFunctionImpl}. Rows whose argument is NULL are ignored.
     */
    private static class UdaAccumulator
    implements Accumulator {
        private final UdaAccumulatorFactory factory;
        private Object value;
        private boolean empty;

        UdaAccumulator(UdaAccumulatorFactory factory) {
            this.factory = factory;
            try {
                this.value = factory.aggFunction.initMethod.invoke(factory.instance);
            } catch (IllegalAccessException | InvocationTargetException e) {
                throw new RuntimeException(e);
            }
            this.empty = true;
        }

        /** Folds the row's argument into the running value, unless it is NULL. */
        @Override
        public void send(Row row) {
            final Object argument = row.getValues()[this.factory.argOrdinal];
            if (argument == null) {
                return;
            }
            try {
                this.value = this.factory.aggFunction.addMethod.invoke(this.factory.instance, new Object[] {this.value, argument});
            } catch (IllegalAccessException | InvocationTargetException e) {
                throw new RuntimeException(e);
            }
            this.empty = false;
        }

        /** Returns NULL for an empty group when the factory asks for it; otherwise the result method's value. */
        @Override
        public Object end() {
            if (this.empty && this.factory.nullIfEmpty) {
                return null;
            }
            try {
                return this.factory.aggFunction.resultMethod.invoke(this.factory.instance, new Object[] {this.value});
            } catch (IllegalAccessException | InvocationTargetException e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Factory for {@link UdaAccumulator}s over a single-argument aggregate class.
     * If the aggregate's methods are not static, one instance of the declaring
     * class is created here and shared by every accumulator this factory produces.
     */
    private static class UdaAccumulatorFactory
    implements AccumulatorFactory {
        final AggregateFunctionImpl aggFunction;
        final int argOrdinal;
        public final Object instance;
        public final boolean nullIfEmpty;

        /**
         * @param aggFunction reflective description of the aggregate class
         * @param call        aggregate call; must have exactly one argument
         * @param nullIfEmpty whether an accumulator that saw no non-NULL argument ends with NULL
         * @throws UnsupportedOperationException if {@code call} does not have exactly one argument
         * @throws RuntimeException if the declaring class cannot be created through its public no-arg constructor
         */
        UdaAccumulatorFactory(AggregateFunctionImpl aggFunction, AggregateCall call, boolean nullIfEmpty) {
            final List<Integer> arguments = call.getArgList();
            if (arguments.size() != 1) {
                throw new UnsupportedOperationException("in current implementation, aggregate must have precisely one argument");
            }
            this.aggFunction = aggFunction;
            this.argOrdinal = arguments.get(0);
            this.instance = aggFunction.isStatic ? null : instantiate(aggFunction);
            this.nullIfEmpty = nullIfEmpty;
        }

        /** Creates an instance of the aggregate's declaring class via its public no-arg constructor. */
        private static Object instantiate(AggregateFunctionImpl aggFunction) {
            try {
                final Constructor<?> noArgConstructor = aggFunction.declaringClass.getConstructor();
                return noArgConstructor.newInstance();
            } catch (IllegalAccessException | InstantiationException | NoSuchMethodException | InvocationTargetException e) {
                throw new RuntimeException(e);
            }
        }

        @Override
        public Accumulator get() {
            return new UdaAccumulator(this);
        }
    }

    /**
     * MAX implementation for DECIMAL values.
     *
     * <p>The accumulator starts as {@code null}, meaning "no value seen yet". It must
     * not be seeded with {@code Double.MIN_VALUE}: that is the smallest
     * <em>positive</em> double (about 4.9E-324), so MAX over only negative or zero
     * values would return that tiny positive number. BigDecimal has no
     * negative-infinity sentinel, so null is used instead.
     */
    public static class MaxBigDecimal
    extends NumericComparison<BigDecimal> {
        public MaxBigDecimal() {
            super(null, MaxBigDecimal::max);
        }

        /**
         * Returns the larger of two values, treating {@code null} as "absent".
         *
         * @return the other argument if either is {@code null}, otherwise {@code a.max(b)}
         */
        public static BigDecimal max(BigDecimal a, BigDecimal b) {
            if (a == null) {
                return b;
            }
            if (b == null) {
                return a;
            }
            return a.max(b);
        }
    }

    /**
     * MAX implementation for double values.
     *
     * <p>Seeded with {@link Double#NEGATIVE_INFINITY}, the identity of max.
     * {@link Double#MIN_VALUE} is the smallest <em>positive</em> double, so using it
     * as the seed made MAX over only negative or zero values return about 4.9E-324.
     */
    public static class MaxDouble
    extends NumericComparison<Double> {
        public MaxDouble() {
            super(Double.NEGATIVE_INFINITY, Math::max);
        }
    }

    /**
     * MAX implementation for float values.
     *
     * <p>Seeded with {@link Float#NEGATIVE_INFINITY}, the identity of max.
     * {@link Float#MIN_VALUE} is the smallest <em>positive</em> float, so using it
     * as the seed made MAX over only negative or zero values return about 1.4E-45.
     */
    public static class MaxFloat
    extends NumericComparison<Float> {
        public MaxFloat() {
            super(Float.NEGATIVE_INFINITY, Math::max);
        }
    }

    /** MAX implementation for long values; the seed {@link Long#MIN_VALUE} is the identity of max. */
    public static class MaxLong
    extends NumericComparison<Long> {
        public MaxLong() {
            super(Long.MIN_VALUE, Long::max);
        }
    }

    /** MAX implementation for int values; the seed {@link Integer#MIN_VALUE} is the identity of max. */
    public static class MaxInt
    extends NumericComparison<Integer> {
        public MaxInt() {
            super(Integer.MIN_VALUE, Integer::max);
        }
    }

    /** MIN implementation for BOOLEAN values; under {@code Boolean} ordering FALSE &lt; TRUE, so the seed is TRUE. */
    public static class MinBoolean {
        public Boolean init() {
            return Boolean.TRUE;
        }

        /** Keeps {@code accumulator} only when it is strictly smaller than {@code value}. */
        public Boolean add(Boolean accumulator, Boolean value) {
            return value.compareTo(accumulator) > 0 ? accumulator : value;
        }

        public Boolean merge(Boolean accumulator0, Boolean accumulator1) {
            return this.add(accumulator0, accumulator1);
        }

        public Boolean result(Boolean accumulator) {
            return accumulator;
        }
    }

    /**
     * MIN implementation for DECIMAL values.
     *
     * <p>The accumulator starts as {@code null}, meaning "no value seen yet". Any
     * finite seed such as {@code new BigDecimal(Double.MAX_VALUE)} would give a wrong
     * MIN when every input is larger than the seed, and it also builds a
     * 309-digit number for every accumulator.
     */
    public static class MinBigDecimal
    extends NumericComparison<BigDecimal> {
        public MinBigDecimal() {
            super(null, MinBigDecimal::min);
        }

        /**
         * Returns the smaller of two values, treating {@code null} as "absent".
         *
         * @return the other argument if either is {@code null}, otherwise {@code a.min(b)}
         */
        public static BigDecimal min(BigDecimal a, BigDecimal b) {
            if (a == null) {
                return b;
            }
            if (b == null) {
                return a;
            }
            return a.min(b);
        }
    }

    /**
     * MIN implementation for double values.
     *
     * <p>Seeded with {@link Double#POSITIVE_INFINITY}, the identity of min. With
     * {@link Double#MAX_VALUE} as the seed, MIN over values that are all +Infinity
     * returned MAX_VALUE instead of +Infinity.
     */
    public static class MinDouble
    extends NumericComparison<Double> {
        public MinDouble() {
            super(Double.POSITIVE_INFINITY, Math::min);
        }
    }

    /**
     * MIN implementation for float values.
     *
     * <p>Seeded with {@link Float#POSITIVE_INFINITY}, the identity of min. With
     * {@link Float#MAX_VALUE} as the seed, MIN over values that are all +Infinity
     * returned MAX_VALUE instead of +Infinity.
     */
    public static class MinFloat
    extends NumericComparison<Float> {
        public MinFloat() {
            super(Float.POSITIVE_INFINITY, Math::min);
        }
    }

    /** MIN implementation for long values; the seed {@link Long#MAX_VALUE} is the identity of min. */
    public static class MinLong
    extends NumericComparison<Long> {
        public MinLong() {
            super(Long.MAX_VALUE, Long::min);
        }
    }

    /** MIN implementation for int values; the seed {@link Integer#MAX_VALUE} is the identity of min. */
    public static class MinInt
    extends NumericComparison<Integer> {
        public MinInt() {
            super(Integer.MAX_VALUE, Integer::min);
        }
    }

    /**
     * Base for MIN/MAX-style aggregates. The accumulator starts at a fixed seed
     * and each incoming value is folded in with a binary function.
     *
     * @param <T> type of the values being compared
     */
    public static class NumericComparison<T> {
        private final T seed;
        private final BiFunction<T, T, T> combiner;

        /**
         * @param initialValue       value returned by {@link #init()}
         * @param comparisonFunction function that folds a value into the accumulator
         */
        public NumericComparison(T initialValue, BiFunction<T, T, T> comparisonFunction) {
            this.seed = initialValue;
            this.combiner = comparisonFunction;
        }

        /** Returns the seed value. */
        public T init() {
            return this.seed;
        }

        /** Folds {@code value} into {@code accumulator}. */
        public T add(T accumulator, T value) {
            return this.combiner.apply(accumulator, value);
        }

        /** Combines two partial accumulators using {@link #add}. */
        public T merge(T accumulator0, T accumulator1) {
            return this.add(accumulator0, accumulator1);
        }

        /** Returns the accumulator unchanged. */
        public T result(T accumulator) {
            return accumulator;
        }
    }

    /** SUM / $SUM0 implementation for DECIMAL values. */
    public static class BigDecimalSum {
        /** Returns zero (scale 0), the starting total. */
        public BigDecimal init() {
            return BigDecimal.ZERO;
        }

        /** Adds {@code v} to the running total. */
        public BigDecimal add(BigDecimal accumulator, BigDecimal v) {
            return accumulator.add(v);
        }

        /** Combines two partial totals using {@link #add}. */
        public BigDecimal merge(BigDecimal accumulator0, BigDecimal accumulator01) {
            return this.add(accumulator0, accumulator01);
        }

        /** Returns the total unchanged. */
        public BigDecimal result(BigDecimal accumulator) {
            return accumulator;
        }
    }

    /** SUM / $SUM0 implementation for floating-point values, accumulated as {@code double}. */
    public static class DoubleSum {
        /** Returns the starting total, 0.0. */
        public double init() {
            return 0.0;
        }

        /** Adds {@code v} to the running total. */
        public double add(double accumulator, double v) {
            return Double.sum(accumulator, v);
        }

        /** Adds two partial totals. */
        public double merge(double accumulator0, double accumulator1) {
            return Double.sum(accumulator0, accumulator1);
        }

        /** Returns the total unchanged. */
        public double result(double accumulator) {
            return accumulator;
        }
    }

    /** SUM / $SUM0 implementation accumulating into a {@code long}; overflow wraps around silently. */
    public static class LongSum {
        /** Returns the starting total, 0. */
        public long init() {
            return 0L;
        }

        /** Adds {@code v} to the running total. */
        public long add(long accumulator, long v) {
            return Long.sum(accumulator, v);
        }

        /** Adds two partial totals. */
        public long merge(long accumulator0, long accumulator1) {
            return Long.sum(accumulator0, accumulator1);
        }

        /** Returns the total unchanged. */
        public long result(long accumulator) {
            return accumulator;
        }
    }

    /** SUM / $SUM0 implementation accumulating into an {@code int}; overflow wraps around silently. */
    public static class IntSum {
        /** Returns the starting total, 0. */
        public int init() {
            return 0;
        }

        /** Adds {@code v} to the running total. */
        public int add(int accumulator, int v) {
            return Integer.sum(accumulator, v);
        }

        /** Adds two partial totals. */
        public int merge(int accumulator0, int accumulator1) {
            return Integer.sum(accumulator0, accumulator1);
        }

        /** Returns the total unchanged. */
        public int result(int accumulator) {
            return accumulator;
        }
    }

    private static interface Accumulator {
        public void send(Row var1);

        public Object end();
    }

    /** The accumulators of one group, one per aggregate call, in call order. */
    private static class AccumulatorList
    extends ArrayList<Accumulator> {
        private AccumulatorList() {
        }

        /** Forwards {@code row} to every accumulator in the list. */
        public void send(Row row) {
            this.forEach(accumulator -> accumulator.send(row));
        }

        /**
         * Writes each accumulator's result into the last {@code size()} slots of
         * {@code r}, in list order.
         */
        public void end(Row.RowBuilder r) {
            final int offset = r.size() - this.size();
            for (int i = 0; i < this.size(); i++) {
                r.set(offset + i, this.get(i).end());
            }
        }
    }

    /**
     * Aggregation state for one grouping set. Maps each distinct key (the values of
     * this set's grouping columns, in bit set iteration order) to its accumulators.
     */
    private class Grouping {
        private final ImmutableBitSet grouping;
        private final Map<Row, AccumulatorList> accumulators = new HashMap<>();

        private Grouping(ImmutableBitSet grouping) {
            this.grouping = grouping;
        }

        /** Routes {@code row} to the accumulators of its key, creating them the first time the key is seen. */
        public void send(Row row) {
            final Row.RowBuilder builder = Row.newBuilder(this.grouping.cardinality());
            int j = 0;
            for (int i : this.grouping) {
                builder.set(j++, row.getObject(i));
            }
            final Row key = builder.build();
            this.accumulators.computeIfAbsent(key, k -> this.newAccumulatorList()).send(row);
        }

        /**
         * Sends one output row per key to {@code sink}. Each row has a slot for every
         * column of the union of grouping sets (slots for columns outside this set
         * stay unset), followed by the aggregate results.
         *
         * @throws InterruptedException propagated from the sink
         */
        public void end(Sink sink) throws InterruptedException {
            if (this.accumulators.isEmpty() && this.grouping.cardinality() == 0) {
                // The empty grouping set produces exactly one row even when no input
                // arrived (e.g. COUNT(*) over an empty table is one row holding 0).
                this.accumulators.put(Row.newBuilder(0).build(), this.newAccumulatorList());
            }
            for (Map.Entry<Row, AccumulatorList> e : this.accumulators.entrySet()) {
                final Row key = e.getKey();
                final Row.RowBuilder rb = Row.newBuilder(AggregateNode.this.outputRowLength);
                // The key holds only this set's columns, while the output row has a
                // slot for every column of the union, so the two need separate indexes.
                // Both are walked in bit set order, and this set is a subset of the union.
                int keyIndex = 0;
                int outputIndex = 0;
                for (int groupPos : AggregateNode.this.unionGroups) {
                    if (this.grouping.get(groupPos)) {
                        rb.set(outputIndex, key.getObject(keyIndex++));
                    }
                    ++outputIndex;
                }
                e.getValue().end(rb);
                sink.send(rb.build());
            }
        }

        /** Creates one fresh accumulator per aggregate call. */
        private AccumulatorList newAccumulatorList() {
            final AccumulatorList list = new AccumulatorList();
            for (AccumulatorFactory factory : AggregateNode.this.accumulatorFactories) {
                list.add(factory.get());
            }
            return list;
        }
    }

    /**
     * Accumulator backed by compiled {@link Scalar}s. Its state lives in
     * {@link #values}; the contexts held by {@link #def} serve as scratch space.
     */
    private static class ScalarAccumulator
    implements Accumulator {
        final ScalarAccumulatorDef def;
        final Object[] values;

        private ScalarAccumulator(ScalarAccumulatorDef def, Object[] values) {
            this.def = def;
            this.values = values;
        }

        /**
         * Lays out [input row | current state] in the send context, then runs the
         * add scalar with {@link #values} as its output array.
         */
        @Override
        public void send(Row row) {
            final Object[] scratch = this.def.sendContext.values;
            System.arraycopy(row.getValues(), 0, scratch, 0, this.def.rowLength);
            System.arraycopy(this.values, 0, scratch, this.def.rowLength, this.values.length);
            this.def.addScalar.execute(this.def.sendContext, this.values);
        }

        /** Copies the state into the end context and evaluates the end scalar on it. */
        @Override
        public Object end() {
            final Context context = this.def.endContext;
            System.arraycopy(this.values, 0, context.values, 0, this.values.length);
            return this.def.endScalar.execute(context);
        }
    }

    /**
     * Factory for {@link ScalarAccumulator}s. It holds the compiled scalars and two
     * contexts that every accumulator it creates reuses as scratch space.
     */
    private static class ScalarAccumulatorDef
    implements AccumulatorFactory {
        final Scalar initScalar;
        final Scalar addScalar;
        final Scalar endScalar;
        final Context sendContext;
        final Context endContext;
        final int rowLength;
        final int accumulatorLength;

        /**
         * @param rowLength         number of fields in an input row
         * @param accumulatorLength number of state slots per accumulator
         * @param root              data context for the scratch contexts
         */
        private ScalarAccumulatorDef(Scalar initScalar, Scalar addScalar, Scalar endScalar, int rowLength, int accumulatorLength, DataContext root) {
            this.initScalar = initScalar;
            this.addScalar = addScalar;
            this.endScalar = endScalar;
            this.accumulatorLength = accumulatorLength;
            this.rowLength = rowLength;
            this.sendContext = newContext(root, rowLength + accumulatorLength);
            this.endContext = newContext(root, accumulatorLength);
        }

        /** Creates a context over {@code root} whose values array has {@code width} slots. */
        private static Context newContext(DataContext root, int width) {
            final Context context = new Context(root);
            context.values = new Object[width];
            return context;
        }

        /** Returns an accumulator whose state slots all start as null; initScalar is not consulted here. */
        @Override
        public Accumulator get() {
            final Object[] state = new Object[this.accumulatorLength];
            return new ScalarAccumulator(this, state);
        }
    }

    private static interface AccumulatorFactory
    extends Supplier<Accumulator> {
    }

    /**
     * COUNT implementation. Counts the rows in which every argument column is
     * non-NULL; a call with no arguments (e.g. COUNT(*)) counts every row.
     */
    private static class CountAccumulator
    implements Accumulator {
        private final AggregateCall call;
        long cnt;

        CountAccumulator(AggregateCall call) {
            this.call = call;
            this.cnt = 0L;
        }

        @Override
        public void send(Row row) {
            for (int ordinal : this.call.getArgList()) {
                if (row.getObject(ordinal) == null) {
                    // A NULL argument excludes the row from the count.
                    return;
                }
            }
            this.cnt++;
        }

        /** Returns the count as a {@code Long}. */
        @Override
        public Object end() {
            return this.cnt;
        }
    }
}

