/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering.dirichlet;

import java.util.ArrayList;
import java.util.List;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.Model;
import org.apache.mahout.clustering.ModelDistribution;
import org.apache.mahout.clustering.dirichlet.DirichletCluster;
import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
import org.apache.mahout.clustering.dirichlet.models.DistributionDescription;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/**
 * State for Dirichlet clustering: a list of {@link DirichletCluster}s together with the current
 * mixture-weight vector.
 *
 * <p>The mixture is drawn from {@link UncommonDistributions#rDirichlet} using the per-cluster total
 * counts and the {@code alpha0} parameter. This happens both at construction time and after every
 * {@link #update(Cluster[])}.
 *
 * <p>NOTE(review): {@code numClusters} and the size of the cluster list are stored independently.
 * {@link #setNumClusters(int)} and {@link #setClusters(List)} do not keep them in sync, while
 * {@link #totalCounts()} and {@link #getModels()} index the cluster list with {@code numClusters}.
 * Callers must keep the two consistent.
 */
public class DirichletState {
    private int numClusters;
    private ModelDistribution<VectorWritable> modelFactory;
    private List<DirichletCluster> clusters;
    private Vector mixture;
    private final double alpha0;

    /**
     * Creates a state with {@code numClusters} clusters whose models are sampled from the prior of
     * {@code modelFactory}, and draws an initial mixture from their total counts.
     *
     * @param modelFactory distribution used to sample the initial cluster models
     * @param numClusters  number of models requested from {@code modelFactory.sampleFromPrior}
     * @param alpha0       parameter passed to {@link UncommonDistributions#rDirichlet} whenever the
     *                     mixture is drawn
     */
    public DirichletState(ModelDistribution<VectorWritable> modelFactory, int numClusters, double alpha0) {
        this.numClusters = numClusters;
        this.modelFactory = modelFactory;
        this.alpha0 = alpha0;
        this.clusters = new ArrayList<DirichletCluster>();
        for (Model<VectorWritable> m : modelFactory.sampleFromPrior(numClusters)) {
            // The cast assumes every sampled model also implements Cluster.
            this.clusters.add(new DirichletCluster((Cluster) m));
        }
        // Use the private helper instead of the public, overridable totalCounts(), so the
        // constructor never calls a method that a subclass could override.
        this.mixture = UncommonDistributions.rDirichlet(this.computeTotalCounts(), alpha0);
    }

    /**
     * Creates a state whose model distribution is built by {@code description}.
     *
     * @param description factory description used to create the model distribution
     * @param numClusters number of clusters to create
     * @param alpha0      parameter passed to {@link UncommonDistributions#rDirichlet}
     */
    public DirichletState(DistributionDescription description, int numClusters, double alpha0) {
        this(description.createModelDistribution(), numClusters, alpha0);
    }

    /** @return the configured number of clusters */
    public int getNumClusters() {
        return this.numClusters;
    }

    /**
     * Sets the configured number of clusters. The cluster list is not resized.
     *
     * @param numClusters new cluster count
     */
    public void setNumClusters(int numClusters) {
        this.numClusters = numClusters;
    }

    /** @return the model distribution this state was created with (or last set) */
    public ModelDistribution<VectorWritable> getModelFactory() {
        return this.modelFactory;
    }

    /** @param modelFactory the model distribution to store */
    public void setModelFactory(ModelDistribution<VectorWritable> modelFactory) {
        this.modelFactory = modelFactory;
    }

    /** @return the live (not copied) list of clusters */
    public List<DirichletCluster> getClusters() {
        return this.clusters;
    }

    /**
     * Replaces the cluster list. {@code numClusters} is not updated.
     *
     * @param clusters new cluster list, stored by reference
     */
    public void setClusters(List<DirichletCluster> clusters) {
        this.clusters = clusters;
    }

    /** @return the current mixture vector */
    public Vector getMixture() {
        return this.mixture;
    }

    /** @param mixture the mixture vector to store */
    public void setMixture(Vector mixture) {
        this.mixture = mixture;
    }

    /**
     * Returns a new vector whose entry {@code i} is the total count of cluster {@code i}, for
     * {@code i} in {@code [0, numClusters)}.
     *
     * @return freshly allocated vector of per-cluster total counts
     */
    public Vector totalCounts() {
        return this.computeTotalCounts();
    }

    /** Non-overridable implementation of {@link #totalCounts()}, safe to call from the constructor. */
    private Vector computeTotalCounts() {
        Vector result = new DenseVector(this.numClusters);
        for (int i = 0; i < this.numClusters; ++i) {
            result.set(i, this.clusters.get(i).getTotalCount());
        }
        return result;
    }

    /**
     * Installs new models and redraws the mixture.
     *
     * <p>For each {@code i}, calls {@code computeParameters()} on {@code newModels[i]} and makes it
     * the model of cluster {@code i}. It then draws a new mixture from
     * {@link UncommonDistributions#rDirichlet} using {@link #totalCounts()} and {@code alpha0}.
     *
     * @param newModels models to install, one per leading cluster
     * @throws IllegalArgumentException if {@code newModels} has more entries than there are
     *                                  clusters; no state is modified in that case
     */
    public void update(Cluster[] newModels) {
        // Validate before mutating anything. Otherwise an oversized array would leave some
        // clusters updated and the mixture stale.
        if (newModels.length > this.clusters.size()) {
            throw new IllegalArgumentException("Expected at most " + this.clusters.size()
                    + " models but got " + newModels.length);
        }
        for (int i = 0; i < newModels.length; ++i) {
            newModels[i].computeParameters();
            this.clusters.get(i).setModel(newModels[i]);
        }
        this.mixture = UncommonDistributions.rDirichlet(this.totalCounts(), this.alpha0);
    }

    /**
     * Returns the mixture-weighted density of {@code x} under cluster {@code k}.
     *
     * @param x point to evaluate
     * @param k cluster index
     * @return {@code mixture.get(k) * clusters.get(k).getModel().pdf(x)}
     */
    public double adjustedProbability(VectorWritable x, int k) {
        double pdf = this.clusters.get(k).getModel().pdf(x);
        double mix = this.mixture.get(k);
        return mix * pdf;
    }

    /**
     * Returns the models of the first {@code numClusters} clusters.
     *
     * @return a new array holding each cluster's current model
     */
    public Model<VectorWritable>[] getModels() {
        // Java cannot instantiate an array of a parameterized type. Every element stored below is
        // a cluster model of type Model<VectorWritable>, so the unchecked conversion is safe.
        @SuppressWarnings("unchecked")
        Model<VectorWritable>[] result = new Model[this.numClusters];
        for (int i = 0; i < this.numClusters; ++i) {
            result[i] = this.clusters.get(i).getModel();
        }
        return result;
    }
}

