/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.infinispan;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import io.quarkiverse.langchain4j.infinispan.runtime.InfinispanSchema;
import io.quarkiverse.langchain4j.infinispan.runtime.LangchainInfinispanItem;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.infinispan.client.hotrod.RemoteCache;
import org.infinispan.client.hotrod.RemoteCacheManager;
import org.infinispan.commons.api.query.Query;
import org.infinispan.commons.configuration.BasicConfiguration;
import org.infinispan.commons.configuration.StringConfiguration;

/**
 * {@link EmbeddingStore} for {@link TextSegment}s backed by an Infinispan remote cache.
 *
 * <p>On construction, the store gets or creates an indexed distributed cache named after
 * {@link InfinispanSchema#getCacheName()}. Each embedding is stored as a
 * {@link LangchainInfinispanItem} keyed by its id. The indexed entity name is
 * {@code "LangchainItem" + dimension}. NOTE(review): this presumably has to match a protobuf
 * entity registered elsewhere — confirm against the schema generator.
 */
public class InfinispanEmbeddingStore implements EmbeddingStore<TextSegment> {
    /** Prefix of the indexed entity name; the vector dimension is appended to it. */
    private static final String ITEM_TYPE_PREFIX = "LangchainItem";

    /**
     * Cache configuration template. The {@code CACHE_NAME} and {@code LANGCHAINITEM}
     * placeholders are substituted in the constructor.
     */
    private static final String DEFAULT_CACHE_CONFIG = "<distributed-cache name=\"CACHE_NAME\">\n<indexing storage=\"local-heap\">\n<indexed-entities>\n<indexed-entity>LANGCHAINITEM</indexed-entity>\n</indexed-entities>\n</indexing>\n</distributed-cache>";

    private final RemoteCache<String, LangchainInfinispanItem> remoteCache;
    private final InfinispanSchema schema;

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates the store, getting or creating the backing indexed cache.
     *
     * @param cacheManager Hot Rod cache manager used to administer and access the cache
     * @param schema       cache name, vector dimension and kNN distance settings
     * @throws NullPointerException if either argument is {@code null}
     */
    public InfinispanEmbeddingStore(RemoteCacheManager cacheManager, InfinispanSchema schema) {
        Objects.requireNonNull(cacheManager, "cacheManager");
        Objects.requireNonNull(schema, "schema");
        String langchainCache = DEFAULT_CACHE_CONFIG
                .replace("CACHE_NAME", schema.getCacheName())
                .replace("LANGCHAINITEM", itemTypeName(schema));
        this.remoteCache = cacheManager.administration()
                .getOrCreateCache(schema.getCacheName(), (BasicConfiguration) new StringConfiguration(langchainCache));
        this.schema = schema;
    }

