/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner;

import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.prestosql.Session;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.predicate.Domain;
import io.prestosql.spi.predicate.TupleDomain;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.planner.DeterminismEvaluator;
import io.prestosql.sql.planner.DomainTranslator;
import io.prestosql.sql.planner.EqualityInference;
import io.prestosql.sql.planner.ExpressionInterpreter;
import io.prestosql.sql.planner.NoOpSymbolResolver;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolsExtractor;
import io.prestosql.sql.planner.TypeAnalyzer;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.AssignUniqueId;
import io.prestosql.sql.planner.plan.DistinctLimitNode;
import io.prestosql.sql.planner.plan.ExchangeNode;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.LimitNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.PlanVisitor;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.SemiJoinNode;
import io.prestosql.sql.planner.plan.SortNode;
import io.prestosql.sql.planner.plan.SpatialJoinNode;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.planner.plan.TopNNode;
import io.prestosql.sql.planner.plan.UnionNode;
import io.prestosql.sql.planner.plan.UnnestNode;
import io.prestosql.sql.planner.plan.ValuesNode;
import io.prestosql.sql.planner.plan.WindowNode;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.NodeRef;
import io.prestosql.sql.tree.SymbolReference;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/**
 * Builds a single {@link Expression} describing what is known about the rows a plan
 * subtree produces, expressed over that subtree's output symbols.
 *
 * <p>The expression is assembled by walking the plan with a {@link PlanVisitor}: filter
 * predicates, join criteria, table-scan constraints and constant {@code VALUES} rows each
 * contribute conjuncts, which are then restricted to the symbols visible at every level.
 */
public class EffectivePredicateExtractor {
    /** True for an identity assignment, i.e. the value is just a reference to the entry's own symbol. */
    private static final Predicate<Map.Entry<Symbol, ? extends Expression>> SYMBOL_MATCHES_EXPRESSION =
            entry -> entry.getValue().equals(entry.getKey().toSymbolReference());

    /** Converts an assignment {@code symbol -> expression} into the comparison {@code symbol = expression}. */
    private static final Function<Map.Entry<Symbol, ? extends Expression>, Expression> ENTRY_TO_EQUALITY =
            entry -> new ComparisonExpression(
                    ComparisonExpression.Operator.EQUAL,
                    entry.getKey().toSymbolReference(),
                    entry.getValue());

    private final DomainTranslator domainTranslator;
    private final Metadata metadata;
    private final boolean useTableProperties;

    /**
     * @param domainTranslator translates tuple domains into expressions; must not be null
     * @param metadata metadata handle passed to the expression utilities; must not be null
     * @param useTableProperties whether table scans may take their predicate from the table
     *        properties instead of the scan's enforced constraint
     */
    public EffectivePredicateExtractor(DomainTranslator domainTranslator, Metadata metadata, boolean useTableProperties) {
        this.domainTranslator = Objects.requireNonNull(domainTranslator, "domainTranslator is null");
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.useTableProperties = useTableProperties;
    }

    /**
     * Computes the effective predicate of {@code node}.
     *
     * @param session session used for metadata lookups and expression evaluation
     * @param node root of the plan subtree to analyse
     * @param types types of the symbols used in the plan
     * @param typeAnalyzer used to type the expressions of {@code VALUES} nodes
     * @return the derived predicate; {@link BooleanLiteral#TRUE_LITERAL} when nothing can be said
     */
    public Expression extract(Session session, PlanNode node, TypeProvider types, TypeAnalyzer typeAnalyzer) {
        Visitor visitor = new Visitor(domainTranslator, metadata, session, types, typeAnalyzer, useTableProperties);
        return node.accept(visitor, null);
    }

    /**
     * Collects {@code symbol = expression} equalities for every entry that is not an
     * identity assignment, preserving the iteration order of {@code entries}.
     */
    private static List<Expression> nonIdentityEqualities(Iterable<? extends Map.Entry<Symbol, ? extends Expression>> entries) {
        ImmutableList.Builder<Expression> equalities = ImmutableList.builder();
        for (Map.Entry<Symbol, ? extends Expression> entry : entries) {
            if (!SYMBOL_MATCHES_EXPRESSION.test(entry)) {
                equalities.add(ENTRY_TO_EQUALITY.apply(entry));
            }
        }
        return equalities.build();
    }

