/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.naivebayes.trainer;

import java.io.IOException;
import java.net.URI;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.map.OpenObjectIntHashMap;

public class NaiveBayesThetaMapper
extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
    private final OpenObjectIntHashMap<String> labelMap = new OpenObjectIntHashMap();
    private Vector featureSum;
    private Vector labelSum;
    private Vector perLabelThetaNormalizer;
    private double alphaI = 1.0;
    private double vocabCount;

    protected void map(IntWritable key, VectorWritable value, Mapper.Context context) throws IOException, InterruptedException {
        Vector vector = value.get();
        int label = key.get();
        double weight = Math.log((vector.zSum() + this.alphaI) / (this.labelSum.get(label) + this.vocabCount));
        this.perLabelThetaNormalizer.set(label, this.perLabelThetaNormalizer.get(label) + weight);
    }

    protected void setup(Mapper.Context context) throws IOException, InterruptedException {
        super.setup(context);
        Configuration conf = context.getConfiguration();
        URI[] localFiles = DistributedCache.getCacheFiles((Configuration)conf);
        if (localFiles == null || localFiles.length < 2) {
            throw new IllegalArgumentException("missing paths from the DistributedCache");
        }
        this.alphaI = conf.getFloat("alphaI", 1.0f);
        Path weightFile = new Path(localFiles[0].getPath());
        for (Pair record : new SequenceFileIterable(weightFile, true, conf)) {
            Text key = (Text)record.getFirst();
            VectorWritable value = (VectorWritable)((Object)record.getSecond());
            if (key.toString().equals("__SJ")) {
                this.featureSum = value.get();
                continue;
            }
            if (!key.toString().equals("__SK")) continue;
            this.labelSum = value.get();
        }
        this.perLabelThetaNormalizer = this.labelSum.like();
        this.vocabCount = this.featureSum.getNumNondefaultElements();
        Path labelMapFile = new Path(localFiles[1].getPath());
        for (Pair record : new SequenceFileIterable(labelMapFile, true, conf)) {
            this.labelMap.put((Object)((Writable)record.getFirst()).toString(), ((IntWritable)record.getSecond()).get());
        }
    }

    protected void cleanup(Mapper.Context context) throws IOException, InterruptedException {
        context.write((Object)new Text("_LTN"), (Object)new VectorWritable(this.perLabelThetaNormalizer));
        super.cleanup(context);
    }
}

