/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.redis;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Json;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
import io.quarkiverse.langchain4j.redis.runtime.RedisSchema;
import io.quarkus.redis.datasource.ReactiveRedisDataSource;
import io.quarkus.redis.datasource.json.ReactiveJsonCommands;
import io.quarkus.redis.datasource.keys.KeyScanArgs;
import io.quarkus.redis.datasource.search.CreateArgs;
import io.quarkus.redis.datasource.search.Document;
import io.quarkus.redis.datasource.search.QueryArgs;
import io.quarkus.redis.datasource.search.SearchQueryResponse;
import io.smallrye.mutiny.Uni;
import io.vertx.mutiny.redis.client.Command;
import io.vertx.mutiny.redis.client.Request;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.jboss.logging.Logger;

public class RedisEmbeddingStore
implements EmbeddingStore<TextSegment> {
    public static final String EXTRA_ATTRIBUTES = "extra_attributes";
    public static final String ID = "id";
    private final ReactiveRedisDataSource ds;
    private final RedisSchema schema;
    private final Logger LOG = Logger.getLogger(RedisEmbeddingStore.class);
    private static final String SCORE_FIELD_NAME = "vector_score";

    public static Builder builder() {
        return new Builder();
    }

    public RedisEmbeddingStore(ReactiveRedisDataSource ds, RedisSchema schema) {
        this.ds = ds;
        this.schema = schema;
        this.createIndexIfDoesNotExist();
    }

    private void createIndexIfDoesNotExist() {
        List indexes = (List)this.ds.search().ft_list().onFailure().invoke(t -> {
            if (t.getMessage().contains("unknown command")) {
                this.LOG.error((Object)"The Redis server does not seem to support RediSearch. Please install the RediSearch module. If using containers, we suggest to use the redis/redis-stack images.");
            }
        }).await().indefinitely();
        if (!indexes.contains(this.schema.getIndexName())) {
            CreateArgs indexCreateArgs = new CreateArgs().onJson().prefixes(new String[]{this.schema.getPrefix()});
            this.schema.defineFields(indexCreateArgs);
            this.LOG.debug((Object)("Creating Redis index " + this.schema.getIndexName()));
            this.ds.search().ftCreate(this.schema.getIndexName(), indexCreateArgs).await().indefinitely();
        } else {
            this.LOG.debug((Object)("Index in Redis already exists: " + this.schema.getIndexName()));
        }
    }

    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        this.add(id, embedding);
        return id;
    }

    public void add(String id, Embedding embedding) {
        this.addInternal(id, embedding, null);
    }

    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        this.addInternal(id, embedding, textSegment);
        return id;
    }

    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
        this.addAllInternal(ids, embeddings, null);
        return ids;
    }

    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
        this.addAllInternal(ids, embeddings, embedded);
        return ids;
    }

    private void addInternal(String id, Embedding embedding, TextSegment embedded) {
        this.addAllInternal(Collections.singletonList(id), Collections.singletonList(embedding), embedded == null ? null : Collections.singletonList(embedded));
    }

    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (ids.isEmpty() || ids.size() != embeddings.size() || embedded != null && embedded.size() != embeddings.size()) {
            throw new IllegalArgumentException("ids, embeddings and embedded must be non-empty and of the same size");
        }
        ReactiveJsonCommands json = this.ds.json();
        int size = ids.size();
        Uni[] unis = new Uni[size];
        for (int i = 0; i < size; ++i) {
            String id = ids.get(i);
            Embedding embedding = embeddings.get(i);
            TextSegment textSegment = embedded == null ? null : embedded.get(i);
            HashMap<String, Object> fields = new HashMap<String, Object>();
            fields.put(this.schema.getVectorFieldName(), embedding.vector());
            if (textSegment != null) {
                fields.put(this.schema.getScalarFieldName(), textSegment.text());
                fields.putAll(textSegment.metadata().asMap());
            }
            String key = this.schema.getPrefix() + id;
            unis[i] = json.jsonSet((Object)key, "$", fields);
        }
        Uni.join().all(unis).andFailFast().await().indefinitely();
    }

    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        String queryTemplate = "*=>[ KNN %d @%s $BLOB AS %s ]";
        String query = String.format(queryTemplate, maxResults, this.schema.getVectorFieldName(), SCORE_FIELD_NAME);
        QueryArgs args = new QueryArgs().sortByAscending(SCORE_FIELD_NAME).param("DIALECT", "2").param("BLOB", referenceEmbedding.vector());
        Uni search = this.ds.search().ftSearch(this.schema.getIndexName(), query, args);
        SearchQueryResponse response = (SearchQueryResponse)search.await().indefinitely();
        return response.documents().stream().map(this::extractEmbeddingMatch).filter(embeddingMatch -> embeddingMatch.score() >= minScore).collect(Collectors.toList());
    }

    private EmbeddingMatch<TextSegment> extractEmbeddingMatch(Document document) {
        try {
            JsonNode jsonNode = QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER.readTree(document.property("$").asString());
            JsonNode embedded = jsonNode.get(this.schema.getScalarFieldName());
            Embedding embedding = new Embedding((float[])Json.fromJson((String)jsonNode.get(this.schema.getVectorFieldName()).toString(), float[].class));
            double score = (2.0 - document.property(SCORE_FIELD_NAME).asDouble()) / 2.0;
            String id = document.key().substring(this.schema.getPrefix().length());
            Map<String, String> metadata = this.schema.getMetadataFields().stream().filter(arg_0 -> ((JsonNode)jsonNode).has(arg_0)).collect(Collectors.toMap(metadataFieldName -> metadataFieldName, name -> jsonNode.get(name).asText()));
            TextSegment textSegment = embedded != null ? new TextSegment(embedded.asText(), Metadata.from(metadata)) : null;
            return new EmbeddingMatch(Double.valueOf(score), id, embedding, (Object)textSegment);
        }
        catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    public void deleteAll() {
        KeyScanArgs args = new KeyScanArgs().match(this.schema.getPrefix() + "*");
        Set keysToDelete = (Set)this.ds.key().scan(args).toMulti().collect().asSet().await().indefinitely();
        if (!keysToDelete.isEmpty()) {
            Request command = Request.cmd((Command)Command.DEL);
            keysToDelete.forEach(arg_0 -> ((Request)command).arg(arg_0));
            this.ds.getRedis().send(command).await().indefinitely();
            this.LOG.debug((Object)("Deleted " + keysToDelete.size() + " keys"));
        }
    }

    public static class Builder {
        private ReactiveRedisDataSource redisClient;
        private RedisSchema schema;

        public Builder dataSource(ReactiveRedisDataSource client) {
            this.redisClient = client;
            return this;
        }

        public Builder schema(RedisSchema schema) {
            this.schema = schema;
            return this;
        }

        public RedisEmbeddingStore build() {
            return new RedisEmbeddingStore(this.redisClient, this.schema);
        }
    }
}

