/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.cohere.runtime;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.model.scoring.ScoringModel;
import io.quarkiverse.langchain4j.cohere.runtime.api.CohereApi;
import io.quarkiverse.langchain4j.cohere.runtime.api.RerankRequest;
import io.quarkiverse.langchain4j.cohere.runtime.api.RerankResponse;
import io.quarkiverse.langchain4j.cohere.runtime.api.RerankResult;
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import jakarta.ws.rs.core.MultivaluedHashMap;
import jakarta.ws.rs.core.MultivaluedMap;
import java.net.URI;
import java.net.URISyntaxException;
import java.time.Duration;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.eclipse.microprofile.rest.client.ext.ClientHeadersFactory;

public class QuarkusCohereScoringModel
implements ScoringModel {
    private final CohereApi cohereApi;
    private final String model;
    private final Integer maxRetries;

    public QuarkusCohereScoringModel(String baseUrl, final String apiKey, String model, Duration timeout, Integer maxRetries) {
        this.model = model;
        this.maxRetries = maxRetries;
        if (this.maxRetries < 1) {
            throw new IllegalArgumentException("max-retries must be at least 1");
        }
        ClientHeadersFactory factory = new ClientHeadersFactory(){

            public MultivaluedMap<String, String> update(MultivaluedMap<String, String> incomingHeaders, MultivaluedMap<String, String> clientOutgoingHeaders) {
                MultivaluedHashMap headers = new MultivaluedHashMap();
                headers.put((Object)"Authorization", Collections.singletonList("Bearer " + apiKey));
                return headers;
            }
        };
        try {
            this.cohereApi = (CohereApi)QuarkusRestClientBuilder.newBuilder().baseUri(new URI(baseUrl)).clientHeadersFactory(factory).connectTimeout(timeout.toSeconds(), TimeUnit.SECONDS).readTimeout(timeout.toSeconds(), TimeUnit.SECONDS).build(CohereApi.class);
        }
        catch (URISyntaxException e) {
            throw new RuntimeException(e);
        }
    }

    public Response<List<Double>> scoreAll(List<TextSegment> segments, String query) {
        List<String> documents = segments.stream().map(TextSegment::text).collect(Collectors.toList());
        RerankRequest request = new RerankRequest(this.model, query, documents);
        RerankResponse response = (RerankResponse)RetryUtils.withRetry(() -> this.cohereApi.rerank(request), (int)this.maxRetries);
        List scores = response.getResults().stream().sorted(Comparator.comparingInt(RerankResult::getIndex)).map(RerankResult::getRelevanceScore).collect(Collectors.toList());
        return Response.from(scores, (TokenUsage)new TokenUsage(response.getMeta().getBilledUnits().getSearchUnits()));
    }
}

