/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.store.embedding.infinispan;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.infinispan.LangChainInfinispanItem;
import dev.langchain4j.store.embedding.infinispan.LangChainItemMarshaller;
import dev.langchain4j.store.embedding.infinispan.LangChainMetadata;
import dev.langchain4j.store.embedding.infinispan.LangChainMetadataMarshaller;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.infinispan.client.hotrod.RemoteCache;
import org.infinispan.client.hotrod.RemoteCacheManager;
import org.infinispan.client.hotrod.configuration.ConfigurationBuilder;
import org.infinispan.commons.api.query.Query;
import org.infinispan.commons.marshall.Marshaller;
import org.infinispan.commons.marshall.ProtoStreamMarshaller;
import org.infinispan.protostream.BaseMarshaller;
import org.infinispan.protostream.FileDescriptorSource;
import org.infinispan.protostream.SerializationContext;
import org.infinispan.protostream.schema.Schema;
import org.infinispan.protostream.schema.Type;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link EmbeddingStore} backed by an Infinispan remote (Hot Rod) cache whose entries are indexed for
 * vector (kNN) search.
 *
 * <p>On construction the store generates a Protobuf schema for the requested embedding dimension,
 * registers it in the client's {@link SerializationContext}, publishes it to the server's
 * {@code ___protobuf_metadata} cache, and configures an indexed distributed cache named after the store.
 */
