package io.trino.plugin.kafka;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.trino.plugin.kafka.encoder.RowEncoder;
import io.trino.spi.Page;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.ConnectorPageSink;
import io.trino.spi.connector.ConnectorSession;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.kafka.clients.producer.Callback;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.RecordMetadata;

/* loaded from: input_file:io/trino/plugin/kafka/KafkaPageSink.class */
public class KafkaPageSink implements ConnectorPageSink {
    private final String topicName;
    private final List<KafkaColumnHandle> columns;
    private final RowEncoder keyEncoder;
    private final RowEncoder messageEncoder;
    private final KafkaProducer<byte[], byte[]> producer;
    private final ProducerCallback producerCallback;
    private long expectedWrittenBytes;

    /* loaded from: input_file:io/trino/plugin/kafka/KafkaPageSink$ProducerCallback.class */
    private static class ProducerCallback implements Callback {
        private long errorCount = 0;
        private long writtenBytes = 0;

        public void onCompletion(RecordMetadata recordMetadata, Exception exc) {
            if (exc != null) {
                this.errorCount++;
            } else {
                this.writtenBytes += recordMetadata.serializedValueSize() + recordMetadata.serializedKeySize();
            }
        }

        public long getErrorCount() {
            return this.errorCount;
        }

        public long getWrittenBytes() {
            return this.writtenBytes;
        }
    }

    public KafkaPageSink(String str, List<KafkaColumnHandle> list, RowEncoder rowEncoder, RowEncoder rowEncoder2, KafkaProducerFactory kafkaProducerFactory, ConnectorSession connectorSession) {
        this.topicName = (String) Objects.requireNonNull(str, "topicName is null");
        this.columns = (List) Objects.requireNonNull(ImmutableList.copyOf(list), "columns is null");
        this.keyEncoder = (RowEncoder) Objects.requireNonNull(rowEncoder, "keyEncoder is null");
        this.messageEncoder = (RowEncoder) Objects.requireNonNull(rowEncoder2, "messageEncoder is null");
        Objects.requireNonNull(kafkaProducerFactory, "producerFactory is null");
        Objects.requireNonNull(connectorSession, "session is null");
        this.producer = kafkaProducerFactory.create(connectorSession);
        this.producerCallback = new ProducerCallback();
        this.expectedWrittenBytes = 0L;
    }

    public long getCompletedBytes() {
        return this.producerCallback.getWrittenBytes();
    }

    public CompletableFuture<?> appendPage(Page page) {
        for (int i = 0; i < page.getPositionCount(); i++) {
            for (int i2 = 0; i2 < page.getChannelCount(); i2++) {
                if (this.columns.get(i2).isKeyCodec()) {
                    this.keyEncoder.appendColumnValue(page.getBlock(i2), i);
                } else {
                    this.messageEncoder.appendColumnValue(page.getBlock(i2), i);
                }
            }
            byte[] byteArray = this.keyEncoder.toByteArray();
            byte[] byteArray2 = this.messageEncoder.toByteArray();
            this.expectedWrittenBytes += byteArray.length + byteArray2.length;
            this.producer.send(new ProducerRecord(this.topicName, byteArray, byteArray2), this.producerCallback);
        }
        return NOT_BLOCKED;
    }

    public CompletableFuture<Collection<Slice>> finish() {
        this.producer.flush();
        this.producer.close();
        try {
            this.keyEncoder.close();
            this.messageEncoder.close();
            Preconditions.checkArgument(this.producerCallback.getWrittenBytes() == this.expectedWrittenBytes, String.format("Actual written bytes: '%s' not equal to expected written bytes: '%s'", Long.valueOf(this.producerCallback.getWrittenBytes()), Long.valueOf(this.expectedWrittenBytes)));
            if (this.producerCallback.getErrorCount() > 0) {
                throw new TrinoException(KafkaErrorCode.KAFKA_PRODUCER_ERROR, String.format("%d producer record(s) failed to send", Long.valueOf(this.producerCallback.getErrorCount())));
            }
            return CompletableFuture.completedFuture(ImmutableList.of());
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to close row encoders", e);
        }
    }

    public void abort() {
        this.producer.close();
    }
}
