package io.trino.execution.buffer;

import com.google.common.base.Preconditions;
import com.google.common.base.VerifyException;
import io.airlift.compress.Compressor;
import io.airlift.compress.lz4.Lz4Compressor;
import io.airlift.slice.SizeOf;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import io.trino.execution.buffer.PageCodecMarker;
import io.trino.spi.Page;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.block.BlockEncodingSerde;
import io.trino.util.Ciphers;
import io.trino.util.LongBigArrayFIFOQueue;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.security.GeneralSecurityException;
import java.util.Objects;
import java.util.Optional;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import org.openjdk.jol.info.ClassLayout;

/* loaded from: input_file:io/trino/execution/buffer/PageSerializer.class */
/**
 * Serializes a {@link Page} into a single {@link Slice} made of a fixed-size header followed by the
 * page body produced by {@code PagesSerdeUtil.writeRawPage}.
 *
 * <p>Header layout (13 bytes): {@code positionCount} (int), codec markers (byte), uncompressed body
 * size (int), serialized body size (int).
 *
 * <p>When compression and/or encryption is enabled, the body is staged in a buffer of
 * {@code blockSizeInBytes} and flushed block by block:
 * <ul>
 *   <li>compression (LZ4) emits an int block marker followed by the block; the marker's sign bit is
 *       set when the block is stored compressed, and the remaining bits hold the block length;</li>
 *   <li>encryption ({@code AES/CBC/PKCS5Padding}) emits an int ciphertext length, the IV, and the
 *       ciphertext;</li>
 *   <li>with both enabled, compressed blocks are fed into the encryption stage.</li>
 * </ul>
 *
 * <p>Instances keep mutable per-page state without synchronization and are not safe for concurrent use.
 */
public class PageSerializer {
    private static final int INSTANCE_SIZE = Math.toIntExact(ClassLayout.parseClass(PageSerializer.class).instanceSize());

    // Header: positionCount (int) | markers (byte) | uncompressed size (int) | serialized body size (int)
    private static final int SERIALIZED_PAGE_UNCOMPRESSED_SIZE_OFFSET = Integer.BYTES + Byte.BYTES;
    private static final int SERIALIZED_PAGE_COMPRESSED_SIZE_OFFSET = SERIALIZED_PAGE_UNCOMPRESSED_SIZE_OFFSET + Integer.BYTES;
    private static final int SERIALIZED_PAGE_HEADER_SIZE = SERIALIZED_PAGE_COMPRESSED_SIZE_OFFSET + Integer.BYTES;

    private static final String CIPHER_TRANSFORMATION = "AES/CBC/PKCS5Padding";
    // Headroom applied to the page's reported size when allocating the page buffer; the buffer
    // still grows on demand if the serialized page turns out larger.
    private static final float PAGE_BUFFER_SIZE_FACTOR = 1.2f;

    private final BlockEncodingSerde blockEncodingSerde;
    private final SerializedPageOutput output;

    /**
     * {@link SliceOutput} that stages raw writes and pushes them through the optional compression and
     * encryption stages into the buffer of the page currently being built.
     */
    private static class SerializedPageOutput extends SliceOutput {
        private static final int INSTANCE_SIZE = Math.toIntExact(ClassLayout.parseClass(SerializedPageOutput.class).instanceSize());
        // Lz4Compressor does not report its own size: approximated as its instance plus an int[4096]
        // (presumably its internal table — confirm against the compressor implementation).
        private static final int COMPRESSOR_RETAINED_SIZE = Math.toIntExact(ClassLayout.parseClass(Lz4Compressor.class).instanceSize() + SizeOf.sizeOfIntArray(4096));
        // SecretKeySpec instance plus its 256-bit (32-byte) key material.
        private static final int ENCRYPTION_KEY_RETAINED_SIZE = Math.toIntExact(ClassLayout.parseClass(SecretKeySpec.class).instanceSize() + SizeOf.sizeOfByteArray(32));
        // javax.crypto.Cipher exposes no size information; this is a fixed rough estimate.
        private static final int CIPHER_RETAINED_SIZE = 1024;
        // A compressed block is kept only if it is smaller than 80% of the raw block.
        private static final double MINIMUM_COMPRESSION_RATIO = 0.8d;

        private final Optional<Lz4Compressor> compressor;
        private final Optional<SecretKey> encryptionKey;
        private final int markers;
        private final Optional<Cipher> cipher;
        /*
         * Buffer pipeline. The last slot holds the page buffer (allocated per page in startPage).
         * buffers[0] always receives raw writes: it is the page buffer itself when neither stage is
         * enabled, the compressor input when compression is enabled, or the cipher input when only
         * encryption is enabled. With both stages, buffers[1] is the compressor sink / cipher input.
         */
        private final WriteBuffer[] buffers;
        private int uncompressedSize;

