package org.jpmml.evaluator;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.util.Precision;
import org.dmg.pmml.BayesInput;
import org.dmg.pmml.BayesInputs;
import org.dmg.pmml.BayesOutput;
import org.dmg.pmml.ContinuousDistribution;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Discretize;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Extension;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.GaussianDistribution;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.NaiveBayesModel;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PairCounts;
import org.dmg.pmml.PoissonDistribution;
import org.dmg.pmml.TargetValueCount;
import org.dmg.pmml.TargetValueCounts;
import org.dmg.pmml.TargetValueStat;
import org.dmg.pmml.TargetValueStats;

/* loaded from: input_file:WEB-INF/lib/pmml-evaluator-1.2.11.jar:org/jpmml/evaluator/NaiveBayesModelEvaluator.class */
/**
 * Evaluator for PMML {@code NaiveBayesModel} elements.
 *
 * <p>Only the classification mining function is supported. Scores are accumulated per target
 * category on a logarithmic scale: every non-missing {@link BayesInput} contributes the logarithm
 * of its conditional probability, and the {@link BayesOutput} target value counts contribute the
 * logarithm of the prior count. The sums are then shifted by their maximum, exponentiated and
 * normalized into a {@link ProbabilityDistribution}.</p>
 *
 * <p>The effective list of Bayes inputs and the per-field count sums are derived once per model
 * object and shared through static caches; each evaluator instance additionally memoizes them in
 * transient fields.</p>
 */
public class NaiveBayesModelEvaluator extends ModelEvaluator<NaiveBayesModel> {

    /** Per-model cache of the effective Bayes inputs (declared inputs plus inputs found in extensions). */
    private static final LoadingCache<NaiveBayesModel, List<BayesInput>> bayesInputCache = CacheUtil.buildLoadingCache(new CacheLoader<NaiveBayesModel, List<BayesInput>>() {
        @Override
        public List<BayesInput> load(NaiveBayesModel naiveBayesModel) {
            return ImmutableList.copyOf(NaiveBayesModelEvaluator.parseBayesInputs(naiveBayesModel));
        }
    });

    /** Per-model cache of count sums, keyed by input field name and then by target category. */
    private static final LoadingCache<NaiveBayesModel, Map<FieldName, Map<String, Double>>> fieldCountSumCache = CacheUtil.buildLoadingCache(new CacheLoader<NaiveBayesModel, Map<FieldName, Map<String, Double>>>() {
        @Override
        public Map<FieldName, Map<String, Double>> load(NaiveBayesModel naiveBayesModel) {
            return ImmutableMap.copyOf(NaiveBayesModelEvaluator.calculateFieldCountSums(naiveBayesModel));
        }
    });

    /** Lazily initialized by {@link #getBayesInputs()}. */
    private transient List<BayesInput> bayesInputs = null;

    /** Lazily initialized by {@link #getFieldCountSums()}. */
    private transient Map<FieldName, Map<String, Double>> fieldCountSums = null;

    /**
     * Creates an evaluator for the Naive Bayes model contained in the PMML document.
     *
     * @param pmml the PMML document; model lookup is delegated to the superclass constructor
     */
    public NaiveBayesModelEvaluator(PMML pmml) {
        super(pmml, NaiveBayesModel.class);
    }

    /**
     * Creates an evaluator for the given Naive Bayes model.
     *
     * @param pmml the PMML document
     * @param naiveBayesModel the model to evaluate
     */
    public NaiveBayesModelEvaluator(PMML pmml, NaiveBayesModel naiveBayesModel) {
        super(pmml, naiveBayesModel);
    }

    /**
     * @return the fixed, human-readable description of this model type
     */
    @Override
    public String getSummary() {
        return "Naive Bayes model";
    }

