package org.dkpro.tc.ml.libsvm;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.compress.utils.IOUtils;
import org.apache.commons.io.FileUtils;
import org.dkpro.lab.engine.TaskContext;
import org.dkpro.lab.storage.StorageService;
import org.dkpro.tc.core.Constants;
import org.dkpro.tc.io.libsvm.LibsvmDataFormatTestTask;
import org.dkpro.tc.ml.libsvm.core.LibsvmPredictor;
import org.dkpro.tc.ml.libsvm.core.LibsvmTrainer;

/* loaded from: input_file:org/dkpro/tc/ml/libsvm/LibsvmTestTask.class */
public class LibsvmTestTask extends LibsvmDataFormatTestTask implements Constants {
    private List<String> buildParameters() {
        ArrayList arrayList = new ArrayList();
        if (this.classificationArguments != null) {
            for (int i = 1; i < this.classificationArguments.size(); i++) {
                arrayList.add((String) this.classificationArguments.get(i));
            }
        }
        return arrayList;
    }

    private List<String> pickGold(List<String> list) {
        ArrayList arrayList = new ArrayList();
        for (String str : list) {
            if (!str.isEmpty()) {
                arrayList.add(str.substring(0, str.indexOf("\t")));
            }
        }
        return arrayList;
    }

    protected Object trainModel(TaskContext taskContext) throws Exception {
        File trainFile = getTrainFile(taskContext);
        File file = new File(taskContext.getFolder("", StorageService.AccessMode.READWRITE), "classifier.ser");
        new LibsvmTrainer().train(trainFile, file, buildParameters());
        return file;
    }

    protected void runPrediction(TaskContext taskContext, Object obj) throws Exception {
        File testFile = getTestFile(taskContext);
        mergePredictionWithGold(taskContext, new LibsvmPredictor().predict(testFile, (File) obj));
    }

    private void mergePredictionWithGold(TaskContext taskContext, List<String> list) throws Exception {
        File testFile = getTestFile(taskContext);
        BufferedWriter bufferedWriter = null;
        try {
            bufferedWriter = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(getPredictionFile(taskContext)), "utf-8"));
            List<String> pickGold = pickGold(FileUtils.readLines(testFile, "utf-8"));
            bufferedWriter.write("#PREDICTION;GOLD\n");
            for (int i = 0; i < pickGold.size(); i++) {
                bufferedWriter.write(list.get(i) + ";" + pickGold.get(i));
                bufferedWriter.write("\n");
            }
            IOUtils.closeQuietly(bufferedWriter);
        } catch (Throwable th) {
            IOUtils.closeQuietly(bufferedWriter);
            throw th;
        }
    }
}