        private SerializedPageOutput(Optional<Lz4Compressor> compressor, Optional<SecretKey> encryptionKey, int blockSizeInBytes) {
            this.compressor = Objects.requireNonNull(compressor, "compressor is null");
            this.encryptionKey = Objects.requireNonNull(encryptionKey, "encryptionKey is null");
            this.buffers = new WriteBuffer[(compressor.isPresent() ? 1 : 0) + (encryptionKey.isPresent() ? 1 : 0) + 1];

            PageCodecMarker.MarkerSet markerSet = PageCodecMarker.MarkerSet.empty();
            if (compressor.isPresent()) {
                buffers[0] = new WriteBuffer(blockSizeInBytes);
                markerSet.add(PageCodecMarker.COMPRESSED);
            }
            if (encryptionKey.isPresent()) {
                // With compression the cipher input receives whole compressed blocks, each prefixed by an int marker.
                int cipherInputSize = compressor.isPresent()
                        ? compressor.get().maxCompressedLength(blockSizeInBytes) + Integer.BYTES
                        : blockSizeInBytes;
                buffers[buffers.length - 2] = new WriteBuffer(cipherInputSize);
                markerSet.add(PageCodecMarker.ENCRYPTED);
                try {
                    this.cipher = Optional.of(Cipher.getInstance(CIPHER_TRANSFORMATION));
                }
                catch (GeneralSecurityException e) {
                    throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Failed to create cipher: " + e.getMessage(), e);
                }
            }
            else {
                this.cipher = Optional.empty();
            }
            this.markers = markerSet.byteValue();
        }

        /**
         * Begins a new page: allocates the page buffer and writes the header, reserving the two size
         * fields that {@link #closePage()} fills in.
         *
         * @param positionCount number of positions in the page
         * @param sizeInBytes the page's reported size, used only to pre-size the page buffer
         */
        public void startPage(int positionCount, int sizeInBytes) {
            // Discard anything a previous page left staged if its serialization failed before
            // closePage(); otherwise those bytes would be flushed into this page's body.
            // After a normal closePage() every staging buffer is already empty.
            for (int index = 0; index < buffers.length - 1; index++) {
                buffers[index].reset();
            }

            WriteBuffer pageBuffer = new WriteBuffer(Math.round(sizeInBytes * PAGE_BUFFER_SIZE_FACTOR) + SERIALIZED_PAGE_HEADER_SIZE);
            pageBuffer.writeInt(positionCount);
            pageBuffer.writeByte(markers);
            // reserve space for the uncompressed and serialized body sizes
            pageBuffer.skip(Integer.BYTES * 2);

            buffers[buffers.length - 1] = pageBuffer;
            uncompressedSize = 0;
        }

        @Override
        public void writeByte(int value) {
            ensureCapacityFor(Byte.BYTES);
            buffers[0].writeByte(value);
            uncompressedSize += Byte.BYTES;
        }

        @Override
        public void writeShort(int value) {
            ensureCapacityFor(Short.BYTES);
            buffers[0].writeShort(value);
            uncompressedSize += Short.BYTES;
        }

        @Override
        public void writeInt(int value) {
            ensureCapacityFor(Integer.BYTES);
            buffers[0].writeInt(value);
            uncompressedSize += Integer.BYTES;
        }

        @Override
        public void writeLong(long value) {
            ensureCapacityFor(Long.BYTES);
            buffers[0].writeLong(value);
            uncompressedSize += Long.BYTES;
        }

        @Override
        public void writeFloat(float value) {
            ensureCapacityFor(Float.BYTES);
            buffers[0].writeFloat(value);
            uncompressedSize += Float.BYTES;
        }

        @Override
        public void writeDouble(double value) {
            ensureCapacityFor(Double.BYTES);
            buffers[0].writeDouble(value);
            uncompressedSize += Double.BYTES;
        }

        @Override
        public void writeBytes(Slice source, int sourceIndex, int length) {
            WriteBuffer staging = buffers[0];
            int currentIndex = sourceIndex;
            int bytesRemaining = length;
            while (bytesRemaining > 0) {
                // Request only a small amount so a full fixed-size staging block gets flushed,
                // then copy as much as currently fits.
                ensureCapacityFor(Math.min(Long.BYTES, bytesRemaining));
                int chunk = Math.min(bytesRemaining, staging.remainingCapacity());
                staging.writeBytes(source, currentIndex, chunk);
                currentIndex += chunk;
                bytesRemaining -= chunk;
            }
            uncompressedSize += length;
        }

