/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.clustering.dirichlet;

import java.io.IOException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.clustering.dirichlet.DirichletCluster;
import org.apache.mahout.clustering.dirichlet.DirichletClusterer;
import org.apache.mahout.clustering.dirichlet.DirichletDriver;
import org.apache.mahout.clustering.dirichlet.DirichletState;
import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
import org.apache.mahout.clustering.dirichlet.models.DistributionDescription;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

public class DirichletMapper
extends Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> {
    private DirichletClusterer clusterer;

    protected void map(WritableComparable<?> key, VectorWritable v, Mapper.Context context) throws IOException, InterruptedException {
        int k = this.clusterer.assignToModel(v);
        context.write((Object)new Text(String.valueOf(k)), (Object)v);
    }

    protected void setup(Mapper.Context context) throws IOException, InterruptedException {
        super.setup(context);
        DirichletState dirichletState = DirichletMapper.getDirichletState(context.getConfiguration());
        for (DirichletCluster cluster : dirichletState.getClusters()) {
            cluster.getModel().configure(context.getConfiguration());
        }
        this.clusterer = new DirichletClusterer(dirichletState);
        for (int i = 0; i < dirichletState.getNumClusters(); ++i) {
            context.write((Object)new Text(Integer.toString(i)), (Object)new VectorWritable((Vector)new DenseVector(0)));
        }
    }

    public void setup(DirichletState state) {
        this.clusterer = new DirichletClusterer(state);
    }

    public static DirichletState getDirichletState(Configuration conf) {
        String statePath = conf.get("org.apache.mahout.clustering.dirichlet.stateIn");
        String descriptionString = conf.get("org.apache.mahout.clustering.dirichlet.modelFactory");
        String numClusters = conf.get("org.apache.mahout.clustering.dirichlet.numClusters");
        String alpha0 = conf.get("org.apache.mahout.clustering.dirichlet.alpha_0");
        DistributionDescription description = DistributionDescription.fromString(descriptionString);
        return DirichletMapper.loadState(conf, statePath, description, Double.parseDouble(alpha0), Integer.parseInt(numClusters));
    }

    protected static DirichletState loadState(Configuration conf, String statePath, DistributionDescription description, double alpha, int k) {
        DirichletState state = DirichletDriver.createState(description, k, alpha);
        Path path = new Path(statePath);
        for (Pair record : new SequenceFileDirIterable(path, PathType.LIST, PathFilters.logsCRCFilter(), conf)) {
            int index = Integer.parseInt(((Writable)record.getFirst()).toString());
            state.getClusters().set(index, (DirichletCluster)record.getSecond());
        }
        state.setMixture(UncommonDistributions.rDirichlet(state.totalCounts(), alpha));
        return state;
    }
}

