/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.cost;

import com.google.common.base.Preconditions;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.Session;
import io.prestosql.cost.ComparisonStatsCalculator;
import io.prestosql.cost.PlanNodeStatsEstimate;
import io.prestosql.cost.PlanNodeStatsEstimateMath;
import io.prestosql.cost.ScalarStatsCalculator;
import io.prestosql.cost.StatsNormalizer;
import io.prestosql.cost.StatsUtil;
import io.prestosql.cost.SymbolStatsEstimate;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.metadata.Metadata;
import io.prestosql.security.AccessControl;
import io.prestosql.security.AllowAllAccessControl;
import io.prestosql.spi.type.BooleanType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.DynamicFilters;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.analyzer.ExpressionAnalyzer;
import io.prestosql.sql.analyzer.Scope;
import io.prestosql.sql.planner.ExpressionInterpreter;
import io.prestosql.sql.planner.LiteralEncoder;
import io.prestosql.sql.planner.LiteralInterpreter;
import io.prestosql.sql.planner.NoOpSymbolResolver;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.tree.AstVisitor;
import io.prestosql.sql.tree.BetweenPredicate;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.InListExpression;
import io.prestosql.sql.tree.InPredicate;
import io.prestosql.sql.tree.IsNotNullPredicate;
import io.prestosql.sql.tree.IsNullPredicate;
import io.prestosql.sql.tree.Literal;
import io.prestosql.sql.tree.LogicalBinaryExpression;
import io.prestosql.sql.tree.Node;
import io.prestosql.sql.tree.NodeRef;
import io.prestosql.sql.tree.NotExpression;
import io.prestosql.sql.tree.Parameter;
import io.prestosql.sql.tree.SymbolReference;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalDouble;
import javax.annotation.Nullable;

/**
 * Estimates the statistics of a plan node's output after a filter predicate is applied.
 *
 * <p>The predicate is first simplified with the expression optimizer and then walked by an
 * {@link AstVisitor} that maps each supported expression form (comparisons, IN, BETWEEN,
 * IS [NOT] NULL, NOT, AND/OR, boolean literals, dynamic-filter function calls) to an estimate.
 * Any other expression form yields {@link PlanNodeStatsEstimate#unknown()}.
 */
public class FilterStatsCalculator {
    /**
     * Row-count multiplier applied in {@code estimateLogicalAnd} to the smallest known conjunct
     * estimate when the conjunction cannot be estimated by chaining both sides.
     */
    static final double UNKNOWN_FILTER_COEFFICIENT = 0.9;
    private final Metadata metadata;
    private final ScalarStatsCalculator scalarStatsCalculator;
    private final StatsNormalizer normalizer;
    private final LiteralEncoder literalEncoder;

