/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.store.embedding.azure.search;

import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.credential.TokenCredential;
import com.azure.core.util.Context;
import com.azure.search.documents.SearchClient;
import com.azure.search.documents.SearchClientBuilder;
import com.azure.search.documents.SearchDocument;
import com.azure.search.documents.indexes.SearchIndexClient;
import com.azure.search.documents.indexes.SearchIndexClientBuilder;
import com.azure.search.documents.indexes.models.HnswAlgorithmConfiguration;
import com.azure.search.documents.indexes.models.HnswParameters;
import com.azure.search.documents.indexes.models.SearchField;
import com.azure.search.documents.indexes.models.SearchFieldDataType;
import com.azure.search.documents.indexes.models.SearchIndex;
import com.azure.search.documents.indexes.models.SemanticConfiguration;
import com.azure.search.documents.indexes.models.SemanticField;
import com.azure.search.documents.indexes.models.SemanticPrioritizedFields;
import com.azure.search.documents.indexes.models.SemanticSearch;
import com.azure.search.documents.indexes.models.VectorSearch;
import com.azure.search.documents.indexes.models.VectorSearchAlgorithmMetric;
import com.azure.search.documents.indexes.models.VectorSearchProfile;
import com.azure.search.documents.models.IndexingResult;
import com.azure.search.documents.models.SearchOptions;
import com.azure.search.documents.models.SearchResult;
import com.azure.search.documents.models.VectorQuery;
import com.azure.search.documents.models.VectorSearchOptions;
import com.azure.search.documents.models.VectorizedQuery;
import com.azure.search.documents.util.SearchPagedIterable;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchRuntimeException;
import dev.langchain4j.store.embedding.azure.search.Document;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link EmbeddingStore} for {@link TextSegment}s backed by an Azure AI Search index.
 * <p>
 * Each entry is stored as a search document with an {@code id}, an optional {@code content}
 * string, a {@code content_vector} and a complex {@code metadata} field whose {@code attributes}
 * collection holds the segment's metadata as key/value pairs. The index is created or updated
 * when the store is constructed: either from a default schema sized by {@code dimensions}, or from
 * a caller-supplied {@link SearchIndex}, in which case that index's name is used for all reads,
 * writes and deletion.
 */
