package edu.berkeley.nlp.lm.io;

import edu.berkeley.nlp.lm.ConfigOptions;
import edu.berkeley.nlp.lm.StringWordIndexer;
import edu.berkeley.nlp.lm.collections.Iterators;
import edu.berkeley.nlp.lm.util.Pair;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:edu/berkeley/nlp/lm/io/KneserNeyFromTextReaderTest.class */
/**
 * End-to-end tests for Kneser-Ney estimation from raw text.
 *
 * <p>Each test reads {@code <name>.txt} with a {@link TextReader}, estimates a model of order
 * {@code discounts.length} with {@link KneserNeyLmReaderCallback}, writes it through
 * {@link KneserNeyFileWritingLmReaderCallback} into memory, and compares the result with the
 * reference file {@code <name>.arpa}. Both sides are stripped of blank lines and sorted with the
 * same comparator, so line order does not matter; numeric columns are compared with a tolerance.
 */
public class KneserNeyFromTextReaderTest {

    /** Absolute tolerance used when comparing the numeric columns of probability lines. */
    private static final double LOG_PROB_TOLERANCE = 0.001d;

    /**
     * Orders lines by number of tab-separated fields, then by the second field when there is one
     * (otherwise by the first field), and finally by the whole line. The last tie-break makes the
     * order total, so the result does not depend on the order in which lines were produced.
     */
    private static final Comparator<String> ARPA_LINE_ORDER = new Comparator<String>() {
        @Override
        public int compare(String left, String right) {
            String[] leftFields = left.split("\t");
            String[] rightFields = right.split("\t");
            if (leftFields.length != rightFields.length) {
                return leftFields.length < rightFields.length ? -1 : 1;
            }
            int byKey = leftFields.length > 1
                    ? leftFields[1].compareTo(rightFields[1])
                    : leftFields[0].compareTo(rightFields[0]);
            return byKey != 0 ? byKey : left.compareTo(right);
        }
    };

    @Test
    public void testBigram() {
        doTest("tiny_test_bigram", new double[]{0.75d, 0.3333300054073334d});
    }

    @Test
    public void testTrigram() {
        doTest("tiny_test_trigram", new double[]{0.75d, 0.6000000238418579d, 0.6000000238418579d});
    }

    @Test
    public void testFivegram() {
        doTest("tiny_test_fivegram", new double[]{0.4000000059604645d, 0.5d, 0.5d, 0.5384619832038879d, 0.4545449912548065d});
    }

    @Test
    public void testBig() {
        doTest("big_test", new double[]{0.7556390166282654d, 0.8919339776039124d, 0.944267988204956d, 0.9559410214424133d, 0.3594360053539276d});
    }

    /**
     * Builds a Kneser-Ney model from {@code <str>.txt} and checks the written output against
     * {@code <str>.arpa}.
     *
     * @param str base name of the test resources, without extension
     * @param dArr Kneser-Ney discounts; its length is passed as the model order
     */
    private void doTest(String str, double[] dArr) {
        final int lmOrder = dArr.length;

        StringWordIndexer wordIndexer = new StringWordIndexer();
        wordIndexer.setStartSymbol("<s>");
        wordIndexer.setEndSymbol("</s>");
        wordIndexer.setUnkSymbol("<unk>");

        String textPath = FileUtils.getFile(str + ".txt").getPath();
        File expectedArpa = FileUtils.getFile(str + ".arpa");
        TextReader textReader = new TextReader(Arrays.asList(textPath), wordIndexer);

        ConfigOptions configOptions = new ConfigOptions();
        configOptions.kneserNeyDiscounts = dArr;
        // NOTE(review): seven zero min-counts, unchanged from the original; presumably enough for
        // every order used here (at most 5) — confirm before adding higher-order tests.
        configOptions.kneserNeyMinCounts = new double[]{0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d, 0.0d};

        KneserNeyLmReaderCallback lmBuilder = new KneserNeyLmReaderCallback(wordIndexer, lmOrder, configOptions);
        textReader.parse(lmBuilder);

        StringWriter arpaText = new StringWriter();
        PrintWriter arpaWriter = new PrintWriter(arpaText);
        lmBuilder.parse(new KneserNeyFileWritingLmReaderCallback(arpaWriter, wordIndexer));
        // Push any pending output into the StringWriter before reading it back
        // (harmless if the callback already closed the writer: PrintWriter.flush does not throw).
        arpaWriter.flush();

        // Split on "\n" or "\r\n" in case the writer emits platform line separators; a stray "\r"
        // would otherwise end up in the sort keys used by ARPA_LINE_ORDER.
        List<String> actualLines = new ArrayList<String>(Arrays.asList(arpaText.toString().split("\\r?\\n")));
        sortAndRemoveBlankLines(actualLines);

        List<String> expectedLines = getLines(expectedArpa);
        sortAndRemoveBlankLines(expectedLines);

        compareLines(actualLines, expectedLines);
    }

