package org.jpmml.evaluator;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.dmg.pmml.ActivationFunctionType;
import org.dmg.pmml.Connection;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Entity;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.NeuralInput;
import org.dmg.pmml.NeuralInputs;
import org.dmg.pmml.NeuralLayer;
import org.dmg.pmml.NeuralNetwork;
import org.dmg.pmml.NeuralOutput;
import org.dmg.pmml.NeuralOutputs;
import org.dmg.pmml.Neuron;
import org.dmg.pmml.NnNormalizationMethodType;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;

/* loaded from: input_file:WEB-INF/lib/pmml-evaluator-1.2.11.jar:org/jpmml/evaluator/NeuralNetworkEvaluator.class */
/**
 * Evaluator for PMML {@link NeuralNetwork} models.
 *
 * <p>Supports the {@code regression} and {@code classification} mining functions. Raw neuron outputs are
 * computed by evaluating the neural inputs and then feeding them forward through each {@link NeuralLayer}
 * in document order.
 */
public class NeuralNetworkEvaluator extends ModelEvaluator<NeuralNetwork> implements HasEntityRegistry<Entity> {

    /**
     * Lazily resolved registry of this model's entities (neural inputs and neurons), keyed by id.
     * Not synchronized: concurrent first calls may each resolve it from {@link #entityCache}.
     */
    private transient BiMap<String, Entity> entityRegistry;

    /**
     * Static (shared) cache from a {@link NeuralNetwork} element to its entity registry, so the registry can be
     * reused across evaluator instances (eviction policy is whatever {@code CacheUtil} configures).
     */
    private static final LoadingCache<NeuralNetwork, BiMap<String, Entity>> entityCache = CacheUtil.buildLoadingCache(new CacheLoader<NeuralNetwork, BiMap<String, Entity>>() {

        /**
         * Registers all neural inputs, followed by all neurons of every layer, in document order.
         *
         * @throws InvalidFeatureException If the network has no {@code NeuralInputs} element.
         */
        @Override
        public BiMap<String, Entity> load(NeuralNetwork neuralNetwork) {
            ImmutableBiMap.Builder<String, Entity> builder = new ImmutableBiMap.Builder<String, Entity>();

            // Counter handed to EntityUtil.put, starting at 1 (presumably used to number entities; see EntityUtil)
            AtomicInteger index = new AtomicInteger(1);

            NeuralInputs neuralInputs = neuralNetwork.getNeuralInputs();
            if (neuralInputs == null) {
                throw new InvalidFeatureException(neuralNetwork);
            }

            for (NeuralInput neuralInput : neuralInputs) {
                builder = EntityUtil.put(neuralInput, index, builder);
            }

            for (NeuralLayer neuralLayer : neuralNetwork.getNeuralLayers()) {
                for (Neuron neuron : neuralLayer.getNeurons()) {
                    builder = EntityUtil.put(neuron, index, builder);
                }
            }

            return builder.build();
        }
    });

    public NeuralNetworkEvaluator(PMML pmml) {
        super(pmml, NeuralNetwork.class);
    }

    public NeuralNetworkEvaluator(PMML pmml, NeuralNetwork neuralNetwork) {
        super(pmml, neuralNetwork);
    }

    @Override
    public String getSummary() {
        return "Neural network";
    }

    /**
     * Returns the registry of neural inputs and neurons, keyed by entity id, resolving it on first use.
     */
    @Override
    public BiMap<String, Entity> getEntityRegistry() {
        if (this.entityRegistry == null) {
            this.entityRegistry = getValue(entityCache);
        }

        return this.entityRegistry;
    }

    /**
     * Evaluates the model according to its mining function and then applies the model's output fields.
     *
     * @throws InvalidResultException If the model is marked as not scorable.
     * @throws UnsupportedFeatureException If the mining function is neither regression nor classification.
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext context) {
        NeuralNetwork neuralNetwork = getModel();
        if (!neuralNetwork.isScorable()) {
            throw new InvalidResultException(neuralNetwork);
        }

        Map<FieldName, ?> predictions;

        MiningFunctionType miningFunction = neuralNetwork.getFunctionName();
        switch (miningFunction) {
            case REGRESSION:
                predictions = evaluateRegression(context);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(context);
                break;
            default:
                throw new UnsupportedFeatureException(neuralNetwork, miningFunction);
        }

        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Maps each {@code NeuralOutput} onto its target field.
     *
     * <p>A {@code FieldRef} output takes the neuron value as-is. A {@code NormContinuous} output denormalizes it.
     * Each value is then post-processed by {@code TargetUtil.evaluateRegressionInternal}.
     *
     * @return Target predictions, or the target defaults when some neural input evaluated to {@code null}.
     */
    private Map<FieldName, ?> evaluateRegression(ModelEvaluationContext context) {
        NeuralNetwork neuralNetwork = getModel();

        Map<String, Double> neuronOutputs = evaluateRaw(context);
        if (neuronOutputs == null) {
            return TargetUtil.evaluateRegressionDefault(context);
        }

        NeuralOutputs neuralOutputs = neuralNetwork.getNeuralOutputs();
        if (neuralOutputs == null) {
            throw new InvalidFeatureException(neuralNetwork);
        }

        Map<FieldName, Object> result = new LinkedHashMap<FieldName, Object>();

        for (NeuralOutput neuralOutput : neuralOutputs) {
            Expression expression = getOutputExpression(neuralOutput);

            if (expression instanceof FieldRef) {
                FieldRef fieldRef = (FieldRef) expression;

                result.put(fieldRef.getField(), getNeuronOutput(neuronOutputs, neuralOutput));
            } else if (expression instanceof NormContinuous) {
                NormContinuous normContinuous = (NormContinuous) expression;

                double value = getNeuronOutput(neuronOutputs, neuralOutput).doubleValue();

                result.put(normContinuous.getField(), Double.valueOf(NormalizationUtil.denormalize(normContinuous, value)));
            } else {
                throw new UnsupportedFeatureException(expression);
            }
        }

        for (Map.Entry<FieldName, Object> entry : result.entrySet()) {
            entry.setValue(TargetUtil.evaluateRegressionInternal(entry.getKey(), entry.getValue(), context));
        }

        return result;
    }