    /** Stores an embedding without text under a random id and returns that id. */
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        this.add(id, embedding);
        return id;
    }

    /** Stores an embedding without text under the given id, replacing any existing entry. */
    public void add(String id, Embedding embedding) {
        this.addInternal(id, embedding, null);
    }

    /** Stores an embedding and its text segment under a random id and returns that id. */
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        this.addInternal(id, embedding, textSegment);
        return id;
    }

    /**
     * Stores embeddings without text under random ids.
     *
     * @return generated ids, in the same order as {@code embeddings}
     * @throws IllegalArgumentException if {@code embeddings} is empty
     */
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = randomIds(embeddings.size());
        this.addAllInternal(ids, embeddings, null);
        return ids;
    }

    /**
     * Stores embeddings paired index-by-index with their text segments under random ids.
     *
     * @return generated ids, in the same order as {@code embeddings}
     * @throws IllegalArgumentException if the lists are empty or differ in size
     */
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = randomIds(embeddings.size());
        this.addAllInternal(ids, embeddings, embedded);
        return ids;
    }

    /**
     * Returns all stored items whose kNN score against {@code referenceEmbedding} is at least
     * {@code minScore}, limited to at most {@code maxResults} query hits.
     *
     * <p>The number of neighbours searched ({@code ~k}) comes from
     * {@link InfinispanSchema#getDistance()}. The {@code minScore} filter is applied after the
     * query has returned, so fewer than {@code maxResults} matches may be returned.
     */
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        String ickle = "select i, score(i) from " + itemTypeName(this.schema)
                + " i where i.floatVector <-> " + Arrays.toString(referenceEmbedding.vector())
                + "~" + this.schema.getDistance();
        // Each projection row is [LangchainInfinispanItem item, Float score].
        Query<Object[]> query = this.remoteCache.query(ickle);
        List<Object[]> hits = query.maxResults(maxResults).list();

        List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>(hits.size());
        for (Object[] hit : hits) {
            LangchainInfinispanItem item = (LangchainInfinispanItem) hit[0];
            double score = ((Number) hit[1]).doubleValue();
            if (score < minScore) {
                continue;
            }
            matches.add(new EmbeddingMatch<>(score, item.getId(),
                    new Embedding(item.getFloatVector()), toTextSegment(item)));
        }
        return matches;
    }

    /**
     * Removes every entry from the backing cache.
     *
     * <p>The removal is synchronous. The previous fire-and-forget {@code clearAsync()} could
     * still be running after this method returned, and could wipe entries added right afterwards.
     */
    public void deleteAll() {
        this.remoteCache.clear();
    }

    /** Returns the indexed entity name for {@code schema}, e.g. {@code LangchainItem384}. */
    private static String itemTypeName(InfinispanSchema schema) {
        return ITEM_TYPE_PREFIX + schema.getDimension();
    }

    /** Generates {@code count} random ids. */
    private static List<String> randomIds(int count) {
        List<String> ids = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            ids.add(Utils.randomUUID());
        }
        return ids;
    }

    /** Single-item convenience wrapper around {@link #addAllInternal}. */
    private void addInternal(String id, Embedding embedding, TextSegment embedded) {
        this.addAllInternal(Collections.singletonList(id), Collections.singletonList(embedding),
                embedded == null ? null : Collections.singletonList(embedded));
    }

    /**
     * Converts the parallel lists into cache items and writes them with a single {@code putAll}.
     *
     * @param embedded optional text segments; {@code null} stores items without text
     * @throws IllegalArgumentException if {@code ids} is empty or the list sizes differ
     */
    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (ids.isEmpty() || ids.size() != embeddings.size() || embedded != null && embedded.size() != embeddings.size()) {
            throw new IllegalArgumentException("ids, embeddings and embedded must be non-empty and of the same size");
        }
        int size = ids.size();
        Map<String, LangchainInfinispanItem> elements = new HashMap<>(size);
        for (int i = 0; i < size; i++) {
            String id = ids.get(i);
            TextSegment segment = embedded == null ? null : embedded.get(i);
            elements.put(id, toItem(id, embeddings.get(i), segment));
        }
        this.remoteCache.putAll(elements);
    }

    /**
     * Builds the cache item for one embedding. Metadata is flattened into two index-aligned
     * lists of keys and stringified values. Without a segment, text and metadata are {@code null}.
     */
    private static LangchainInfinispanItem toItem(String id, Embedding embedding, TextSegment segment) {
        if (segment == null) {
            return new LangchainInfinispanItem(id, embedding.vector(), null, null, null);
        }
        Map<String, ?> metadata = segment.metadata().asMap();
        List<String> metadataKeys = new ArrayList<>(metadata.size());
        List<String> metadataValues = new ArrayList<>(metadata.size());
        for (Map.Entry<String, ?> entry : metadata.entrySet()) {
            metadataKeys.add(entry.getKey());
            // String.valueOf instead of a cast, so non-String metadata values do not throw.
            metadataValues.add(String.valueOf(entry.getValue()));
        }
        return new LangchainInfinispanItem(id, embedding.vector(), segment.text(), metadataKeys, metadataValues);
    }

    /**
     * Rebuilds the text segment stored in {@code item}, or returns {@code null} if the item has
     * no text. Missing or mismatched metadata lists are tolerated instead of throwing.
     */
    private static TextSegment toTextSegment(LangchainInfinispanItem item) {
        if (item.getText() == null) {
            return null;
        }
        Map<String, String> metadata = new HashMap<>();
        List<String> metadataKeys = item.getMetadataKeys();
        List<String> metadataValues = item.getMetadataValues();
        if (metadataKeys != null && metadataValues != null) {
            int count = Math.min(metadataKeys.size(), metadataValues.size());
            for (int i = 0; i < count; i++) {
                metadata.put(metadataKeys.get(i), metadataValues.get(i));
            }
        }
        return new TextSegment(item.getText(), new Metadata(metadata));
    }

    /** Fluent builder for {@link InfinispanEmbeddingStore}. */
    public static class Builder {
        private RemoteCacheManager cacheManager;
        private InfinispanSchema schema;

        /** Sets the Hot Rod cache manager (required). */
        public Builder cacheManager(RemoteCacheManager client) {
            this.cacheManager = client;
            return this;
        }

        /** Sets the schema describing the cache and vector settings (required). */
        public Builder schema(InfinispanSchema schema) {
            this.schema = schema;
            return this;
        }

        /**
         * Creates the store, which gets or creates the backing cache.
         *
         * @throws NullPointerException if the cache manager or schema was not set
         */
        public InfinispanEmbeddingStore build() {
            return new InfinispanEmbeddingStore(this.cacheManager, this.schema);
        }
    }
}

