package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorTableHandle;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.ScalarStatsCalculator;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.plugin.tpch.TpchColumnHandle;
import io.trino.spi.connector.Assignment;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ConnectorPartitioningHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.connector.ConnectorTablePartitioning;
import io.trino.spi.connector.ConnectorTableProperties;
import io.trino.spi.connector.ProjectionApplicationResult;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Constant;
import io.trino.spi.expression.FieldDereference;
import io.trino.spi.expression.Variable;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.planner.ConnectorExpressionTranslator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.RuleAssert;
import io.trino.sql.planner.iterative.rule.test.RuleTester;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.StringLiteral;
import io.trino.sql.tree.SubscriptExpression;
import io.trino.sql.tree.SymbolReference;
import io.trino.testing.TestingSession;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link PushProjectionIntoTableScan}.
 *
 * <p>Covers three cases: the rule does not fire when the mock connector has no projection callback;
 * it rewrites a project over a table scan into a scan of a "projected_" table (including the
 * expected statistics) when the callback is present; and it rejects a pushdown that changes the
 * table partitioning.
 */
public class TestPushProjectionIntoTableScan {
    private static final String TEST_SCHEMA = "test_schema";
    private static final String TEST_TABLE = "test_table";
    private static final Type ROW_TYPE = RowType.from(Arrays.asList(RowType.field("a", BigintType.BIGINT), RowType.field("b", BigintType.BIGINT)));

    // Partitioning reported by the mock connector for TEST_TABLE only; projected tables report none.
    private static final ConnectorPartitioningHandle PARTITIONING_HANDLE = new ConnectorPartitioningHandle() {
    };

    private static final Session MOCK_SESSION = TestingSession.testSessionBuilder()
            .setCatalog("test-catalog")
            .setSchema(TEST_SCHEMA)
            .build();

