package io.trino.tests;

import com.google.common.collect.ImmutableMap;
import io.trino.plugin.tpch.ColumnNaming;
import io.trino.testing.DistributedQueryRunner;
import io.trino.testing.statistics.MetricComparisonStrategies;
import io.trino.testing.statistics.Metrics;
import io.trino.testing.statistics.StatisticsAssertion;
import io.trino.tests.tpch.TpchQueryRunnerBuilder;
import io.trino.tpch.TpchTable;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/* loaded from: input_file:io/trino/tests/TestTpchDistributedStats.class */
/**
 * Compares the planner's statistics estimates with actual query results for queries over the
 * TPC-H connector, using a {@link StatisticsAssertion} backed by a distributed query runner.
 *
 * <p>Each test issues one or more queries and declares, per metric, either the comparison
 * strategy the estimate must satisfy or that no estimate is expected.
 */
public class TestTpchDistributedStats {
    private StatisticsAssertion statisticsAssertion;

    /**
     * Builds a distributed query runner with a {@code tpch} catalog using standard column naming
     * and wraps it in a {@link StatisticsAssertion}.
     *
     * <p>If catalog registration or the construction of the assertion helper fails, the runner is
     * closed before the failure propagates. {@link #tearDown()} only closes
     * {@code statisticsAssertion}, so an unwrapped runner would otherwise never be closed.
     *
     * @throws Exception if the query runner cannot be built or configured
     */
    @BeforeClass
    public void setup() throws Exception {
        DistributedQueryRunner queryRunner = TpchQueryRunnerBuilder.builder()
                // Session properties applied to every query issued through this runner.
                // NOTE(review): partial aggregation is presumably disabled so aggregation
                // estimates line up with actuals — confirm.
                .amendSession(sessionBuilder -> sessionBuilder
                        .setSystemProperty("prefer_partial_aggregation", "false")
                        .setSystemProperty("collect_plan_statistics_for_all_queries", "true"))
                .buildWithoutCatalogs();
        try {
            queryRunner.createCatalog(
                    "tpch",
                    "tpch",
                    ImmutableMap.of("tpch.column-naming", ColumnNaming.STANDARD.name()));
            this.statisticsAssertion = new StatisticsAssertion(queryRunner);
        }
        catch (Throwable failure) {
            // Do not leak the runner: nothing else holds a reference to it at this point.
            try {
                queryRunner.close();
            }
            catch (Throwable closeFailure) {
                if (closeFailure != failure) {
                    failure.addSuppressed(closeFailure);
                }
            }
            throw failure;
        }
    }

    /**
     * Closes the statistics assertion helper, if one was created, and clears the field.
     */
    @AfterClass(alwaysRun = true)
    public void tearDown() {
        // alwaysRun = true means this also runs after a failed setup(), when the field is unset.
        if (this.statisticsAssertion == null) {
            return;
        }
        try {
            this.statisticsAssertion.close();
        }
        finally {
            this.statisticsAssertion = null;
        }
    }

