package edu.berkeley.nlp.lm.io;

import edu.berkeley.nlp.lm.ArrayEncodedNgramLanguageModel;
import edu.berkeley.nlp.lm.ArrayEncodedProbBackoffLm;
import edu.berkeley.nlp.lm.ConfigOptions;
import edu.berkeley.nlp.lm.ContextEncodedNgramLanguageModel;
import edu.berkeley.nlp.lm.ContextEncodedProbBackoffLm;
import edu.berkeley.nlp.lm.StringWordIndexer;
import edu.berkeley.nlp.lm.cache.ArrayEncodedCachingLmWrapper;
import edu.berkeley.nlp.lm.cache.ContextEncodedCachingLmWrapper;
import edu.berkeley.nlp.lm.collections.Iterators;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:edu/berkeley/nlp/lm/io/PerplexityTest.class */
/**
 * Regression tests that score a fixed corpus with a language model read from
 * {@value #BIG_TEST_ARPA} and compare the corpus total log probability against a gold value.
 *
 * <p>The same gold totals are checked for array-encoded and context-encoded models, with
 * ranked and unranked prob/backoff storage ({@code ConfigOptions.storeRankedProbBackoffs}),
 * the {@code testCompressed*} variant of the array-encoded reader, and with the not-thread-safe
 * and thread-safe caching wrappers.
 */
public class PerplexityTest {
    /** Small corpus scored by the {@code *Tiny*} tests. */
    public static final String TEST_PERPLEX_TINY_TXT = "test_perplex_tiny.txt";
    /** Larger corpus scored by the remaining tests. */
    public static final String TEST_PERPLEX_TXT = "test_perplex.txt";
    /** ARPA file every language model in this test is read from. */
    public static final String BIG_TEST_ARPA = "big_test.arpa";
    /** Gold total log probability of {@link #TEST_PERPLEX_TXT} under {@link #BIG_TEST_ARPA}. */
    public static final float TEST_PERPLEX_GOLD_PROB = -2675.41f;
    /** Gold total log probability of {@link #TEST_PERPLEX_TINY_TXT} under {@link #BIG_TEST_ARPA}. */
    public static final float TEST_PERPLEX_TINY_GOLD_PROB = -38.9312f;

    /** Integer argument passed to every {@code wrapWithCache*} factory call in this class. */
    private static final int CACHE_ARG = 16;

    /** Allowed difference between the summed corpus score and its gold value. */
    private static final double CORPUS_TOLERANCE = 0.1d;

    /** Allowed difference between a hand-accumulated sentence score and {@code scoreSentence}. */
    private static final double SENTENCE_TOLERANCE = 1.0E-5d;

    @Test
    public void testTiny() {
        testArrayEncodedLogProb(getLm(false), FileUtils.getFile(TEST_PERPLEX_TINY_TXT), TEST_PERPLEX_TINY_GOLD_PROB);
    }

    @Test
    public void testTinyUnranked() {
        testArrayEncodedLogProb(getLm(true), FileUtils.getFile(TEST_PERPLEX_TINY_TXT), TEST_PERPLEX_TINY_GOLD_PROB);
    }

    @Test
    public void testTinyContextEncoded() {
        testContextEncodedLogProb(getContextEncodedLm(false), FileUtils.getFile(TEST_PERPLEX_TINY_TXT), TEST_PERPLEX_TINY_GOLD_PROB);
    }

    @Test
    public void testTinyContextEncodedUnranked() {
        testContextEncodedLogProb(getContextEncodedLm(true), FileUtils.getFile(TEST_PERPLEX_TINY_TXT), TEST_PERPLEX_TINY_GOLD_PROB);
    }

    @Test
    public void test() {
        testArrayEncodedLogProb(getLm(false), FileUtils.getFile(TEST_PERPLEX_TXT), TEST_PERPLEX_GOLD_PROB);
    }