public class AzureAiSearchEmbeddingStore
implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(AzureAiSearchEmbeddingStore.class);
    /** Name of the index used when no custom {@link SearchIndex} (or one without a name) is given. */
    private static final String INDEX_NAME = "vectorsearch";
    private static final String DEFAULT_FIELD_ID = "id";
    private static final String DEFAULT_FIELD_CONTENT = "content";
    private static final String DEFAULT_FIELD_CONTENT_VECTOR = "content_vector";
    private static final String DEFAULT_FIELD_METADATA = "metadata";
    private static final String DEFAULT_FIELD_METADATA_SOURCE = "source";
    private static final String DEFAULT_FIELD_METADATA_ATTRS = "attributes";
    private SearchIndexClient searchIndexClient;
    private SearchClient searchClient;
    // Name of the index this store actually reads from / writes to (set in initialize).
    private String indexName;

    /**
     * Creates a store authenticated with an API key, using the default index schema.
     *
     * @param endpoint      the Azure AI Search service endpoint
     * @param keyCredential the API key credential
     * @param dimensions    dimension of the {@code content_vector} field
     */
    public AzureAiSearchEmbeddingStore(String endpoint, AzureKeyCredential keyCredential, int dimensions) {
        this.initialize(endpoint, keyCredential, null, dimensions, null);
    }

    /**
     * Creates a store authenticated with an API key, using a caller-supplied index definition.
     *
     * @param endpoint      the Azure AI Search service endpoint
     * @param keyCredential the API key credential
     * @param index         the index to create or update; its name is used for all operations
     */
    public AzureAiSearchEmbeddingStore(String endpoint, AzureKeyCredential keyCredential, SearchIndex index) {
        this.initialize(endpoint, keyCredential, null, 0, index);
    }

    /**
     * Creates a store authenticated with a token credential, using the default index schema.
     *
     * @param endpoint        the Azure AI Search service endpoint
     * @param tokenCredential the token credential
     * @param dimensions      dimension of the {@code content_vector} field
     */
    public AzureAiSearchEmbeddingStore(String endpoint, TokenCredential tokenCredential, int dimensions) {
        this.initialize(endpoint, null, tokenCredential, dimensions, null);
    }

    /**
     * Creates a store authenticated with a token credential, using a caller-supplied index definition.
     *
     * @param endpoint        the Azure AI Search service endpoint
     * @param tokenCredential the token credential
     * @param index           the index to create or update; its name is used for all operations
     */
    public AzureAiSearchEmbeddingStore(String endpoint, TokenCredential tokenCredential, SearchIndex index) {
        this.initialize(endpoint, null, tokenCredential, 0, index);
    }

    /**
     * Builds the index/search clients and creates or updates the index.
     * {@code keyCredential} takes precedence over {@code tokenCredential} when both are non-null.
     */
    private void initialize(String endpoint, AzureKeyCredential keyCredential, TokenCredential tokenCredential, int dimensions, SearchIndex index) {
        // The search client must target the index that is actually created: a caller-supplied
        // index keeps its own name, otherwise documents would be written to / read from the
        // default index instead of the one just created.
        this.indexName = index != null && index.getName() != null ? index.getName() : INDEX_NAME;
        if (keyCredential != null) {
            this.searchIndexClient = new SearchIndexClientBuilder().endpoint(endpoint).credential(keyCredential).buildClient();
            this.searchClient = new SearchClientBuilder().endpoint(endpoint).credential(keyCredential).indexName(this.indexName).buildClient();
        } else {
            this.searchIndexClient = new SearchIndexClientBuilder().endpoint(endpoint).credential(tokenCredential).buildClient();
            this.searchClient = new SearchClientBuilder().endpoint(endpoint).credential(tokenCredential).indexName(this.indexName).buildClient();
        }
        if (index == null) {
            this.createOrUpdateIndex(dimensions);
        } else {
            this.createOrUpdateIndex(index);
        }
    }

    /**
     * Creates or updates the default index with id, content, content_vector (HNSW, cosine metric)
     * and metadata fields, plus a semantic configuration over the content field.
     *
     * @param dimensions dimension of the {@code content_vector} field
     */
    void createOrUpdateIndex(int dimensions) {
        List<SearchField> fields = new ArrayList<>();
        fields.add(new SearchField(DEFAULT_FIELD_ID, SearchFieldDataType.STRING)
                .setKey(true)
                .setFilterable(true));
        fields.add(new SearchField(DEFAULT_FIELD_CONTENT, SearchFieldDataType.STRING)
                .setSearchable(true)
                .setFilterable(true));
        fields.add(new SearchField(DEFAULT_FIELD_CONTENT_VECTOR, SearchFieldDataType.collection(SearchFieldDataType.SINGLE))
                .setSearchable(true)
                .setVectorSearchDimensions(dimensions)
                .setVectorSearchProfileName("vector-search-profile"));
        fields.add(new SearchField(DEFAULT_FIELD_METADATA, SearchFieldDataType.COMPLEX)
                .setFields(Arrays.asList(
                        new SearchField(DEFAULT_FIELD_METADATA_SOURCE, SearchFieldDataType.STRING).setFilterable(true),
                        new SearchField(DEFAULT_FIELD_METADATA_ATTRS, SearchFieldDataType.collection(SearchFieldDataType.COMPLEX))
                                .setFields(Arrays.asList(
                                        new SearchField("key", SearchFieldDataType.STRING).setFilterable(true),
                                        new SearchField("value", SearchFieldDataType.STRING).setFilterable(true))))));
        VectorSearch vectorSearch = new VectorSearch()
                .setAlgorithms(Collections.singletonList(
                        new HnswAlgorithmConfiguration("vector-search-algorithm")
                                .setParameters(new HnswParameters()
                                        .setMetric(VectorSearchAlgorithmMetric.COSINE)
                                        .setM(4)
                                        .setEfSearch(500)
                                        .setEfConstruction(400))))
                .setProfiles(Collections.singletonList(
                        new VectorSearchProfile("vector-search-profile", "vector-search-algorithm")));
        SemanticSearch semanticSearch = new SemanticSearch()
                .setDefaultConfigurationName("semantic-search-config")
                .setConfigurations(Arrays.asList(new SemanticConfiguration("semantic-search-config",
                        new SemanticPrioritizedFields()
                                .setContentFields(new SemanticField[]{new SemanticField(DEFAULT_FIELD_CONTENT)})
                                .setKeywordsFields(new SemanticField[]{new SemanticField(DEFAULT_FIELD_CONTENT)}))));
        SearchIndex index = new SearchIndex(INDEX_NAME)
                .setFields(fields)
                .setVectorSearch(vectorSearch)
                .setSemanticSearch(semanticSearch);
        this.searchIndexClient.createOrUpdateIndex(index);
    }

    /**
     * Creates or updates the given caller-supplied index as-is.
     *
     * @param index the index definition
     */
    void createOrUpdateIndex(SearchIndex index) {
        this.searchIndexClient.createOrUpdateIndex(index);
    }

    /**
     * Deletes the index this store operates on (the custom index's name, or the default name).
     */
    public void deleteIndex() {
        this.searchIndexClient.deleteIndex(this.indexName);
    }

    /**
     * Adds an embedding under a freshly generated random UUID.
     *
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        this.addInternal(id, embedding, null);
        return id;
    }

    /**
     * Adds an embedding under the given id.
     */
    @Override
    public void add(String id, Embedding embedding) {
        this.addInternal(id, embedding, null);
    }

    /**
     * Adds an embedding together with its text segment under a freshly generated random UUID.
     *
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        this.addInternal(id, embedding, textSegment);
        return id;
    }

    /**
     * Adds embeddings, each under a freshly generated random UUID.
     *
     * @return the generated ids, in the same order as {@code embeddings}
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
        this.addAllInternal(ids, embeddings, null);
        return ids;
    }

    /**
     * Adds embeddings with their text segments, each under a freshly generated random UUID.
     *
     * @param embedded must have the same size as {@code embeddings}
     * @return the generated ids, in the same order as {@code embeddings}
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
        this.addAllInternal(ids, embeddings, embedded);
        return ids;
    }

    /**
     * Runs a k-nearest-neighbour vector query against {@code content_vector}.
     *
     * @param referenceEmbedding the query embedding
     * @param maxResults         maximum number of matches returned
     * @param minScore           matches with a relevance score below this are skipped
     * @return at most {@code maxResults} matches, in the order returned by the service
     */
    @Override
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        List<Float> vector = referenceEmbedding.vectorAsList();
        VectorizedQuery vectorizedQuery = new VectorizedQuery(vector)
                .setFields(new String[]{DEFAULT_FIELD_CONTENT_VECTOR})
                .setKNearestNeighborsCount(maxResults);
        // setTop bounds the overall result set so the paged iterable does not fetch beyond maxResults.
        SearchOptions searchOptions = new SearchOptions()
                .setTop(maxResults)
                .setVectorSearchOptions(new VectorSearchOptions().setQueries(new VectorQuery[]{vectorizedQuery}));
        SearchPagedIterable searchResults = this.searchClient.search(null, searchOptions, Context.NONE);
        List<EmbeddingMatch<TextSegment>> result = new ArrayList<>();
        for (SearchResult searchResult : searchResults) {
            if (result.size() >= maxResults) {
                break;
            }
            double score = fromAzureScoreToRelevanceScore(searchResult.getScore());
            if (score < minScore) {
                continue;
            }
            SearchDocument searchDocument = searchResult.getDocument(SearchDocument.class);
            String embeddingId = (String) searchDocument.get(DEFAULT_FIELD_ID);
            List<?> embeddingList = (List<?>) searchDocument.get(DEFAULT_FIELD_CONTENT_VECTOR);
            Embedding embedding = Embedding.from(doublesListToFloatArray(embeddingList));
            String embeddedContent = (String) searchDocument.get(DEFAULT_FIELD_CONTENT);
            TextSegment embedded = null;
            if (Utils.isNotNullOrBlank(embeddedContent)) {
                Metadata langChainMetadata = Metadata.from(readMetadataAttributes(searchDocument));
                embedded = TextSegment.textSegment(embeddedContent, langChainMetadata);
            }
            result.add(new EmbeddingMatch<>(score, embeddingId, embedding, embedded));
        }
        return result;
    }

    /**
     * Reads {@code metadata.attributes} (a list of {key, value} objects) into a map.
     * Missing or malformed metadata yields an empty map instead of failing the whole query,
     * since documents in a caller-supplied index may not carry this field.
     */
    private static Map<String, String> readMetadataAttributes(SearchDocument searchDocument) {
        Map<String, String> attributesMap = new HashMap<>();
        Object metadata = searchDocument.get(DEFAULT_FIELD_METADATA);
        if (!(metadata instanceof Map)) {
            return attributesMap;
        }
        Object attributes = ((Map<?, ?>) metadata).get(DEFAULT_FIELD_METADATA_ATTRS);
        if (!(attributes instanceof List)) {
            return attributesMap;
        }
        for (Object attribute : (List<?>) attributes) {
            if (!(attribute instanceof Map)) {
                continue;
            }
            Map<?, ?> innerAttribute = (Map<?, ?>) attribute;
            attributesMap.put((String) innerAttribute.get("key"), (String) innerAttribute.get("value"));
        }
        return attributesMap;
    }

    /** Adds a single entry by delegating to {@link #addAllInternal}. */
    private void addInternal(String id, Embedding embedding, TextSegment embedded) {
        this.addAllInternal(Collections.singletonList(id), Collections.singletonList(embedding), embedded == null ? null : Collections.singletonList(embedded));
    }

    /**
     * Uploads one search document per id/embedding (and optional segment).
     *
     * @throws AzureAiSearchRuntimeException if the service reports a failed indexing result
     */
    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("Empty embeddings - no ops");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");
        List<Document> searchDocuments = new ArrayList<>(ids.size());
        for (int i = 0; i < ids.size(); ++i) {
            Document document = new Document();
            document.setId(ids.get(i));
            document.setContentVector(embeddings.get(i).vectorAsList());
            if (embedded != null) {
                document.setContent(embedded.get(i).text());
                Document.Metadata metadata = new Document.Metadata();
                List<Document.Metadata.Attribute> attributes = new ArrayList<>();
                for (Map.Entry<String, ?> entry : embedded.get(i).metadata().asMap().entrySet()) {
                    Document.Metadata.Attribute attribute = new Document.Metadata.Attribute();
                    attribute.setKey(entry.getKey());
                    // The index stores values as strings; toString() is identical for String
                    // values and avoids a ClassCastException for any non-String value.
                    Object value = entry.getValue();
                    attribute.setValue(value == null ? null : value.toString());
                    attributes.add(attribute);
                }
                metadata.setAttributes(attributes);
                document.setMetadata(metadata);
            }
            searchDocuments.add(document);
        }
        List<IndexingResult> indexingResults = this.searchClient.uploadDocuments(searchDocuments).getResults();
        for (IndexingResult indexingResult : indexingResults) {
            if (!indexingResult.isSucceeded()) {
                throw new AzureAiSearchRuntimeException("Failed to add embedding: " + indexingResult.getErrorMessage());
            }
            log.debug("Added embedding: {}", indexingResult.getKey());
        }
    }

    /**
     * Converts a deserialized JSON number list to a float array. Accepts any {@link Number}
     * element type, not only {@link Double}.
     */
    private static float[] doublesListToFloatArray(List<?> numbers) {
        float[] array = new float[numbers.size()];
        for (int i = 0; i < numbers.size(); ++i) {
            array[i] = ((Number) numbers.get(i)).floatValue();
        }
        return array;
    }

    /**
     * Converts an Azure AI Search cosine score into a LangChain4j relevance score.
     * The first line inverts {@code score = 1 / (1 + cosineDistance)}, and cosine similarity is
     * {@code 1 - cosineDistance}. The similarity is then mapped by
     * {@link RelevanceScore#fromCosineSimilarity(double)}.
     */
    private static double fromAzureScoreToRelevanceScore(double score) {
        double cosineDistance = (1.0 - score) / score;
        double cosineSimilarity = 1.0 - cosineDistance;
        return RelevanceScore.fromCosineSimilarity(cosineSimilarity);
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builder for {@link AzureAiSearchEmbeddingStore}. Requires an endpoint, at least one of
     * apiKey / tokenCredential (apiKey wins if both are set), and either dimensions or an index
     * (index wins if both are set).
     */
    public static class Builder {
        private String endpoint;
        private AzureKeyCredential keyCredential;
        private TokenCredential tokenCredential;
        private int dimensions;
        private SearchIndex index;

        public Builder endpoint(String endpoint) {
            this.endpoint = endpoint;
            return this;
        }

        public Builder apiKey(String apiKey) {
            this.keyCredential = new AzureKeyCredential(apiKey);
            return this;
        }

        public Builder tokenCredential(TokenCredential tokenCredential) {
            this.tokenCredential = tokenCredential;
            return this;
        }

        public Builder dimensions(int dimensions) {
            this.dimensions = dimensions;
            return this;
        }

        public Builder index(SearchIndex index) {
            this.index = index;
            return this;
        }

        public AzureAiSearchEmbeddingStore build() {
            ValidationUtils.ensureNotNull(this.endpoint, "endpoint");
            ValidationUtils.ensureTrue(this.keyCredential != null || this.tokenCredential != null, "either apiKey or tokenCredential must be set");
            ValidationUtils.ensureTrue(this.dimensions > 0 || this.index != null, "either dimensions or index must be set");
            if (this.keyCredential == null) {
                if (this.index == null) {
                    return new AzureAiSearchEmbeddingStore(this.endpoint, this.tokenCredential, this.dimensions);
                }
                return new AzureAiSearchEmbeddingStore(this.endpoint, this.tokenCredential, this.index);
            }
            if (this.index == null) {
                return new AzureAiSearchEmbeddingStore(this.endpoint, this.keyCredential, this.dimensions);
            }
            return new AzureAiSearchEmbeddingStore(this.endpoint, this.keyCredential, this.index);
        }
    }
}

