package io.trino.orc;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Longs;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.hive.orc.util.Murmur3;
import io.trino.orc.OrcWriterOptions;
import io.trino.orc.TupleDomainOrcPredicate;
import io.trino.orc.metadata.ColumnMetadata;
import io.trino.orc.metadata.CompressedMetadataWriter;
import io.trino.orc.metadata.CompressionKind;
import io.trino.orc.metadata.OrcColumnId;
import io.trino.orc.metadata.OrcMetadataReader;
import io.trino.orc.metadata.OrcMetadataWriter;
import io.trino.orc.metadata.statistics.BinaryStatistics;
import io.trino.orc.metadata.statistics.BloomFilter;
import io.trino.orc.metadata.statistics.BooleanStatistics;
import io.trino.orc.metadata.statistics.ColumnStatistics;
import io.trino.orc.metadata.statistics.DateStatistics;
import io.trino.orc.metadata.statistics.DecimalStatistics;
import io.trino.orc.metadata.statistics.DoubleStatistics;
import io.trino.orc.metadata.statistics.IntegerStatistics;
import io.trino.orc.metadata.statistics.StringStatistics;
import io.trino.orc.metadata.statistics.TimestampStatistics;
import io.trino.orc.metadata.statistics.Utf8BloomFilterBuilder;
import io.trino.orc.proto.OrcProto;
import io.trino.orc.protobuf.CodedInputStream;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.Range;
import io.trino.spi.predicate.ValueSet;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DateType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.RealType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarbinaryType;
import io.trino.spi.type.VarcharType;
import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.sql.Timestamp;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
import org.assertj.core.api.AbstractBooleanAssert;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.Fail;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/orc/TestOrcBloomFilters.class */
/**
 * Tests for the ORC {@link BloomFilter}.
 *
 * <p>Covers:
 * <ul>
 *   <li>in-memory membership checks;</li>
 *   <li>round-tripping through {@link CompressedMetadataWriter} / {@link OrcMetadataReader} and the raw protobuf;</li>
 *   <li>predicate evaluation through {@link TupleDomainOrcPredicate};</li>
 *   <li>bit-set and hash compatibility with {@code io.trino.hive.orc.util.BloomFilter} and {@link Murmur3}.</li>
 * </ul>
 */
public class TestOrcBloomFilters {
    private static final int TEST_INTEGER = 12345;
    private static final byte[] TEST_STRING = "ORC_STRING".getBytes(StandardCharsets.UTF_8);
    private static final byte[] TEST_STRING_NOT_WRITTEN = "ORC_STRING_not".getBytes(StandardCharsets.UTF_8);

    /**
     * Sample values keyed by their Java representation and mapped to the Trino type they stand for.
     * A REAL value is stored as the float's raw int bits widened to a {@code Long}.
     */
    private static final Map<Object, Type> TEST_VALUES = ImmutableMap.<Object, Type>builder()
            .put(Slices.wrappedBuffer(TEST_STRING), VarcharType.VARCHAR)
            .put(Slices.wrappedBuffer(new byte[] {12, 34, 56}), VarbinaryType.VARBINARY)
            .put(4312L, BigintType.BIGINT)
            .put(123, IntegerType.INTEGER)
            .put(789, SmallintType.SMALLINT)
            .put(77, TinyintType.TINYINT)
            .put(901, DateType.DATE)
            .put(987654L, TimestampType.TIMESTAMP_MILLIS)
            .put(234.567d, DoubleType.DOUBLE)
            .put((long) Float.floatToIntBits(987.654f), RealType.REAL)
            .buildOrThrow();