    @Test
    public void testUnranked() {
        testArrayEncodedLogProb(getLm(true), FileUtils.getFile(TEST_PERPLEX_TXT), TEST_PERPLEX_GOLD_PROB);
    }

    @Test
    public void testCompressed() {
        testArrayEncodedLogProb(readCompressedLm(), FileUtils.getFile(TEST_PERPLEX_TXT), TEST_PERPLEX_GOLD_PROB);
    }

    @Test
    public void testCompressedCached() {
        checkArrayEncodedWithCaches(readCompressedLm(), FileUtils.getFile(TEST_PERPLEX_TXT), TEST_PERPLEX_GOLD_PROB);
    }

    @Test
    public void testContextEncoded() {
        testContextEncodedLogProb(getContextEncodedLm(false), FileUtils.getFile(TEST_PERPLEX_TXT), TEST_PERPLEX_GOLD_PROB);
    }

    @Test
    public void testContextEncodedUnranked() {
        testContextEncodedLogProb(getContextEncodedLm(true), FileUtils.getFile(TEST_PERPLEX_TXT), TEST_PERPLEX_GOLD_PROB);
    }

    @Test
    public void testCachedTiny() {
        checkArrayEncodedWithCaches(getLm(false), FileUtils.getFile(TEST_PERPLEX_TINY_TXT), TEST_PERPLEX_TINY_GOLD_PROB);
    }

    @Test
    public void testCachedTinyUnranked() {
        checkArrayEncodedWithCaches(getLm(true), FileUtils.getFile(TEST_PERPLEX_TINY_TXT), TEST_PERPLEX_TINY_GOLD_PROB);
    }

    @Test
    public void testCachedTinyContextEncoded() {
        checkContextEncodedWithCaches(getContextEncodedLm(false), FileUtils.getFile(TEST_PERPLEX_TINY_TXT), TEST_PERPLEX_TINY_GOLD_PROB);
    }

    @Test
    public void testCachedTinyContextEncodedUnranked() {
        checkContextEncodedWithCaches(getContextEncodedLm(true), FileUtils.getFile(TEST_PERPLEX_TINY_TXT), TEST_PERPLEX_TINY_GOLD_PROB);
    }

    @Test
    public void testCached() {
        checkArrayEncodedWithCaches(getLm(false), FileUtils.getFile(TEST_PERPLEX_TXT), TEST_PERPLEX_GOLD_PROB);
    }

    @Test
    public void testCachedUnranked() {
        checkArrayEncodedWithCaches(getLm(true), FileUtils.getFile(TEST_PERPLEX_TXT), TEST_PERPLEX_GOLD_PROB);
    }

    @Test
    public void testCachedContextEncoded() {
        checkContextEncodedWithCaches(getContextEncodedLm(false), FileUtils.getFile(TEST_PERPLEX_TXT), TEST_PERPLEX_GOLD_PROB);
    }

    @Test
    public void testCachedContextEncodedUnranked() {
        checkContextEncodedWithCaches(getContextEncodedLm(true), FileUtils.getFile(TEST_PERPLEX_TXT), TEST_PERPLEX_GOLD_PROB);
    }

    /**
     * Scores {@code file} through both the not-thread-safe and the thread-safe array-encoded
     * caching wrappers around {@code lm}.
     */
    private static void checkArrayEncodedWithCaches(ArrayEncodedProbBackoffLm<String> lm, File file, float goldLogProb) {
        testArrayEncodedLogProb(ArrayEncodedCachingLmWrapper.wrapWithCacheNotThreadSafe(lm, CACHE_ARG), file, goldLogProb);
        testArrayEncodedLogProb(ArrayEncodedCachingLmWrapper.wrapWithCacheThreadSafe(lm, CACHE_ARG), file, goldLogProb);
    }

