package io.trino.orc;

import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.orc.metadata.ColumnMetadata;
import io.trino.orc.metadata.OrcColumnId;
import io.trino.orc.metadata.statistics.BloomFilter;
import io.trino.orc.metadata.statistics.ColumnStatistics;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DateTimeEncoding;
import io.trino.spi.type.DateType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.LongTimestamp;
import io.trino.spi.type.LongTimestampWithTimeZone;
import io.trino.spi.type.MapType;
import io.trino.spi.type.RealType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.SqlDate;
import io.trino.spi.type.SqlDecimal;
import io.trino.spi.type.SqlTimestamp;
import io.trino.spi.type.SqlTimestampWithTimeZone;
import io.trino.spi.type.TimeZoneKey;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.TimestampWithTimeZoneType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.UuidType;
import io.trino.spi.type.VarbinaryType;
import io.trino.spi.type.VarcharType;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.testng.Assert;

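/**
 * Factory for {@link OrcPredicate}s used in ORC tests. The predicates never filter data; they
 * assert that the column statistics handed to {@link OrcPredicate#matches} agree with the values
 * that were written.
 */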
public final class TestingOrcPredicate {
    /** Row count of a full stripe; statistics over exactly this many rows are treated as one stripe. */
    public static final int ORC_STRIPE_SIZE = 30_000;
    /** Row count of a full row group; statistics over exactly this many rows are treated as one row group. */
    public static final int ORC_ROW_GROUP_SIZE = 10_000;

    /* loaded from: input_file:io/trino/orc/TestingOrcPredicate$BasicOrcPredicate.class */
    public static class BasicOrcPredicate<T> implements OrcPredicate {
        private final List<T> expectedValues;

        public BasicOrcPredicate(Iterable<?> iterable, Class<T> cls) {
            ArrayList arrayList = new ArrayList();
            Iterator<?> it = iterable.iterator();
            while (it.hasNext()) {
                arrayList.add(cls.cast(it.next()));
            }
            this.expectedValues = Collections.unmodifiableList(arrayList);
        }

        public boolean matches(long j, ColumnMetadata<ColumnStatistics> columnMetadata) {
            ColumnStatistics columnStatistics = (ColumnStatistics) columnMetadata.get(new OrcColumnId(1));
            Assert.assertTrue(columnStatistics.hasNumberOfValues());
            if (j == this.expectedValues.size()) {
                assertChunkStats(this.expectedValues, columnStatistics);
                return true;
            }
            if (j == 10000) {
                matchMiddleSection(columnStatistics, TestingOrcPredicate.ORC_ROW_GROUP_SIZE);
                return true;
            }
            if (j == 30000) {
                matchMiddleSection(columnStatistics, TestingOrcPredicate.ORC_STRIPE_SIZE);
                return true;
            }
            if (j == this.expectedValues.size() % TestingOrcPredicate.ORC_ROW_GROUP_SIZE || j == this.expectedValues.size() % TestingOrcPredicate.ORC_STRIPE_SIZE) {
                assertChunkStats(this.expectedValues.subList((int) (this.expectedValues.size() - j), this.expectedValues.size()), columnStatistics);
                return true;
            }
            Assert.fail("Unexpected number of rows: " + j);
            return true;
        }

        private void matchMiddleSection(ColumnStatistics columnStatistics, int i) {
            int i2 = 0;
            while (true) {
                int i3 = i2;
                if (i3 >= this.expectedValues.size()) {
                    Assert.fail("match not found for middle section");
                    return;
                }
                int min = Math.min(i, this.expectedValues.size() - i3);
                if (chunkMatchesStats(this.expectedValues.subList(i3, i3 + min), columnStatistics)) {
                    return;
                } else {
                    i2 = i3 + min;
                }
            }
        }

        private void assertChunkStats(List<T> list, ColumnStatistics columnStatistics) {
            Assert.assertTrue(chunkMatchesStats(list, columnStatistics));
        }

        protected boolean chunkMatchesStats(List<T> list, ColumnStatistics columnStatistics) {
            return columnStatistics.getNumberOfValues() == list.stream().filter(Objects::nonNull).count();
        }
    }

    /**
     * Verifies BOOLEAN column statistics: no integer, double, string or date statistics may be
     * present, and when boolean statistics exist their true-value count must equal the number of
     * {@code TRUE} values in the chunk.
     */
    public static class BooleanOrcPredicate extends BasicOrcPredicate<Boolean> {
        public BooleanOrcPredicate(Iterable<?> expectedValues) {
            super(expectedValues, Boolean.class);
        }

        @Override
        protected boolean chunkMatchesStats(List<Boolean> chunk, ColumnStatistics columnStatistics) {
            Assert.assertNull(columnStatistics.getIntegerStatistics());
            Assert.assertNull(columnStatistics.getDoubleStatistics());
            Assert.assertNull(columnStatistics.getStringStatistics());
            Assert.assertNull(columnStatistics.getDateStatistics());

            if (!super.chunkMatchesStats(chunk, columnStatistics)) {
                return false;
            }

            // without boolean statistics only the value count (checked above) can be verified
            if (columnStatistics.getBooleanStatistics() == null) {
                return true;
            }

            // Boolean.TRUE::equals is null-safe, so null entries are simply not counted
            long expectedTrueCount = chunk.stream().filter(Boolean.TRUE::equals).count();
            return columnStatistics.getBooleanStatistics().getTrueValueCount() == expectedTrueCount;
        }
    }

    /**
     * Verifies CHAR column statistics. Values are compared after {@link String#trim()}, and the
     * stored min/max only have to enclose the expected range (stored min &lt;= expected min,
     * stored max &gt;= expected max). A bloom filter must not be present.
     */
    public static class CharOrcPredicate extends BasicOrcPredicate<String> {
        public CharOrcPredicate(Iterable<?> expectedValues) {
            super(expectedValues, String.class);
        }

        @Override
        protected boolean chunkMatchesStats(List<String> chunk, ColumnStatistics columnStatistics) {
            Assert.assertNull(columnStatistics.getBooleanStatistics());
            Assert.assertNull(columnStatistics.getIntegerStatistics());
            Assert.assertNull(columnStatistics.getDoubleStatistics());
            Assert.assertNull(columnStatistics.getDateStatistics());
            Assert.assertNull(columnStatistics.getBloomFilter());

            if (!super.chunkMatchesStats(chunk, columnStatistics)) {
                return false;
            }
            if (columnStatistics.getStringStatistics() == null) {
                return true;
            }

            List<String> trimmedValues = chunk.stream()
                    .filter(Objects::nonNull)
                    .map(String::trim)
                    .collect(Collectors.toList());
            if (trimmedValues.isEmpty()) {
                return columnStatistics.getStringStatistics().getMin() == null
                        && columnStatistics.getStringStatistics().getMax() == null;
            }

            String expectedMin = Collections.min(trimmedValues);
            String expectedMax = Collections.max(trimmedValues);
            if (columnStatistics.getStringStatistics().getMin().toStringUtf8().trim().compareTo(expectedMin) > 0) {
                return false;
            }
            return columnStatistics.getStringStatistics().getMax().toStringUtf8().trim().compareTo(expectedMax) >= 0;
        }
    }

    /**
     * Verifies DATE column statistics against expected values supplied as {@code Long}s: only date
     * statistics may be present, their min/max must equal the smallest/largest non-null value, and
     * every non-null value must pass the bloom filter when one exists.
     */
    public static class DateOrcPredicate extends BasicOrcPredicate<Long> {
        public DateOrcPredicate(Iterable<?> expectedValues) {
            super(expectedValues, Long.class);
        }

        @Override
        protected boolean chunkMatchesStats(List<Long> chunk, ColumnStatistics columnStatistics) {
            Assert.assertNull(columnStatistics.getBooleanStatistics());
            Assert.assertNull(columnStatistics.getIntegerStatistics());
            Assert.assertNull(columnStatistics.getDoubleStatistics());
            Assert.assertNull(columnStatistics.getStringStatistics());

            boolean countMatches = super.chunkMatchesStats(chunk, columnStatistics);
            if (!countMatches) {
                return false;
            }
            if (columnStatistics.getDateStatistics() == null) {
                return true;
            }

            List<Long> present = chunk.stream().filter(Objects::nonNull).collect(Collectors.toList());
            if (present.isEmpty()) {
                // an all-null chunk must carry no bounds
                return columnStatistics.getDateStatistics().getMin() == null
                        && columnStatistics.getDateStatistics().getMax() == null;
            }

            // min is checked before max is read, matching the short-circuit of a combined check
            if (columnStatistics.getDateStatistics().getMin().longValue() != Collections.min(present)) {
                return false;
            }
            if (columnStatistics.getDateStatistics().getMax().longValue() != Collections.max(present)) {
                return false;
            }

            BloomFilter bloomFilter = columnStatistics.getBloomFilter();
            return bloomFilter == null || present.stream().allMatch(value -> bloomFilter.testLong(value));
        }
    }

    /**
     * DECIMAL predicate: expected values are {@code SqlDecimal}s and only the non-null value count
     * inherited from {@link BasicOrcPredicate} is verified.
     */
    private static class DecimalOrcPredicate extends BasicOrcPredicate<SqlDecimal> {
        public DecimalOrcPredicate(Iterable<?> expectedValues) {
            super(expectedValues, SqlDecimal.class);
        }
    }

    /**
     * Verifies REAL/DOUBLE column statistics: every non-null value must pass the bloom filter when
     * one exists, and double min/max (when present) must be within 0.001 of the expected extremes.
     */
    public static class DoubleOrcPredicate extends BasicOrcPredicate<Double> {
        public DoubleOrcPredicate(Iterable<?> expectedValues) {
            super(expectedValues, Double.class);
        }

        @Override
        protected boolean chunkMatchesStats(List<Double> chunk, ColumnStatistics columnStatistics) {
            Assert.assertNull(columnStatistics.getBooleanStatistics());
            Assert.assertNull(columnStatistics.getIntegerStatistics());
            Assert.assertNull(columnStatistics.getStringStatistics());
            Assert.assertNull(columnStatistics.getDateStatistics());

            boolean countMatches = super.chunkMatchesStats(chunk, columnStatistics);
            if (!countMatches) {
                return false;
            }

            List<Double> present = chunk.stream().filter(Objects::nonNull).collect(Collectors.toList());

            BloomFilter bloomFilter = columnStatistics.getBloomFilter();
            if (bloomFilter != null && !present.stream().allMatch(value -> bloomFilter.testDouble(value))) {
                return false;
            }

            if (columnStatistics.getDoubleStatistics() == null) {
                return true;
            }
            if (present.isEmpty()) {
                return columnStatistics.getDoubleStatistics().getMin() == null
                        && columnStatistics.getDoubleStatistics().getMax() == null;
            }

            double expectedMin = Collections.min(present);
            double expectedMax = Collections.max(present);
            // the max bound is only read when the min bound is within tolerance
            return Math.abs(columnStatistics.getDoubleStatistics().getMin().doubleValue() - expectedMin) <= 0.001
                    && Math.abs(columnStatistics.getDoubleStatistics().getMax().doubleValue() - expectedMax) <= 0.001;
        }
    }

    /**
     * Verifies integer-encoded column statistics: exact min/max, the sum (when recorded), and
     * optionally that every non-null value passes the bloom filter.
     */
    public static class LongOrcPredicate extends BasicOrcPredicate<Long> {
        private final boolean testBloomFilter;

        /**
         * @param testBloomFilter whether non-null values must pass the column's bloom filter
         * @param expectedValues the values written to the column, in order; may contain {@code null}
         */
        public LongOrcPredicate(boolean testBloomFilter, Iterable<?> expectedValues) {
            super(expectedValues, Long.class);
            this.testBloomFilter = testBloomFilter;
        }

        @Override
        protected boolean chunkMatchesStats(List<Long> chunk, ColumnStatistics columnStatistics) {
            Assert.assertNull(columnStatistics.getBooleanStatistics());
            Assert.assertNull(columnStatistics.getDoubleStatistics());
            Assert.assertNull(columnStatistics.getStringStatistics());
            Assert.assertNull(columnStatistics.getDateStatistics());

            if (!super.chunkMatchesStats(chunk, columnStatistics)) {
                return false;
            }
            if (columnStatistics.getIntegerStatistics() == null) {
                return true;
            }

            List<Long> nonNullValues = chunk.stream().filter(Objects::nonNull).collect(Collectors.toList());
            if (nonNullValues.isEmpty()) {
                // an all-null chunk must carry no bounds
                if (columnStatistics.getIntegerStatistics().getMin() != null || columnStatistics.getIntegerStatistics().getMax() != null) {
                    return false;
                }
            } else if (!columnStatistics.getIntegerStatistics().getMin().equals(Collections.min(nonNullValues))
                    || !columnStatistics.getIntegerStatistics().getMax().equals(Collections.max(nonNullValues))) {
                return false;
            }

            // getSum() returns a boxed value. NOTE(review): presumably null when the writer did not
            // record a sum (e.g. on overflow) — treat it as unknown instead of failing with an NPE.
            Long actualSum = columnStatistics.getIntegerStatistics().getSum();
            if (actualSum != null) {
                long expectedSum = nonNullValues.stream().mapToLong(Long::longValue).sum();
                if (actualSum != expectedSum) {
                    return false;
                }
            }

            BloomFilter bloomFilter = columnStatistics.getBloomFilter();
            if (!testBloomFilter || bloomFilter == null) {
                return true;
            }
            for (Long value : nonNullValues) {
                if (!bloomFilter.testLong(value)) {
                    return false;
                }
            }
            return true;
        }
    }

    /**
     * Verifies VARCHAR column statistics. When a bloom filter exists, every non-null value must pass
     * it and random byte probes must not be accepted too often. String min/max (when present) must
     * equal the min/max of the UTF-8 encoded values.
     */
    public static class StringOrcPredicate extends BasicOrcPredicate<String> {
        // number and length of random probes used to estimate the bloom filter's false-positive rate
        private static final int FALSE_POSITIVE_PROBES = 100_000;
        private static final int PROBE_LENGTH = 32;
        // upper bound on the fraction of random probes the bloom filter may accept
        private static final double MAX_FALSE_POSITIVE_RATE = 0.55;

        public StringOrcPredicate(Iterable<?> expectedValues) {
            super(expectedValues, String.class);
        }

        @Override
        protected boolean chunkMatchesStats(List<String> chunk, ColumnStatistics columnStatistics) {
            Assert.assertNull(columnStatistics.getBooleanStatistics());
            Assert.assertNull(columnStatistics.getIntegerStatistics());
            Assert.assertNull(columnStatistics.getDoubleStatistics());
            Assert.assertNull(columnStatistics.getDateStatistics());

            if (!super.chunkMatchesStats(chunk, columnStatistics)) {
                return false;
            }

            List<Slice> expectedSlices = chunk.stream()
                    .filter(Objects::nonNull)
                    .map(Slices::utf8Slice)
                    .collect(Collectors.toList());

            BloomFilter bloomFilter = columnStatistics.getBloomFilter();
            if (bloomFilter != null) {
                for (Slice slice : expectedSlices) {
                    if (!bloomFilter.testSlice(slice)) {
                        return false;
                    }
                }
                if (falsePositiveRate(bloomFilter) > MAX_FALSE_POSITIVE_RATE) {
                    return false;
                }
            }

            if (columnStatistics.getStringStatistics() == null) {
                return true;
            }
            if (expectedSlices.isEmpty()) {
                return columnStatistics.getStringStatistics().getMin() == null
                        && columnStatistics.getStringStatistics().getMax() == null;
            }
            return columnStatistics.getStringStatistics().getMin().equals(Collections.min(expectedSlices))
                    && columnStatistics.getStringStatistics().getMax().equals(Collections.max(expectedSlices));
        }

        /** Returns the fraction of random byte arrays that {@code bloomFilter} reports as possibly present. */
        private static double falsePositiveRate(BloomFilter bloomFilter) {
            ThreadLocalRandom random = ThreadLocalRandom.current();
            byte[] probe = new byte[PROBE_LENGTH];
            int hits = 0;
            for (int i = 0; i < FALSE_POSITIVE_PROBES; i++) {
                random.nextBytes(probe);
                if (bloomFilter.test(probe)) {
                    hits++;
                }
            }
            return (double) hits / FALSE_POSITIVE_PROBES;
        }
    }

    private TestingOrcPredicate() {
        // holder of static factories and nested predicate classes; not meant to be instantiated
    }

    /**
     * Creates a predicate that checks the statistics of a column of {@code type} against the values
     * that were written to it.
     *
     * @param type the Trino type of the column
     * @param expectedValues the values written to the column, in order; {@code null} entries denote nulls
     * @return a predicate whose checks depend on {@code type}
     * @throws IllegalArgumentException if {@code type} is not supported
     */
    public static OrcPredicate createOrcPredicate(Type type, Iterable<?> expectedValues) {
        List<Object> values = Lists.newArrayList(expectedValues);
        if (BooleanType.BOOLEAN.equals(type)) {
            return new BooleanOrcPredicate(values);
        }
        if (TinyintType.TINYINT.equals(type) || SmallintType.SMALLINT.equals(type) || IntegerType.INTEGER.equals(type) || BigintType.BIGINT.equals(type)) {
            return new LongOrcPredicate(true, transform(values, value -> ((Number) value).longValue()));
        }
        if (DateType.DATE.equals(type)) {
            // the (long) cast matters: DateOrcPredicate casts every element to Long
            return new DateOrcPredicate(transform(values, value -> (long) ((SqlDate) value).getDays()));
        }
        if (RealType.REAL.equals(type) || DoubleType.DOUBLE.equals(type)) {
            return new DoubleOrcPredicate(transform(values, value -> ((Number) value).doubleValue()));
        }
        if (type instanceof VarbinaryType || UuidType.UUID.equals(type)) {
            return new BasicOrcPredicate<>(values, Object.class);
        }
        if (type instanceof VarcharType) {
            return new StringOrcPredicate(values);
        }
        if (type instanceof CharType) {
            return new CharOrcPredicate(values);
        }
        if (type instanceof DecimalType) {
            return new DecimalOrcPredicate(values);
        }
        if (TimestampType.TIMESTAMP_MILLIS.equals(type)) {
            return new LongOrcPredicate(false, transform(values, value -> ((SqlTimestamp) value).getMillis()));
        }
        if (TimestampType.TIMESTAMP_MICROS.equals(type)) {
            return new LongOrcPredicate(false, transform(values, value -> ((SqlTimestamp) value).getEpochMicros()));
        }
        if (TimestampType.TIMESTAMP_NANOS.equals(type)) {
            return new BasicOrcPredicate<>(transform(values, value -> {
                SqlTimestamp timestamp = (SqlTimestamp) value;
                return new LongTimestamp(timestamp.getEpochMicros(), timestamp.getPicosOfMicros());
            }), LongTimestamp.class);
        }
        if (TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS.equals(type)) {
            // millisecond timestamps with time zone are compared in their packed (UTC) encoding
            return new LongOrcPredicate(false, transform(values,
                    value -> DateTimeEncoding.packDateTimeWithZone(((SqlTimestampWithTimeZone) value).getEpochMillis(), TimeZoneKey.UTC_KEY)));
        }
        if (TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS.equals(type) || TimestampWithTimeZoneType.TIMESTAMP_TZ_NANOS.equals(type)) {
            return new BasicOrcPredicate<>(transform(values, value -> {
                SqlTimestampWithTimeZone timestamp = (SqlTimestampWithTimeZone) value;
                return LongTimestampWithTimeZone.fromEpochMillisAndFraction(timestamp.getEpochMillis(), timestamp.getPicosOfMilli(), timestamp.getTimeZoneKey());
            }), LongTimestampWithTimeZone.class);
        }
        if (type instanceof ArrayType || type instanceof MapType || type instanceof RowType) {
            return new BasicOrcPredicate<>(values, Object.class);
        }
        throw new IllegalArgumentException("Unsupported type " + type);
    }

    /**
     * Maps every non-null element of {@code list} through {@code function}; {@code null} elements
     * stay {@code null}, so positions in the result line up with the input.
     */
    private static <T> List<T> transform(List<Object> list, Function<Object, T> function) {
        List<T> transformed = new ArrayList<>(list.size());
        for (Object element : list) {
            transformed.add(element == null ? null : function.apply(element));
        }
        return transformed;
    }
}