    /**
     * Groups the {@code NormDiscrete} outputs by target field into one probability distribution per field.
     * Each category's probability is the referenced neuron's output. Each distribution is then post-processed
     * by {@code TargetUtil.evaluateClassificationInternal}.
     *
     * @return Target classifications, or the target defaults when some neural input evaluated to {@code null}.
     */
    private Map<FieldName, ? extends Classification> evaluateClassification(ModelEvaluationContext context) {
        NeuralNetwork neuralNetwork = getModel();

        BiMap<String, Entity> entityRegistry = getEntityRegistry();

        Map<String, Double> neuronOutputs = evaluateRaw(context);
        if (neuronOutputs == null) {
            return TargetUtil.evaluateClassificationDefault(context);
        }

        NeuralOutputs neuralOutputs = neuralNetwork.getNeuralOutputs();
        if (neuralOutputs == null) {
            throw new InvalidFeatureException(neuralNetwork);
        }

        Map<FieldName, EntityProbabilityDistribution> distributions = new LinkedHashMap<FieldName, EntityProbabilityDistribution>();

        for (NeuralOutput neuralOutput : neuralOutputs) {
            Entity entity = entityRegistry.get(neuralOutput.getOutputNeuron());

            Expression expression = getOutputExpression(neuralOutput);
            if (!(expression instanceof NormDiscrete)) {
                throw new UnsupportedFeatureException(expression);
            }

            NormDiscrete normDiscrete = (NormDiscrete) expression;

            FieldName field = normDiscrete.getField();

            EntityProbabilityDistribution distribution = distributions.get(field);
            if (distribution == null) {
                distribution = new EntityProbabilityDistribution(entityRegistry);

                distributions.put(field, distribution);
            }

            distribution.put(entity, normDiscrete.getValue(), getNeuronOutput(neuronOutputs, neuralOutput));
        }

        Map<FieldName, Classification> result = new LinkedHashMap<FieldName, Classification>();

        for (Map.Entry<FieldName, EntityProbabilityDistribution> entry : distributions.entrySet()) {
            result.put(entry.getKey(), TargetUtil.evaluateClassificationInternal(entry.getKey(), entry.getValue(), context));
        }

        return result;
    }

    /**
     * Returns the expression of the neural output's derived field.
     *
     * @throws InvalidFeatureException If the derived field or its expression is missing.
     */
    private Expression getOutputExpression(NeuralOutput neuralOutput) {
        DerivedField derivedField = neuralOutput.getDerivedField();
        if (derivedField == null) {
            throw new InvalidFeatureException(neuralOutput);
        }

        Expression expression = derivedField.getExpression();
        if (expression == null) {
            throw new InvalidFeatureException(derivedField);
        }

        return expression;
    }

    /**
     * Looks up the computed value of the neuron (or neural input) that the given neural output refers to.
     *
     * @throws InvalidFeatureException If {@code outputNeuron} does not name any computed neuron or input.
     */
    static private Double getNeuronOutput(Map<String, Double> neuronOutputs, NeuralOutput neuralOutput) {
        Double value = neuronOutputs.get(neuralOutput.getOutputNeuron());
        if (value == null) {
            throw new InvalidFeatureException(neuralOutput);
        }

        return value;
    }

