/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.bayes.datastore;

import java.util.Collection;
import org.apache.hadoop.conf.Configuration;
import org.apache.mahout.classifier.bayes.common.BayesParameters;
import org.apache.mahout.classifier.bayes.exceptions.InvalidDatastoreException;
import org.apache.mahout.classifier.bayes.interfaces.Datastore;
import org.apache.mahout.classifier.bayes.io.SequenceFileModelReader;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.map.OpenIntDoubleHashMap;
import org.apache.mahout.math.map.OpenObjectIntHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link Datastore} that keeps an entire Bayes/CBayes model in memory.
 *
 * <p>The constructor registers the glob patterns of the model's part files on the supplied
 * {@link BayesParameters}. {@link #initialize()} then hands this instance to
 * {@link SequenceFileModelReader}, which presumably populates it through the public
 * {@code loadFeatureWeight}/{@code set*} callbacks below (TODO confirm against the reader).
 *
 * <p>Feature and label strings are mapped to dense integer ids (0, 1, 2, ...) in order of first
 * appearance during loading. The query methods ({@code getWeight}) never assign new ids: an unseen
 * feature or label has weight 0. This keeps the vocabulary size and the label set independent of
 * which documents have been classified.
 *
 * <p>No synchronization is performed anywhere in this class.
 */
public class InMemoryBayesDatastore
implements Datastore {
    private static final Logger log = LoggerFactory.getLogger(InMemoryBayesDatastore.class);

    /** Returned by {@link #idOf} when a key has never been assigned an id. */
    private static final int UNKNOWN_ID = -1;

    /** Feature string -> dense feature id (row index of {@link #weightMatrix}). */
    private final OpenObjectIntHashMap<String> featureDictionary = new OpenObjectIntHashMap<String>();
    /** Label string -> dense label id (column index of {@link #weightMatrix}). */
    private final OpenObjectIntHashMap<String> labelDictionary = new OpenObjectIntHashMap<String>();
    /** Feature id -> summed weight of that feature ("sigma_j"). */
    private final OpenIntDoubleHashMap sigmaJ = new OpenIntDoubleHashMap();
    /** Label id -> summed weight of that label ("sigma_k"). */
    private final OpenIntDoubleHashMap sigmaK = new OpenIntDoubleHashMap();
    /** Label id -> raw theta normalizer of that label. */
    private final OpenIntDoubleHashMap thetaNormalizerPerLabel = new OpenIntDoubleHashMap();
    // NOTE(review): declared cardinality is {1, 0}, yet rows/columns are later written at arbitrary
    // feature/label ids; this relies on SparseMatrix tolerating indices beyond its declared size --
    // confirm against the Mahout version in use.
    private final Matrix weightMatrix = new SparseMatrix(new int[]{1, 0});
    private final BayesParameters params;
    /** Largest |per-label theta normalizer| seen so far; starts at 1.0, so never below 1.0. */
    private double thetaNormalizer = 1.0;
    private final double alphaI;
    private double sigmaJsigmaK = 1.0;

    /**
     * Creates an empty datastore and registers on {@code params} the part-file patterns, relative to
     * {@code params.getBasePath()}, from which the model will be read.
     *
     * @param params job parameters; mutated by adding the model path keys. {@code alpha_i} is read
     *               from it, defaulting to {@code "1.0"}.
     */
    public InMemoryBayesDatastore(BayesParameters params) {
        String basePath = params.getBasePath();
        this.params = params;
        params.set("sigma_j", basePath + "/trainer-weights/Sigma_j/part-*");
        params.set("sigma_k", basePath + "/trainer-weights/Sigma_k/part-*");
        params.set("sigma_kSigma_j", basePath + "/trainer-weights/Sigma_kSigma_j/part-*");
        params.set("thetaNormalizer", basePath + "/trainer-thetaNormalizer/part-*");
        params.set("weight", basePath + "/trainer-tfIdf/trainer-tfIdf/part-*");
        this.alphaI = Double.parseDouble(params.get("alpha_i", "1.0"));
    }

    /**
     * Loads the model via {@link SequenceFileModelReader} and logs, per label, its raw theta
     * normalizer, the global normalizer, and their ratio.
     */
    @Override
    public void initialize() throws InvalidDatastoreException {
        Configuration conf = new Configuration();
        SequenceFileModelReader.loadModel(this, this.params, conf);
        for (String label : this.getKeys("")) {
            // Labels come straight from labelDictionary, so the lookup always succeeds.
            double labelNormalizer = this.thetaNormalizerPerLabel.get(idOf(this.labelDictionary, label));
            log.info("{} {} {} {}", new Object[]{label, labelNormalizer, this.thetaNormalizer,
                labelNormalizer / this.thetaNormalizer});
        }
    }

    /**
     * Returns all known labels. The {@code name} argument is ignored.
     */
    @Override
    public Collection<String> getKeys(String name) throws InvalidDatastoreException {
        return this.labelDictionary.keys();
    }

    /**
     * Looks up a cell of the {@code "weight"} matrix.
     *
     * @param matrixName must be {@code "weight"}
     * @param row        feature string
     * @param column     label string, or {@code "sigma_j"} for the feature's summed weight
     * @return the stored weight, or 0 if the feature or label was never loaded
     * @throws InvalidDatastoreException if {@code matrixName} is not {@code "weight"}
     */
    @Override
    public double getWeight(String matrixName, String row, String column) throws InvalidDatastoreException {
        if (!"weight".equals(matrixName)) {
            throw new InvalidDatastoreException("Matrix not found: " + matrixName);
        }
        int fid = idOf(this.featureDictionary, row);
        if (fid == UNKNOWN_ID) {
            // Unseen feature: it has no stored weight. Registering it here would silently
            // grow the vocabulary reported as "vocabCount".
            return 0.0;
        }
        if ("sigma_j".equals(column)) {
            return this.sigmaJ.get(fid);
        }
        int lid = idOf(this.labelDictionary, column);
        return lid == UNKNOWN_ID ? 0.0 : this.weightMatrix.getQuick(fid, lid);
    }

    /**
     * Looks up a scalar or per-label value.
     *
     * <ul>
     *   <li>{@code ("sumWeight", "sigma_jSigma_k")}: the total weight</li>
     *   <li>{@code ("sumWeight", "vocabCount")}: the number of distinct loaded features</li>
     *   <li>{@code ("thetaNormalizer", label)}: the label's normalizer divided by the global one</li>
     *   <li>{@code ("params", "alpha_i")}: the alpha_i parameter</li>
     *   <li>{@code ("labelWeight", label)}: the label's summed weight</li>
     * </ul>
     * Unknown labels yield 0.
     *
     * @throws InvalidDatastoreException for any other vector name / index combination
     */
    @Override
    public double getWeight(String vectorName, String index) throws InvalidDatastoreException {
        if ("sumWeight".equals(vectorName)) {
            if ("sigma_jSigma_k".equals(index)) {
                return this.sigmaJsigmaK;
            }
            if ("vocabCount".equals(index)) {
                return this.featureDictionary.size();
            }
            throw new InvalidDatastoreException("Unknown index for sumWeight: " + index);
        }
        if ("thetaNormalizer".equals(vectorName)) {
            int lid = idOf(this.labelDictionary, index);
            // thetaNormalizer is never below 1.0, so the division is always defined.
            return lid == UNKNOWN_ID ? 0.0 : this.thetaNormalizerPerLabel.get(lid) / this.thetaNormalizer;
        }
        if ("params".equals(vectorName)) {
            if ("alpha_i".equals(index)) {
                return this.alphaI;
            }
            throw new InvalidDatastoreException("Unknown index for params: " + index);
        }
        if ("labelWeight".equals(vectorName)) {
            int lid = idOf(this.labelDictionary, index);
            return lid == UNKNOWN_ID ? 0.0 : this.sigmaK.get(lid);
        }
        throw new InvalidDatastoreException("Vector not found: " + vectorName);
    }

    /** Returns the id already assigned to {@code key}, or {@link #UNKNOWN_ID}; never mutates. */
    private static int idOf(OpenObjectIntHashMap<String> dictionary, String key) {
        // containsKey is required: get() cannot distinguish "absent" from id 0.
        return dictionary.containsKey(key) ? dictionary.get(key) : UNKNOWN_ID;
    }

    /** Returns the id of {@code key}, assigning the next dense id if it is new. */
    private static int idOrAssign(OpenObjectIntHashMap<String> dictionary, String key) {
        int id = idOf(dictionary, key);
        if (id == UNKNOWN_ID) {
            id = dictionary.size();
            dictionary.put(key, id);
        }
        return id;
    }

    /** Get-or-create id for a feature; used only while loading the model. */
    private int getFeatureID(String feature) {
        return idOrAssign(this.featureDictionary, feature);
    }

    /** Get-or-create id for a label; used only while loading the model. */
    private int getLabelID(String label) {
        return idOrAssign(this.labelDictionary, label);
    }

    /** Stores the weight of {@code feature} for {@code label}, registering either if new. */
    public void loadFeatureWeight(String feature, String label, double weight) {
        int fid = this.getFeatureID(feature);
        int lid = this.getLabelID(label);
        this.weightMatrix.setQuick(fid, lid, weight);
    }

    /** Stores the summed weight ("sigma_j") of {@code feature}, registering it if new. */
    public void setSumFeatureWeight(String feature, double weight) {
        int fid = this.getFeatureID(feature);
        this.sigmaJ.put(fid, weight);
    }

    /** Stores the summed weight ("sigma_k") of {@code label}, registering it if new. */
    public void setSumLabelWeight(String label, double weight) {
        int lid = this.getLabelID(label);
        this.sigmaK.put(lid, weight);
    }

    /**
     * Stores the raw theta normalizer of {@code label}. It also raises the global normalizer to
     * {@code |weight|} when that is larger (the global value starts at 1.0).
     */
    public void setThetaNormalizer(String label, double weight) {
        int lid = this.getLabelID(label);
        this.thetaNormalizerPerLabel.put(lid, weight);
        this.thetaNormalizer = Math.max(this.thetaNormalizer, Math.abs(weight));
    }

    /** Stores the total weight ("sigma_jSigma_k"). */
    public void setSigmaJSigmaK(double weight) {
        this.sigmaJsigmaK = weight;
    }
}