public class InfinispanEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(InfinispanEmbeddingStore.class);

    /** Server-side cache in which Protobuf schemas must be registered. */
    private static final String PROTOBUF_METADATA_CACHE = "___protobuf_metadata";

    private final RemoteCache<String, LangChainInfinispanItem> remoteCache;
    private final LangChainItemMarshaller itemMarshaller;
    private final LangChainMetadataMarshaller metadataMarshaller;

    /**
     * Cache definition template. The placeholders CACHE_NAME, LANGCHAINITEM and LANGCHAIN_METADATA
     * are substituted in the constructor.
     */
    private static final String DEFAULT_CACHE_CONFIG = "<distributed-cache name=\"CACHE_NAME\">\n<indexing storage=\"local-heap\">\n<indexed-entities>\n<indexed-entity>LANGCHAINITEM</indexed-entity>\n<indexed-entity>LANGCHAIN_METADATA</indexed-entity>\n</indexed-entities>\n</indexing>\n</distributed-cache>";
    public static final String ITEM_PACKAGE = "dev.langchain4j";
    public static final String LANGCHAIN_ITEM = "LangChainItem";
    public static final String METADATA_ITEM = "LangChainMetadata";

    /**
     * Creates the store, registering the schema and cache configuration on the Infinispan server.
     *
     * @param builder   Hot Rod client configuration; it is mutated (cache definition and marshaller added)
     * @param name      name of the remote cache holding the embeddings
     * @param dimension embedding dimension; it is also used to derive unique Protobuf type names
     * @throws IllegalArgumentException if an argument is null/blank or {@code dimension} is not positive
     */
    public InfinispanEmbeddingStore(ConfigurationBuilder builder, String name, Integer dimension) {
        ValidationUtils.ensureNotNull(builder, "builder");
        ValidationUtils.ensureNotBlank(name, "name");
        ValidationUtils.ensureNotNull(dimension, "dimension");
        ValidationUtils.ensureTrue(dimension > 0, "dimension must be greater than zero");

        // Type names embed the dimension so stores with different dimensions can coexist on one server.
        String langchainType = LANGCHAIN_ITEM + dimension;
        String metadataType = METADATA_ITEM + dimension;
        this.itemMarshaller = new LangChainItemMarshaller(computeTypeWithPackage(langchainType));
        this.metadataMarshaller = new LangChainMetadataMarshaller(computeTypeWithPackage(metadataType));

        builder.remoteCache(name).configuration(DEFAULT_CACHE_CONFIG
                .replace("CACHE_NAME", name)
                .replace("LANGCHAINITEM", itemMarshaller.getTypeName())
                .replace("LANGCHAIN_METADATA", metadataMarshaller.getTypeName()));

        String fileName = "dev.langchain4j.dimension." + dimension + ".proto";
        String schemaContent = buildSchema(fileName, langchainType, metadataType, dimension);

        ProtoStreamMarshaller marshaller = new ProtoStreamMarshaller();
        SerializationContext serializationContext = marshaller.getSerializationContext();
        serializationContext.registerProtoFiles(FileDescriptorSource.fromString(fileName, schemaContent));
        serializationContext.registerMarshaller(metadataMarshaller);
        serializationContext.registerMarshaller(itemMarshaller);
        builder.marshaller(marshaller);

        RemoteCacheManager cacheManager = new RemoteCacheManager(builder.build());
        try {
            // The server needs the schema to index the generated entity types.
            RemoteCache<String, String> protobufMetadataCache = cacheManager.getCache(PROTOBUF_METADATA_CACHE);
            protobufMetadataCache.put(fileName, schemaContent);
            this.remoteCache = cacheManager.getCache(name);
        } catch (RuntimeException e) {
            // Do not leave the client's connections open when initialisation fails half-way.
            try {
                cacheManager.stop();
            } catch (RuntimeException suppressed) {
                e.addSuppressed(suppressed);
            }
            throw e;
        }
    }

    /** Prefixes a simple type name with {@link #ITEM_PACKAGE}. */
    private static String computeTypeWithPackage(String langchainType) {
        return ITEM_PACKAGE + "." + langchainType;
    }

    /**
     * Builds the Protobuf schema text declaring the metadata and item messages for one dimension.
     *
     * @param fileName      schema file name, the same one used when registering the schema
     * @param langchainType simple name of the item message
     * @param metadataType  simple name of the metadata message
     * @param dimension     vector dimension declared in the {@code @Vector} annotation
     * @return the rendered schema
     */
    private static String buildSchema(String fileName, String langchainType, String metadataType, int dimension) {
        Schema schema = new Schema.Builder(fileName)
                .packageName(ITEM_PACKAGE)
                .addMessage(metadataType)
                .addComment("@Indexed")
                .addField(Type.Scalar.STRING, "name", 1)
                .addComment("@Text")
                .addField(Type.Scalar.STRING, "value", 2)
                .addComment("@Text")
                .addMessage(langchainType)
                .addComment("@Indexed")
                .addField(Type.Scalar.STRING, "id", 1)
                .addComment("@Text")
                .addField(Type.Scalar.STRING, "text", 2)
                .addComment("@Keyword")
                .addRepeatedField(Type.Scalar.FLOAT, "embedding", 3)
                .addComment("@Vector(dimension=" + dimension + ", similarity=COSINE)")
                .addRepeatedField(Type.create(metadataType), "metadata", 4)
                .build();
        return schema.toString();
    }

    @Override
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        add(id, embedding);
        return id;
    }

    @Override
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
        addAllInternal(ids, embeddings, null);
        return ids;
    }

    @Override
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
        addAllInternal(ids, embeddings, embedded);
        return ids;
    }

    /**
     * Runs a kNN vector query and returns the matches whose score is at least {@code minScore}.
     *
     * @param referenceEmbedding query vector
     * @param maxResults         number of nearest neighbours requested (and maximum returned)
     * @param minScore           minimum score a hit needs in order to be returned
     * @return matches in the order returned by the query
     */
    @Override
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        // "~k" is the number of nearest neighbours the kNN search collects. It must follow maxResults;
        // otherwise the result set is capped at that fixed k no matter how many results were requested.
        Query<Object[]> query = remoteCache.query("select i, score(i) from " + itemMarshaller.getTypeName()
                + " i where i.embedding <-> " + Arrays.toString(referenceEmbedding.vector()) + "~" + maxResults);
        List<Object[]> hits = query.maxResults(maxResults).list();
        return hits.stream()
                .map(hit -> toMatch(hit, minScore))
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
    }

    /**
     * Converts one projection row {@code [item, score]} into a match.
     * Returns null when the score is below {@code minScore}.
     */
    private static EmbeddingMatch<TextSegment> toMatch(Object[] hit, double minScore) {
        LangChainInfinispanItem item = (LangChainInfinispanItem) hit[0];
        Float score = (Float) hit[1];
        if (score.doubleValue() < minScore) {
            return null;
        }
        TextSegment embedded = null;
        if (item.text() != null) {
            HashMap<String, String> map = new HashMap<>();
            // Guard against items unmarshalled without a metadata collection.
            if (item.metadata() != null) {
                for (LangChainMetadata metadata : item.metadata()) {
                    map.put(metadata.name(), metadata.value());
                }
            }
            embedded = new TextSegment(item.text(), new Metadata(map));
        }
        return new EmbeddingMatch<>(score.doubleValue(), item.id(), new Embedding(item.embedding()), embedded);
    }

    private void addInternal(String id, Embedding embedding, TextSegment embedded) {
        addAllInternal(Collections.singletonList(id), Collections.singletonList(embedding),
                embedded == null ? null : Collections.singletonList(embedded));
    }

    /**
     * Stores all embeddings (and optional segments) in a single {@code putAll}.
     * Null or empty {@code ids} or {@code embeddings} are logged and ignored.
     *
     * @throws IllegalArgumentException if the list sizes disagree
     */
    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("do not add empty embeddings to infinispan");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(),
                "embeddings size is not equal to embedded size");
        int size = ids.size();
        HashMap<String, LangChainInfinispanItem> elements = new HashMap<>(size);
        for (int i = 0; i < size; ++i) {
            String id = ids.get(i);
            TextSegment textSegment = embedded == null ? null : embedded.get(i);
            elements.put(id, toItem(id, embeddings.get(i), textSegment));
        }
        remoteCache.putAll(elements);
    }

    /** Converts an embedding and an optional segment into the cached item representation. */
    private static LangChainInfinispanItem toItem(String id, Embedding embedding, TextSegment textSegment) {
        if (textSegment == null) {
            return new LangChainInfinispanItem(id, embedding.vector(), null, null);
        }
        Set<LangChainMetadata> metadata = textSegment.metadata().asMap().entrySet().stream()
                .map(e -> new LangChainMetadata(e.getKey(), e.getValue()))
                .collect(Collectors.toSet());
        return new LangChainInfinispanItem(id, embedding.vector(), textSegment.text(), metadata);
    }

    public static Builder builder() {
        return new Builder();
    }

    /** Returns the underlying remote cache. */
    public RemoteCache<String, LangChainInfinispanItem> remoteCache() {
        return remoteCache;
    }

    /** Removes every entry from the underlying remote cache. */
    public void clearCache() {
        remoteCache.clear();
    }

    /**
     * Fluent builder for {@link InfinispanEmbeddingStore}.
     * Validation happens in the store's constructor.
     */
    public static class Builder {
        private ConfigurationBuilder builder;
        private String name;
        private Integer dimension;

        public Builder cacheName(String name) {
            this.name = name;
            return this;
        }

        public Builder dimension(Integer dimension) {
            this.dimension = dimension;
            return this;
        }

        public Builder infinispanConfigBuilder(ConfigurationBuilder builder) {
            this.builder = builder;
            return this;
        }

        public InfinispanEmbeddingStore build() {
            return new InfinispanEmbeddingStore(builder, name, dimension);
        }
    }
}

