/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.explainability.integrationtests.opennlp;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import opennlp.tools.langdetect.Language;
import opennlp.tools.langdetect.LanguageDetector;
import opennlp.tools.langdetect.LanguageDetectorME;
import opennlp.tools.langdetect.LanguageDetectorModel;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.LocalExplainer;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionInputsDataDistribution;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.SimplePrediction;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.utils.DataUtils;
import org.kie.kogito.explainability.utils.ExplainabilityMetrics;
import org.kie.kogito.explainability.utils.ValidationUtils;

class OpenNLPLimeExplainerTest {
    OpenNLPLimeExplainerTest() {
    }

    @ParameterizedTest
    @ValueSource(ints={0})
    void testOpenNLPLangDetect(int seed) throws Exception {
        Random random = new Random();
        random.setSeed(seed);
        LimeConfig limeConfig = new LimeConfig().withSamples(10).withPerturbationContext(new PerturbationContext(random, 1));
        LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
        PredictionProvider model = this.getModel();
        Function<String, List<String>> tokenizer = this.getTokenizer();
        PredictionInput testInput = this.getTestInput(tokenizer);
        List predictionOutputs = (List)model.predictAsync(List.of(testInput)).get();
        org.junit.jupiter.api.Assertions.assertNotNull((Object)predictionOutputs);
        org.junit.jupiter.api.Assertions.assertFalse((boolean)predictionOutputs.isEmpty());
        PredictionOutput output = (PredictionOutput)predictionOutputs.get(0);
        org.junit.jupiter.api.Assertions.assertNotNull((Object)output);
        org.junit.jupiter.api.Assertions.assertNotNull((Object)output.getOutputs());
        org.junit.jupiter.api.Assertions.assertEquals((int)1, (int)output.getOutputs().size());
        org.junit.jupiter.api.Assertions.assertEquals((Object)"ita", (Object)((Output)output.getOutputs().get(0)).getValue().asString());
        org.junit.jupiter.api.Assertions.assertEquals((double)0.03, (double)((Output)output.getOutputs().get(0)).getScore(), (double)0.01);
        SimplePrediction prediction = new SimplePrediction(testInput, output);
        Map saliencyMap = (Map)limeExplainer.explainAsync((Prediction)prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        for (Saliency saliency : saliencyMap.values()) {
            org.junit.jupiter.api.Assertions.assertNotNull((Object)saliency);
            double i1 = ExplainabilityMetrics.impactScore((PredictionProvider)model, (Prediction)prediction, (List)saliency.getPositiveFeatures(3));
            org.junit.jupiter.api.Assertions.assertEquals((double)1.0, (double)i1);
        }
        org.junit.jupiter.api.Assertions.assertDoesNotThrow(() -> OpenNLPLimeExplainerTest.lambda$testOpenNLPLangDetect$0(model, (Prediction)prediction, limeExplainer));
        List<PredictionInput> inputs = this.getSamples(tokenizer);
        String decision = "lang";
        PredictionInputsDataDistribution distribution = new PredictionInputsDataDistribution(inputs);
        int k = 2;
        int chunkSize = 2;
        double f1 = ExplainabilityMetrics.getLocalSaliencyF1((String)decision, (PredictionProvider)model, (LocalExplainer)limeExplainer, (DataDistribution)distribution, (int)k, (int)chunkSize);
        AssertionsForClassTypes.assertThat((double)f1).isBetween(Double.valueOf(0.5), Double.valueOf(1.0));
    }

    private Function<String, List<String>> getTokenizer() {
        return s -> Arrays.asList(s.split("\\W"));
    }

    private PredictionProvider getModel() throws IOException {
        InputStream is = this.getClass().getResourceAsStream("/opennlp/langdetect-183.bin");
        LanguageDetectorModel languageDetectorModel = new LanguageDetectorModel(is);
        LanguageDetectorME languageDetector = new LanguageDetectorME(languageDetectorModel);
        return arg_0 -> OpenNLPLimeExplainerTest.lambda$getModel$3((LanguageDetector)languageDetector, arg_0);
    }

    private List<PredictionInput> getSamples(Function<String, List<String>> tokenizer) {
        List<String> texts = List.of("we want your money", "please reply quickly", "you are the lucky winner", "italiani, spaghetti pizza mandolino", "guten tag", "allez les bleus", "daje roma");
        ArrayList<PredictionInput> inputs = new ArrayList<PredictionInput>();
        for (String text : texts) {
            inputs.add(new PredictionInput(List.of(FeatureFactory.newFulltextFeature((String)"text", (String)text, tokenizer))));
        }
        return inputs;
    }

    private PredictionInput getTestInput(Function<String, List<String>> tokenizer) {
        String inputText = "italiani,spaghetti pizza mandolino";
        ArrayList<Feature> features = new ArrayList<Feature>();
        features.add(FeatureFactory.newFulltextFeature((String)"text", (String)inputText, tokenizer));
        return new PredictionInput(features);
    }

    @Test
    void testExplanationStabilityWithOptimization() throws ExecutionException, InterruptedException, TimeoutException, IOException {
        PredictionProvider model = this.getModel();
        List<PredictionInput> samples = this.getSamples(this.getTokenizer());
        List predictionOutputs = (List)model.predictAsync(samples.subList(0, 5)).get();
        List predictions = DataUtils.getPredictions(samples, (List)predictionOutputs);
        LimeConfigOptimizer limeConfigOptimizer = new LimeConfigOptimizer().withSampling(false);
        Random random = new Random();
        random.setSeed(0L);
        LimeConfig limeConfig = new LimeConfig().withSamples(10).withPerturbationContext(new PerturbationContext(random, 1));
        LimeConfig optimizedConfig = limeConfigOptimizer.optimize(limeConfig, predictions, model);
        Assertions.assertThat((Object)optimizedConfig).isNotSameAs((Object)limeConfig);
        LimeExplainer limeExplainer = new LimeExplainer(optimizedConfig);
        PredictionInput testPredictionInput = this.getTestInput(this.getTokenizer());
        List testPredictionOutputs = (List)model.predictAsync(List.of(testPredictionInput)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        SimplePrediction instance = new SimplePrediction(testPredictionInput, (PredictionOutput)testPredictionOutputs.get(0));
        org.junit.jupiter.api.Assertions.assertDoesNotThrow(() -> OpenNLPLimeExplainerTest.lambda$testExplanationStabilityWithOptimization$4(model, (Prediction)instance, limeExplainer));
    }

    private static /* synthetic */ void lambda$testExplanationStabilityWithOptimization$4(PredictionProvider model, Prediction instance, LimeExplainer limeExplainer) throws Throwable {
        ValidationUtils.validateLocalSaliencyStability((PredictionProvider)model, (Prediction)instance, (LocalExplainer)limeExplainer, (int)1, (double)0.9, (double)0.8);
    }

    private static /* synthetic */ CompletableFuture lambda$getModel$3(LanguageDetector languageDetector, List inputs) {
        return CompletableFuture.supplyAsync(() -> {
            LinkedList<PredictionOutput> results = new LinkedList<PredictionOutput>();
            for (PredictionInput predictionInput : inputs) {
                StringBuilder builder = new StringBuilder();
                for (Feature f : predictionInput.getFeatures()) {
                    if (builder.length() > 0) {
                        builder.append(' ');
                    }
                    builder.append(f.getValue().asString());
                }
                Language language = languageDetector.predictLanguage((CharSequence)builder.toString());
                PredictionOutput predictionOutput = new PredictionOutput(List.of(new Output("lang", Type.TEXT, new Value((Object)language.getLang()), language.getConfidence())));
                results.add(predictionOutput);
            }
            return results;
        });
    }

    private static /* synthetic */ void lambda$testOpenNLPLangDetect$0(PredictionProvider model, Prediction prediction, LimeExplainer limeExplainer) throws Throwable {
        ValidationUtils.validateLocalSaliencyStability((PredictionProvider)model, (Prediction)prediction, (LocalExplainer)limeExplainer, (int)2, (double)0.6, (double)0.6);
    }
}