    /**
     * Plan visitor that returns, for each node, the predicate derived for that node's output.
     */
    private static class Visitor
            extends PlanVisitor<Expression, Void> {
        private final DomainTranslator domainTranslator;
        private final Metadata metadata;
        private final Session session;
        private final TypeProvider types;
        private final TypeAnalyzer typeAnalyzer;
        private final boolean useTableProperties;

        public Visitor(DomainTranslator domainTranslator, Metadata metadata, Session session, TypeProvider types, TypeAnalyzer typeAnalyzer, boolean useTableProperties) {
            this.domainTranslator = Objects.requireNonNull(domainTranslator, "domainTranslator is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.types = Objects.requireNonNull(types, "types is null");
            this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
            this.useTableProperties = useTableProperties;
        }

        /** Generic case (presumably the fallback for node types without an override): nothing is known. */
        @Override
        protected Expression visitPlan(PlanNode node, Void context) {
            return BooleanLiteral.TRUE_LITERAL;
        }

        /**
         * With grouping keys, the source predicate is restricted to those keys. Without any
         * grouping key nothing is reported (presumably because a global aggregation emits a
         * row even for empty input -- TODO confirm).
         */
        @Override
        public Expression visitAggregation(AggregationNode node, Void context) {
            Collection<Symbol> groupingKeys = node.getGroupingKeys();
            if (groupingKeys.isEmpty()) {
                return BooleanLiteral.TRUE_LITERAL;
            }
            return pullExpressionThroughSymbols(node.getSource().accept(this, context), groupingKeys);
        }

        /** Source predicate AND the deterministic conjuncts of the filter's own predicate. */
        @Override
        public Expression visitFilter(FilterNode node, Void context) {
            Expression sourcePredicate = node.getSource().accept(this, context);
            // Non-deterministic conjuncts are dropped before combining.
            Expression deterministicFilter = ExpressionUtils.filterDeterministicConjuncts(metadata, node.getPredicate());
            return ExpressionUtils.combineConjuncts(metadata, deterministicFilter, sourcePredicate);
        }

        /** Keeps only the conjuncts shared by every exchange source, mapped positionally to the outputs. */
        @Override
        public Expression visitExchange(ExchangeNode node, Void context) {
            return deriveCommonPredicates(node, sourceIndex -> {
                List<Symbol> inputs = node.getInputs().get(sourceIndex);
                List<Symbol> outputs = node.getOutputSymbols();
                Map<Symbol, SymbolReference> outputToInput = new HashMap<>();
                for (int position = 0; position < inputs.size(); position++) {
                    outputToInput.put(outputs.get(position), inputs.get(position).toSymbolReference());
                }
                return outputToInput.entrySet();
            });
        }

        /**
         * Adds an equality for every non-identity assignment to the source predicate and
         * restricts the result to the projection's output symbols.
         */
        @Override
        public Expression visitProject(ProjectNode node, Void context) {
            Expression sourcePredicate = node.getSource().accept(this, context);
            List<Expression> conjuncts = ImmutableList.<Expression>builder()
                    .addAll(nonIdentityEqualities(node.getAssignments().entrySet()))
                    .add(sourcePredicate)
                    .build();
            return pullExpressionThroughSymbols(
                    ExpressionUtils.combineConjuncts(metadata, conjuncts),
                    node.getOutputSymbols());
        }

        /** The source predicate is forwarded unchanged. */
        @Override
        public Expression visitTopN(TopNNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        /** The source predicate is forwarded unchanged. */
        @Override
        public Expression visitLimit(LimitNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        /** The source predicate is forwarded unchanged. */
        @Override
        public Expression visitAssignUniqueId(AssignUniqueId node, Void context) {
            return node.getSource().accept(this, context);
        }

        /** The source predicate is forwarded unchanged. */
        @Override
        public Expression visitDistinctLimit(DistinctLimitNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        /**
         * Translates the scan's constraint into an expression over the scan's symbols. The
         * enforced constraint is used unless table properties are enabled and the enforced
         * constraint is not "all", in which case the table properties' predicate is used.
         */
        @Override
        public Expression visitTableScan(TableScanNode node, Void context) {
            // Scan assignments map symbol -> column; the inverse view maps column -> symbol.
            Map<ColumnHandle, Symbol> columnToSymbol = ImmutableBiMap.copyOf(node.getAssignments()).inverse();

            TupleDomain<ColumnHandle> constraint = node.getEnforcedConstraint();
            if (useTableProperties && !constraint.isAll()) {
                constraint = metadata.getTableProperties(session, node.getTable()).getPredicate();
            }

            TupleDomain<Symbol> symbolDomain = constraint.simplify().transform(columnToSymbol::get);
            return domainTranslator.toPredicate(symbolDomain);
        }

        /** The source predicate is forwarded unchanged. */
        @Override
        public Expression visitSort(SortNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        /** The source predicate is forwarded unchanged. */
        @Override
        public Expression visitWindow(WindowNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        /** Keeps only the conjuncts shared by every union branch, using the node's output symbol map. */
        @Override
        public Expression visitUnion(UnionNode node, Void context) {
            return deriveCommonPredicates(node, sourceIndex -> node.outputSymbolMap(sourceIndex).entries());
        }

        /**
         * INNER and LEFT unnest: the optional filter combined with the source predicate,
         * restricted to the output symbols. RIGHT and FULL unnest: nothing is reported.
         *
         * @throws UnsupportedOperationException for any other join type
         */
        @Override
        public Expression visitUnnest(UnnestNode node, Void context) {
            Expression sourcePredicate = node.getSource().accept(this, context);
            switch (node.getJoinType()) {
                case INNER:
                case LEFT: {
                    Expression filter = node.getFilter().orElse(BooleanLiteral.TRUE_LITERAL);
                    return pullExpressionThroughSymbols(
                            ExpressionUtils.combineConjuncts(metadata, filter, sourcePredicate),
                            node.getOutputSymbols());
                }
                case RIGHT:
                case FULL:
                    return BooleanLiteral.TRUE_LITERAL;
                default:
                    throw new UnsupportedOperationException("Unknown UNNEST join type: " + node.getJoinType());
            }
        }

        /**
         * Combines both side predicates with the equi-join criteria according to the join type.
         * For INNER joins the join filter is included as well and everything is restricted to
         * the output symbols. For outer joins, conjuncts touching the side that may be
         * null-extended go through {@link #pullNullableConjunctsThroughOuterJoin}.
         *
         * @throws UnsupportedOperationException for an unrecognised join type
         */
        @Override
        public Expression visitJoin(JoinNode node, Void context) {
            Expression leftPredicate = node.getLeft().accept(this, context);
            Expression rightPredicate = node.getRight().accept(this, context);
            List<Expression> joinConjuncts = node.getCriteria().stream()
                    .map(JoinNode.EquiJoinClause::toExpression)
                    .collect(ImmutableList.toImmutableList());
            List<Symbol> outputs = node.getOutputSymbols();

            switch (node.getType()) {
                case INNER: {
                    List<Expression> conjuncts = ImmutableList.of(
                            leftPredicate,
                            rightPredicate,
                            ExpressionUtils.combineConjuncts(metadata, joinConjuncts),
                            node.getFilter().orElse(BooleanLiteral.TRUE_LITERAL));
                    return pullExpressionThroughSymbols(ExpressionUtils.combineConjuncts(metadata, conjuncts), outputs);
                }
                case LEFT: {
                    Predicate<Symbol> fromRight = node.getRight().getOutputSymbols()::contains;
                    List<Expression> conjuncts = ImmutableList.<Expression>builder()
                            .add(pullExpressionThroughSymbols(leftPredicate, outputs))
                            .addAll(pullNullableConjunctsThroughOuterJoin(ExpressionUtils.extractConjuncts(rightPredicate), outputs, fromRight))
                            .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, outputs, fromRight))
                            .build();
                    return ExpressionUtils.combineConjuncts(metadata, conjuncts);
                }
                case RIGHT: {
                    Predicate<Symbol> fromLeft = node.getLeft().getOutputSymbols()::contains;
                    List<Expression> conjuncts = ImmutableList.<Expression>builder()
                            .add(pullExpressionThroughSymbols(rightPredicate, outputs))
                            .addAll(pullNullableConjunctsThroughOuterJoin(ExpressionUtils.extractConjuncts(leftPredicate), outputs, fromLeft))
                            .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, outputs, fromLeft))
                            .build();
                    return ExpressionUtils.combineConjuncts(metadata, conjuncts);
                }
                case FULL: {
                    Predicate<Symbol> fromLeft = node.getLeft().getOutputSymbols()::contains;
                    Predicate<Symbol> fromRight = node.getRight().getOutputSymbols()::contains;
                    List<Expression> conjuncts = ImmutableList.<Expression>builder()
                            .addAll(pullNullableConjunctsThroughOuterJoin(ExpressionUtils.extractConjuncts(leftPredicate), outputs, fromLeft))
                            .addAll(pullNullableConjunctsThroughOuterJoin(ExpressionUtils.extractConjuncts(rightPredicate), outputs, fromRight))
                            .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, outputs, fromLeft, fromRight))
                            .build();
                    return ExpressionUtils.combineConjuncts(metadata, conjuncts);
                }
                default:
                    throw new UnsupportedOperationException("Unknown join type: " + node.getType());
            }
        }

        /**
         * Derives one value domain per output column from the literal rows and translates the
         * combined tuple domain into an expression.
         *
         * <ul>
         *   <li>No output symbols: nothing is reported.</li>
         *   <li>A column containing a non-deterministic cell is left out of the domain map.</li>
         *   <li>If evaluating a cell yields an {@link Expression} (presumably: it could not be
         *       reduced to a constant), the whole node reports nothing.</li>
         * </ul>
         */
        @Override
        public Expression visitValues(ValuesNode node, Void context) {
            List<Symbol> outputs = node.getOutputSymbols();
            if (outputs.isEmpty()) {
                return BooleanLiteral.TRUE_LITERAL;
            }

            // Type every cell expression up front so the interpreter can evaluate them.
            List<Expression> allCells = new ArrayList<>();
            for (List<Expression> row : node.getRows()) {
                allCells.addAll(row);
            }
            Map<NodeRef<Expression>, Type> cellTypes = typeAnalyzer.getTypes(session, types, allCells);

            ImmutableMap.Builder<Symbol, Domain> columnDomains = ImmutableMap.builder();
            for (int column = 0; column < outputs.size(); column++) {
                Symbol symbol = outputs.get(column);
                Type type = types.get(symbol);

                ImmutableList.Builder<Object> nonNullValues = ImmutableList.builder();
                boolean sawNull = false;
                boolean deterministic = true;
                for (int row = 0; row < node.getRows().size(); row++) {
                    Expression cell = node.getRows().get(row).get(column);
                    if (!DeterminismEvaluator.isDeterministic(cell, metadata)) {
                        deterministic = false;
                        break;
                    }

                    Object evaluated = ExpressionInterpreter.expressionOptimizer(cell, metadata, session, cellTypes)
                            .optimize(NoOpSymbolResolver.INSTANCE);
                    if (evaluated instanceof Expression) {
                        // Gives up on the entire VALUES node, not just this column.
                        return BooleanLiteral.TRUE_LITERAL;
                    }

                    if (evaluated == null) {
                        sawNull = true;
                    }
                    else {
                        nonNullValues.add(evaluated);
                    }
                }

                if (deterministic) {
                    List<Object> values = nonNullValues.build();
                    Domain domain = Domain.none(type);
                    if (!values.isEmpty()) {
                        domain = domain.union(Domain.multipleValues(type, values));
                    }
                    if (sawNull) {
                        domain = domain.union(Domain.onlyNull(type));
                    }
                    columnDomains.put(symbol, domain);
                }
            }

            TupleDomain<Symbol> tupleDomain = TupleDomain.withColumnDomains(columnDomains.build());
            return domainTranslator.toPredicate(tupleDomain.simplify());
        }

        /**
         * Prepares conjuncts coming from the side(s) of an outer join that may be null-extended.
         * Each conjunct is restricted to {@code outputSymbols}; a conjunct that references no
         * symbols afterwards is replaced by {@code TRUE}; the result is then passed through
         * {@link ExpressionUtils#expressionOrNullSymbols} with the given scopes.
         *
         * @param conjuncts conjuncts to pull through the join
         * @param outputSymbols symbols visible above the join
         * @param nullSymbolScopes one membership test per side whose symbols may be null
         */
        @SafeVarargs
        private final Iterable<Expression> pullNullableConjunctsThroughOuterJoin(List<Expression> conjuncts, Collection<Symbol> outputSymbols, Predicate<Symbol>... nullSymbolScopes) {
            return conjuncts.stream()
                    .map(conjunct -> pullExpressionThroughSymbols(conjunct, outputSymbols))
                    .map(pulled -> SymbolsExtractor.extractAll(pulled).isEmpty() ? BooleanLiteral.TRUE_LITERAL : pulled)
                    .map(ExpressionUtils.expressionOrNullSymbols(nullSymbolScopes))
                    .collect(ImmutableList.toImmutableList());
        }

        /** Only the source side's predicate is forwarded. */
        @Override
        public Expression visitSemiJoin(SemiJoinNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        /**
         * INNER: both side predicates restricted to the output symbols. LEFT: the left predicate
         * restricted to the output symbols plus the right conjuncts treated as nullable.
         *
         * @throws IllegalArgumentException for any other spatial join type
         */
        @Override
        public Expression visitSpatialJoin(SpatialJoinNode node, Void context) {
            Expression leftPredicate = node.getLeft().accept(this, context);
            Expression rightPredicate = node.getRight().accept(this, context);
            List<Symbol> outputs = node.getOutputSymbols();

            switch (node.getType()) {
                case INNER: {
                    List<Expression> conjuncts = ImmutableList.of(
                            pullExpressionThroughSymbols(leftPredicate, outputs),
                            pullExpressionThroughSymbols(rightPredicate, outputs));
                    return ExpressionUtils.combineConjuncts(metadata, conjuncts);
                }
                case LEFT: {
                    Predicate<Symbol> fromRight = node.getRight().getOutputSymbols()::contains;
                    List<Expression> conjuncts = ImmutableList.<Expression>builder()
                            .add(pullExpressionThroughSymbols(leftPredicate, outputs))
                            .addAll(pullNullableConjunctsThroughOuterJoin(ExpressionUtils.extractConjuncts(rightPredicate), outputs, fromRight))
                            .build();
                    return ExpressionUtils.combineConjuncts(metadata, conjuncts);
                }
                default:
                    throw new IllegalArgumentException("Unsupported spatial join type: " + node.getType());
            }
        }

        /**
         * For a multi-source node, computes each source's predicate in terms of the node's
         * output symbols and returns the conjunction of the conjuncts common to all sources.
         *
         * @param node node whose sources are inspected; expected to have at least one source
         *        (the first {@code next()} call fails otherwise)
         * @param mapping for a source index, the output-symbol to input-reference entries
         */
        private Expression deriveCommonPredicates(PlanNode node, Function<Integer, Collection<Map.Entry<Symbol, SymbolReference>>> mapping) {
            List<Set<Expression>> conjunctsPerSource = new ArrayList<>();
            for (int sourceIndex = 0; sourceIndex < node.getSources().size(); sourceIndex++) {
                Expression sourcePredicate = node.getSources().get(sourceIndex).accept(this, null);
                List<Expression> conjuncts = ImmutableList.<Expression>builder()
                        .addAll(nonIdentityEqualities(mapping.apply(sourceIndex)))
                        .add(sourcePredicate)
                        .build();
                Expression overOutputs = pullExpressionThroughSymbols(
                        ExpressionUtils.combineConjuncts(metadata, conjuncts),
                        node.getOutputSymbols());
                conjunctsPerSource.add(ImmutableSet.copyOf(ExpressionUtils.extractConjuncts(overOutputs)));
            }

            // Fold the per-source sets down to their intersection.
            Iterator<Set<Expression>> remaining = conjunctsPerSource.iterator();
            Set<Expression> common = remaining.next();
            while (remaining.hasNext()) {
                common = Sets.intersection(common, remaining.next());
            }
            return ExpressionUtils.combineConjuncts(metadata, common);
        }

        /**
         * Re-expresses {@code expression} using only {@code symbols}. Deterministic
         * non-inferrable conjuncts are rewritten into that scope via {@link EqualityInference}
         * and dropped when the rewrite returns {@code null}; the in-scope equalities produced
         * by the inference are then appended.
         */
        private Expression pullExpressionThroughSymbols(Expression expression, Collection<Symbol> symbols) {
            EqualityInference inference = EqualityInference.newInstance(metadata, expression);
            Set<Symbol> scope = ImmutableSet.copyOf(symbols);

            ImmutableList.Builder<Expression> pulled = ImmutableList.builder();
            for (Expression conjunct : EqualityInference.nonInferrableConjuncts(metadata, expression)) {
                if (!DeterminismEvaluator.isDeterministic(conjunct, metadata)) {
                    continue;
                }
                Expression rewritten = inference.rewrite(conjunct, scope);
                if (rewritten != null) {
                    pulled.add(rewritten);
                }
            }
            pulled.addAll(inference.generateEqualitiesPartitionedBy(scope).getScopeEqualities());
            return ExpressionUtils.combineConjuncts(metadata, pulled.build());
        }
    }
}