    /**
     * Compares two sorted, blank-free line lists position by position.
     *
     * <p>Lines whose expected text starts with {@code '-'} are treated as probability lines:
     * they must have 2 or 3 tab-separated fields, the second field must match exactly, and the
     * first and (if present) third fields are compared as doubles within
     * {@link #LOG_PROB_TOLERANCE}. All other lines must match exactly after trimming.
     * NOTE(review): a probability line with a non-negative first field would be compared as
     * plain text — confirm the reference files never contain one.
     *
     * @param list lines produced by the code under test
     * @param list2 lines read from the reference file
     */
    private void compareLines(List<String> list, List<String> list2) {
        List<String> actualLines = list;
        List<String> expectedLines = list2;
        Assert.assertEquals("number of non-blank lines", expectedLines.size(), actualLines.size());
        for (int i = 0; i < expectedLines.size(); i++) {
            String actual = actualLines.get(i).trim();
            String expected = expectedLines.get(i).trim();
            String context = "line " + i + ": actual=[" + actual + "] expected=[" + expected + "]";
            if (!expected.startsWith("-")) {
                Assert.assertEquals(context, expected, actual);
                continue;
            }
            Assert.assertTrue(context, actual.startsWith("-"));
            String[] actualFields = actual.split("\t");
            String[] expectedFields = expected.split("\t");
            Assert.assertEquals(context, expectedFields.length, actualFields.length);
            Assert.assertTrue(context, actualFields.length == 2 || actualFields.length == 3);
            Assert.assertEquals(context, expectedFields[1], actualFields[1]);
            Assert.assertEquals(context, Double.parseDouble(expectedFields[0]),
                    Double.parseDouble(actualFields[0]), LOG_PROB_TOLERANCE);
            if (actualFields.length == 3) {
                Assert.assertEquals(context, Double.parseDouble(expectedFields[2]),
                        Double.parseDouble(actualFields[2]), LOG_PROB_TOLERANCE);
            }
        }
    }

    /**
     * Reads all lines of {@code file} using {@code IOUtils.lineIterator}.
     *
     * @param file file to read
     * @return a mutable list of the file's lines, in file order
     * @throws RuntimeException wrapping any {@link IOException} raised while reading
     */
    private List<String> getLines(File file) {
        List<String> lines = new ArrayList<String>();
        try {
            for (String line : Iterators.able(IOUtils.lineIterator(file.getAbsolutePath()))) {
                lines.add(line);
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return lines;
    }

    /**
     * Removes whitespace-only lines from {@code list} in place, then sorts it with
     * {@link #ARPA_LINE_ORDER}.
     *
     * @param list mutable list of lines
     */
    private void sortAndRemoveBlankLines(List<String> list) {
        for (Iterator<String> it = list.iterator(); it.hasNext();) {
            if (it.next().trim().isEmpty()) {
                it.remove();
            }
        }
        Collections.sort(list, ARPA_LINE_ORDER);
    }
}
