/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.watsonx;

import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
import io.quarkiverse.langchain4j.watsonx.WatsonxModel;
import io.quarkiverse.langchain4j.watsonx.bean.Parameters;
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest;
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationResponse;
import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest;
import io.smallrye.mutiny.Context;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;

public class WatsonxStreamingChatModel
extends WatsonxModel
implements StreamingChatLanguageModel,
TokenCountEstimator {
    private final ObjectMapper mapper = QuarkusJsonCodecFactory.SnakeCaseObjectMapperHolder.MAPPER;

    public WatsonxStreamingChatModel(WatsonxModel.Builder config) {
        super(config);
    }

    public void generate(List<ChatMessage> messages, final StreamingResponseHandler<AiMessage> handler) {
        Parameters.LengthPenalty lengthPenalty = null;
        if (Objects.nonNull(this.decayFactor) || Objects.nonNull(this.startIndex)) {
            lengthPenalty = new Parameters.LengthPenalty(this.decayFactor, this.startIndex);
        }
        Parameters parameters = Parameters.builder().decodingMethod(this.decodingMethod).lengthPenalty(lengthPenalty).minNewTokens(this.minNewTokens).maxNewTokens(this.maxNewTokens).randomSeed(this.randomSeed).stopSequences(this.stopSequences).temperature(this.temperature).topP(this.topP).topK(this.topK).repetitionPenalty(this.repetitionPenalty).truncateInputTokens(this.truncateInputTokens).includeStopSequence(this.includeStopSequence).build();
        final TextGenerationRequest request = new TextGenerationRequest(this.modelId, this.projectId, this.toInput(messages), parameters);
        final Context context = Context.of((Object[])new Object[]{"response", new ArrayList()});
        this.generateBearerToken().onItem().transformToMulti((Function)new Function<String, Flow.Publisher<? extends String>>(){

            @Override
            public Flow.Publisher<? extends String> apply(String token) {
                return WatsonxStreamingChatModel.this.client.chatStreaming(request, token, WatsonxStreamingChatModel.this.version);
            }
        }).subscribe().with(context, (Consumer)new Consumer<String>(){

            @Override
            public void accept(String response) {
                try {
                    if (response == null || response.isBlank()) {
                        return;
                    }
                    TextGenerationResponse obj = (TextGenerationResponse)WatsonxStreamingChatModel.this.mapper.readValue(response, TextGenerationResponse.class);
                    ((List)context.get("response")).add(obj);
                    handler.onNext(obj.results().get(0).generatedText());
                }
                catch (Exception e) {
                    handler.onError((Throwable)e);
                }
            }
        }, (Consumer)new Consumer<Throwable>(){

            @Override
            public void accept(Throwable error) {
                handler.onError(error);
            }
        }, new Runnable(){

            @Override
            public void run() {
                List list = (List)context.get("response");
                int inputTokenCount = 0;
                int outputTokenCount = 0;
                String stopReason = null;
                StringBuilder builder = new StringBuilder();
                for (int i = 0; i < list.size(); ++i) {
                    TextGenerationResponse.Result response = ((TextGenerationResponse)list.get(i)).results().get(0);
                    if (i == 0) {
                        inputTokenCount = response.inputTokenCount();
                    }
                    if (i == list.size() - 1) {
                        outputTokenCount = response.generatedTokenCount();
                        stopReason = response.stopReason();
                    }
                    builder.append(response.generatedText());
                }
                AiMessage message = new AiMessage(builder.toString());
                TokenUsage tokenUsage = new TokenUsage(Integer.valueOf(inputTokenCount), Integer.valueOf(outputTokenCount));
                FinishReason finishReason = WatsonxStreamingChatModel.this.toFinishReason(stopReason);
                handler.onComplete(Response.from((Object)message, (TokenUsage)tokenUsage, (FinishReason)finishReason));
            }
        });
    }

    public int estimateTokenCount(List<ChatMessage> messages) {
        String input = messages.stream().map(ChatMessage::text).collect(Collectors.joining(" "));
        final TokenizationRequest request = new TokenizationRequest(this.modelId, input, this.projectId);
        return WatsonxStreamingChatModel.retryOn(new Callable<Integer>(){

            @Override
            public Integer call() throws Exception {
                String token = (String)WatsonxStreamingChatModel.this.generateBearerToken().await().atMost(Duration.ofSeconds(10L));
                return WatsonxStreamingChatModel.this.client.tokenization(request, token, WatsonxStreamingChatModel.this.version).result().tokenCount();
            }
        });
    }
}

