package dev.lydtech.component.framework.kafka;

import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import lombok.extern.slf4j.Slf4j;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.producer.Producer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.RecordMetadata;
import org.awaitility.Awaitility;

/**
 * Test helper for sending messages to, and consuming messages from, a Kafka broker.
 *
 * <p>Holds a {@link KafkaConsumerFactory} and a {@link KafkaProducerFactory} that are both
 * configured with the same broker URL.
 */
@Slf4j
public class KafkaBrokerUtils {

    private static final String KEY_BROKER_URL = "brokerUrl";
    private static final String KEY_GROUPID = "groupId";
    private static final String DEFAULT_BROKER_HOST = "localhost";

    private final KafkaConsumerFactory kafkaConsumerFactory;
    private final KafkaProducerFactory kafkaProducerFactory;

    /**
     * Creates the utils from a config map.
     *
     * @param config map holding the broker URL under {@code "brokerUrl"} and the consumer group id
     *               under {@code "groupId"}; both values are cast to {@code String}
     */
    public KafkaBrokerUtils(Map<String, Object> config) {
        // Use the same key constants that getKafkaConfig writes with, so the two cannot drift apart.
        this((String) config.get(KEY_BROKER_URL), (String) config.get(KEY_GROUPID));
    }

    /**
     * Creates the utils for the given broker and consumer group.
     *
     * @param brokerUrl broker URL handed to both the consumer and the producer factory
     * @param groupId   consumer group id handed to the consumer factory
     */
    public KafkaBrokerUtils(String brokerUrl, String groupId) {
        kafkaConsumerFactory = new KafkaConsumerFactory(brokerUrl, groupId);
        kafkaProducerFactory = new KafkaProducerFactory(brokerUrl);
        log.debug("Created Kafka message broker utils with brokerUrl={} and groupID={}", brokerUrl, groupId);
    }

    /**
     * Creates utils pointing at {@code tcp://localhost:<kafka.port>}, where the port is read from
     * the {@code kafka.port} system property.
     *
     * @param kafkaGroupId consumer group id to use
     * @return a new instance
     * @throws RuntimeException if {@code kafka.port} is not set
     * @throws NumberFormatException if {@code kafka.port} is not an integer
     */
    public static KafkaBrokerUtils createBroker(String kafkaGroupId) {
        return createBroker(getKafkaConfig(kafkaGroupId));
    }

    /**
     * Creates utils from a config map; see {@link #KafkaBrokerUtils(Map)} for the expected keys.
     *
     * @param config broker configuration
     * @return a new instance
     */
    public static KafkaBrokerUtils createBroker(Map<String, Object> config) {
        return new KafkaBrokerUtils(config);
    }

    /**
     * Builds the config map for a broker on {@value #DEFAULT_BROKER_HOST} using the port from the
     * {@code kafka.port} system property.
     */
    private static Map<String, Object> getKafkaConfig(String kafkaGroupId) {
        final int kafkaPort = Optional.ofNullable(System.getProperty("kafka.port"))
                .map(Integer::parseInt)
                .orElseThrow(() -> new RuntimeException("Kafka port not set"));

        final Map<String, Object> config = new HashMap<>();
        config.put(KEY_BROKER_URL, "tcp://" + DEFAULT_BROKER_HOST + ":" + kafkaPort);
        config.put(KEY_GROUPID, kafkaGroupId);
        return config;
    }

    /**
     * Sends one message synchronously (blocks until the broker acknowledges the send).
     *
     * @param topic   destination topic
     * @param key     record key
     * @param payload record value
     * @throws Exception if the send fails or the wait for its result is interrupted
     */
    public void sendMessage(String topic, String key, String payload) throws Exception {
        // NOTE(review): the record is declared <Long, String> via the raw constructor to match the
        // producer's declared key type, yet the key passed in is a String — confirm the producer's
        // key serializer is String-based and consider aligning the generic types.
        @SuppressWarnings({"unchecked", "rawtypes"})
        final ProducerRecord<Long, String> record = new ProducerRecord(topic, null, key, payload);
        // createProducer() is called once per send; presumably it builds a fresh producer (confirm),
        // so close it here to release its network thread and buffers instead of leaking them.
        try (Producer<Long, String> producer = kafkaProducerFactory.createProducer()) {
            final RecordMetadata metadata = producer.send(record).get();
            // Log the original String arguments rather than record.key(), whose static type is Long.
            log.debug("Sent record(key={} value={}) meta(topic={}, partition={}, offset={})",
                    key, payload, metadata.topic(), metadata.partition(), metadata.offset());
        }
    }

    /**
     * Creates a consumer subscribed to the given topic.
     *
     * @param topic topic to subscribe to
     * @return the subscribed consumer
     */
    public Consumer createConsumer(String topic) {
        return kafkaConsumerFactory.createAndSubscribe(topic);
    }

    /**
     * 1. Poll for messages on the application’s outbound topic.
     * 2. Assert the expected number are received.
     * 3. Performs the specified number of extra polls after the expected number received to ensure no further events.
     * 4. Returns the consumed events.
     *
     * <p>Polling happens once per second for at most 30 seconds. If the expected count is not
     * reached, or more events than expected arrive (e.g. duplicates during the extra polls), the
     * condition never holds and Awaitility fails with a timeout.
     *
     * @param testName           prefix for log lines
     * @param consumer           consumer to poll
     * @param expectedEventCount exact number of events expected; must not be negative
     * @param furtherPolls       extra polls to perform after the expected count is reached; must
     *                           not be negative
     * @return the records consumed, in the order they were received
     * @throws IllegalArgumentException if either count is negative
     */
    public List<ConsumerRecord<String, String>> consumeAndAssert(String testName, Consumer consumer, int expectedEventCount, int furtherPolls) throws Exception {
        // A negative target can never be matched and would silently burn the whole 30s timeout.
        if (expectedEventCount < 0) {
            throw new IllegalArgumentException("expectedEventCount must not be negative: " + expectedEventCount);
        }
        if (furtherPolls < 0) {
            throw new IllegalArgumentException("furtherPolls must not be negative: " + furtherPolls);
        }

        final AtomicInteger totalReceivedEvents = new AtomicInteger();
        // Starts at -1 so the poll on which the expected count is first reached counts as extra poll 0.
        final AtomicInteger totalExtraPolls = new AtomicInteger(-1);
        final AtomicInteger pollCount = new AtomicInteger();
        final List<ConsumerRecord<String, String>> events = new ArrayList<>();

        Awaitility.await()
                .atMost(30, TimeUnit.SECONDS)
                .pollInterval(1, TimeUnit.SECONDS)
                .until(() -> {
                    final ConsumerRecords<String, String> consumerRecords = consumer.poll(Duration.ofMillis(100));
                    consumerRecords.forEach(record -> {
                        log.info("{} - received: {}", testName, record.value());
                        totalReceivedEvents.incrementAndGet();
                        events.add(record);
                    });
                    if (totalReceivedEvents.get() == expectedEventCount) {
                        // Track the extra polls, allowing for time to consume duplicates.
                        totalExtraPolls.incrementAndGet();
                    }
                    pollCount.getAndIncrement();
                    log.info("{} - poll count: {} - received count: {}", testName, pollCount.get(), totalReceivedEvents.get());
                    return totalReceivedEvents.get() == expectedEventCount && totalExtraPolls.get() == furtherPolls;
                });
        return events;
    }
}