    @Test
    public void testHiveBloomFilterSerde() {
        BloomFilter bloomFilter = new BloomFilter(1_000_000L, 0.05d);

        // Binary values, checked through both the byte[] and the Slice entry points.
        bloomFilter.add(TEST_STRING);
        Assertions.assertThat(bloomFilter.test(TEST_STRING)).isTrue();
        Assertions.assertThat(bloomFilter.testSlice(Slices.wrappedBuffer(TEST_STRING))).isTrue();
        Assertions.assertThat(bloomFilter.test(TEST_STRING_NOT_WRITTEN)).isFalse();
        Assertions.assertThat(bloomFilter.testSlice(Slices.wrappedBuffer(TEST_STRING_NOT_WRITTEN))).isFalse();

        // Long values.
        bloomFilter.addLong(TEST_INTEGER);
        Assertions.assertThat(bloomFilter.testLong(TEST_INTEGER)).isTrue();
        Assertions.assertThat(bloomFilter.testLong(TEST_INTEGER + 1)).isFalse();

        // A filter rebuilt from the raw bit set and hash-function count must answer identically.
        BloomFilter copy = new BloomFilter(bloomFilter.getBitSet(), bloomFilter.getNumHashFunctions());
        Assertions.assertThat(copy.test(TEST_STRING)).isTrue();
        Assertions.assertThat(copy.testSlice(Slices.wrappedBuffer(TEST_STRING))).isTrue();
        Assertions.assertThat(copy.test(TEST_STRING_NOT_WRITTEN)).isFalse();
        Assertions.assertThat(copy.testSlice(Slices.wrappedBuffer(TEST_STRING_NOT_WRITTEN))).isFalse();
        Assertions.assertThat(copy.testLong(TEST_INTEGER)).isTrue();
        Assertions.assertThat(copy.testLong(TEST_INTEGER + 1)).isFalse();
    }

    @Test
    public void testOrcHiveBloomFilterSerde() throws Exception {
        BloomFilter bloomFilter = new BloomFilter(1000L, 0.05d);
        bloomFilter.add(TEST_STRING);
        Assertions.assertThat(bloomFilter.test(TEST_STRING)).isTrue();
        Assertions.assertThat(bloomFilter.testSlice(Slices.wrappedBuffer(TEST_STRING))).isTrue();

        // Serialize with the ORC metadata writer (uncompressed) and read it back with the ORC metadata reader.
        CompressedMetadataWriter metadataWriter = new CompressedMetadataWriter(
                new OrcMetadataWriter(OrcWriterOptions.WriterIdentification.TRINO),
                CompressionKind.NONE,
                1024);
        Slice serialized = metadataWriter.writeBloomFilters(ImmutableList.of(bloomFilter));

        List<BloomFilter> readBloomFilters = new OrcMetadataReader(new OrcReaderOptions())
                .readBloomFilterIndexes(serialized.getInput());
        Assertions.assertThat(readBloomFilters.size()).isEqualTo(1);

        BloomFilter readBloomFilter = readBloomFilters.get(0);
        Assertions.assertThat(readBloomFilter.test(TEST_STRING)).isTrue();
        Assertions.assertThat(readBloomFilter.testSlice(Slices.wrappedBuffer(TEST_STRING))).isTrue();
        Assertions.assertThat(readBloomFilter.test(TEST_STRING_NOT_WRITTEN)).isFalse();
        Assertions.assertThat(readBloomFilter.testSlice(Slices.wrappedBuffer(TEST_STRING_NOT_WRITTEN))).isFalse();
        Assertions.assertThat(readBloomFilter.getNumBits()).isEqualTo(bloomFilter.getNumBits());
        Assertions.assertThat(readBloomFilter.getNumHashFunctions()).isEqualTo(bloomFilter.getNumHashFunctions());
        Assertions.assertThat(readBloomFilter.getBitSet()).isEqualTo(bloomFilter.getBitSet());

        // Decode the same bytes directly as the protobuf message and check the wire-level fields.
        List<OrcProto.BloomFilter> protoBloomFilters = OrcProto.BloomFilterIndex
                .parseFrom(CodedInputStream.newInstance(serialized.getBytes()))
                .getBloomFilterList();
        Assertions.assertThat(protoBloomFilters.size()).isEqualTo(1);

        OrcProto.BloomFilter protoBloomFilter = protoBloomFilters.get(0);
        Assertions.assertThat(Longs.toArray(protoBloomFilter.getBitsetList())).isEqualTo(bloomFilter.getBitSet());
        Assertions.assertThat(protoBloomFilter.getNumHashFunctions()).isEqualTo(bloomFilter.getNumHashFunctions());
        Assertions.assertThat(protoBloomFilter.getBitsetCount()).isEqualTo(bloomFilter.getBitSet().length);
    }