    /**
     * Scores {@code file} through both the not-thread-safe and the thread-safe context-encoded
     * caching wrappers around {@code lm}.
     */
    private static void checkContextEncodedWithCaches(ContextEncodedProbBackoffLm<String> lm, File file, float goldLogProb) {
        testContextEncodedLogProb(ContextEncodedCachingLmWrapper.wrapWithCacheNotThreadSafe(lm, CACHE_ARG), file, goldLogProb);
        testContextEncodedLogProb(ContextEncodedCachingLmWrapper.wrapWithCacheThreadSafe(lm, CACHE_ARG), file, goldLogProb);
    }

    /** Returns fresh options with {@code unknownWordLogProb} set to 0, shared by every LM here. */
    private static ConfigOptions newConfig() {
        ConfigOptions options = new ConfigOptions();
        options.unknownWordLogProb = 0.0d;
        return options;
    }

    /**
     * Returns {@link #newConfig()} options with {@code storeRankedProbBackoffs} set to
     * {@code !unranked}.
     */
    private static ConfigOptions newConfig(boolean unranked) {
        ConfigOptions options = newConfig();
        options.storeRankedProbBackoffs = !unranked;
        return options;
    }

    /**
     * Reads {@link #BIG_TEST_ARPA} as an array-encoded LM with the reader's boolean flag set to
     * {@code true}, the variant exercised by the {@code testCompressed*} cases. Ranked storage is
     * left at its {@link ConfigOptions} default.
     */
    private static ArrayEncodedProbBackoffLm<String> readCompressedLm() {
        File arpa = FileUtils.getFile(BIG_TEST_ARPA);
        return LmReaders.readArrayEncodedLmFromArpa(arpa.getPath(), true, new StringWordIndexer(), newConfig(), Integer.MAX_VALUE);
    }

    /** Reads {@link #BIG_TEST_ARPA} as a context-encoded LM, optionally with unranked storage. */
    private ContextEncodedProbBackoffLm<String> getContextEncodedLm(boolean unranked) {
        File arpa = FileUtils.getFile(BIG_TEST_ARPA);
        return LmReaders.readContextEncodedLmFromArpa(arpa.getPath(), new StringWordIndexer(), newConfig(unranked), Integer.MAX_VALUE);
    }

    /**
     * Reads {@link #BIG_TEST_ARPA} as an array-encoded LM (reader flag {@code false}),
     * optionally with unranked storage.
     */
    private ArrayEncodedProbBackoffLm<String> getLm(boolean unranked) {
        File arpa = FileUtils.getFile(BIG_TEST_ARPA);
        return LmReaders.readArrayEncodedLmFromArpa(arpa.getPath(), false, new StringWordIndexer(), newConfig(unranked), Integer.MAX_VALUE);
    }