        @Override
        public void writeBytes(byte[] source, int sourceIndex, int length) {
            WriteBuffer staging = buffers[0];
            int currentIndex = sourceIndex;
            int bytesRemaining = length;
            while (bytesRemaining > 0) {
                // see writeBytes(Slice, int, int)
                ensureCapacityFor(Math.min(Long.BYTES, bytesRemaining));
                int chunk = Math.min(bytesRemaining, staging.remainingCapacity());
                staging.writeBytes(source, currentIndex, chunk);
                currentIndex += chunk;
                bytesRemaining -= chunk;
            }
            uncompressedSize += length;
        }

        /**
         * Flushes any staged data, fills in the header size fields and returns the serialized page.
         * All buffers are reset and the page buffer is released afterwards.
         *
         * @return the serialized page, header included
         */
        public Slice closePage() {
            compress();
            encrypt();

            WriteBuffer pageBuffer = buffers[buffers.length - 1];
            int serializedPageSize = pageBuffer.getPosition();
            Slice slice = pageBuffer.getSlice();
            slice.setInt(SERIALIZED_PAGE_UNCOMPRESSED_SIZE_OFFSET, uncompressedSize);
            slice.setInt(SERIALIZED_PAGE_COMPRESSED_SIZE_OFFSET, serializedPageSize - SERIALIZED_PAGE_HEADER_SIZE);

            // Copy when less than half of the backing array is used, so the returned page does not
            // keep a mostly-empty allocation alive; otherwise return a view.
            Slice page = serializedPageSize < slice.length() / 2
                    ? Slices.copyOf(slice, 0, serializedPageSize)
                    : slice.slice(0, serializedPageSize);

            for (WriteBuffer buffer : buffers) {
                buffer.reset();
            }
            buffers[buffers.length - 1] = null;
            uncompressedSize = 0;
            return page;
        }

        /**
         * Makes room for {@code bytes} more bytes in {@code buffers[0]}: grows it when it is the page
         * buffer itself, otherwise flushes the fixed-size staging block through the pipeline.
         */
        private void ensureCapacityFor(int bytes) {
            if (buffers[0].remainingCapacity() >= bytes) {
                return;
            }
            if (compressor.isEmpty() && encryptionKey.isEmpty()) {
                buffers[0].ensureCapacityFor(bytes);
            }
            else {
                // compress()/encrypt() size their own sink buffers
                compress();
                encrypt();
            }
        }

        /**
         * Moves the staged raw bytes from {@code buffers[0]} into {@code buffers[1]} as one block:
         * an int marker followed by either the LZ4 output or, if that is not small enough, the raw bytes.
         * No-op when compression is disabled.
         */
        private void compress() {
            if (this.compressor.isEmpty()) {
                return;
            }
            Compressor lz4 = this.compressor.get();
            WriteBuffer source = buffers[0];
            WriteBuffer sink = buffers[1];

            int rawSize = source.getPosition();
            int maxCompressedLength = lz4.maxCompressedLength(rawSize);
            // room for the int block marker followed by the block itself
            sink.ensureCapacityFor(maxCompressedLength + Integer.BYTES);

            byte[] sourceArray = source.getSlice().byteArray();
            int sourceOffset = source.getSlice().byteArrayOffset();
            byte[] sinkArray = sink.getSlice().byteArray();
            int blockOffset = sink.getSlice().byteArrayOffset() + sink.getPosition() + Integer.BYTES;

            int compressedSize = lz4.compress(sourceArray, sourceOffset, rawSize, sinkArray, blockOffset, maxCompressedLength);
            boolean compressed = rawSize * MINIMUM_COMPRESSION_RATIO > compressedSize;
            int blockSize;
            if (compressed) {
                blockSize = compressedSize;
            }
            else {
                // not worth it: overwrite the compressor output with the raw bytes
                System.arraycopy(sourceArray, sourceOffset, sinkArray, blockOffset, rawSize);
                blockSize = rawSize;
            }

            sink.writeInt(createBlockMarker(compressed, blockSize));
            sink.skip(blockSize);
            source.reset();
        }

        /**
         * Encodes a block marker: the block size, with the sign bit set when the block is compressed.
         */
        private static int createBlockMarker(boolean compressed, int size) {
            return compressed ? size | Integer.MIN_VALUE : size;
        }

