/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;

public class AggregateMergeRule
extends RelOptRule {
    public static final AggregateMergeRule INSTANCE = new AggregateMergeRule();

    private AggregateMergeRule() {
        this(AggregateMergeRule.operand(Aggregate.class, AggregateMergeRule.operandJ(Aggregate.class, null, agg -> Aggregate.isSimple(agg), AggregateMergeRule.any()), new RelOptRuleOperand[0]), RelFactories.LOGICAL_BUILDER);
    }

    public AggregateMergeRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory) {
        super(operand, relBuilderFactory, null);
    }

    private boolean isAggregateSupported(AggregateCall aggCall) {
        if (aggCall.isDistinct() || aggCall.hasFilter() || aggCall.isApproximate() || aggCall.getArgList().size() > 1) {
            return false;
        }
        SqlSplittableAggFunction splitter = aggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
        return splitter != null;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        Aggregate topAgg = (Aggregate)call.rel(0);
        Aggregate bottomAgg = (Aggregate)call.rel(1);
        if (topAgg.getGroupCount() > bottomAgg.getGroupCount()) {
            return;
        }
        ImmutableBitSet bottomGroupSet = bottomAgg.getGroupSet();
        HashMap<Integer, Integer> map = new HashMap<Integer, Integer>();
        bottomGroupSet.forEach(v -> map.put(map.size(), (Integer)v));
        for (int k : topAgg.getGroupSet()) {
            if (map.containsKey(k)) continue;
            return;
        }
        ImmutableBitSet topGroupSet = topAgg.getGroupSet().permute(map);
        if (!bottomGroupSet.contains(topGroupSet)) {
            return;
        }
        boolean hasEmptyGroup = topAgg.getGroupSets().stream().anyMatch(n -> n.isEmpty());
        ArrayList<AggregateCall> finalCalls = new ArrayList<AggregateCall>();
        for (AggregateCall topCall : topAgg.getAggCallList()) {
            if (!this.isAggregateSupported(topCall) || topCall.getArgList().size() == 0) {
                return;
            }
            int bottomIndex = topCall.getArgList().get(0) - bottomGroupSet.cardinality();
            if (bottomIndex >= bottomAgg.getAggCallList().size() || bottomIndex < 0) {
                return;
            }
            AggregateCall bottomCall = bottomAgg.getAggCallList().get(bottomIndex);
            if (!this.isAggregateSupported(bottomCall) || bottomCall.getAggregation() == SqlStdOperatorTable.COUNT && hasEmptyGroup) {
                return;
            }
            SqlSplittableAggFunction splitter = Objects.requireNonNull(bottomCall.getAggregation().unwrap(SqlSplittableAggFunction.class));
            AggregateCall finalCall = splitter.merge(topCall, bottomCall);
            if (finalCall == null) {
                return;
            }
            finalCalls.add(finalCall);
        }
        ImmutableList newGroupingSets = null;
        if (topAgg.getGroupType() != Aggregate.Group.SIMPLE) {
            newGroupingSets = ImmutableBitSet.ORDERING.immutableSortedCopy(ImmutableBitSet.permute(topAgg.getGroupSets(), map));
        }
        Aggregate finalAgg = topAgg.copy(topAgg.getTraitSet(), bottomAgg.getInput(), topGroupSet, (List<ImmutableBitSet>)newGroupingSets, finalCalls);
        call.transformTo(finalAgg);
    }
}

