package org.nasdanika.rag.openai;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.EmbeddingsOptions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.nasdanika.common.ProgressMonitor;
import org.nasdanika.rag.core.KeyExtractor;
import org.nasdanika.rag.core.StringFloatVectorKeyExtractor;
import org.nasdanika.rag.core.StringMapFloatVectorKeyExtractor;

/* loaded from: input_file:org/nasdanika/rag/openai/OpenAIEmbeddingsKeyExtractor.class */
public class OpenAIEmbeddingsKeyExtractor implements KeyExtractor<List<String>, List<List<Float>>> {
    private OpenAIClient client;
    private String model;
    private String user;
    private String deploymentOrModelId;

    public OpenAIEmbeddingsKeyExtractor(OpenAIClient openAIClient, String str, String str2, String str3) {
        this.client = openAIClient;
        this.deploymentOrModelId = str;
        this.model = str2;
        this.user = str3;
    }

    public List<List<Float>> extract(List<String> list, ProgressMonitor progressMonitor) {
        EmbeddingsOptions embeddingsOptions = new EmbeddingsOptions(list);
        if (this.model != null) {
            embeddingsOptions.setModel(this.model);
        }
        if (this.user != null) {
            embeddingsOptions.setUser(this.user);
        }
        return this.client.getEmbeddings(this.deploymentOrModelId, embeddingsOptions).getData().stream().map((v0) -> {
            return v0.getEmbedding();
        }).toList();
    }

    public StringFloatVectorKeyExtractor asStringFloatVectorKeyExtractor() {
        return (str, progressMonitor) -> {
            return extract(Collections.singletonList(str), progressMonitor).get(0);
        };
    }

    public StringMapFloatVectorKeyExtractor asStringMapFloatVectorKeyExtractor() {
        return (map, progressMonitor) -> {
            ArrayList arrayList = new ArrayList();
            ArrayList arrayList2 = new ArrayList();
            map.entrySet().forEach(entry -> {
                arrayList.add((String) entry.getKey());
                arrayList2.add((String) entry.getValue());
            });
            LinkedHashMap linkedHashMap = new LinkedHashMap();
            List<List<Float>> extract = extract((List<String>) arrayList2, progressMonitor);
            for (int i = 0; i < extract.size(); i++) {
                linkedHashMap.put((String) arrayList.get(i), extract.get(i));
            }
            return linkedHashMap;
        };
    }

    public <T extends KeyExtractor<?, ?>> T adapt(Class<T> cls) {
        return cls.isInstance(this) ? this : cls.isAssignableFrom(StringFloatVectorKeyExtractor.class) ? asStringFloatVectorKeyExtractor() : cls.isAssignableFrom(StringMapFloatVectorKeyExtractor.class) ? asStringMapFloatVectorKeyExtractor() : (T) super.adapt(cls);
    }
}