    @Test
    public void testDoesNotFire() {
        String columnName = "input_column";
        Type columnType = ROW_TYPE;
        ColumnHandle inputColumnHandle = column(columnName, columnType);

        // No ApplyProjection callback is registered with the mock connector, so the rule must not fire.
        MockConnectorFactory factory = createMockFactory(ImmutableMap.of(columnName, inputColumnHandle), Optional.empty());
        try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(factory).build()) {
            ruleTester.assertThat(createRule(ruleTester))
                    .on(planBuilder -> {
                        Symbol symbol = planBuilder.symbol(columnName, columnType);
                        return planBuilder.project(
                                Assignments.of(
                                        planBuilder.symbol("symbol_dereference", BigintType.BIGINT),
                                        new SubscriptExpression(symbol.toSymbolReference(), new LongLiteral("1"))),
                                planBuilder.tableScan(
                                        ruleTester.getCurrentCatalogTableHandle(TEST_SCHEMA, TEST_TABLE),
                                        ImmutableList.of(symbol),
                                        ImmutableMap.of(symbol, inputColumnHandle)));
                    })
                    .withSession(MOCK_SESSION)
                    .doesNotFire();
        }
    }

    @Test
    public void testPushProjection() {
        Type columnType = ROW_TYPE;
        Symbol baseColumn = new Symbol("col0");
        ColumnHandle columnHandle = new TpchColumnHandle(baseColumn.getName(), columnType);

        MockConnectorFactory factory = createMockFactory(ImmutableMap.of(baseColumn.getName(), columnHandle), Optional.of(this::mockApplyProjection));
        try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(factory).build()) {
            TypeAnalyzer typeAnalyzer = TypeAnalyzer.createTestingTypeAnalyzer(ruleTester.getPlannerContext());

            // One output symbol per ConnectorExpression kind handled by mockApplyProjection.
            Symbol identity = new Symbol("symbol_identity");
            Symbol dereference = new Symbol("symbol_dereference");
            Symbol constant = new Symbol("symbol_constant");
            Symbol call = new Symbol("symbol_call");
            Map<Symbol, Type> types = ImmutableMap.of(
                    baseColumn, ROW_TYPE,
                    identity, ROW_TYPE,
                    dereference, BigintType.BIGINT,
                    constant, BigintType.BIGINT,
                    call, VarcharType.VARCHAR);

            Map<Symbol, Expression> inputProjections = ImmutableMap.<Symbol, Expression>builder()
                    .put(identity, baseColumn.toSymbolReference())
                    .put(dereference, new SubscriptExpression(baseColumn.toSymbolReference(), new LongLiteral("1")))
                    .put(constant, new LongLiteral("5"))
                    .put(call, new FunctionCall(
                            ruleTester.getMetadata()
                                    .resolveBuiltinFunction("starts_with", TypeSignatureProvider.fromTypes(VarcharType.VARCHAR, VarcharType.VARCHAR))
                                    .toQualifiedName(),
                            ImmutableList.of(new StringLiteral("abc"), new StringLiteral("ab"))))
                    .buildOrThrow();

            // mockApplyProjection names each new column after ConnectorExpression#toString, so translate
            // every input expression up front to know the column names the rewritten plan should use.
            Session transactionSession = MOCK_SESSION.beginTransactionId(
                    ruleTester.getQueryRunner().getTransactionManager().beginTransaction(false),
                    ruleTester.getQueryRunner().getTransactionManager(),
                    ruleTester.getQueryRunner().getAccessControl());
            Map<Symbol, String> connectorNames = inputProjections.entrySet().stream()
                    .collect(ImmutableMap.toImmutableMap(
                            Map.Entry::getKey,
                            entry -> ConnectorExpressionTranslator.translate(
                                            transactionSession,
                                            entry.getValue(),
                                            TypeProvider.viewOf(types),
                                            ruleTester.getPlannerContext(),
                                            typeAnalyzer)
                                    .orElseThrow()
                                    .toString()));

            Map<Symbol, String> newNames = ImmutableMap.of(
                    identity, "projected_variable_" + connectorNames.get(identity),
                    dereference, "projected_dereference_" + connectorNames.get(dereference),
                    constant, "projected_constant_" + connectorNames.get(constant),
                    call, "projected_call_" + connectorNames.get(call));
            Map<String, ColumnHandle> newColumnHandles = newNames.entrySet().stream()
                    .collect(ImmutableMap.toImmutableMap(
                            Map.Entry::getValue,
                            entry -> column(entry.getValue(), types.get(entry.getKey()))));

            RuleAssert ruleAssert = ruleTester.assertThat(createRule(ruleTester))
                    .on(planBuilder -> {
                        // Register every symbol with its type before building the plan.
                        types.forEach((symbol, type) -> planBuilder.symbol(symbol.getName(), type));
                        return planBuilder.project(
                                new Assignments(inputProjections),
                                planBuilder.tableScan(scan -> {
                                    scan.setTableHandle(ruleTester.getCurrentCatalogTableHandle(TEST_SCHEMA, TEST_TABLE))
                                            .setSymbols(ImmutableList.copyOf(types.keySet()))
                                            .setAssignments(types.keySet().stream()
                                                    .collect(Collectors.toMap(Function.identity(), symbol -> columnHandle)))
                                            .setStatistics(Optional.of(PlanNodeStatsEstimate.builder()
                                                    .setOutputRowCount(42)
                                                    .addSymbolStatistics(baseColumn, SymbolStatsEstimate.builder()
                                                            .setNullsFraction(0)
                                                            .setDistinctValuesCount(33)
                                                            .build())
                                                    .build()));
                                }));
                    })
                    .withSession(MOCK_SESSION);

            Predicate<ConnectorTableHandle> expectedTableHandle = new MockConnectorTableHandle(
                    new SchemaTableName(TEST_SCHEMA, "projected_" + TEST_TABLE),
                    TupleDomain.all(),
                    Optional.of(ImmutableList.copyOf(newColumnHandles.values())))::equals;
            Map<String, Predicate<ColumnHandle>> expectedColumns = newColumnHandles.entrySet().stream()
                    .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, entry -> entry.getValue()::equals));
            Optional<PlanNodeStatsEstimate> expectedStatistics = Optional.of(PlanNodeStatsEstimate.builder()
                    .setOutputRowCount(42)
                    // The constant projection has exactly one known value: 5.
                    .addSymbolStatistics(new Symbol(newNames.get(constant)), SymbolStatsEstimate.builder()
                            .setDistinctValuesCount(1)
                            .setNullsFraction(0)
                            .setLowValue(5)
                            .setHighValue(5)
                            .build())
                    // NOTE(review): expected under the lower-cased projected name; presumably the symbol
                    // name for the call is normalized somewhere in the rule — confirm.
                    .addSymbolStatistics(new Symbol(newNames.get(call).toLowerCase(Locale.ENGLISH)), SymbolStatsEstimate.builder()
                            .setDistinctValuesCount(1)
                            .setNullsFraction(0)
                            .build())
                    // Expected to match the statistics given to the base column in the input scan.
                    .addSymbolStatistics(new Symbol(newNames.get(identity)), SymbolStatsEstimate.builder()
                            .setDistinctValuesCount(33)
                            .setNullsFraction(0)
                            .build())
                    .addSymbolStatistics(new Symbol(newNames.get(dereference)), SymbolStatsEstimate.unknown())
                    .build());

            // Each projection should become a plain reference to the column created by the connector.
            ruleAssert.matches(PlanMatchPattern.project(
                    newNames.entrySet().stream()
                            .collect(ImmutableMap.toImmutableMap(
                                    entry -> entry.getKey().getName(),
                                    entry -> PlanMatchPattern.expression(symbolReference(entry.getValue())))),
                    PlanMatchPattern.tableScan(
                            expectedTableHandle,
                            TupleDomain.all(),
                            expectedColumns,
                            expectedStatistics::equals)));
        }
    }

    @Test
    public void testPartitioningChanged() {
        ColumnHandle columnHandle = new TpchColumnHandle("col0", VarcharType.VARCHAR);
        MockConnectorFactory factory = createMockFactory(ImmutableMap.of("col0", columnHandle), Optional.of(this::mockApplyProjection));
        try (RuleTester ruleTester = RuleTester.builder().withDefaultCatalogConnectorFactory(factory).build()) {
            // TEST_TABLE reports PARTITIONING_HANDLE, but the "projected_" table returned by
            // mockApplyProjection reports no partitioning, so the rule must reject the pushdown.
            Assertions.assertThatThrownBy(() -> ruleTester.assertThat(createRule(ruleTester))
                            .on(planBuilder -> {
                                Symbol symbol = planBuilder.symbol("col", VarcharType.VARCHAR);
                                return planBuilder.project(
                                        Assignments.of(),
                                        planBuilder.tableScan(
                                                ruleTester.getCurrentCatalogTableHandle(TEST_SCHEMA, TEST_TABLE),
                                                ImmutableList.<Symbol>of(symbol),
                                                ImmutableMap.<Symbol, ColumnHandle>of(symbol, columnHandle),
                                                Optional.of(true)));
                            })
                            .withSession(MOCK_SESSION)
                            .matches(PlanMatchPattern.anyTree()))
                    .hasMessage("Partitioning must not change after projection is pushed down");
        }
    }

    /**
     * Builds a mock connector exposing a single table {@code TEST_SCHEMA.TEST_TABLE}.
     *
     * @param columns column name to handle; every handle must be a {@link TpchColumnHandle}
     *        because its type is read from it (a different handle type fails with a ClassCastException)
     * @param applyProjection optional projection-pushdown callback; when empty none is registered
     * @return the configured connector factory
     */
    private MockConnectorFactory createMockFactory(Map<String, ColumnHandle> columns, Optional<MockConnectorFactory.ApplyProjection> applyProjection) {
        List<ColumnMetadata> columnMetadata = columns.entrySet().stream()
                .map(entry -> new ColumnMetadata(entry.getKey(), ((TpchColumnHandle) entry.getValue()).getType()))
                .collect(ImmutableList.toImmutableList());

        MockConnectorFactory.Builder builder = MockConnectorFactory.builder()
                .withListSchemaNames(connectorSession -> ImmutableList.of(TEST_SCHEMA))
                .withListTables((connectorSession, schemaName) -> TEST_SCHEMA.equals(schemaName) ? ImmutableList.of(TEST_TABLE) : ImmutableList.of())
                .withGetColumns(schemaTableName -> columnMetadata)
                .withGetTableProperties((connectorSession, tableHandle) -> {
                    // Only the original table is partitioned; any other table (e.g. "projected_...") is not.
                    if (((MockConnectorTableHandle) tableHandle).getTableName().getTableName().equals(TEST_TABLE)) {
                        return new ConnectorTableProperties(
                                TupleDomain.all(),
                                Optional.of(new ConnectorTablePartitioning(PARTITIONING_HANDLE, ImmutableList.of(column("col", VarcharType.VARCHAR)))),
                                Optional.empty(),
                                ImmutableList.of());
                    }
                    return new ConnectorTableProperties();
                });
        if (applyProjection.isPresent()) {
            builder = builder.withApplyProjection(applyProjection.get());
        }
        return builder.build();
    }

    /**
     * Mock projection pushdown that accepts every projection.
     *
     * <p>The table is renamed to {@code "projected_" + name}. Each projected expression becomes a new
     * column named {@code <kind prefix> + expression.toString()}, and the projection is replaced by a
     * {@link Variable} referencing that column.
     *
     * @throws UnsupportedOperationException if a projection is not a Variable, FieldDereference,
     *         Constant or Call
     */
    private Optional<ProjectionApplicationResult<ConnectorTableHandle>> mockApplyProjection(
            ConnectorSession session,
            ConnectorTableHandle tableHandle,
            List<ConnectorExpression> projections,
            Map<String, ColumnHandle> assignments) {
        SchemaTableName tableName = ((MockConnectorTableHandle) tableHandle).getTableName();
        SchemaTableName projectedTableName = new SchemaTableName(tableName.getSchemaName(), "projected_" + tableName.getTableName());

        ImmutableList.Builder<ConnectorExpression> newProjections = ImmutableList.builder();
        ImmutableList.Builder<Assignment> newAssignments = ImmutableList.builder();
        ImmutableList.Builder<ColumnHandle> newColumns = ImmutableList.builder();
        for (ConnectorExpression projection : projections) {
            String variableName = projectedPrefix(projection) + projection.toString();
            Type type = projection.getType();
            ColumnHandle newColumn = new TpchColumnHandle(variableName, type);
            newProjections.add(new Variable(variableName, type));
            newAssignments.add(new Assignment(variableName, newColumn, type));
            newColumns.add(newColumn);
        }

        ConnectorTableHandle projectedHandle = new MockConnectorTableHandle(projectedTableName, TupleDomain.all(), Optional.of(newColumns.build()));
        return Optional.of(new ProjectionApplicationResult<>(projectedHandle, newProjections.build(), newAssignments.build(), false));
    }

    /** Returns the column-name prefix that mockApplyProjection uses for the given expression kind. */
    private static String projectedPrefix(ConnectorExpression expression) {
        if (expression instanceof Variable) {
            return "projected_variable_";
        }
        if (expression instanceof FieldDereference) {
            return "projected_dereference_";
        }
        if (expression instanceof Constant) {
            return "projected_constant_";
        }
        if (expression instanceof Call) {
            return "projected_call_";
        }
        throw new UnsupportedOperationException("Unsupported projection: " + expression);
    }

    /** Creates the rule under test, wired to the tester's planner context and type analyzer. */
    private static PushProjectionIntoTableScan createRule(RuleTester ruleTester) {
        PlannerContext plannerContext = ruleTester.getPlannerContext();
        TypeAnalyzer typeAnalyzer = ruleTester.getTypeAnalyzer();
        return new PushProjectionIntoTableScan(plannerContext, typeAnalyzer, new ScalarStatsCalculator(plannerContext, typeAnalyzer));
    }

    private static SymbolReference symbolReference(String name) {
        return new SymbolReference(name);
    }

    /** Creates a {@link TpchColumnHandle} with the given name and type. */
    public static ColumnHandle column(String name, Type type) {
        return new TpchColumnHandle(name, type);
    }
}