    @Test
    public void testBloomFilterPredicateValuesExisting() {
        BloomFilter bloomFilter = new BloomFilter(TEST_VALUES.size() * 10, 0.01d);

        for (Map.Entry<Object, Type> entry : TEST_VALUES.entrySet()) {
            Object value = entry.getKey();
            if (value instanceof Long) {
                if (entry.getValue() instanceof RealType) {
                    // REAL keys hold float bits; decode them and add the value as a double.
                    bloomFilter.addDouble(Float.intBitsToFloat(((Number) value).intValue()));
                }
                else {
                    bloomFilter.addLong((Long) value);
                }
            }
            else if (value instanceof Integer) {
                bloomFilter.addLong((Integer) value);
            }
            // The String, BigDecimal and Timestamp branches are not reached by the current TEST_VALUES;
            // they are kept so that new sample values of those kinds are handled.
            else if (value instanceof String) {
                bloomFilter.add(((String) value).getBytes(StandardCharsets.UTF_8));
            }
            else if (value instanceof BigDecimal) {
                bloomFilter.add(value.toString().getBytes(StandardCharsets.UTF_8));
            }
            else if (value instanceof Slice) {
                bloomFilter.add(((Slice) value).getBytes());
            }
            else if (value instanceof Timestamp) {
                bloomFilter.addLong(((Timestamp) value).getTime());
            }
            else if (value instanceof Double) {
                bloomFilter.addDouble((Double) value);
            }
            else {
                Fail.fail("Unsupported type " + value.getClass());
            }
        }

        for (Map.Entry<Object, Type> entry : TEST_VALUES.entrySet()) {
            Assertions.assertThat(TupleDomainOrcPredicate.checkInBloomFilter(bloomFilter, entry.getKey(), entry.getValue()))
                    .describedAs("type " + entry.getValue())
                    .isTrue();
        }
    }

    @Test
    public void testBloomFilterPredicateValuesNonExisting() {
        BloomFilter bloomFilter = new BloomFilter(TEST_VALUES.size() * 10, 0.01d);

        // Nothing was added, so no value may be reported as present.
        for (Map.Entry<Object, Type> entry : TEST_VALUES.entrySet()) {
            Assertions.assertThat(TupleDomainOrcPredicate.checkInBloomFilter(bloomFilter, entry.getKey(), entry.getValue()))
                    .describedAs("type " + entry.getValue())
                    .isFalse();
        }
    }

    @Test
    public void testMatches() {
        TupleDomainOrcPredicate predicate = TupleDomainOrcPredicate.builder()
                .setBloomFiltersEnabled(true)
                .addColumn(OrcColumnId.ROOT_COLUMN, Domain.singleValue(BigintType.BIGINT, 1234L))
                .build();
        TupleDomainOrcPredicate emptyPredicate = TupleDomainOrcPredicate.builder().build();

        ColumnMetadata<ColumnStatistics> matchingBloomFilter = integerColumnWithBloomFilter(
                new Utf8BloomFilterBuilder(1000, 0.01d).addLong(1234L).buildBloomFilter());
        ColumnMetadata<ColumnStatistics> emptyBloomFilter = integerColumnWithBloomFilter(
                new Utf8BloomFilterBuilder(1000, 0.01d).buildBloomFilter());
        ColumnMetadata<ColumnStatistics> noBloomFilter = integerColumnWithBloomFilter(null);

        Assertions.assertThat(predicate.matches(1L, matchingBloomFilter)).isTrue();
        Assertions.assertThat(predicate.matches(1L, noBloomFilter)).isTrue();
        Assertions.assertThat(predicate.matches(1L, emptyBloomFilter)).isFalse();
        Assertions.assertThat(emptyPredicate.matches(1L, matchingBloomFilter)).isTrue();
    }

