/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sequencelearning.hmm;

import java.util.Collection;
import java.util.Iterator;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmAlgorithms;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmEvaluator;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmModel;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;

/**
 * Static training routines for {@link HmmModel} instances.
 *
 * <p>Supervised training ({@link #trainSupervised}, {@link #trainSupervisedSequence}) estimates
 * the model by counting transitions and emissions in labelled hidden/observed sequence pairs.
 * Unsupervised training ({@link #trainViterbi}, {@link #trainBaumWelch}) iteratively re-estimates
 * a copy of an initial model from an observed sequence alone. In every case the transition and
 * emission matrices of the returned model are row-normalized.
 */
public final class HmmTrainer {

    private HmmTrainer() {
    }

    /**
     * Trains a model from a single labelled sequence.
     *
     * <p>Every transition and emission cell starts at {@code pseudoCount} (a pseudo count of 0 is
     * replaced by {@link Double#MIN_VALUE} so no row can sum to zero). Observed transitions and
     * emissions are then counted and each row is normalized. Initial probabilities are uniform.
     *
     * @param nrOfHiddenStates number of hidden states
     * @param nrOfOutputStates number of output (observable) states
     * @param observedSequence observed state indices
     * @param hiddenSequence hidden state indices aligned with {@code observedSequence}
     * @param pseudoCount value added to every count before normalization
     * @return the estimated model
     * @throws IllegalArgumentException if {@code observedSequence} is empty or
     *     {@code hiddenSequence} is shorter than it
     */
    public static HmmModel trainSupervised(int nrOfHiddenStates, int nrOfOutputStates, int[] observedSequence, int[] hiddenSequence, double pseudoCount) {
        double effectivePseudoCount = effectivePseudoCount(pseudoCount);
        Matrix transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates);
        Matrix emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates);
        transitionMatrix.assign(effectivePseudoCount);
        emissionMatrix.assign(effectivePseudoCount);
        Vector initialProbabilities = new DenseVector(nrOfHiddenStates);
        initialProbabilities.assign(1.0 / nrOfHiddenStates);
        countTransitions(transitionMatrix, emissionMatrix, observedSequence, hiddenSequence);
        normalizeRows(transitionMatrix, nrOfHiddenStates, nrOfHiddenStates);
        normalizeRows(emissionMatrix, nrOfHiddenStates, nrOfOutputStates);
        return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
    }

    /**
     * Adds the transition and emission counts of one labelled sequence to the given matrices.
     *
     * @param transitionMatrix hidden-to-hidden counts, incremented in place
     * @param emissionMatrix hidden-to-output counts, incremented in place
     * @param observedSequence observed state indices; must be non-empty
     * @param hiddenSequence hidden state indices; only the first
     *     {@code observedSequence.length} entries are used
     * @throws IllegalArgumentException if {@code observedSequence} is empty or
     *     {@code hiddenSequence} is shorter than it
     */
    private static void countTransitions(Matrix transitionMatrix, Matrix emissionMatrix, int[] observedSequence, int[] hiddenSequence) {
        if (observedSequence.length == 0) {
            throw new IllegalArgumentException("observed sequence must not be empty");
        }
        if (hiddenSequence.length < observedSequence.length) {
            throw new IllegalArgumentException("hidden sequence (length " + hiddenSequence.length
                + ") is shorter than observed sequence (length " + observedSequence.length + ")");
        }
        increment(emissionMatrix, hiddenSequence[0], observedSequence[0]);
        for (int t = 1; t < observedSequence.length; ++t) {
            increment(transitionMatrix, hiddenSequence[t - 1], hiddenSequence[t]);
            increment(emissionMatrix, hiddenSequence[t], observedSequence[t]);
        }
    }

    /**
     * Trains a model from several labelled sequences.
     *
     * <p>Works like {@link #trainSupervised}, except that initial probabilities are also estimated:
     * each starts at the (effective) pseudo count and is incremented once for every sequence that
     * begins in that hidden state.
     *
     * @param nrOfHiddenStates number of hidden states
     * @param nrOfOutputStates number of output (observable) states
     * @param hiddenSequences hidden state sequences, paired in iteration order with
     *     {@code observedSequences}
     * @param observedSequences observed state sequences
     * @param pseudoCount value added to every count before normalization
     * @return the estimated model
     * @throws IllegalArgumentException if a paired sequence is empty or its hidden part is
     *     shorter than its observed part
     */
    public static HmmModel trainSupervisedSequence(int nrOfHiddenStates, int nrOfOutputStates, Collection<int[]> hiddenSequences, Collection<int[]> observedSequences, double pseudoCount) {
        double effectivePseudoCount = effectivePseudoCount(pseudoCount);
        Matrix transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates);
        Matrix emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates);
        Vector initialProbabilities = new DenseVector(nrOfHiddenStates);
        transitionMatrix.assign(effectivePseudoCount);
        emissionMatrix.assign(effectivePseudoCount);
        initialProbabilities.assign(effectivePseudoCount);
        Iterator<int[]> hiddenIt = hiddenSequences.iterator();
        Iterator<int[]> observedIt = observedSequences.iterator();
        // Sequences are paired in lockstep; surplus entries of the longer collection are ignored.
        while (hiddenIt.hasNext() && observedIt.hasNext()) {
            int[] hiddenSequence = hiddenIt.next();
            int[] observedSequence = observedIt.next();
            // Counting first validates the pair before hiddenSequence[0] is read below.
            countTransitions(transitionMatrix, emissionMatrix, observedSequence, hiddenSequence);
            int firstState = hiddenSequence[0];
            initialProbabilities.setQuick(firstState, initialProbabilities.getQuick(firstState) + 1.0);
        }
        normalizeRows(transitionMatrix, nrOfHiddenStates, nrOfHiddenStates);
        normalizeRows(emissionMatrix, nrOfHiddenStates, nrOfOutputStates);
        normalizeVector(initialProbabilities, nrOfHiddenStates);
        return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
    }

    /**
     * Viterbi training: repeatedly decodes the most likely hidden path under the previous model and
     * re-estimates transition and emission probabilities from that path.
     *
     * <p>Initial probabilities are copied from {@code initialModel} and never updated. Iteration
     * stops after {@code maxIterations} rounds or when {@link #checkConvergence} reports a change
     * below {@code epsilon}. {@code initialModel} itself is not modified.
     *
     * @param initialModel starting model (cloned)
     * @param observedSequence observed state indices
     * @param pseudoCount value added to every count before normalization (0 means
     *     {@link Double#MIN_VALUE})
     * @param epsilon convergence threshold
     * @param maxIterations maximum number of iterations
     * @param scaled forwarded to {@code HmmAlgorithms.viterbiAlgorithm}
     * @return the trained model
     */
    public static HmmModel trainViterbi(HmmModel initialModel, int[] observedSequence, double pseudoCount, double epsilon, int maxIterations, boolean scaled) {
        double effectivePseudoCount = effectivePseudoCount(pseudoCount);
        String cloneFailure = "Cloning HmmModels broke. Check for programming errors, changed APIs.";
        HmmModel lastIteration = cloneModel(initialModel, cloneFailure);
        HmmModel iteration = cloneModel(initialModel, cloneFailure);
        int[] viterbiPath = new int[observedSequence.length];
        int[][] phi = new int[observedSequence.length - 1][initialModel.getNrOfHiddenStates()];
        double[][] delta = new double[observedSequence.length][initialModel.getNrOfHiddenStates()];
        for (int i = 0; i < maxIterations; ++i) {
            HmmAlgorithms.viterbiAlgorithm(viterbiPath, delta, phi, lastIteration, observedSequence, scaled);
            Matrix emissionMatrix = iteration.getEmissionMatrix();
            Matrix transitionMatrix = iteration.getTransitionMatrix();
            // Reset to the pseudo count, then count along the decoded path.
            emissionMatrix.assign(effectivePseudoCount);
            transitionMatrix.assign(effectivePseudoCount);
            countTransitions(transitionMatrix, emissionMatrix, observedSequence, viterbiPath);
            int hiddenCount = iteration.getNrOfHiddenStates();
            normalizeRows(transitionMatrix, hiddenCount, hiddenCount);
            normalizeRows(emissionMatrix, hiddenCount, iteration.getNrOfOutputStates());
            if (checkConvergence(lastIteration, iteration, epsilon)) {
                break;
            }
            lastIteration.assign(iteration);
        }
        return iteration;
    }

    /**
     * Baum-Welch (expectation-maximization) training on a single observed sequence.
     *
     * <p>Each iteration runs the forward and backward algorithms on the current model. It then
     * replaces the initial, transition and emission parameters with their expected counts and
     * normalizes them. With {@code scaled == true} the forward/backward variables are combined in
     * log space. Iteration stops after {@code maxIterations} rounds or when
     * {@link #checkConvergence} reports a change below {@code epsilon}. {@code initialModel}
     * itself is not modified.
     *
     * @param initialModel starting model (cloned)
     * @param observedSequence observed state indices
     * @param epsilon convergence threshold
     * @param maxIterations maximum number of iterations
     * @param scaled whether to use the log-scaled forward/backward variant
     * @return the trained model
     */
    public static HmmModel trainBaumWelch(HmmModel initialModel, int[] observedSequence, double epsilon, int maxIterations, boolean scaled) {
        String cloneFailure = "Cloning HmmModels broke. Check for programming errors, changed APIs etc.";
        HmmModel lastIteration = cloneModel(initialModel, cloneFailure);
        HmmModel iteration = cloneModel(initialModel, cloneFailure);
        int hiddenCount = initialModel.getNrOfHiddenStates();
        int visibleCount = observedSequence.length;
        // Rows are time steps, columns are hidden states.
        Matrix alpha = new DenseMatrix(visibleCount, hiddenCount);
        Matrix beta = new DenseMatrix(visibleCount, hiddenCount);
        for (int it = 0; it < maxIterations; ++it) {
            Vector initialProbabilities = iteration.getInitialProbabilities();
            Matrix emissionMatrix = iteration.getEmissionMatrix();
            Matrix transitionMatrix = iteration.getTransitionMatrix();
            HmmAlgorithms.forwardAlgorithm(alpha, iteration, observedSequence, scaled);
            HmmAlgorithms.backwardAlgorithm(beta, iteration, observedSequence, scaled);
            if (scaled) {
                logScaledBaumWelch(observedSequence, iteration, alpha, beta);
            } else {
                unscaledBaumWelch(observedSequence, iteration, alpha, beta);
            }
            int modelHiddenCount = iteration.getNrOfHiddenStates();
            normalizeRows(transitionMatrix, modelHiddenCount, modelHiddenCount);
            normalizeRows(emissionMatrix, modelHiddenCount, iteration.getNrOfOutputStates());
            normalizeVector(initialProbabilities, modelHiddenCount);
            if (checkConvergence(lastIteration, iteration, epsilon)) {
                break;
            }
            lastIteration.assign(iteration);
        }
        return iteration;
    }

    /**
     * Baum-Welch re-estimation step with unscaled (plain probability) forward/backward variables.
     * Overwrites the parameters of {@code iteration} with unnormalized expected counts.
     *
     * @param observedSequence observed state indices
     * @param iteration model updated in place
     * @param alpha forward variables, indexed (time, state)
     * @param beta backward variables, indexed (time, state)
     */
    private static void unscaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) {
        Vector initialProbabilities = iteration.getInitialProbabilities();
        Matrix emissionMatrix = iteration.getEmissionMatrix();
        Matrix transitionMatrix = iteration.getTransitionMatrix();
        int hiddenCount = iteration.getNrOfHiddenStates();
        int outputCount = iteration.getNrOfOutputStates();
        double modelLikelihood = HmmEvaluator.modelLikelihood(alpha, false);
        for (int state = 0; state < hiddenCount; ++state) {
            initialProbabilities.setQuick(state, alpha.getQuick(0, state) * beta.getQuick(0, state));
        }
        // Transitions must be re-estimated before emissions: they read the *old* emission values.
        for (int from = 0; from < hiddenCount; ++from) {
            for (int to = 0; to < hiddenCount; ++to) {
                double expected = 0.0;
                for (int t = 0; t < observedSequence.length - 1; ++t) {
                    expected += alpha.getQuick(t, from) * emissionMatrix.getQuick(to, observedSequence[t + 1]) * beta.getQuick(t + 1, to);
                }
                transitionMatrix.setQuick(from, to, transitionMatrix.getQuick(from, to) * expected / modelLikelihood);
            }
        }
        for (int state = 0; state < hiddenCount; ++state) {
            for (int output = 0; output < outputCount; ++output) {
                double expected = 0.0;
                for (int t = 0; t < observedSequence.length; ++t) {
                    if (observedSequence[t] == output) {
                        expected += alpha.getQuick(t, state) * beta.getQuick(t, state);
                    }
                }
                emissionMatrix.setQuick(state, output, expected / modelLikelihood);
            }
        }
    }

    /**
     * Baum-Welch re-estimation step with log-scaled forward/backward variables (natural logs).
     * Overwrites the parameters of {@code iteration} with unnormalized expected counts, converted
     * back from log space.
     *
     * @param observedSequence observed state indices
     * @param iteration model updated in place
     * @param alpha log forward variables, indexed (time, state)
     * @param beta log backward variables, indexed (time, state)
     */
    private static void logScaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) {
        Vector initialProbabilities = iteration.getInitialProbabilities();
        Matrix emissionMatrix = iteration.getEmissionMatrix();
        Matrix transitionMatrix = iteration.getTransitionMatrix();
        int hiddenCount = iteration.getNrOfHiddenStates();
        int outputCount = iteration.getNrOfOutputStates();
        // Used below as a log-likelihood (subtracted inside Math.exp).
        double modelLikelihood = HmmEvaluator.modelLikelihood(alpha, true);
        for (int state = 0; state < hiddenCount; ++state) {
            initialProbabilities.setQuick(state, Math.exp(alpha.getQuick(0, state) + beta.getQuick(0, state)));
        }
        // Transitions must be re-estimated before emissions: they read the *old* emission values.
        for (int from = 0; from < hiddenCount; ++from) {
            for (int to = 0; to < hiddenCount; ++to) {
                double logSum = Double.NEGATIVE_INFINITY;
                for (int t = 0; t < observedSequence.length - 1; ++t) {
                    double term = alpha.getQuick(t, from) + Math.log(emissionMatrix.getQuick(to, observedSequence[t + 1])) + beta.getQuick(t + 1, to);
                    // Skip zero-probability (and NaN) terms.
                    if (term > Double.NEGATIVE_INFINITY) {
                        logSum = logAdd(logSum, term);
                    }
                }
                transitionMatrix.setQuick(from, to, transitionMatrix.getQuick(from, to) * Math.exp(logSum - modelLikelihood));
            }
        }
        for (int state = 0; state < hiddenCount; ++state) {
            for (int output = 0; output < outputCount; ++output) {
                double logSum = Double.NEGATIVE_INFINITY;
                for (int t = 0; t < observedSequence.length; ++t) {
                    if (observedSequence[t] != output) {
                        continue;
                    }
                    double term = alpha.getQuick(t, state) + beta.getQuick(t, state);
                    if (term > Double.NEGATIVE_INFINITY) {
                        logSum = logAdd(logSum, term);
                    }
                }
                emissionMatrix.setQuick(state, output, Math.exp(logSum - modelLikelihood));
            }
        }
    }

    /**
     * Returns true when the Frobenius-norm distance between the old and new transition matrices,
     * plus the same distance for the emission matrices, is below {@code epsilon}. Initial
     * probabilities are not compared.
     */
    private static boolean checkConvergence(HmmModel oldModel, HmmModel newModel, double epsilon) {
        int hiddenCount = oldModel.getNrOfHiddenStates();
        double transitionDistance = frobeniusDistance(oldModel.getTransitionMatrix(), newModel.getTransitionMatrix(), hiddenCount, hiddenCount);
        double emissionDistance = frobeniusDistance(oldModel.getEmissionMatrix(), newModel.getEmissionMatrix(), hiddenCount, oldModel.getNrOfOutputStates());
        return transitionDistance + emissionDistance < epsilon;
    }

    /** Square root of the summed squared differences over the leading {@code rows x columns} cells. */
    private static double frobeniusDistance(Matrix a, Matrix b, int rows, int columns) {
        double squared = 0.0;
        for (int row = 0; row < rows; ++row) {
            for (int col = 0; col < columns; ++col) {
                double diff = a.getQuick(row, col) - b.getQuick(row, col);
                squared += diff * diff;
            }
        }
        return Math.sqrt(squared);
    }

    /** Divides each of the first {@code rows} rows (over {@code columns} columns) by its sum, in place. */
    private static void normalizeRows(Matrix matrix, int rows, int columns) {
        for (int row = 0; row < rows; ++row) {
            double sum = 0.0;
            for (int col = 0; col < columns; ++col) {
                sum += matrix.getQuick(row, col);
            }
            for (int col = 0; col < columns; ++col) {
                matrix.setQuick(row, col, matrix.getQuick(row, col) / sum);
            }
        }
    }

    /** Divides the first {@code size} entries of {@code vector} by their sum, in place. */
    private static void normalizeVector(Vector vector, int size) {
        double sum = 0.0;
        for (int i = 0; i < size; ++i) {
            sum += vector.getQuick(i);
        }
        for (int i = 0; i < size; ++i) {
            vector.setQuick(i, vector.getQuick(i) / sum);
        }
    }

    /** Adds 1 to the given matrix cell. */
    private static void increment(Matrix matrix, int row, int column) {
        matrix.setQuick(row, column, matrix.getQuick(row, column) + 1.0);
    }

    /** A pseudo count of 0 is replaced by the smallest positive double so normalization never divides by zero. */
    private static double effectivePseudoCount(double pseudoCount) {
        return pseudoCount == 0.0 ? Double.MIN_VALUE : pseudoCount;
    }

    /**
     * Returns {@code log(exp(logA) + exp(logB))} without overflow. The naive
     * {@code logB + log(1 + exp(logA - logB))} yields +Infinity once {@code logA} exceeds
     * {@code logB} by more than about 709.
     */
    private static double logAdd(double logA, double logB) {
        if (logA == Double.NEGATIVE_INFINITY) {
            return logB;
        }
        if (logB == Double.NEGATIVE_INFINITY) {
            return logA;
        }
        double max = Math.max(logA, logB);
        double min = Math.min(logA, logB);
        return max + Math.log1p(Math.exp(min - max));
    }

    /** Clones {@code model}, converting the (unexpected) checked failure into an error that keeps its cause. */
    private static HmmModel cloneModel(HmmModel model, String failureMessage) {
        try {
            return model.clone();
        } catch (CloneNotSupportedException e) {
            UnknownError error = new UnknownError(failureMessage);
            error.initCause(e);
            throw error;
        }
    }
}