    /**
     * Evaluates the model and post-processes the predictions through {@code OutputUtil}.
     *
     * @param modelEvaluationContext the context supplying the input field values
     * @return the result of {@code OutputUtil.evaluate} applied to the classification result
     * @throws InvalidResultException if the model is flagged as not scorable
     * @throws UnsupportedFeatureException if the mining function is not classification
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext modelEvaluationContext) {
        NaiveBayesModel naiveBayesModel = getModel();
        if (!naiveBayesModel.isScorable()) {
            throw new InvalidResultException(naiveBayesModel);
        }

        MiningFunctionType functionName = naiveBayesModel.getFunctionName();
        switch (functionName) {
            case CLASSIFICATION:
                return OutputUtil.evaluate(evaluateClassification(modelEvaluationContext), modelEvaluationContext);
            default:
                throw new UnsupportedFeatureException(naiveBayesModel, functionName);
        }
    }

    /**
     * Computes the probability distribution over the target categories.
     *
     * @param modelEvaluationContext the context supplying the input field values
     * @return the result of {@code TargetUtil.evaluateClassification} for the target field
     * @throws InvalidFeatureException if the Bayes output is missing, declares no target field,
     *         or yields no target categories to score
     */
    private Map<FieldName, ? extends Classification> evaluateClassification(ModelEvaluationContext modelEvaluationContext) {
        NaiveBayesModel naiveBayesModel = getModel();
        double threshold = naiveBayesModel.getThreshold();

        // Target category -> sum of log-probabilities accumulated so far.
        Map<String, Double> logProbabilities = new LinkedHashMap<String, Double>();

        Map<FieldName, Map<String, Double>> countSumsByField = getFieldCountSums();
        for (BayesInput bayesInput : getBayesInputs()) {
            FieldName fieldName = bayesInput.getFieldName();
            FieldValue value = modelEvaluationContext.evaluate(fieldName);
            if (value == null) {
                // A missing input value contributes nothing to the score.
                continue;
            }

            TargetValueStats targetValueStats = getTargetValueStats(bayesInput);
            if (targetValueStats != null) {
                // Continuous input: described by per-category distributions rather than counts.
                calculateContinuousProbabilities(value, targetValueStats, threshold, logProbabilities);
                continue;
            }

            DerivedField derivedField = bayesInput.getDerivedField();
            if (derivedField != null) {
                value = applyDiscretization(derivedField, value);
            }

            Map<String, Double> countSums = countSumsByField.get(fieldName);
            TargetValueCounts targetValueCounts = getTargetValueCounts(bayesInput, value);
            if (targetValueCounts != null) {
                calculateDiscreteProbabilities(countSums, targetValueCounts, threshold, logProbabilities);
            }
        }

        BayesOutput bayesOutput = naiveBayesModel.getBayesOutput();
        if (bayesOutput == null) {
            throw new InvalidFeatureException(naiveBayesModel);
        }
        TargetValueCounts priorCounts = bayesOutput.getTargetValueCounts();
        if (priorCounts == null) {
            throw new InvalidFeatureException(bayesOutput);
        }
        calculatePriorProbabilities(priorCounts, logProbabilities);

        if (logProbabilities.isEmpty()) {
            // Collections.max would otherwise fail with a bare NoSuchElementException.
            throw new InvalidFeatureException(bayesOutput);
        }

        // Shift by the maximum before exponentiating so that the largest term becomes exp(0) = 1.
        double maxLogProbability = Collections.max(logProbabilities.values());
        ProbabilityDistribution probabilityDistribution = new ProbabilityDistribution();
        for (Map.Entry<String, Double> entry : logProbabilities.entrySet()) {
            probabilityDistribution.put(entry.getKey(), Double.valueOf(Math.exp(entry.getValue().doubleValue() - maxLogProbability)));
        }
        probabilityDistribution.normalizeValues();

        FieldName targetFieldName = bayesOutput.getFieldName();
        if (targetFieldName == null) {
            throw new InvalidFeatureException(bayesOutput);
        }
        return TargetUtil.evaluateClassification(targetFieldName, probabilityDistribution, modelEvaluationContext);
    }