    /**
     * Runs the forward pass.
     *
     * <p>Each neural input is evaluated from its derived field. Each layer's neurons then compute
     * {@code activation(bias + sum(weight * from))} over the outputs of the inputs and earlier layers, and the
     * layer's outputs are normalized.
     *
     * @return Outputs of all neural inputs and neurons, keyed by id, or {@code null} if any neural input evaluated
     *         to {@code null}.
     * @throws InvalidFeatureException If the network structure is incomplete or a connection refers to an id that
     *         has not been computed yet.
     */
    private Map<String, Double> evaluateRaw(EvaluationContext context) {
        NeuralNetwork neuralNetwork = getModel();

        Map<String, Double> result = new HashMap<String, Double>();

        NeuralInputs neuralInputs = neuralNetwork.getNeuralInputs();
        if (neuralInputs == null) {
            throw new InvalidFeatureException(neuralNetwork);
        }

        for (NeuralInput neuralInput : neuralInputs) {
            DerivedField derivedField = neuralInput.getDerivedField();
            if (derivedField == null) {
                throw new InvalidFeatureException(neuralInput);
            }

            FieldValue value = ExpressionUtil.evaluate(derivedField, context);
            if (value == null) {
                return null;
            }

            result.put(neuralInput.getId(), value.asDouble());
        }

        for (NeuralLayer neuralLayer : neuralNetwork.getNeuralLayers()) {
            Map<String, Double> layerOutputs = new HashMap<String, Double>();

            for (Neuron neuron : neuralLayer.getNeurons()) {
                // The bias attribute is optional; an absent bias contributes nothing to the weighted sum
                Double bias = neuron.getBias();

                double z = (bias != null ? bias.doubleValue() : 0d);

                for (Connection connection : neuron.getConnections()) {
                    // Only inputs and neurons of earlier layers are visible here (layerOutputs is merged afterwards)
                    Double input = result.get(connection.getFrom());
                    if (input == null) {
                        throw new InvalidFeatureException(connection);
                    }

                    z += input.doubleValue() * connection.getWeight();
                }

                layerOutputs.put(neuron.getId(), Double.valueOf(activation(z, neuralLayer)));
            }

            normalizeNeuronOutputs(neuralLayer, layerOutputs);

            result.putAll(layerOutputs);
        }

        return result;
    }

    /**
     * Normalizes a layer's outputs in place, using the layer's normalization method or else the network's.
     *
     * @throws InvalidFeatureException If neither element specifies a normalization method.
     * @throws UnsupportedFeatureException If the normalization method is not none, simplemax or softmax.
     */
    private void normalizeNeuronOutputs(NeuralLayer neuralLayer, Map<String, Double> neuronOutputs) {
        NeuralNetwork neuralNetwork = getModel();

        // The element that supplied the normalization method, reported in exceptions
        PMMLObject locatable = neuralLayer;

        NnNormalizationMethodType normalizationMethod = neuralLayer.getNormalizationMethod();
        if (normalizationMethod == null) {
            locatable = neuralNetwork;
            normalizationMethod = neuralNetwork.getNormalizationMethod();
        }

        if (normalizationMethod == null) {
            throw new InvalidFeatureException(locatable);
        }

        switch (normalizationMethod) {
            case NONE:
                return;
            case SIMPLEMAX:
                Classification.normalize(neuronOutputs);
                return;
            case SOFTMAX:
                Classification.normalizeSoftMax(neuronOutputs);
                return;
            default:
                throw new UnsupportedFeatureException(locatable, normalizationMethod);
        }
    }

    /**
     * Applies the activation function to a neuron's weighted input {@code z}. The function comes from the layer,
     * or else from the network.
     *
     * @throws InvalidFeatureException If no activation function (or, for threshold, no threshold) is specified.
     * @throws UnsupportedFeatureException If the activation function is not implemented here.
     */
    private double activation(double z, NeuralLayer neuralLayer) {
        NeuralNetwork neuralNetwork = getModel();

        // The element that supplied the activation function, reported in exceptions
        PMMLObject locatable = neuralLayer;

        ActivationFunctionType activationFunction = neuralLayer.getActivationFunction();
        if (activationFunction == null) {
            locatable = neuralNetwork;
            activationFunction = neuralNetwork.getActivationFunction();
        }

        if (activationFunction == null) {
            throw new InvalidFeatureException(locatable);
        }

        switch (activationFunction) {
            case THRESHOLD:
                {
                    // A layer-level threshold takes precedence over the network-level one
                    Double threshold = neuralLayer.getThreshold();
                    if (threshold == null) {
                        threshold = neuralNetwork.getThreshold();
                    }

                    if (threshold == null) {
                        throw new InvalidFeatureException(neuralNetwork);
                    }

                    return z > threshold.doubleValue() ? 1d : 0d;
                }
            case LOGISTIC:
                return 1d / (1d + Math.exp(-z));
            case TANH:
                return Math.tanh(z);
            case IDENTITY:
                return z;
            case EXPONENTIAL:
                return Math.exp(z);
            case RECIPROCAL:
                return 1d / z;
            case SQUARE:
                return z * z;
            case GAUSS:
                return Math.exp(-(z * z));
            case SINE:
                return Math.sin(z);
            case COSINE:
                return Math.cos(z);
            case ELLIOTT:
                return z / (1d + Math.abs(z));
            case ARCTAN:
                // PMML defines arctan as 2 * arctan(Z) / Pi, which scales the output into (-1, 1)
                return 2d * Math.atan(z) / Math.PI;
            default:
                throw new UnsupportedFeatureException(locatable, activationFunction);
        }
    }
}
