package org.dkpro.tc.io.libsvm.serialization;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.descriptor.ExternalResource;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.pear.util.FileUtil;
import org.apache.uima.resource.ResourceInitializationException;
import org.dkpro.tc.api.features.Feature;
import org.dkpro.tc.api.features.FeatureExtractorResource_ImplBase;
import org.dkpro.tc.api.features.Instance;
import org.dkpro.tc.api.type.TextClassificationOutcome;
import org.dkpro.tc.core.ml.ModelSerialization_ImplBase;
import org.dkpro.tc.core.task.uima.InstanceExtractor;
import org.dkpro.tc.io.libsvm.AdapterFormat;

/* loaded from: input_file:org/dkpro/tc/io/libsvm/serialization/LibsvmDataFormatLoadModelConnector.class */
/**
 * Base annotator that applies a previously trained model working on the libsvm data format.
 *
 * <p>For every CAS the connector extracts instances, writes them to a temporary
 * libsvm-formatted file, delegates the actual prediction to {@link #runPrediction(File)} and
 * writes the predicted values back into the {@link TextClassificationOutcome} annotations of
 * the CAS.
 */
public abstract class LibsvmDataFormatLoadModelConnector extends ModelSerialization_ImplBase {
    /** Value written in the outcome column of every line of the prediction input file. */
    protected String OUTCOME_PLACEHOLDER = "-1";

    /** Folder containing the stored model and its mapping files. */
    @ConfigurationParameter(name = "tcModel", mandatory = true)
    protected File tcModelLocation;

    /** Feature extractors handed to the {@link InstanceExtractor}. */
    @ExternalResource(key = "featureExtractors", mandatory = true)
    protected FeatureExtractorResource_ImplBase[] featureExtractors;

    /** Feature mode handed to the {@link InstanceExtractor}. */
    @ConfigurationParameter(name = "featureMode", mandatory = true)
    protected String featureMode;

    /** Learning mode; the value {@code "regression"} disables the outcome mapping. */
    @ConfigurationParameter(name = "learningMode", mandatory = true)
    protected String learningMode;

    /** Maps the second column of the outcome mapping file to its first column. */
    protected Map<String, String> integer2OutcomeMapping;

    /** Maps a feature name to the integer id stored in the feature name mapping file. */
    protected Map<String, Integer> featureMapping;

    /**
     * Loads the outcome and feature mappings from the model folder and verifies the model
     * version.
     *
     * @param uimaContext the UIMA context passed on to the super class
     * @throws ResourceInitializationException if loading a mapping or the version check fails
     */
    public void initialize(UimaContext uimaContext) throws ResourceInitializationException {
        super.initialize(uimaContext);
        try {
            this.integer2OutcomeMapping = loadInteger2OutcomeMapping(this.tcModelLocation);
            this.featureMapping = loadFeature2IntegerMapping(this.tcModelLocation);
            verifyTcVersion(this.tcModelLocation, getClass());
        } catch (Exception e) {
            throw new ResourceInitializationException(e);
        }
    }

    /**
     * Reads the tab-separated feature name mapping file located in the model folder.
     *
     * @param file the model folder
     * @return map from the first column (feature name) to the second column parsed as integer
     * @throws IOException if the mapping file cannot be read
     */
    private Map<String, Integer> loadFeature2IntegerMapping(File file) throws IOException {
        Map<String, Integer> mapping = new HashMap<>();
        File mappingFile = new File(file, AdapterFormat.getFeatureNameMappingFilename());
        for (String line : FileUtils.readLines(mappingFile, "utf-8")) {
            String[] columns = line.split("\t");
            mapping.put(columns[0], Integer.valueOf(columns[1]));
        }
        return mapping;
    }

    /**
     * Reads the tab-separated outcome mapping file located in the model folder. In regression
     * mode no file is read and an empty map is returned.
     *
     * @param file the model folder
     * @return map from the second column to the first column of each line
     * @throws IOException if the mapping file cannot be read
     */
    private Map<String, String> loadInteger2OutcomeMapping(File file) throws IOException {
        Map<String, String> mapping = new HashMap<>();
        if (isRegression()) {
            return mapping;
        }
        File mappingFile = new File(file, AdapterFormat.getOutcomeMappingFilename());
        for (String line : FileUtils.readLines(mappingFile, "utf-8")) {
            String[] columns = line.split("\t");
            // The file is keyed the other way round, so the columns are swapped here.
            mapping.put(columns[1], columns[0]);
        }
        return mapping;
    }

    /**
     * @return {@code true} if the configured learning mode is {@code "regression"}
     */
    private boolean isRegression() {
        return this.learningMode.equals("regression");
    }