    /**
     * Discretizes an input value using the {@link Discretize} expression of a derived field and
     * refines the result against that derived field.
     *
     * @throws InvalidFeatureException if the derived field has no expression
     * @throws UnsupportedFeatureException if the expression is not a {@link Discretize}
     * @throws EvaluationException if discretization produces no value
     */
    private static FieldValue applyDiscretization(DerivedField derivedField, FieldValue value) {
        Expression expression = derivedField.getExpression();
        if (expression == null) {
            throw new InvalidFeatureException(derivedField);
        }
        if (!(expression instanceof Discretize)) {
            throw new UnsupportedFeatureException(expression);
        }

        FieldValue discretizedValue = DiscretizationUtil.discretize((Discretize) expression, value);
        if (discretizedValue == null) {
            throw new EvaluationException();
        }
        return FieldValueUtil.refine(derivedField, discretizedValue);
    }

    /**
     * Adds, for every target category, the log of the distribution-based probability of the value.
     * Probabilities smaller than {@code threshold} are replaced by {@code threshold}.
     *
     * @throws InvalidFeatureException if a distribution is neither Gaussian nor Poisson
     */
    private void calculateContinuousProbabilities(FieldValue fieldValue, TargetValueStats targetValueStats, double threshold, Map<String, Double> logProbabilities) {
        Number number = fieldValue.asNumber();
        for (TargetValueStat targetValueStat : targetValueStats) {
            ContinuousDistribution distribution = targetValueStat.getContinuousDistribution();
            if (!(distribution instanceof GaussianDistribution) && !(distribution instanceof PoissonDistribution)) {
                throw new InvalidFeatureException(targetValueStat);
            }

            double probability = Math.max(DistributionUtil.probability(distribution, number), threshold);
            updateSum(targetValueStat.getValue(), Double.valueOf(Math.log(probability)), logProbabilities);
        }
    }

    /**
     * Adds, for every target category, the log of {@code count / countSum}. A count that is zero
     * (within {@code Precision.EPSILON}) is scored with {@code threshold} instead.
     *
     * @param countSums the per-category count sums of the input field being scored
     */
    private void calculateDiscreteProbabilities(Map<String, Double> countSums, TargetValueCounts targetValueCounts, double threshold, Map<String, Double> logProbabilities) {
        for (TargetValueCount targetValueCount : targetValueCounts) {
            String targetValue = targetValueCount.getValue();
            double count = targetValueCount.getCount();

            double probability;
            if (VerificationUtil.isZero(Double.valueOf(count), Precision.EPSILON)) {
                // Decide this first: the ratio is never needed for a zero count.
                probability = threshold;
            } else {
                probability = count / countSums.get(targetValue).doubleValue();
            }
            updateSum(targetValue, Double.valueOf(Math.log(probability)), logProbabilities);
        }
    }

    /**
     * Adds, for every target category, the log of its raw count. The counts are not divided by
     * their total here; the caller normalizes the final distribution.
     */
    private void calculatePriorProbabilities(TargetValueCounts targetValueCounts, Map<String, Double> logProbabilities) {
        for (TargetValueCount targetValueCount : targetValueCounts) {
            updateSum(targetValueCount.getValue(), Double.valueOf(Math.log(targetValueCount.getCount())), logProbabilities);
        }
    }

    /**
     * @return the Bayes inputs of the model, fetched through the shared cache on first use
     */
    protected List<BayesInput> getBayesInputs() {
        if (this.bayesInputs == null) {
            this.bayesInputs = getValue(bayesInputCache);
        }
        return this.bayesInputs;
    }

    /**
     * @return the count sums per input field and target category, fetched through the shared
     *         cache on first use
     */
    protected Map<FieldName, Map<String, Double>> getFieldCountSums() {
        if (this.fieldCountSums == null) {
            this.fieldCountSums = getValue(fieldCountSumCache);
        }
        return this.fieldCountSums;
    }