    @Test
    public void testMatchesExpandedRange() {
        TupleDomainOrcPredicate predicate = TupleDomainOrcPredicate.builder()
                .setBloomFiltersEnabled(true)
                .addColumn(OrcColumnId.ROOT_COLUMN, Domain.create(
                        ValueSet.ofRanges(Range.range(BigintType.BIGINT, 1233L, true, 1235L, true)),
                        false))
                .setDomainCompactionThreshold(100)
                .build();

        ColumnMetadata<ColumnStatistics> containsValueInRange = integerColumnWithBloomFilter(
                new Utf8BloomFilterBuilder(1000, 0.01d).addLong(1234L).buildBloomFilter());
        ColumnMetadata<ColumnStatistics> containsValueOutsideRange = integerColumnWithBloomFilter(
                new Utf8BloomFilterBuilder(1000, 0.01d).addLong(9876L).buildBloomFilter());

        Assertions.assertThat(predicate.matches(1L, containsValueInRange)).isTrue();
        Assertions.assertThat(predicate.matches(1L, containsValueOutsideRange)).isFalse();
    }

    @Test
    public void testMatchesNonExpandedRange() {
        ColumnMetadata<ColumnStatistics> columnMetadata = integerColumnWithBloomFilter(
                new Utf8BloomFilterBuilder(1000, 0.01d).addLong(1500L).buildBloomFilter());

        // The builder is reused below: only the compaction threshold differs between the two predicates.
        TupleDomainOrcPredicate.TupleDomainOrcPredicateBuilder builder = TupleDomainOrcPredicate.builder()
                .setBloomFiltersEnabled(true)
                .addColumn(OrcColumnId.ROOT_COLUMN, Domain.create(
                        ValueSet.ofRanges(Range.range(BigintType.BIGINT, 1233L, true, 1235L, true)),
                        false));

        Assertions.assertThat(builder.setDomainCompactionThreshold(1).build().matches(1L, columnMetadata)).isTrue();
        Assertions.assertThat(builder.setDomainCompactionThreshold(100).build().matches(1L, columnMetadata)).isFalse();
    }