    /**
     * @param metadata metadata used for expression analysis, optimization and literal handling
     * @param scalarStatsCalculator computes statistics for non-symbol scalar expressions
     * @param normalizer normalizes every estimate produced by the visitor
     * @throws NullPointerException if any argument is null
     */
    public FilterStatsCalculator(Metadata metadata, ScalarStatsCalculator scalarStatsCalculator, StatsNormalizer normalizer) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.scalarStatsCalculator = Objects.requireNonNull(scalarStatsCalculator, "scalarStatsCalculator is null");
        this.normalizer = Objects.requireNonNull(normalizer, "normalizer is null");
        this.literalEncoder = new LiteralEncoder(metadata);
    }

    /**
     * Returns the estimated statistics of {@code statsEstimate} after filtering by {@code predicate}.
     *
     * @param statsEstimate statistics of the filter's input
     * @param predicate the filter predicate; it is simplified before estimation
     * @param session session used for expression analysis and optimization
     * @param types types of the symbols referenced by the predicate
     * @return a normalized estimate, or an unknown estimate when the predicate form is not supported
     */
    public PlanNodeStatsEstimate filterStats(PlanNodeStatsEstimate statsEstimate, Expression predicate, Session session, TypeProvider types) {
        Expression simplifiedExpression = simplifyExpression(session, predicate, types);
        return new FilterExpressionStatsCalculatingVisitor(statsEstimate, session, types).process(simplifiedExpression);
    }

    /**
     * Optimizes {@code predicate} and re-encodes the result as a BOOLEAN expression.
     * A {@code null} optimization result is replaced with {@code false}.
     */
    private Expression simplifyExpression(Session session, Expression predicate, TypeProvider types) {
        Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(session, predicate, types);
        ExpressionInterpreter interpreter = ExpressionInterpreter.expressionOptimizer(predicate, metadata, session, expressionTypes);
        Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE);
        if (value == null) {
            // A NULL predicate result is encoded as FALSE.
            value = false;
        }
        return literalEncoder.toExpression(value, BooleanType.BOOLEAN);
    }

    /**
     * Analyzes {@code expression} (without subquery support) and returns the type of every
     * sub-expression.
     *
     * @throws IllegalStateException if the analyzer encounters a node it reports as unexpected
     */
    private Map<NodeRef<Expression>, Type> getExpressionTypes(Session session, Expression expression, TypeProvider types) {
        ExpressionAnalyzer expressionAnalyzer = ExpressionAnalyzer.createWithoutSubqueries(
                metadata,
                new AllowAllAccessControl(),
                session,
                types,
                Collections.emptyMap(),
                // Previously built as "Unexpected node: %s" + node, which left a literal "%s" in the message.
                node -> new IllegalStateException(String.format("Unexpected node: %s", node)),
                WarningCollector.NOOP,
                false);
        expressionAnalyzer.analyze(expression, Scope.create());
        return expressionAnalyzer.getExpressionTypes();
    }

    /**
     * Computes the statistics remaining after a predicate is applied to {@code input}.
     * Every result returned through {@link #process(Node, Void)} is passed through the normalizer.
     */
    private class FilterExpressionStatsCalculatingVisitor
            extends AstVisitor<PlanNodeStatsEstimate, Void> {
        private final PlanNodeStatsEstimate input;
        private final Session session;
        private final TypeProvider types;

        FilterExpressionStatsCalculatingVisitor(PlanNodeStatsEstimate input, Session session, TypeProvider types) {
            this.input = input;
            this.session = session;
            this.types = types;
        }

        /** Dispatches to the matching {@code visit*} method and normalizes its result. */
        @Override
        public PlanNodeStatsEstimate process(Node node, @Nullable Void context) {
            return normalizer.normalize(super.process(node, context), types);
        }

        /** Fallback for expression forms without a dedicated estimator. */
        @Override
        protected PlanNodeStatsEstimate visitExpression(Expression node, Void context) {
            return PlanNodeStatsEstimate.unknown();
        }

        @Override
        protected PlanNodeStatsEstimate visitNotExpression(NotExpression node, Void context) {
            if (node.getValue() instanceof IsNullPredicate) {
                // NOT (x IS NULL) is estimated directly as x IS NOT NULL.
                return process(new IsNotNullPredicate(((IsNullPredicate) node.getValue()).getValue()));
            }
            // Otherwise: the input minus the rows matched by the negated expression.
            return PlanNodeStatsEstimateMath.subtractSubsetStats(input, process(node.getValue()));
        }

        @Override
        protected PlanNodeStatsEstimate visitLogicalBinaryExpression(LogicalBinaryExpression node, Void context) {
            switch (node.getOperator()) {
                case AND:
                    return estimateLogicalAnd(node.getLeft(), node.getRight());
                case OR:
                    return estimateLogicalOr(node.getLeft(), node.getRight());
                default:
                    throw new IllegalArgumentException("Unexpected binary operator: " + node.getOperator());
            }
        }

        /**
         * Estimates {@code left AND right}.
         *
         * <p>First tries to estimate {@code right} on top of the estimate for {@code left}. If that
         * is not possible, it takes the smallest known single-side estimate and scales its row count
         * by {@link #UNKNOWN_FILTER_COEFFICIENT}.
         */
        private PlanNodeStatsEstimate estimateLogicalAnd(Expression left, Expression right) {
            PlanNodeStatsEstimate leftEstimate = process(left);
            if (!leftEstimate.isOutputRowCountUnknown()) {
                PlanNodeStatsEstimate logicalAndEstimate =
                        new FilterExpressionStatsCalculatingVisitor(leftEstimate, session, types).process(right);
                if (!logicalAndEstimate.isOutputRowCountUnknown()) {
                    return logicalAndEstimate;
                }
            }

            PlanNodeStatsEstimate rightEstimate = process(right);
            PlanNodeStatsEstimate smallestKnownEstimate;
            if (leftEstimate.isOutputRowCountUnknown()) {
                smallestKnownEstimate = rightEstimate;
            } else if (rightEstimate.isOutputRowCountUnknown()) {
                smallestKnownEstimate = leftEstimate;
            } else {
                smallestKnownEstimate = leftEstimate.getOutputRowCount() <= rightEstimate.getOutputRowCount() ? leftEstimate : rightEstimate;
            }
            if (smallestKnownEstimate.isOutputRowCountUnknown()) {
                return PlanNodeStatsEstimate.unknown();
            }
            return smallestKnownEstimate.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT);
        }

        /**
         * Estimates {@code left OR right} by inclusion-exclusion:
         * stats(left) + stats(right) - stats(left AND right), capped by the input statistics.
         * Returns unknown if any of the three component estimates is unknown.
         */
        private PlanNodeStatsEstimate estimateLogicalOr(Expression left, Expression right) {
            PlanNodeStatsEstimate leftEstimate = process(left);
            if (leftEstimate.isOutputRowCountUnknown()) {
                return PlanNodeStatsEstimate.unknown();
            }
            PlanNodeStatsEstimate rightEstimate = process(right);
            if (rightEstimate.isOutputRowCountUnknown()) {
                return PlanNodeStatsEstimate.unknown();
            }
            // "left AND right": estimate right on top of the left estimate.
            PlanNodeStatsEstimate andEstimate =
                    new FilterExpressionStatsCalculatingVisitor(leftEstimate, session, types).process(right);
            if (andEstimate.isOutputRowCountUnknown()) {
                return PlanNodeStatsEstimate.unknown();
            }
            return PlanNodeStatsEstimateMath.capStats(
                    PlanNodeStatsEstimateMath.subtractSubsetStats(
                            PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues(leftEstimate, rightEstimate),
                            andEstimate),
                    input);
        }

        /**
         * TRUE keeps the input unchanged. FALSE yields zero rows, with zeroed statistics for every
         * symbol that had known statistics.
         */
        @Override
        protected PlanNodeStatsEstimate visitBooleanLiteral(BooleanLiteral node, Void context) {
            if (node.getValue()) {
                return input;
            }
            PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
            result.setOutputRowCount(0.0);
            input.getSymbolsWithKnownStatistics().forEach(symbol -> result.addSymbolStatistics(symbol, SymbolStatsEstimate.zero()));
            return result.build();
        }

        /**
         * For a symbol operand, keeps the non-null fraction of rows and sets the symbol's
         * nulls fraction to 0. Other operands are not estimated.
         */
        @Override
        protected PlanNodeStatsEstimate visitIsNotNullPredicate(IsNotNullPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)) {
                return PlanNodeStatsEstimate.unknown();
            }
            Symbol symbol = Symbol.from(node.getValue());
            SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
            PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
            result.setOutputRowCount(input.getOutputRowCount() * (1.0 - symbolStats.getNullsFraction()));
            result.addSymbolStatistics(symbol, symbolStats.mapNullsFraction(x -> 0.0));
            return result.build();
        }

        /**
         * For a symbol operand, keeps the null fraction of rows and marks the symbol as all-null
         * (nulls fraction 1, NaN range, zero distinct values). Other operands are not estimated.
         */
        @Override
        protected PlanNodeStatsEstimate visitIsNullPredicate(IsNullPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)) {
                return PlanNodeStatsEstimate.unknown();
            }
            Symbol symbol = Symbol.from(node.getValue());
            SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
            PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
            result.setOutputRowCount(input.getOutputRowCount() * symbolStats.getNullsFraction());
            result.addSymbolStatistics(symbol, SymbolStatsEstimate.builder()
                    .setNullsFraction(1.0)
                    .setLowValue(Double.NaN)
                    .setHighValue(Double.NaN)
                    .setDistinctValuesCount(0.0)
                    .build());
            return result.build();
        }

        /**
         * Rewrites {@code x BETWEEN min AND max} as {@code x >= min AND x <= max}. This is only done
         * for a symbol {@code x} and single-valued bounds.
         */
        @Override
        protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)) {
                return PlanNodeStatsEstimate.unknown();
            }
            if (!getExpressionStats(node.getMin()).isSingleValue()) {
                return PlanNodeStatsEstimate.unknown();
            }
            if (!getExpressionStats(node.getMax()).isSingleValue()) {
                return PlanNodeStatsEstimate.unknown();
            }
            SymbolStatsEstimate valueStats = input.getSymbolStatistics(Symbol.from(node.getValue()));
            Expression lowerBound = new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin());
            Expression upperBound = new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, node.getValue(), node.getMax());
            // estimateLogicalAnd estimates the left conjunct first and feeds it into the right one,
            // so this choice decides which bound is applied first (assuming ExpressionUtils.and
            // keeps argument order). With an infinite low value, the lower bound goes first.
            Expression transformed = Double.isInfinite(valueStats.getLowValue())
                    ? ExpressionUtils.and(new Expression[] {lowerBound, upperBound})
                    : ExpressionUtils.and(new Expression[] {upperBound, lowerBound});
            return process(transformed);
        }

        /**
         * Estimates {@code x IN (v1, v2, ...)} as the sum of the {@code x = vi} estimates. The row
         * count is capped at the number of non-null input values of {@code x}. For a symbol
         * {@code x}, its distinct-value count is capped at the pre-filter count.
         */
        @Override
        protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context) {
            if (!(node.getValueList() instanceof InListExpression)) {
                return PlanNodeStatsEstimate.unknown();
            }
            InListExpression inList = (InListExpression) node.getValueList();
            ImmutableList<PlanNodeStatsEstimate> equalityEstimates = inList.getValues().stream()
                    .map(inValue -> process(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, node.getValue(), inValue)))
                    .collect(ImmutableList.toImmutableList());
            if (equalityEstimates.stream().anyMatch(PlanNodeStatsEstimate::isOutputRowCountUnknown)) {
                return PlanNodeStatsEstimate.unknown();
            }
            PlanNodeStatsEstimate inEstimate = equalityEstimates.stream()
                    .reduce(PlanNodeStatsEstimateMath::addStatsAndSumDistinctValues)
                    .orElse(PlanNodeStatsEstimate.unknown());
            if (inEstimate.isOutputRowCountUnknown()) {
                return PlanNodeStatsEstimate.unknown();
            }
            SymbolStatsEstimate valueStats = getExpressionStats(node.getValue());
            if (valueStats.isUnknown()) {
                return PlanNodeStatsEstimate.unknown();
            }
            double notNullValuesBeforeIn = input.getOutputRowCount() * (1.0 - valueStats.getNullsFraction());
            PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
            result.setOutputRowCount(Double.min(inEstimate.getOutputRowCount(), notNullValuesBeforeIn));
            if (node.getValue() instanceof SymbolReference) {
                Symbol valueSymbol = Symbol.from(node.getValue());
                SymbolStatsEstimate newSymbolStats = inEstimate.getSymbolStatistics(valueSymbol)
                        .mapDistinctValuesCount(newDistinctValuesCount -> Double.min(newDistinctValuesCount, valueStats.getDistinctValuesCount()));
                result.addSymbolStatistics(valueSymbol, newSymbolStats);
            }
            return result.build();
        }

        /**
         * Estimates a binary comparison. The operands are first normalized: a symbol goes on the
         * left and a literal on the right, and {@code x op x} for a symbol becomes
         * {@code x IS NOT NULL}. Then it delegates to {@link ComparisonStatsCalculator}.
         *
         * @throws IllegalArgumentException if both operands are literals
         */
        @Override
        protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression node, Void context) {
            ComparisonExpression.Operator operator = node.getOperator();
            Expression left = node.getLeft();
            Expression right = node.getRight();
            Preconditions.checkArgument(
                    !(left instanceof Literal && right instanceof Literal),
                    "Literal-to-literal not supported here, should be eliminated earlier");

            if (!(left instanceof SymbolReference) && right instanceof SymbolReference) {
                // Normalize so that the symbol is on the left.
                return process(new ComparisonExpression(operator.flip(), right, left));
            }
            if (left instanceof Literal && !(right instanceof Literal)) {
                // Normalize so that the literal is on the right.
                return process(new ComparisonExpression(operator.flip(), right, left));
            }
            if (left instanceof SymbolReference && left.equals(right)) {
                return process(new IsNotNullPredicate(left));
            }

            SymbolStatsEstimate leftStats = getExpressionStats(left);
            Optional<Symbol> leftSymbol = left instanceof SymbolReference ? Optional.of(Symbol.from(left)) : Optional.empty();
            if (right instanceof Literal) {
                OptionalDouble literal = doubleValueFromLiteral(getType(left), (Literal) right);
                return ComparisonStatsCalculator.estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, literal, operator);
            }

            SymbolStatsEstimate rightStats = getExpressionStats(right);
            if (rightStats.isSingleValue()) {
                // A single-valued right side is treated like a literal; a NaN value maps to "no value".
                OptionalDouble value = Double.isNaN(rightStats.getLowValue()) ? OptionalDouble.empty() : OptionalDouble.of(rightStats.getLowValue());
                return ComparisonStatsCalculator.estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, value, operator);
            }

            Optional<Symbol> rightSymbol = right instanceof SymbolReference ? Optional.of(Symbol.from(right)) : Optional.empty();
            return ComparisonStatsCalculator.estimateExpressionToExpressionComparison(input, leftStats, leftSymbol, rightStats, rightSymbol, operator);
        }

        /** Dynamic-filter calls are estimated as TRUE (no filtering). Other calls are unknown. */
        @Override
        protected PlanNodeStatsEstimate visitFunctionCall(FunctionCall node, Void context) {
            if (DynamicFilters.isDynamicFilter(node)) {
                return process(BooleanLiteral.TRUE_LITERAL, context);
            }
            return PlanNodeStatsEstimate.unknown();
        }

        /**
         * Returns the type of {@code expression}. A symbol's type comes from {@code types};
         * other expressions are analyzed.
         *
         * @throws NullPointerException if a referenced symbol has no type
         */
        private Type getType(Expression expression) {
            if (expression instanceof SymbolReference) {
                Symbol symbol = Symbol.from(expression);
                return Objects.requireNonNull(types.get(symbol), () -> String.format("No type for symbol %s", symbol));
            }
            ExpressionAnalyzer expressionAnalyzer = ExpressionAnalyzer.createWithoutSubqueries(
                    metadata,
                    new AllowAllAccessControl(),
                    session,
                    types,
                    ImmutableMap.of(),
                    node -> new VerifyException("Unexpected subquery"),
                    WarningCollector.NOOP,
                    false);
            return expressionAnalyzer.analyze(expression, Scope.create());
        }

        /**
         * Returns the statistics of {@code expression}. A symbol's statistics are looked up in
         * {@code input}; other expressions go through the scalar stats calculator.
         *
         * @throws NullPointerException if the input has no statistics entry for a referenced symbol
         */
        private SymbolStatsEstimate getExpressionStats(Expression expression) {
            if (expression instanceof SymbolReference) {
                Symbol symbol = Symbol.from(expression);
                return Objects.requireNonNull(input.getSymbolStatistics(symbol), () -> String.format("No statistics for symbol %s", symbol));
            }
            return scalarStatsCalculator.calculate(expression, input, session, types);
        }

        /**
         * Evaluates {@code literal} and converts it to the double representation used by the stats
         * code for values of {@code type}. The result is empty when no such representation exists.
         */
        private OptionalDouble doubleValueFromLiteral(Type type, Literal literal) {
            Object literalValue = LiteralInterpreter.evaluate(
                    metadata,
                    session.toConnectorSession(),
                    getExpressionTypes(session, literal, types),
                    literal);
            return StatsUtil.toStatsRepresentation(metadata, session, type, literalValue);
        }
    }
}