    /**
     * Sums, for every Bayes input, the target value counts of all its pair counts per target
     * category.
     *
     * @param naiveBayesModel the model whose (cached) Bayes inputs are aggregated
     * @return a new mutable map from input field name to (target category to count sum)
     * @throws InvalidFeatureException if a pair counts element has no target value counts
     */
    public static Map<FieldName, Map<String, Double>> calculateFieldCountSums(NaiveBayesModel naiveBayesModel) {
        Map<FieldName, Map<String, Double>> result = new LinkedHashMap<FieldName, Map<String, Double>>();

        List<BayesInput> bayesInputs = CacheUtil.getValue(naiveBayesModel, bayesInputCache);
        for (BayesInput bayesInput : bayesInputs) {
            Map<String, Double> countSums = new LinkedHashMap<String, Double>();
            for (PairCounts pairCounts : bayesInput.getPairCounts()) {
                TargetValueCounts targetValueCounts = pairCounts.getTargetValueCounts();
                if (targetValueCounts == null) {
                    // Same validation as the lookup in getTargetValueCounts.
                    throw new InvalidFeatureException(pairCounts);
                }
                for (TargetValueCount targetValueCount : targetValueCounts) {
                    updateSum(targetValueCount.getValue(), Double.valueOf(targetValueCount.getCount()), countSums);
                }
            }
            result.put(bayesInput.getFieldName(), countSums);
        }
        return result;
    }

    /**
     * Collects the Bayes inputs of a model. When the {@code BayesInputs} element carries
     * extensions, any {@link BayesInput} objects found in the extension content are appended
     * after the regularly declared inputs.
     *
     * @param naiveBayesModel the model to inspect
     * @return the model's own input list when there are no extensions, otherwise a new list
     */
    public static List<BayesInput> parseBayesInputs(NaiveBayesModel naiveBayesModel) {
        BayesInputs bayesInputs = naiveBayesModel.getBayesInputs();
        if (!bayesInputs.hasExtensions()) {
            return bayesInputs.getBayesInputs();
        }

        List<BayesInput> result = new ArrayList<BayesInput>(bayesInputs.getBayesInputs());
        for (Extension extension : bayesInputs.getExtensions()) {
            for (Object content : extension.getContent()) {
                if (content instanceof BayesInput) {
                    result.add((BayesInput) content);
                }
            }
        }
        return result;
    }

    /**
     * Adds {@code value} to the running sum stored under {@code key}; an absent key counts as zero.
     */
    private static void updateSum(String key, Double value, Map<String, Double> sums) {
        Double previous = sums.get(key);
        double base = (previous != null) ? previous.doubleValue() : 0.0d;
        sums.put(key, Double.valueOf(base + value.doubleValue()));
    }

    /**
     * @return the target value statistics of the input, or {@code null} if it declares none
     */
    private static TargetValueStats getTargetValueStats(BayesInput bayesInput) {
        return bayesInput.getTargetValueStats();
    }

    /**
     * Finds the target value counts that correspond to the given input value.
     *
     * @return the matching target value counts, or {@code null} if no pair counts element matches
     * @throws InvalidFeatureException if the matching pair counts element has no target value counts
     */
    // NOTE(review): the raw HasParsedValueMapping cast is kept from the original because the
    // interface's declaration is not visible here; parameterize it once that is confirmed.
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static TargetValueCounts getTargetValueCounts(BayesInput bayesInput, FieldValue fieldValue) {
        if (bayesInput instanceof HasParsedValueMapping) {
            return (TargetValueCounts) fieldValue.getMapping((HasParsedValueMapping) bayesInput);
        }

        for (PairCounts pairCounts : bayesInput.getPairCounts()) {
            if (fieldValue.equalsString(pairCounts.getValue())) {
                TargetValueCounts targetValueCounts = pairCounts.getTargetValueCounts();
                if (targetValueCounts == null) {
                    throw new InvalidFeatureException(pairCounts);
                }
                return targetValueCounts;
            }
        }
        return null;
    }
}