    /**
     * Scores every line of {@code file} with a context-encoded model and checks the result.
     *
     * <p>Each line is split on single spaces and wrapped in the indexer's start and end symbols.
     * Words are scored left to right, threading an {@code LmContextInfo} through the calls. For
     * every word the test asserts that the log probability is the same whether or not an
     * output context is requested. Each sentence total must match {@code scoreSentence}, and the
     * corpus total must match {@code goldLogProb}.
     *
     * @param lm          model under test
     * @param file        corpus with one sentence per line
     * @param goldLogProb expected corpus total
     * @throws RuntimeException wrapping any {@link IOException} raised while reading {@code file}
     */
    public static void testContextEncodedLogProb(ContextEncodedNgramLanguageModel<String> lm, File file, float goldLogProb) {
        float corpusTotal = 0.0f;
        try {
            for (String line : Iterators.able(IOUtils.lineIterator(file.getPath()))) {
                String[] words = line.trim().split(" ");
                int[] sentence = new int[words.length + 2];
                sentence[0] = lm.getWordIndexer().getOrAddIndexFromString((String) lm.getWordIndexer().getStartSymbol());
                sentence[sentence.length - 1] = lm.getWordIndexer().getOrAddIndexFromString((String) lm.getWordIndexer().getEndSymbol());
                for (int w = 0; w < words.length; w++) {
                    sentence[w + 1] = lm.getWordIndexer().getIndexPossiblyUnk(words[w]);
                }

                ContextEncodedNgramLanguageModel.LmContextInfo context = new ContextEncodedNgramLanguageModel.LmContextInfo();
                // Score the start symbol only to advance the context; its value is not summed.
                lm.getLogProb(context.offset, context.order, sentence[0], context);
                float sentenceTotal = 0.0f;
                for (int pos = 1; pos < sentence.length; pos++) {
                    // Arguments are evaluated before either call, so both calls see the same context.
                    float withoutContextOut = lm.getLogProb(context.offset, context.order, sentence[pos], (ContextEncodedNgramLanguageModel.LmContextInfo) null);
                    float withContextOut = lm.getLogProb(context.offset, context.order, sentence[pos], context);
                    // Requesting the successor context must not change the returned probability.
                    Assert.assertEquals(withContextOut, withoutContextOut, (double) Float.MIN_VALUE);
                    sentenceTotal += withContextOut;
                }
                Assert.assertEquals(sentenceTotal, lm.scoreSentence(Arrays.asList(words)), SENTENCE_TOLERANCE);
                corpusTotal += sentenceTotal;
            }
            // JUnit convention: expected (gold) first, actual (computed) second.
            Assert.assertEquals(goldLogProb, corpusTotal, CORPUS_TOLERANCE);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Scores every line of {@code file} with an array-encoded model and checks the result.
     *
     * <p>Each line is split on single spaces and wrapped in the indexer's start and end symbols.
     * The sentence is scored as a sequence of {@code getLogProb(sentence, start, end)} windows.
     * Together the two loops below visit every window end from 2 to {@code sentence.length}
     * exactly once. Window end 1, which would contain only the start symbol, is never scored.
     * Each sentence total must match {@code scoreSentence}, and the corpus total must match
     * {@code goldLogProb}.
     *
     * @param lm          model under test
     * @param file        corpus with one sentence per line
     * @param goldLogProb expected corpus total
     * @throws RuntimeException wrapping any {@link IOException} raised while reading {@code file}
     */
    public static void testArrayEncodedLogProb(ArrayEncodedNgramLanguageModel<String> lm, File file, float goldLogProb) {
        float corpusTotal = 0.0f;
        try {
            for (String line : Iterators.able(IOUtils.lineIterator(file.getPath()))) {
                String[] words = line.trim().split(" ");
                int[] sentence = new int[words.length + 2];
                sentence[0] = lm.getWordIndexer().getOrAddIndexFromString((String) lm.getWordIndexer().getStartSymbol());
                sentence[sentence.length - 1] = lm.getWordIndexer().getOrAddIndexFromString((String) lm.getWordIndexer().getEndSymbol());
                for (int w = 0; w < words.length; w++) {
                    sentence[w + 1] = lm.getWordIndexer().getIndexPossiblyUnk(words[w]);
                }

                int order = lm.getLmOrder();
                float sentenceTotal = 0.0f;
                // Leading windows that start at the sentence start and are shorter than or equal to the model order.
                for (int end = 2; end <= Math.min(order, sentence.length); end++) {
                    sentenceTotal += lm.getLogProb(sentence, 0, end);
                }
                // Remaining windows are exactly `order` tokens wide, sliding one position at a time.
                for (int start = 1; start <= sentence.length - order; start++) {
                    sentenceTotal += lm.getLogProb(sentence, start, start + order);
                }
                Assert.assertEquals(sentenceTotal, lm.scoreSentence(Arrays.asList(words)), SENTENCE_TOLERANCE);
                corpusTotal += sentenceTotal;
            }
            // JUnit convention: expected (gold) first, actual (computed) second.
            Assert.assertEquals(goldLogProb, corpusTotal, CORPUS_TOLERANCE);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}