        /**
         * Encrypts the staged bytes of {@code buffers[length - 2]} into the page buffer as one block:
         * int ciphertext length, IV, ciphertext. No-op when encryption is disabled.
         */
        private void encrypt() {
            if (this.encryptionKey.isEmpty()) {
                return;
            }
            Cipher activeCipher = initCipher(this.encryptionKey.get());
            byte[] iv = activeCipher.getIV();
            WriteBuffer source = buffers[buffers.length - 2];
            WriteBuffer sink = buffers[buffers.length - 1];

            // length prefix + IV + ciphertext (the IV length is counted twice; the extra bytes are unused slack)
            sink.ensureCapacityFor(activeCipher.getOutputSize(source.getPosition()) + iv.length + Integer.BYTES + iv.length);
            int lengthPosition = sink.getPosition();
            sink.skip(Integer.BYTES);
            sink.writeBytes(iv, 0, iv.length);

            int encryptedSize;
            try {
                // update() + doFinal(output, offset) write the ciphertext directly into the sink's backing array
                byte[] sinkArray = sink.getSlice().byteArray();
                int ciphertextOffset = sink.getSlice().byteArrayOffset() + sink.getPosition();
                encryptedSize = activeCipher.update(
                        source.getSlice().byteArray(),
                        source.getSlice().byteArrayOffset(),
                        source.getPosition(),
                        sinkArray,
                        ciphertextOffset);
                encryptedSize += activeCipher.doFinal(sinkArray, ciphertextOffset + encryptedSize);
            }
            catch (GeneralSecurityException e) {
                throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Failed to encrypt data: " + e.getMessage(), e);
            }

            sink.getSlice().setInt(lengthPosition, encryptedSize);
            sink.skip(encryptedSize);
            source.reset();
        }

        /**
         * Re-initializes the reusable cipher for encryption with {@code key}. No IV is supplied; per the
         * {@link Cipher#init(int, java.security.Key)} contract the provider generates the parameters,
         * and {@link #encrypt()} reads the resulting IV via {@link Cipher#getIV()}.
         *
         * @throws VerifyException if no cipher was created (encryption disabled)
         * @throws TrinoException if initialization fails
         */
        private Cipher initCipher(SecretKey key) {
            Cipher reusableCipher = cipher.orElseThrow(() -> new VerifyException("cipher is expected to be present"));
            try {
                reusableCipher.init(Cipher.ENCRYPT_MODE, key);
                return reusableCipher;
            }
            catch (GeneralSecurityException e) {
                throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Failed to init cipher: " + e.getMessage(), e);
            }
        }

        @Override
        public long getRetainedSize() {
            long size = INSTANCE_SIZE;
            size += SizeOf.sizeOf(compressor, ignored -> COMPRESSOR_RETAINED_SIZE);
            size += SizeOf.sizeOf(encryptionKey, ignored -> ENCRYPTION_KEY_RETAINED_SIZE);
            size += SizeOf.sizeOf(cipher, ignored -> CIPHER_RETAINED_SIZE);
            for (WriteBuffer buffer : buffers) {
                // the page buffer slot is null between pages
                if (buffer != null) {
                    size += buffer.getRetainedSizeInBytes();
                }
            }
            return size;
        }

        @Override
        public int writableBytes() {
            return Integer.MAX_VALUE;
        }

        @Override
        public boolean isWritable() {
            return true;
        }

        @Override
        public void writeBytes(byte[] source) {
            writeBytes(source, 0, source.length);
        }

        @Override
        public void writeBytes(Slice source) {
            writeBytes(source, 0, source.length());
        }

        @Override
        public void writeBytes(InputStream in, int length) throws IOException {
            throw new UnsupportedOperationException();
        }

        @Override
        public Slice slice() {
            throw new UnsupportedOperationException();
        }

        @Override
        public Slice getUnderlyingSlice() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void reset() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void reset(int position) {
            throw new UnsupportedOperationException();
        }

        @Override
        public int size() {
            throw new UnsupportedOperationException();
        }

        @Override
        public String toString(Charset charset) {
            throw new UnsupportedOperationException();
        }

        @Override
        public SliceOutput appendLong(long value) {
            writeLong(value);
            return this;
        }

        @Override
        public SliceOutput appendDouble(double value) {
            writeDouble(value);
            return this;
        }

        @Override
        public SliceOutput appendInt(int value) {
            writeInt(value);
            return this;
        }

        @Override
        public SliceOutput appendShort(int value) {
            writeShort(value);
            return this;
        }

        @Override
        public SliceOutput appendByte(int value) {
            writeByte(value);
            return this;
        }

        @Override
        public SliceOutput appendBytes(byte[] source, int sourceIndex, int length) {
            writeBytes(source, sourceIndex, length);
            return this;
        }