    /**
     * Predicts an outcome for every {@link TextClassificationOutcome} of the CAS.
     *
     * <p>The i-th line of the prediction file is assigned to the i-th outcome annotation. In
     * regression mode the line is stored verbatim; otherwise a trailing zero fraction is removed
     * (e.g. {@code "3.0"} becomes {@code "3"}) and the result is looked up in
     * {@link #integer2OutcomeMapping}.
     *
     * @param jCas the CAS to process
     * @throws AnalysisEngineProcessException wrapping any failure during prediction
     */
    public void process(JCas jCas) throws AnalysisEngineProcessException {
        try {
            File predictionFile = runPrediction(createInputFile(jCas));
            List<TextClassificationOutcome> outcomes = getOutcomeAnnotations(jCas);
            List<String> predictions = FileUtils.readLines(predictionFile, "utf-8");
            checkErrorConditionNumberOfOutcomesEqualsNumberOfPredictions(outcomes, predictions);

            boolean regression = isRegression();
            for (int i = 0; i < outcomes.size(); i++) {
                String prediction = predictions.get(i);
                if (regression) {
                    outcomes.get(i).setOutcome(prediction);
                } else {
                    // Only strip a trailing zero fraction; an unanchored "\\.0" would also
                    // rewrite values such as "1.05" to "15" and hit the wrong label.
                    String key = prediction.replaceAll("\\.0+$", "");
                    // NOTE(review): an unknown key yields a null outcome - confirm this is intended.
                    outcomes.get(i).setOutcome(this.integer2OutcomeMapping.get(key));
                }
            }
        } catch (Exception e) {
            throw new AnalysisEngineProcessException(e);
        }
    }

    /**
     * @param jCas the CAS to query
     * @return a new list holding all {@link TextClassificationOutcome} annotations of the CAS
     */
    private List<TextClassificationOutcome> getOutcomeAnnotations(JCas jCas) {
        return new ArrayList<>(JCasUtil.select(jCas, TextClassificationOutcome.class));
    }

    /**
     * Ensures that there is exactly one prediction per outcome annotation.
     *
     * @param list the outcome annotations
     * @param list2 the lines of the prediction file
     * @throws IllegalStateException if the sizes differ
     */
    private void checkErrorConditionNumberOfOutcomesEqualsNumberOfPredictions(List<TextClassificationOutcome> list, List<String> list2) {
        if (list.size() != list2.size()) {
            throw new IllegalStateException("Expected [" + list.size() + "] predictions but were [" + list2.size() + "]");
        }
    }

    /**
     * Runs the classifier on the given libsvm-formatted input file.
     *
     * @param file the file written by this class, one instance per line
     * @return a file whose lines are read as predictions, one per instance
     * @throws Exception if the prediction fails
     */
    protected abstract File runPrediction(File file) throws Exception;

    /**
     * Writes the instances extracted from the CAS into a temporary file. Each line consists of
     * the outcome placeholder, the value of {@link #injectSequenceId(Instance)} and one
     * tab-separated {@code id:value} entry for every feature accepted by the sanity check.
     *
     * @param jCas the CAS to extract instances from
     * @return the temporary file, registered for deletion on JVM exit
     * @throws Exception if extraction or writing fails
     */
    private File createInputFile(JCas jCas) throws Exception {
        File inputFile = FileUtil.createTempFile("libsvm", ".txt");
        inputFile.deleteOnExit();
        // try-with-resources closes the writer even if extraction or writing fails.
        try (BufferedWriter writer = new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream(inputFile), "utf-8"))) {
            InstanceExtractor extractor = new InstanceExtractor(this.featureMode, this.featureExtractors, true);
            for (Instance instance : extractor.getInstances(jCas, true)) {
                writer.write(this.OUTCOME_PLACEHOLDER);
                writer.write(injectSequenceId(instance));
                for (Feature feature : instance.getFeatures()) {
                    if (sanityCheckValue(feature)) {
                        writer.write("\t");
                        // NOTE(review): a feature name missing from the mapping is written as
                        // "null:value" - confirm whether such features should be skipped.
                        writer.write(this.featureMapping.get(feature.getName()) + ":" + feature.getValue());
                    }
                }
                writer.write("\n");
            }
        }
        return inputFile;
    }

    /**
     * Hook for subclasses to add text directly after the outcome placeholder of a line.
     *
     * @param instance the instance currently being written
     * @return the text to insert; this implementation returns the empty string
     */
    protected String injectSequenceId(Instance instance) {
        return "";
    }

    /**
     * Decides whether a feature is written to the input file.
     *
     * @param feature the feature to check
     * @return {@code true} if the value is a {@link Number}; {@code false} for the instance id
     *         feature and for any other value that can be parsed as a double
     * @throws IllegalArgumentException if the value is neither a number nor parseable as double
     */
    private boolean sanityCheckValue(Feature feature) {
        Object value = feature.getValue();
        if (value instanceof Number) {
            return true;
        }
        if (feature.getName().equals("DKProTCInstanceID")) {
            return false;
        }
        try {
            // Parsing is done for validation only; the parsed result is discarded.
            Double.valueOf((String) value);
        } catch (Exception e) {
            throw new IllegalArgumentException("Feature [" + feature.getName() + "] has a non-numeric value [" + feature.getValue() + "]", e);
        }
        // NOTE(review): numeric strings pass validation but are still excluded - confirm intent.
        return false;
    }
}
