package io.trino.plugin.opa;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorPlugin;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.VarcharType;
import java.util.List;
import java.util.Set;
import org.assertj.core.api.Assertions;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;

/**
 * Integration tests for OPA-driven row filtering and column masking.
 *
 * <p>Each test starts a Trino query runner pointed at an OPA container, installs a mock connector
 * exposing {@code sample_catalog.sample_schema.restricted_table} and {@code unrestricted_table}
 * (both backed by the same dummy rows), submits a Rego policy, and checks what the users
 * {@code admin} and {@code bob} can see.
 */
@Testcontainers
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class TestOpaAccessControlDataFilteringSystem {
    private static final String OPA_ALLOW_POLICY_NAME = "allow";
    private static final String OPA_ROW_LEVEL_FILTERING_POLICY_NAME = "rowFilters";
    private static final String OPA_COLUMN_MASKING_POLICY_NAME = "columnMask";

    /**
     * Allows everything; for every user except {@code admin}, hides rows of
     * {@code sample_catalog.sample_schema.restricted_table} whose {@code user_type} is {@code 'customer'}.
     */
    private static final String SAMPLE_ROW_LEVEL_FILTERING_POLICY = "package trino\nimport future.keywords.in\nimport future.keywords.if\nimport future.keywords.contains\n\ndefault allow := true\n\ntable_resource := input.action.resource.table\nis_admin {\n  input.context.identity.user == \"admin\"\n}\n\nrowFilters contains {\"expression\": \"user_type <> 'customer'\"} if {\n    not is_admin\n    table_resource.catalogName == \"sample_catalog\"\n    table_resource.schemaName == \"sample_schema\"\n    table_resource.tableName == \"restricted_table\"\n}";

    /**
     * Allows everything; for every user except {@code admin}, on
     * {@code sample_catalog.sample_schema.restricted_table} masks {@code user_phone} to NULL and
     * {@code user_name} to {@code '****'} followed by its last three characters.
     */
    private static final String SAMPLE_COLUMN_MASKING_POLICY = "package trino\nimport future.keywords.in\nimport future.keywords.if\nimport future.keywords.contains\n\ndefault allow := true\n\ncolumn_resource := input.action.resource.column\nis_admin {\n  input.context.identity.user == \"admin\"\n}\n\ncolumnMask := {\"expression\": \"NULL\"} if {\n    not is_admin\n    column_resource.catalogName == \"sample_catalog\"\n    column_resource.schemaName == \"sample_schema\"\n    column_resource.tableName == \"restricted_table\"\n    column_resource.columnName == \"user_phone\"\n}\n\ncolumnMask := {\"expression\": \"'****' || substring(user_name, -3)\"} if {\n    not is_admin\n    column_resource.catalogName == \"sample_catalog\"\n    column_resource.schemaName == \"sample_schema\"\n    column_resource.tableName == \"restricted_table\"\n    column_resource.columnName == \"user_name\"\n}\n";

    @Container
    private static final OpaContainer OPA_CONTAINER = new OpaContainer();
    private static final Set<String> DUMMY_CUSTOMERS_IN_TABLE = ImmutableSet.of("customer_one", "customer_two");
    private static final Set<String> DUMMY_INTERNAL_USERS_IN_TABLE = ImmutableSet.of("some_internal_user");
    private static final Set<String> ALL_DUMMY_USERS_IN_TABLE = ImmutableSet.<String>builder()
            .addAll(DUMMY_INTERNAL_USERS_IN_TABLE)
            .addAll(DUMMY_CUSTOMERS_IN_TABLE)
            .build();

    // With the PER_CLASS lifecycle a single instance serves every test, so this field survives
    // between tests; teardown() clears it so a failed setup never re-tears a previous runner.
    private QueryRunnerHelper runner;

    @AfterEach
    public void teardown() {
        QueryRunnerHelper current = this.runner;
        this.runner = null;
        if (current != null) {
            current.teardown();
        }
    }

    @Test
    public void testRowFilteringEnabled() throws Exception {
        setupTrinoWithOpa(baseOpaConfig()
                .setOpaRowFiltersUri(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_ROW_LEVEL_FILTERING_POLICY_NAME)));
        OPA_CONTAINER.submitPolicy(SAMPLE_ROW_LEVEL_FILTERING_POLICY);
        assertResultsForUser("admin", "SELECT user_name FROM sample_catalog.sample_schema.restricted_table", ALL_DUMMY_USERS_IN_TABLE);
        assertResultsForUser("admin", "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table", ALL_DUMMY_USERS_IN_TABLE);
        assertResultsForUser("bob", "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table", ALL_DUMMY_USERS_IN_TABLE);
        assertResultsForUser("bob", "SELECT user_name FROM sample_catalog.sample_schema.restricted_table", DUMMY_INTERNAL_USERS_IN_TABLE);
    }

    @Test
    public void testRowFilteringDisabledDoesNothing() throws Exception {
        // No row-filters URI configured: the policy's rowFilters rule must never be consulted.
        setupTrinoWithOpa(baseOpaConfig());
        OPA_CONTAINER.submitPolicy(SAMPLE_ROW_LEVEL_FILTERING_POLICY);
        assertResultsForUser("admin", "SELECT user_name FROM sample_catalog.sample_schema.restricted_table", ALL_DUMMY_USERS_IN_TABLE);
        assertResultsForUser("admin", "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table", ALL_DUMMY_USERS_IN_TABLE);
        assertResultsForUser("bob", "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table", ALL_DUMMY_USERS_IN_TABLE);
        assertResultsForUser("bob", "SELECT user_name FROM sample_catalog.sample_schema.restricted_table", ALL_DUMMY_USERS_IN_TABLE);
    }

    @Test
    public void testColumnMasking() throws Exception {
        setupTrinoWithOpa(baseOpaConfig()
                .setOpaColumnMaskingUri(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_COLUMN_MASKING_POLICY_NAME)));
        OPA_CONTAINER.submitPolicy(SAMPLE_COLUMN_MASKING_POLICY);

        assertResultsForUser("admin", "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table", ALL_DUMMY_USERS_IN_TABLE);
        assertResultsForUser("bob", "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table", ALL_DUMMY_USERS_IN_TABLE);
        assertResultsForUser("admin", "SELECT user_name FROM sample_catalog.sample_schema.restricted_table", ALL_DUMMY_USERS_IN_TABLE);
        // Mirrors the policy's mask expression: '****' || substring(user_name, -3)
        Set<String> maskedUserNames = ALL_DUMMY_USERS_IN_TABLE.stream()
                .map(userName -> "****" + userName.substring(userName.length() - 3))
                .collect(ImmutableSet.toImmutableSet());
        assertResultsForUser("bob", "SELECT user_name FROM sample_catalog.sample_schema.restricted_table", maskedUserNames);

        // user_phone is populated with each user's name hash code (see dummyRows()).
        Set<String> unmaskedPhones = ALL_DUMMY_USERS_IN_TABLE.stream()
                .map(userName -> String.valueOf(userName.hashCode()))
                .collect(ImmutableSet.toImmutableSet());
        assertResultsForUser("admin", "SELECT user_phone FROM sample_catalog.sample_schema.unrestricted_table", unmaskedPhones);
        assertResultsForUser("bob", "SELECT user_phone FROM sample_catalog.sample_schema.unrestricted_table", unmaskedPhones);
        assertResultsForUser("admin", "SELECT user_phone FROM sample_catalog.sample_schema.restricted_table", unmaskedPhones);
        assertResultsForUser("bob", "SELECT user_phone FROM sample_catalog.sample_schema.restricted_table", ImmutableSet.of("<NULL>"));
    }

    @Test
    public void testColumnMaskingDisabledDoesNothing() throws Exception {
        // No column-masking URI configured: the policy's columnMask rules must never be consulted.
        setupTrinoWithOpa(baseOpaConfig());
        OPA_CONTAINER.submitPolicy(SAMPLE_COLUMN_MASKING_POLICY);
        assertResultsForUser("admin", "SELECT user_name FROM sample_catalog.sample_schema.restricted_table", ALL_DUMMY_USERS_IN_TABLE);
        assertResultsForUser("admin", "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table", ALL_DUMMY_USERS_IN_TABLE);
        assertResultsForUser("bob", "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table", ALL_DUMMY_USERS_IN_TABLE);
        assertResultsForUser("bob", "SELECT user_name FROM sample_catalog.sample_schema.restricted_table", ALL_DUMMY_USERS_IN_TABLE);
    }

    @Test
    public void testColumnMaskingAndRowFiltering() throws Exception {
        setupTrinoWithOpa(baseOpaConfig()
                .setOpaColumnMaskingUri(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_COLUMN_MASKING_POLICY_NAME))
                .setOpaRowFiltersUri(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_ROW_LEVEL_FILTERING_POLICY_NAME)));
        // For non-admins: hide customer rows and mask user_name to NULL, on every table.
        OPA_CONTAINER.submitPolicy("package trino\nimport future.keywords.in\nimport future.keywords.if\nimport future.keywords.contains\n\ndefault allow := true\n\nis_admin {\n  input.context.identity.user == \"admin\"\n}\n\ntable_resource := input.action.resource.table\ncolumn_resource := input.action.resource.column\n\nrowFilters contains {\"expression\": \"user_type <> 'customer'\"} if {\n    not is_admin\n}\ncolumnMask := {\"expression\": \"NULL\"} if {\n    not is_admin\n    column_resource.columnName == \"user_name\"\n}");
        Set<String> allUserTypes = ImmutableSet.of("internal_user", "customer");
        assertResultsForUser("admin", "SELECT user_name FROM sample_catalog.sample_schema.restricted_table", ALL_DUMMY_USERS_IN_TABLE);
        assertResultsForUser("admin", "SELECT user_type FROM sample_catalog.sample_schema.restricted_table", allUserTypes);
        assertResultsForUser("bob", "SELECT user_name FROM sample_catalog.sample_schema.restricted_table", ImmutableSet.of("<NULL>"));
        assertResultsForUser("bob", "SELECT user_type FROM sample_catalog.sample_schema.restricted_table", ImmutableSet.of("internal_user"));
    }

    /**
     * Runs {@code query} as {@code user} and asserts the result set equals {@code expected}, ignoring order.
     *
     * @param user identity the query is executed as
     * @param query SQL to run
     * @param expected exact set of string-rendered values expected back
     */
    private void assertResultsForUser(String user, @Language("SQL") String query, Set<String> expected) {
        Assertions.assertThat(this.runner.querySetOfStrings(user, query)).containsExactlyInAnyOrderElementsOf(expected);
    }

    /** Returns a fresh config whose authorization decisions go to the OPA {@code allow} rule. */
    private static OpaConfig baseOpaConfig() {
        return new OpaConfig().setOpaUri(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_ALLOW_POLICY_NAME));
    }

    /**
     * Creates the query runner for {@code opaConfig} and registers {@code sample_catalog}, a mock
     * catalog with schema {@code sample_schema} holding {@code restricted_table} and
     * {@code unrestricted_table}, both with columns (user_type varchar, user_name varchar,
     * user_phone integer) and the rows from {@link #dummyRows()}.
     */
    private void setupTrinoWithOpa(OpaConfig opaConfig) {
        // Assign first so @AfterEach tears the runner down even if catalog setup below fails.
        this.runner = QueryRunnerHelper.withOpaConfig(opaConfig);
        MockConnectorFactory connectorFactory = MockConnectorFactory.builder()
                .withListSchemaNames(session -> ImmutableList.of("sample_schema"))
                .withListTables((session, schemaName) -> ImmutableList.of("restricted_table", "unrestricted_table"))
                .withGetColumns(schemaTableName -> ImmutableList.of(
                        ColumnMetadata.builder().setName("user_type").setType(VarcharType.VARCHAR).build(),
                        ColumnMetadata.builder().setName("user_name").setType(VarcharType.VARCHAR).build(),
                        ColumnMetadata.builder().setName("user_phone").setType(IntegerType.INTEGER).build()))
                .withData(schemaTableName -> dummyRows())
                .build();
        this.runner.getBaseQueryRunner().installPlugin(new MockConnectorPlugin(connectorFactory));
        this.runner.getBaseQueryRunner().createCatalog("sample_catalog", "mock");
    }

    /**
     * Builds the mock table contents: one row per dummy user as (user_type, user_name, user_phone),
     * where user_type is {@code "customer"} or {@code "internal_user"} and user_phone is the
     * user name's {@link String#hashCode()}.
     */
    private static List<List<?>> dummyRows() {
        ImmutableList.Builder<List<?>> rows = ImmutableList.builder();
        for (String userName : DUMMY_CUSTOMERS_IN_TABLE) {
            rows.add(ImmutableList.of("customer", userName, userName.hashCode()));
        }
        for (String userName : DUMMY_INTERNAL_USERS_IN_TABLE) {
            rows.add(ImmutableList.of("internal_user", userName, userName.hashCode()));
        }
        return rows.build();
    }
}
