package eus.ixa.ixa.pipe.ml;

import eus.ixa.ixa.pipe.ml.document.DocSample;
import eus.ixa.ixa.pipe.ml.document.DocSampleStream;
import eus.ixa.ixa.pipe.ml.document.DocumentClassifierEvaluationMonitor;
import eus.ixa.ixa.pipe.ml.document.DocumentClassifierEvaluator;
import eus.ixa.ixa.pipe.ml.document.DocumentClassifierFactory;
import eus.ixa.ixa.pipe.ml.document.DocumentClassifierME;
import eus.ixa.ixa.pipe.ml.document.DocumentClassifierModel;
import eus.ixa.ixa.pipe.ml.document.features.DocumentFeatureDescriptor;
import eus.ixa.ixa.pipe.ml.document.features.DocumentModelResources;
import eus.ixa.ixa.pipe.ml.utils.Flags;
import eus.ixa.ixa.pipe.ml.utils.IOUtils;
import java.io.IOException;
import java.nio.charset.Charset;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;

/* loaded from: input_file:eus/ixa/ixa/pipe/ml/DocumentClassifierTrainer.class */
/**
 * Trains a {@link DocumentClassifierModel} from a training set and evaluates it
 * against a test set, both configured through {@link TrainingParameters}.
 *
 * <p>The constructor reads the language, the "clear features" flags and the
 * {@code TrainSet}/{@code TestSet} settings, opens both sample streams and
 * builds the {@link DocumentClassifierFactory} used by {@link #train}.
 */
public class DocumentClassifierTrainer {
    /** Settings key holding the location of the training data. */
    private static final String TRAIN_SET_KEY = "TrainSet";
    /** Settings key holding the location of the evaluation data. */
    private static final String TEST_SET_KEY = "TestSet";

    private final String lang;
    private final String trainData;
    private final String testData;
    private final ObjectStream<DocSample> trainSamples;
    private final ObjectStream<DocSample> testSamples;
    private final String clearTrainingFeatures;
    private final String clearEvaluationFeatures;
    private DocumentClassifierFactory docClassFactory;

    /**
     * Builds a trainer from the given parameters.
     *
     * @param trainingParameters the training configuration; its settings must
     *        contain the {@code TrainSet} and {@code TestSet} entries
     * @throws IOException if the sample streams or the factory resources cannot be loaded
     * @throws IllegalArgumentException if {@code TrainSet} or {@code TestSet} is missing
     */
    public DocumentClassifierTrainer(TrainingParameters trainingParameters) throws IOException {
        this.lang = Flags.getLanguage(trainingParameters);
        this.clearTrainingFeatures = Flags.getClearTrainingFeatures(trainingParameters);
        this.clearEvaluationFeatures = Flags.getClearEvaluationFeatures(trainingParameters);
        // Fail fast with a clear message instead of passing null into stream construction.
        this.trainData = requireSetting(trainingParameters, TRAIN_SET_KEY);
        this.testData = requireSetting(trainingParameters, TEST_SET_KEY);
        this.trainSamples = getDocumentStream(this.trainData, this.clearTrainingFeatures);
        this.testSamples = getDocumentStream(this.testData, this.clearEvaluationFeatures);
        // NOTE: this public, non-final method is invoked during construction; an
        // overriding subclass will run before its own fields are initialised.
        createDocumentClassificationFactory(trainingParameters);
    }

    /**
     * Reads a mandatory string setting from the training parameters.
     *
     * @param trainingParameters the parameters to read from
     * @param key the settings key to look up
     * @return the non-null value stored under {@code key}
     * @throws IllegalArgumentException if no value is stored under {@code key}
     */
    private static String requireSetting(TrainingParameters trainingParameters, String key) {
        String value = (String) trainingParameters.getSettings().get(key);
        if (value == null) {
            throw new IllegalArgumentException("Missing required training parameter: " + key);
        }
        return value;
    }

    /**
     * Creates the feature descriptor from the parameters, prints it to standard
     * error, and installs a {@link DocumentClassifierFactory} built from that
     * descriptor (UTF-8 encoded) and the resources loaded for these parameters.
     *
     * @param trainingParameters the training configuration
     * @throws IOException declared for resource loading failures
     */
    public void createDocumentClassificationFactory(TrainingParameters trainingParameters) throws IOException {
        String featureDescription = DocumentFeatureDescriptor.createDocumentFeatureDescriptor(trainingParameters);
        System.err.println(featureDescription);
        byte[] descriptorBytes = featureDescription.getBytes(Charset.forName("UTF-8"));
        setDocumentClassifierFactory(DocumentClassifierFactory.create(
                DocumentClassifierFactory.class.getName(),
                descriptorBytes,
                DocumentModelResources.loadDocumentResources(trainingParameters)));
    }

    /**
     * Trains a model on the training samples, evaluates it on the test samples
     * and prints the accuracy to standard output.
     *
     * <p>On an {@link IOException} the error is reported on standard error and
     * the JVM is terminated with exit status 1.
     *
     * @param trainingParameters the training configuration passed to the learner
     * @return the trained model
     * @throws IllegalStateException if no classifier factory has been set
     */
    public final DocumentClassifierModel train(TrainingParameters trainingParameters) {
        if (getDocumentClassificationFactory() == null) {
            throw new IllegalStateException("The DocumentClassificationFactory must be instantiated!!");
        }
        DocumentClassifierModel trainedModel = null;
        try {
            trainedModel = DocumentClassifierME.train(this.lang, this.trainSamples, trainingParameters, this.docClassFactory);
            // The evaluator lives only inside the try so it can never be dereferenced while null.
            DocumentClassifierEvaluator evaluator = new DocumentClassifierEvaluator(
                    new DocumentClassifierME(trainedModel), new DocumentClassifierEvaluationMonitor[0]);
            evaluator.evaluate(this.testSamples);
            System.out.println("Final Result: \n" + evaluator.getAccuracy());
        } catch (IOException e) {
            System.err.println("IO error while loading training and test sets!");
            e.printStackTrace();
            System.exit(1);
        }
        return trainedModel;
    }

    /**
     * Opens a stream of document samples.
     *
     * @param str the location of the data file, handed to
     *        {@code IOUtils.readFileIntoMarkableStreamFactory}
     * @param str2 the "clear features" flag value, passed as the first argument
     *        of the {@link DocSampleStream} constructor
     * @return a stream of {@link DocSample} objects over the given file
     * @throws IOException if the stream cannot be created
     */
    public static ObjectStream<DocSample> getDocumentStream(String str, String str2) throws IOException {
        return new DocSampleStream(str2, IOUtils.readFileIntoMarkableStreamFactory(str));
    }

    /**
     * Returns the classifier factory currently in use.
     *
     * @return the factory, or {@code null} if none has been set
     */
    public final DocumentClassifierFactory getDocumentClassificationFactory() {
        return this.docClassFactory;
    }

    /**
     * Replaces the classifier factory used for training.
     *
     * @param documentClassifierFactory the factory to install
     * @return the factory that was just installed
     */
    public final DocumentClassifierFactory setDocumentClassifierFactory(DocumentClassifierFactory documentClassifierFactory) {
        this.docClassFactory = documentClassifierFactory;
        return this.docClassFactory;
    }
}
