/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering.lda;

import java.util.Iterator;
import org.apache.commons.math.special.Gamma;
import org.apache.mahout.clustering.lda.LDAState;
import org.apache.mahout.clustering.lda.LDAUtil;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleDoubleFunction;

/**
 * Per-document variational inference for Latent Dirichlet Allocation.
 *
 * <p>For one document this class alternates between two updates until the likelihood
 * stops improving or {@code MAX_ITER} passes have run:
 * <ul>
 *   <li>the per-word topic responsibilities (phi, stored in log space);</li>
 *   <li>the document's topic Dirichlet parameters (gamma).</li>
 * </ul>
 * The topic/word log-probabilities and the topic smoothing come from an {@link LDAState}.
 *
 * <p>An instance keeps a reusable phi buffer in a mutable field without synchronization.
 * It should therefore not be shared between threads that call {@link #infer} concurrently.
 */
public class LDAInference {
    /** Relative change in log-likelihood, {@code (oldLL - ll) / oldLL}, below which iteration stops. */
    private static final double E_STEP_CONVERGENCE = 1.0E-6;
    /** Maximum number of E-step passes per document. */
    private static final int MAX_ITER = 20;
    /** numTopics x docLength buffer of log responsibilities, reused between calls to {@link #infer}. */
    private DenseMatrix phi;
    private final LDAState state;

    /**
     * Creates an inference helper for the given model state.
     *
     * @param state supplies the number of topics, the topic smoothing and
     *     {@code logProbWordGivenTopic(word, topic)}
     */
    public LDAInference(LDAState state) {
        this.state = state;
    }

    /**
     * Infers the topic parameters of a single document.
     *
     * <p>The loop stops after {@code MAX_ITER} passes, or earlier once
     * {@code (oldLL - ll) / oldLL} drops below {@code E_STEP_CONVERGENCE}. Because the
     * ratio becomes negative when the likelihood decreases, a decrease also ends the loop.
     * The first pass never ends it, because the previous likelihood starts out positive.
     *
     * <p><b>Aliasing:</b> the returned document references this instance's phi buffer
     * rather than a copy. The next call to {@code infer} with a vector of the same size
     * zeroes and overwrites that buffer. Read the returned phi values before calling
     * {@code infer} again.
     *
     * @param wordCounts document vector indexed by word id. Its non-zero values are
     *     treated as per-word weights (presumably term counts).
     * @return the fitted gamma, the log-space phi, the word-to-column map and the
     *     log-likelihood from the last pass
     */
    public InferredDocument infer(Vector wordCounts) {
        int numTopics = state.getNumTopics();
        double docTotal = wordCounts.zSum();
        // Words are indexed by id, so phi and the word->column map are sized by the
        // vector's size() rather than by its number of non-zero entries.
        int docLength = wordCounts.size();

        // Start every topic at the smoothing plus an even share of the document's total weight.
        DenseVector gamma = new DenseVector(numTopics);
        gamma.assign(state.getTopicSmoothing() + docTotal / numTopics);
        DenseVector nextGamma = new DenseVector(numTopics);
        createPhiMatrix(docLength);
        Vector digammaGamma = digammaGamma(gamma);

        // map[word] = column of phi that holds this word's log responsibilities.
        int[] map = new int[docLength];
        boolean converged = false;
        // Log-likelihoods are negative; a positive sentinel prevents a first-pass exit.
        double oldLL = 1.0;
        for (int iteration = 0; !converged && iteration < MAX_ITER; ++iteration) {
            nextGamma.assign(state.getTopicSmoothing());
            int mapping = 0;
            Iterator<Vector.Element> iter = wordCounts.iterateNonZero();
            while (iter.hasNext()) {
                Vector.Element e = iter.next();
                int word = e.index();
                Vector phiW = eStepForWord(word, digammaGamma);
                phi.assignColumn(mapping, phiW);
                // The map is only recorded on the first pass. This assumes iterateNonZero()
                // yields words in the same order on every pass.
                if (iteration == 0) {
                    map[word] = mapping;
                }
                // phiW is in log space: add weight * responsibility to each topic.
                double count = e.get();
                for (int k = 0; k < nextGamma.size(); ++k) {
                    nextGamma.setQuick(k, nextGamma.getQuick(k) + count * Math.exp(phiW.getQuick(k)));
                }
                ++mapping;
            }
            // Swap buffers so the freshly accumulated values become the current gamma.
            DenseVector previous = gamma;
            gamma = nextGamma;
            nextGamma = previous;
            digammaGamma = digammaGamma(gamma);
            double ll = computeLikelihood(wordCounts, map, phi, gamma, digammaGamma);
            converged = oldLL < 0.0 && (oldLL - ll) / oldLL < E_STEP_CONVERGENCE;
            oldLL = ll;
        }
        return new InferredDocument(wordCounts, gamma, map, phi, oldLL);
    }

    /**
     * Returns {@code psi(gamma_k) - psi(sum_j gamma_j)} for every topic k. This is the
     * expected log of topic proportion k under Dirichlet(gamma).
     */
    private Vector digammaGamma(Vector gamma) {
        Vector result = LDAInference.digamma(gamma);
        double digammaSumGamma = LDAInference.digamma(gamma.zSum());
        int numTopics = state.getNumTopics();
        for (int i = 0; i < numTopics; ++i) {
            result.setQuick(i, result.getQuick(i) - digammaSumGamma);
        }
        return result;
    }

    /**
     * Makes {@code phi} a numTopics x docLength matrix ready for a new document.
     * A new matrix is allocated when there is none yet or the column count differs.
     * Otherwise the existing buffer is zeroed and reused.
     */
    private void createPhiMatrix(int docLength) {
        if (phi == null || phi.getRow(0).size() != docLength) {
            phi = new DenseMatrix(state.getNumTopics(), docLength);
        } else {
            phi.assign(0.0);
        }
    }

