package org.apache.lucene.classification;

import java.io.IOException;
import java.io.StringReader;
import java.util.LinkedList;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.index.AtomicReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;

/* loaded from: input_file:org/apache/lucene/classification/SimpleNaiveBayesClassifier.class */
/**
 * A simple Naive Bayes {@link Classifier} that derives all of its statistics from a Lucene index.
 *
 * <p>{@link #train} must be called before {@link #assignClass}. Each distinct term of the class
 * field is treated as a candidate class; the class whose {@code prior * likelihood} is highest
 * for the tokenized input wins.
 *
 * <p>NOTE: probabilities are multiplied directly (not summed in log space), so the score can
 * underflow towards {@code 0} for very long inputs.
 */
public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
    private AtomicReader atomicReader;
    private String textFieldName;
    private String classFieldName;
    private int docsWithClassSize;
    private Analyzer analyzer;
    private IndexSearcher indexSearcher;

    /**
     * Binds this classifier to an index and caches the number of documents that carry a class.
     *
     * @param atomicReader reader over the index holding the training documents
     * @param textFieldName name of the field whose text is used as features
     * @param classFieldName name of the field holding the class of each document
     * @param analyzer analyzer used to tokenize the text passed to {@link #assignClass}
     * @throws IOException if the index cannot be read
     */
    @Override
    public void train(AtomicReader atomicReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
        this.atomicReader = atomicReader;
        this.indexSearcher = new IndexSearcher(this.atomicReader);
        this.textFieldName = textFieldName;
        this.classFieldName = classFieldName;
        this.analyzer = analyzer;
        this.docsWithClassSize = countDocsWithClass();
    }

    /**
     * Counts the documents that have at least one term in the class field.
     *
     * @return the number of such documents, or {@code 0} when the class field has no terms
     * @throws IOException if the index cannot be read
     */
    private int countDocsWithClass() throws IOException {
        Terms classTerms = MultiFields.getTerms(this.atomicReader, this.classFieldName);
        if (classTerms == null) {
            // The class field is absent from the index: nothing is labelled.
            return 0;
        }
        int docCount = classTerms.getDocCount();
        if (docCount == -1) {
            // The statistic is not available; fall back to counting the hits of a match-all
            // wildcard query on the class field.
            TotalHitCountCollector hitCounter = new TotalHitCountCollector();
            this.indexSearcher.search(new WildcardQuery(new Term(this.classFieldName, "*")), hitCounter);
            docCount = hitCounter.getTotalHits();
        }
        return docCount;
    }

    /**
     * Runs the configured analyzer over the given text and collects the produced terms.
     *
     * @param doc the text to tokenize
     * @return the terms in the order the analyzer emitted them
     * @throws IOException if the analyzer fails while consuming the text
     */
    private String[] tokenizeDoc(String doc) throws IOException {
        LinkedList<String> tokens = new LinkedList<>();
        // try-with-resources so the stream is closed even when tokenization throws.
        try (TokenStream tokenStream = this.analyzer.tokenStream(this.textFieldName, new StringReader(doc))) {
            CharTermAttribute termAttribute = tokenStream.addAttribute(CharTermAttribute.class);
            tokenStream.reset();
            while (tokenStream.incrementToken()) {
                tokens.add(termAttribute.toString());
            }
            tokenStream.end();
        }
        return tokens.toArray(new String[tokens.size()]);
    }

    /**
     * Assigns the most probable class to the given text.
     *
     * @param str the text to classify
     * @return the best class and its score; an empty {@link BytesRef} with score {@code 0} when
     *     no class scores above zero (including when the class field has no terms)
     * @throws IOException if {@link #train} has not been called, or if the index cannot be read
     */
    @Override
    public ClassificationResult<BytesRef> assignClass(String str) throws IOException {
        if (this.atomicReader == null) {
            throw new IOException("You must first call Classifier#train first");
        }
        double bestScore = 0.0d;
        BytesRef bestClass = new BytesRef();

        Terms classTerms = MultiFields.getTerms(this.atomicReader, this.classFieldName);
        if (classTerms == null) {
            // No classes are known, so there is nothing to choose from.
            return new ClassificationResult<>(bestClass, bestScore);
        }

        TermsEnum classes = classTerms.iterator((TermsEnum) null);
        String[] tokens = tokenizeDoc(str);
        BytesRef candidate;
        while ((candidate = classes.next()) != null) {
            double score = calculatePrior(candidate) * calculateLikelihood(tokens, candidate);
            if (score > bestScore) {
                bestScore = score;
                // Deep copy: the enum may reuse the bytes behind 'candidate' on the next call to
                // next(), and BytesRef#clone() would keep sharing that buffer.
                bestClass = BytesRef.deepCopyOf(candidate);
            }
        }
        return new ClassificationResult<>(bestClass, bestScore);
    }

    /**
     * Computes the add-one smoothed likelihood of the tokens given a class, as the product over
     * all tokens of {@code (docs in class containing token + 1) / denominator}.
     *
     * @param strArr the tokens of the text being classified
     * @param bytesRef the candidate class
     * @return the product of the per-token ratios ({@code 1.0} for no tokens)
     * @throws IOException if the index cannot be read
     */
    private double calculateLikelihood(String[] strArr, BytesRef bytesRef) throws IOException {
        // The denominator depends only on the class, so compute it once instead of per token.
        double denominator = getTextTermFreqForClass(bytesRef) + this.docsWithClassSize;
        double likelihood = 1.0d;
        for (String token : strArr) {
            double numerator = getWordFreqForClass(token, bytesRef) + 1;
            likelihood *= numerator / denominator;
        }
        return likelihood;
    }

    /**
     * Estimates how many text-field terms belong to a class: the average number of unique terms
     * per document in the text field times the number of documents in the class.
     *
     * @param bytesRef the class
     * @return the estimate, or {@code 0} when the text field statistics are missing or unusable
     * @throws IOException if the index cannot be read
     */
    private double getTextTermFreqForClass(BytesRef bytesRef) throws IOException {
        Terms textTerms = MultiFields.getTerms(this.atomicReader, this.textFieldName);
        if (textTerms == null) {
            return 0.0d;
        }
        long numPostings = textTerms.getSumDocFreq();
        int docsWithText = textTerms.getDocCount();
        if (numPostings < 0 || docsWithText <= 0) {
            // Statistics unavailable (reported as -1) or no documents: avoid a bogus average.
            return 0.0d;
        }
        // Divide in floating point; integer division would truncate the average.
        double avgUniqueTermsPerDoc = numPostings / (double) docsWithText;
        int docsInClass = this.atomicReader.docFreq(new Term(this.classFieldName, bytesRef));
        return avgUniqueTermsPerDoc * docsInClass;
    }

    /**
     * Counts the documents that contain the given word in the text field and belong to the class.
     *
     * @param str the word (an analyzed term of the text field)
     * @param bytesRef the class
     * @return the number of matching documents
     * @throws IOException if the search fails
     */
    private int getWordFreqForClass(String str, BytesRef bytesRef) throws IOException {
        BooleanQuery wordInClass = new BooleanQuery();
        wordInClass.add(new BooleanClause(new TermQuery(new Term(this.textFieldName, str)), BooleanClause.Occur.MUST));
        wordInClass.add(new BooleanClause(new TermQuery(new Term(this.classFieldName, bytesRef)), BooleanClause.Occur.MUST));
        TotalHitCountCollector hitCounter = new TotalHitCountCollector();
        this.indexSearcher.search(wordInClass, hitCounter);
        return hitCounter.getTotalHits();
    }

    /**
     * Computes the prior of a class: the fraction of labelled documents that belong to it.
     *
     * @param bytesRef the class
     * @return the prior, or {@code 0} when no document carries a class
     * @throws IOException if the index cannot be read
     */
    private double calculatePrior(BytesRef bytesRef) throws IOException {
        if (this.docsWithClassSize <= 0) {
            return 0.0d;
        }
        // Cast before dividing; int / int would truncate almost every prior to 0.
        return (double) docCount(bytesRef) / this.docsWithClassSize;
    }

    /**
     * Returns the number of documents whose class field contains the given class term.
     *
     * @param bytesRef the class
     * @return the document frequency of the class term
     * @throws IOException if the index cannot be read
     */
    private int docCount(BytesRef bytesRef) throws IOException {
        return this.atomicReader.docFreq(new Term(this.classFieldName, bytesRef));
    }
}