    @Test
    public void testBloomFilterCompatibility() {
        for (int iteration = 0; iteration < 200; iteration++) {
            double fpp = ThreadLocalRandom.current().nextDouble(0.01d, 0.1d);
            int size = ThreadLocalRandom.current().nextInt(100, TestingOrcPredicate.ORC_ROW_GROUP_SIZE);
            int entries = ThreadLocalRandom.current().nextInt(size / 2, size);

            BloomFilter actual = new BloomFilter(size, fpp);
            io.trino.hive.orc.util.BloomFilter expected = new io.trino.hive.orc.util.BloomFilter(size, fpp);

            Assertions.assertThat(actual.test((byte[]) null)).isFalse();
            Assertions.assertThat(expected.test((byte[]) null)).isFalse();

            // One random value of each kind per entry; binary values are arrays of random length.
            byte[][] binaryValues = new byte[entries][];
            long[] longValues = new long[entries];
            double[] doubleValues = new double[entries];
            float[] floatValues = new float[entries];
            for (int i = 0; i < entries; i++) {
                binaryValues[i] = randomBytes(ThreadLocalRandom.current().nextInt(100));
                longValues[i] = ThreadLocalRandom.current().nextLong();
                doubleValues[i] = ThreadLocalRandom.current().nextDouble();
                floatValues[i] = ThreadLocalRandom.current().nextFloat();
            }

            // Both filters are still empty, so no value may be reported as present.
            for (int i = 0; i < entries; i++) {
                Assertions.assertThat(actual.test(binaryValues[i])).isFalse();
                Assertions.assertThat(actual.testSlice(Slices.wrappedBuffer(binaryValues[i]))).isFalse();
                Assertions.assertThat(actual.testLong(longValues[i])).isFalse();
                Assertions.assertThat(actual.testDouble(doubleValues[i])).isFalse();
                Assertions.assertThat(actual.testFloat(floatValues[i])).isFalse();
                Assertions.assertThat(expected.test(binaryValues[i])).isFalse();
                Assertions.assertThat(expected.testLong(longValues[i])).isFalse();
                Assertions.assertThat(expected.testDouble(doubleValues[i])).isFalse();
                // The Hive filter is exercised with floats widened to double.
                Assertions.assertThat(expected.testDouble(floatValues[i])).isFalse();
            }

            for (int i = 0; i < entries; i++) {
                actual.add(binaryValues[i]);
                actual.addLong(longValues[i]);
                actual.addDouble(doubleValues[i]);
                actual.addFloat(floatValues[i]);
                expected.add(binaryValues[i]);
                expected.addLong(longValues[i]);
                expected.addDouble(doubleValues[i]);
                expected.addDouble(floatValues[i]);
            }

            for (int i = 0; i < entries; i++) {
                Assertions.assertThat(actual.test(binaryValues[i])).isTrue();
                Assertions.assertThat(actual.testSlice(Slices.wrappedBuffer(binaryValues[i]))).isTrue();
                Assertions.assertThat(actual.testLong(longValues[i])).isTrue();
                Assertions.assertThat(actual.testDouble(doubleValues[i])).isTrue();
                Assertions.assertThat(actual.testFloat(floatValues[i])).isTrue();
                Assertions.assertThat(expected.test(binaryValues[i])).isTrue();
                Assertions.assertThat(expected.testLong(longValues[i])).isTrue();
                Assertions.assertThat(expected.testDouble(doubleValues[i])).isTrue();
                Assertions.assertThat(expected.testDouble(floatValues[i])).isTrue();
            }

            // Null is a storable value in both implementations.
            actual.add((byte[]) null);
            expected.add((byte[]) null);
            Assertions.assertThat(actual.test((byte[]) null)).isTrue();
            Assertions.assertThat(actual.testSlice((Slice) null)).isTrue();
            Assertions.assertThat(expected.test((byte[]) null)).isTrue();

            // After identical inserts, both implementations must produce the same bits.
            Assertions.assertThat(actual.getBitSet()).isEqualTo(expected.getBitSet());
        }
    }

    @Test
    public void testHashCompatibility() {
        for (int length = 0; length < 1000; length++) {
            for (int attempt = 0; attempt < 100; attempt++) {
                byte[] data = randomBytes(length);
                Assertions.assertThat(BloomFilter.OrcMurmur3.hash64(data)).isEqualTo(Murmur3.hash64(data));
            }
        }
    }

    /**
     * Builds single-column metadata with integer statistics in [10, 2000] and the given bloom filter.
     * All other statistics slots are null.
     *
     * @param bloomFilter the column's bloom filter, or {@code null} for a column without one
     */
    private static ColumnMetadata<ColumnStatistics> integerColumnWithBloomFilter(BloomFilter bloomFilter) {
        return new ColumnMetadata<>(ImmutableList.of(new ColumnStatistics(
                (Long) null,
                0L,
                (BooleanStatistics) null,
                new IntegerStatistics(10L, 2000L, (Long) null),
                (DoubleStatistics) null,
                (Long) null,
                (StringStatistics) null,
                (DateStatistics) null,
                (TimestampStatistics) null,
                (DecimalStatistics) null,
                (BinaryStatistics) null,
                bloomFilter)));
    }

    /** Returns a new array of {@code length} random bytes. */
    private static byte[] randomBytes(int length) {
        byte[] bytes = new byte[length];
        ThreadLocalRandom.current().nextBytes(bytes);
        return bytes;
    }
}