    /**
     * Full scans of every TPC-H table must have an exact output row count estimate.
     */
    @Test
    public void testTableScanStats() {
        TpchTable.getTables().forEach(table -> this.statisticsAssertion.check(
                "SELECT * FROM " + table.getTableName(),
                checks -> checks.estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.noError())));
    }

    /**
     * A range filter on a date column must be estimated within the default tolerance.
     */
    @Test
    public void testFilter() {
        this.statisticsAssertion.check(
                "SELECT * FROM lineitem WHERE l_shipdate <= DATE '1998-12-01' - INTERVAL '90' DAY",
                checks -> checks.estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.defaultTolerance()));
    }

    /**
     * An equi-join between part and partsupp must be estimated within the default tolerance.
     */
    @Test
    public void testJoin() {
        this.statisticsAssertion.check(
                "SELECT * FROM  part, partsupp WHERE p_partkey = ps_partkey",
                checks -> checks.estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.defaultTolerance()));
    }

    /**
     * UNION (distinct) and UNION ALL over identical and overlapping inputs, each with its own
     * per-query tolerance on the output row count.
     */
    @Test
    public void testUnion() {
        this.statisticsAssertion.check(
                "SELECT * FROM nation UNION SELECT * FROM nation",
                checks -> checks.estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.relativeError(1.0d, 1.0d)));
        this.statisticsAssertion.check(
                "SELECT * FROM nation UNION ALL SELECT * FROM nation",
                checks -> checks.estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.noError()));
        this.statisticsAssertion.check(
                "SELECT * FROM orders WHERE o_custkey < 755 OR o_orderstatus = '0' UNION SELECT * FROM orders WHERE o_custkey > 755 OR o_orderstatus = 'F'",
                checks -> checks.estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.relativeError(0.3d, 0.35d)));
        this.statisticsAssertion.check(
                "SELECT * FROM orders WHERE o_custkey < 755 OR o_orderstatus = '0' UNION ALL SELECT * FROM orders WHERE o_custkey > 755 OR o_orderstatus = 'F'",
                checks -> checks.estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.defaultTolerance()));
        this.statisticsAssertion.check(
                "SELECT * FROM orders WHERE o_custkey < 900 UNION SELECT * FROM orders WHERE o_custkey > 600",
                checks -> checks.estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.relativeError(0.15d, 0.25d)));
        this.statisticsAssertion.check(
                "SELECT * FROM orders WHERE o_custkey < 900 UNION ALL SELECT * FROM orders WHERE o_custkey > 600",
                checks -> checks.estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.defaultTolerance()));
    }

    /**
     * INTERSECT queries are expected to produce no output row count estimate.
     */
    @Test
    public void testIntersect() {
        this.statisticsAssertion.check(
                "SELECT * FROM nation INTERSECT SELECT * FROM nation",
                checks -> checks.noEstimate(Metrics.OUTPUT_ROW_COUNT));
        this.statisticsAssertion.check(
                "SELECT * FROM orders WHERE o_custkey < 900 INTERSECT SELECT * FROM orders WHERE o_custkey > 600",
                checks -> checks.noEstimate(Metrics.OUTPUT_ROW_COUNT));
    }

    /**
     * EXCEPT queries are expected to produce no output row count estimate.
     */
    @Test
    public void testExcept() {
        this.statisticsAssertion.check(
                "SELECT * FROM nation EXCEPT SELECT * FROM nation",
                checks -> checks.noEstimate(Metrics.OUTPUT_ROW_COUNT));
        this.statisticsAssertion.check(
                "SELECT * FROM orders WHERE o_custkey < 900 EXCEPT SELECT * FROM orders WHERE o_custkey > 600",
                checks -> checks.noEstimate(Metrics.OUTPUT_ROW_COUNT));
    }

    /**
     * A scalar subquery must have exact estimates, both when it matches nothing and when it
     * matches one row (where the distinct-value count of the result column is also checked).
     */
    @Test
    public void testEnforceSingleRow() {
        this.statisticsAssertion.check(
                "SELECT (SELECT n_regionkey FROM nation WHERE n_name = 'nosuchvalue') AS sub",
                checks -> checks.estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.noError()));
        this.statisticsAssertion.check(
                "SELECT (SELECT n_regionkey FROM nation WHERE n_name = 'GERMANY') AS sub",
                checks -> checks
                        .estimate(Metrics.distinctValuesCount("sub"), MetricComparisonStrategies.noError())
                        .estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.noError()));
    }

    /**
     * A literal VALUES clause must have an exact output row count estimate.
     */
    @Test
    public void testValues() {
        this.statisticsAssertion.check(
                "VALUES 1",
                checks -> checks.estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.noError()));
    }

    /**
     * IN-subquery semi joins: exact for an unfiltered subquery, within an absolute error of 15
     * rows when the subquery is filtered.
     */
    @Test
    public void testSemiJoin() {
        this.statisticsAssertion.check(
                "SELECT * FROM nation WHERE n_regionkey IN (SELECT r_regionkey FROM region)",
                checks -> checks.estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.noError()));
        this.statisticsAssertion.check(
                "SELECT * FROM nation WHERE n_regionkey IN (SELECT r_regionkey FROM region WHERE r_regionkey % 3 = 0)",
                checks -> checks.estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.absoluteError(15.0d)));
    }

    /**
     * A LIMIT over a table scan must have an exact output row count estimate.
     */
    @Test
    public void testLimit() {
        this.statisticsAssertion.check(
                "SELECT * FROM nation LIMIT 10",
                checks -> checks.estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.noError()));
    }

    /**
     * A two-column GROUP BY must be estimated within an absolute error of 2 rows.
     */
    @Test
    public void testGroupBy() {
        this.statisticsAssertion.check(
                "SELECT l_returnflag, l_linestatus FROM lineitem GROUP BY l_returnflag, l_linestatus",
                checks -> checks.estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.absoluteError(2.0d)));
    }

    /**
     * An ORDER BY over a table scan must have an exact output row count estimate.
     */
    @Test
    public void testSort() {
        this.statisticsAssertion.check(
                "SELECT * FROM nation ORDER BY n_nationkey",
                checks -> checks.estimate(Metrics.OUTPUT_ROW_COUNT, MetricComparisonStrategies.noError()));
    }
}