    /**
     * Computes the variational log-likelihood bound for the current gamma and phi.
     *
     * @param wordCounts the document; only its non-zero entries contribute
     * @param map word id to phi column, as built by {@link #infer}
     * @param phi log responsibilities, one column per non-zero word
     * @param gamma current topic Dirichlet parameters
     * @param digammaGamma {@code psi(gamma_k) - psi(sum gamma)} for each topic
     * @return the log-likelihood bound
     */
    private double computeLikelihood(Vector wordCounts, int[] map, Matrix phi, Vector gamma, Vector digammaGamma) {
        int numTopics = state.getNumTopics();
        double alpha = state.getTopicSmoothing();

        // Dirichlet prior terms.
        double ll = 0.0;
        ll += Gamma.logGamma(alpha * numTopics);
        ll -= numTopics * Gamma.logGamma(alpha);
        for (int k = 0; k < numTopics; ++k) {
            double gammaK = gamma.get(k);
            ll += (alpha - gammaK) * digammaGamma.getQuick(k);
            ll += Gamma.logGamma(gammaK);
        }
        ll -= Gamma.logGamma(gamma.zSum());

        // Per-word terms, weighted by each word's value in the document.
        Iterator<Vector.Element> iter = wordCounts.iterateNonZero();
        while (iter.hasNext()) {
            Vector.Element e = iter.next();
            int w = e.index();
            double n = e.get();
            int mapping = map[w];
            for (int k = 0; k < numTopics; ++k) {
                double logPhi = phi.getQuick(k, mapping);
                ll += Math.exp(logPhi) * (digammaGamma.getQuick(k) - logPhi + state.logProbWordGivenTopic(w, k)) * n;
            }
        }
        return ll;
    }

    /**
     * Computes the log responsibilities of every topic for one word.
     * Each entry is {@code logProbWordGivenTopic(word, k) + digammaGamma(k)}, shifted so
     * the entries normalize in log space. This assumes
     * {@code LDAUtil.logSum(a, b) = log(exp(a) + exp(b))}.
     */
    private Vector eStepForWord(int word, Vector digammaGamma) {
        int numTopics = state.getNumTopics();
        DenseVector logPhi = new DenseVector(numTopics);
        double logTotal = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < numTopics; ++k) {
            double value = state.logProbWordGivenTopic(word, k) + digammaGamma.getQuick(k);
            logPhi.setQuick(k, value);
            logTotal = LDAUtil.logSum(logTotal, value);
        }
        for (int k = 0; k < numTopics; ++k) {
            logPhi.setQuick(k, logPhi.getQuick(k) - logTotal);
        }
        return logPhi;
    }

    /** Applies {@link #digamma(double)} to each element of {@code v}, returning a new dense vector. */
    private static Vector digamma(Vector v) {
        DenseVector result = new DenseVector(v.size());
        // The function ignores its first argument (this vector's own value); the second
        // is presumably the matching element of v.
        result.assign(v, new DoubleDoubleFunction() {
            @Override
            public double apply(double unused, double g) {
                return LDAInference.digamma(g);
            }
        });
        return result;
    }

    /**
     * Approximates the digamma function psi(x).
     *
     * <p>The recurrence {@code psi(x) = psi(x + 1) - 1/x} shifts x above 5. After that,
     * the asymptotic series {@code ln x - 1/(2x) - sum B_2n / (2n x^2n)} is evaluated.
     * Its coefficients are -1/12, 1/120, -1/252, 1/240, -1/132, 691/32760, -1/12 and
     * 3617/8160, taken in powers of {@code 1/x^2}.
     */
    private static double digamma(double x) {
        double r = 0.0;
        while (x <= 5.0) {
            r -= 1.0 / x;
            x += 1.0;
        }
        double f = 1.0 / (x * x);
        double t = f * (-0.08333333333333333
                + f * (0.008333333333333333
                + f * (-0.003968253968253968
                + f * (0.004166666666666667
                + f * (-0.007575757575757576
                + f * (0.021092796092796094
                + f * (-0.08333333333333333
                + f * 3617.0 / 8160.0)))))));
        return r + Math.log(x) - 0.5 / x + t;
    }

    /**
     * Result of {@link LDAInference#infer} for one document.
     *
     * <p>When produced by {@code infer}, the phi matrix is that inference object's reusable
     * buffer and is overwritten by its next {@code infer} call of the same size.
     */
    public static class InferredDocument {
        private final Vector wordCounts;
        private final Vector gamma;
        private final Matrix mphi;
        private final int[] columnMap;
        private final double logLikelihood;

        InferredDocument(Vector wordCounts, Vector gamma, int[] columnMap, Matrix phi, double ll) {
            this.wordCounts = wordCounts;
            this.gamma = gamma;
            this.mphi = phi;
            this.columnMap = columnMap;
            this.logLikelihood = ll;
        }

        /**
         * Returns the log responsibility of topic {@code k} for word {@code w}.
         *
         * <p>This is only meaningful for words with a non-zero entry in the document.
         * Any other word id falls back to column 0, which holds the values for the
         * first non-zero word.
         */
        public double phi(int k, int w) {
            return this.mphi.getQuick(k, this.columnMap[w]);
        }

        /** Returns the document vector that was inferred. */
        public Vector getWordCounts() {
            return this.wordCounts;
        }

        /** Returns the fitted topic Dirichlet parameters. */
        public Vector getGamma() {
            return this.gamma;
        }

        /** Returns the log-likelihood bound computed in the last inference pass. */
        public double getLogLikelihood() {
            return this.logLikelihood;
        }
    }
}