        @Override
        public SliceOutput appendBytes(byte[] source) {
            return appendBytes(source, 0, source.length);
        }

        @Override
        public SliceOutput appendBytes(Slice slice) {
            writeBytes(slice);
            return this;
        }
    }

    /**
     * Byte buffer with an explicit write position over a {@link Slice} allocated by
     * {@link Slices#allocate}. Writes do not check bounds themselves; callers reserve space with
     * {@link #ensureCapacityFor(int)} or {@link #remainingCapacity()} first.
     *
     * <p>NOTE(review): decompiler metadata says this class was originally private; it is only used
     * inside this file. Visibility is left unchanged here.
     */
    public static class WriteBuffer {
        private static final int INSTANCE_SIZE = Math.toIntExact(ClassLayout.parseClass(WriteBuffer.class).instanceSize());
        private Slice slice;
        private int position;

        public WriteBuffer(int initialCapacity) {
            this.slice = Slices.allocate(initialCapacity);
        }

        public void writeByte(int value) {
            slice.setByte(position, value);
            position += Byte.BYTES;
        }

        public void writeShort(int value) {
            slice.setShort(position, value);
            position += Short.BYTES;
        }

        public void writeInt(int value) {
            slice.setInt(position, value);
            position += Integer.BYTES;
        }

        public void writeLong(long value) {
            slice.setLong(position, value);
            position += Long.BYTES;
        }

        public void writeFloat(float value) {
            slice.setFloat(position, value);
            position += Float.BYTES;
        }

        public void writeDouble(double value) {
            slice.setDouble(position, value);
            position += Double.BYTES;
        }

        public void writeBytes(Slice source, int sourceIndex, int length) {
            slice.setBytes(position, source, sourceIndex, length);
            position += length;
        }

        public void writeBytes(byte[] source, int sourceIndex, int length) {
            slice.setBytes(position, source, sourceIndex, length);
            position += length;
        }

        /** Advances the write position without writing (used to reserve space filled in later). */
        public void skip(int length) {
            position += length;
        }

        public int remainingCapacity() {
            return slice.length() - position;
        }

        public int getPosition() {
            return position;
        }

        public Slice getSlice() {
            return slice;
        }

        /** Rewinds the write position to zero; the backing slice is kept for reuse. */
        public void reset() {
            position = 0;
        }

        public long getRetainedSizeInBytes() {
            return INSTANCE_SIZE + slice.getRetainedSize();
        }

        /**
         * Ensures at least {@code bytes} bytes can be written at the current position, possibly
         * replacing the backing slice with a larger one (so callers must re-read {@link #getSlice()}).
         */
        public void ensureCapacityFor(int bytes) {
            slice = Slices.ensureSize(slice, position + bytes);
        }
    }

    /**
     * @param blockEncodingSerde serde used to write the page's blocks
     * @param compressionEnabled whether to LZ4-compress the page body
     * @param encryptionKey optional key; when present it must be a 256-bit {@link SecretKeySpec}
     * @param blockSizeInBytes size of the staging block used when compression or encryption is enabled
     * @throws NullPointerException if {@code blockEncodingSerde} or {@code encryptionKey} is null
     * @throws IllegalArgumentException if the key is not a 256-bit {@link SecretKeySpec}
     */
    public PageSerializer(BlockEncodingSerde blockEncodingSerde, boolean compressionEnabled, Optional<SecretKey> encryptionKey, int blockSizeInBytes) {
        this.blockEncodingSerde = Objects.requireNonNull(blockEncodingSerde, "blockEncodingSerde is null");
        Objects.requireNonNull(encryptionKey, "encryptionKey is null");
        encryptionKey.ifPresent(secretKey -> Preconditions.checkArgument(
                Ciphers.is256BitSecretKeySpec(secretKey),
                "encryptionKey is expected to be an instance of SecretKeySpec containing a 256bit key"));
        Optional<Lz4Compressor> compressor = compressionEnabled ? Optional.of(new Lz4Compressor()) : Optional.empty();
        this.output = new SerializedPageOutput(compressor, encryptionKey, blockSizeInBytes);
    }

    /**
     * Serializes {@code page} into a new or freshly sliced {@link Slice}.
     *
     * @throws ArithmeticException if the page's reported size does not fit in an int
     */
    public Slice serialize(Page page) {
        output.startPage(page.getPositionCount(), Math.toIntExact(page.getSizeInBytes()));
        PagesSerdeUtil.writeRawPage(page, output, blockEncodingSerde);
        return output.closePage();
    }

    public long getRetainedSizeInBytes() {
        return INSTANCE_SIZE + output.getRetainedSize();
    }
}
