/*
 * Decompiled with CFR 0.152.
 */
package net.myrrix.common.math;

import com.google.common.base.Preconditions;
import java.lang.reflect.Field;
import java.util.Arrays;
import net.myrrix.common.collection.FastByIDFloatMap;
import net.myrrix.common.collection.FastByIDMap;
import net.myrrix.common.collection.FastIDSet;
import net.myrrix.common.math.SimpleVectorMath;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.util.FastMath;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class MatrixUtils {
    private static final Logger log = LoggerFactory.getLogger(MatrixUtils.class);

    /** Width, in characters, of every cell written by {@link #matrixToString(FastByIDMap)}. */
    private static final int PRINT_COLUMN_WIDTH = 12;

    /**
     * Threshold handed to {@link QRDecomposition}; a diagonal entry of R whose absolute value is at or
     * below it marks the matrix as (near-)singular. Read from the system property
     * {@code common.matrix.singularityThreshold}, defaulting to 0.001.
     */
    private static final double SINGULARITY_THRESHOLD =
        Double.parseDouble(System.getProperty("common.matrix.singularityThreshold", "0.001"));

    private MatrixUtils() {
    }

    /**
     * Adds {@code value} to entry (row, column) of a sparse matrix that is stored twice: once indexed
     * by row and once indexed by column. Row/column maps are created on demand.
     *
     * @param row row ID
     * @param column column ID
     * @param value amount passed to {@code FastByIDFloatMap.increment} in both views
     * @param RbyRow row-major view: row ID -> (column ID -> value)
     * @param RbyColumn column-major view: column ID -> (row ID -> value)
     */
    public static void addTo(long row, long column, float value,
                             FastByIDMap<FastByIDFloatMap> RbyRow,
                             FastByIDMap<FastByIDFloatMap> RbyColumn) {
        getOrCreate(RbyRow, row).increment(column, value);
        getOrCreate(RbyColumn, column).increment(row, value);
    }

    /** Returns the inner map stored under {@code key}, first inserting an empty one if absent. */
    private static FastByIDFloatMap getOrCreate(FastByIDMap<FastByIDFloatMap> maps, long key) {
        FastByIDFloatMap inner = maps.get(key);
        if (inner == null) {
            inner = new FastByIDFloatMap();
            maps.put(key, inner);
        }
        return inner;
    }

    /**
     * Removes entry (row, column) from both views of a sparse matrix. A row or column map left empty
     * by the removal is itself removed. Missing rows/columns are ignored.
     *
     * @param row row ID
     * @param column column ID
     * @param RbyRow row-major view: row ID -> (column ID -> value)
     * @param RbyColumn column-major view: column ID -> (row ID -> value)
     */
    public static void remove(long row, long column,
                              FastByIDMap<FastByIDFloatMap> RbyRow,
                              FastByIDMap<FastByIDFloatMap> RbyColumn) {
        removeEntry(RbyRow, row, column);
        removeEntry(RbyColumn, column, row);
    }

    /** Removes {@code innerKey} from the map under {@code outerKey}, dropping that map if it empties. */
    private static void removeEntry(FastByIDMap<FastByIDFloatMap> maps, long outerKey, long innerKey) {
        FastByIDFloatMap inner = maps.get(outerKey);
        if (inner == null) {
            return;
        }
        inner.remove(innerKey);
        if (inner.isEmpty()) {
            maps.remove(outerKey);
        }
    }

    /**
     * Computes the Moore-Penrose pseudo-inverse of M, whose vectors are treated as the rows of M.
     * Each vector m is mapped to (M'M)^-1 m, i.e. the result holds the columns of (M'M)^-1 M',
     * keyed like M.
     *
     * @param M vectors forming the rows of M; may be null or empty
     * @return the transformed vectors, or M itself when M is null or empty
     * @throws SingularMatrixException if M'M is near-singular (see {@link #invert(RealMatrix)})
     */
    public static FastByIDMap<float[]> getPseudoInverse(FastByIDMap<float[]> M) {
        if (M == null || M.isEmpty()) {
            return M;
        }
        return multiply(getTransposeTimesSelfInverse(M), M);
    }

    /**
     * Computes (M'M)^-1, where the vectors of M are its rows.
     *
     * @param M vectors forming the rows of M; may be null or empty
     * @return the inverse of M'M, or null when M is null or empty
     * @throws SingularMatrixException if M'M is near-singular
     */
    public static RealMatrix getTransposeTimesSelfInverse(FastByIDMap<float[]> M) {
        if (M == null || M.isEmpty()) {
            return null;
        }
        return invert(transposeTimesSelf(M));
    }

    /**
     * Inverts a matrix using a QR decomposition with {@link #SINGULARITY_THRESHOLD}.
     *
     * <p>When the matrix is near-singular, this method first logs a warning and the diagonal of R. It
     * also logs the index of the first diagonal entry at or below the threshold, as a hint for a
     * smaller {@code model.features}. It then rethrows the exception.
     *
     * @param M square matrix to invert
     * @return the inverse, always as an {@link Array2DRowRealMatrix}
     * @throws SingularMatrixException if M is near-singular
     */
    public static RealMatrix invert(RealMatrix M) {
        QRDecomposition decomposition = new QRDecomposition(M, SINGULARITY_THRESHOLD);
        DecompositionSolver solver = decomposition.getSolver();
        RealMatrix inverse;
        try {
            inverse = solver.getInverse();
        } catch (SingularMatrixException sme) {
            log.warn("{} x {} matrix is near-singular (threshold {}); add more data or decrease the value of model.features ({})",
                     new Object[] {M.getRowDimension(), M.getColumnDimension(), SINGULARITY_THRESHOLD, sme.toString()});
            logSuggestedDimensionality(decomposition);
            throw sme;
        }
        // getData() always returns a fresh copy, so the new matrix can adopt it without copying again.
        // Converting to Array2DRowRealMatrix keeps accessMatrixDataDirectly on its zero-copy path.
        return new Array2DRowRealMatrix(inverse.getData(), false);
    }

    /**
     * Logs the diagonal of R from a failed decomposition. It also logs the index of the first entry
     * at or below {@link #SINGULARITY_THRESHOLD}. Uses the public {@code getR()} rather than
     * reflecting on private fields.
     */
    private static void logSuggestedDimensionality(QRDecomposition decomposition) {
        RealMatrix r = decomposition.getR();
        int diagonalLength = FastMath.min(r.getRowDimension(), r.getColumnDimension());
        double[] rDiag = new double[diagonalLength];
        for (int i = 0; i < diagonalLength; ++i) {
            rDiag[i] = r.getEntry(i, i);
        }
        log.info("QR decomposition diagonal: {}", Arrays.toString(rDiag));
        for (int i = 0; i < diagonalLength; ++i) {
            if (FastMath.abs(rDiag[i]) <= SINGULARITY_THRESHOLD) {
                log.info("Suggested value of -Dmodel.features is about {} or less", i);
                break;
            }
        }
    }

    /**
     * Multiplies every vector of S by the matrix M, storing the result under the same key.
     *
     * @param M matrix applied on the left of each vector
     * @param S vectors to transform
     * @return a new map of key -> M * vector, narrowed to float
     */
    public static FastByIDMap<float[]> multiply(RealMatrix M, FastByIDMap<float[]> S) {
        FastByIDMap<float[]> result = new FastByIDMap<float[]>(S.size(), 1.25f);
        double[][] matrixData = accessMatrixDataDirectly(M);
        for (FastByIDMap.MapEntry<float[]> entry : S.entrySet()) {
            result.put(entry.getKey(), matrixMultiply(matrixData, entry.getValue()));
        }
        return result;
    }

    /**
     * Computes X * Y' as a dense matrix whose (row, col) entry is the dot product of X's vector
     * {@code row} and Y's vector {@code col}.
     *
     * <p>NOTE: vectors are looked up by the integer indices 0..size()-1, so this assumes both maps
     * are keyed densely by those indices. A missing index yields null, which is passed on to
     * {@code SimpleVectorMath.dot}.
     *
     * @param X vectors forming the rows of X
     * @param Y vectors forming the rows of Y
     * @return an X.size() x Y.size() matrix
     */
    public static RealMatrix multiplyXYT(FastByIDMap<float[]> X, FastByIDMap<float[]> Y) {
        int Ysize = Y.size();
        int Xsize = X.size();
        Array2DRowRealMatrix result = new Array2DRowRealMatrix(Xsize, Ysize);
        for (int row = 0; row < Xsize; ++row) {
            float[] xRow = X.get(row); // invariant across the inner loop
            for (int col = 0; col < Ysize; ++col) {
                result.setEntry(row, col, SimpleVectorMath.dot(xRow, Y.get(col)));
            }
        }
        return result;
    }

    /**
     * Returns the matrix entries as a row-major array for read-only use. For an
     * {@link Array2DRowRealMatrix} this is the backing array itself (no copy); any other
     * implementation is copied via {@code getData()}.
     */
    private static double[][] accessMatrixDataDirectly(RealMatrix matrix) {
        if (matrix instanceof Array2DRowRealMatrix) {
            return ((Array2DRowRealMatrix) matrix).getDataRef();
        }
        return matrix.getData();
    }

    /**
     * Computes matrix * V in double precision.
     *
     * @param matrix matrix whose rows have at least V.length entries
     * @param V vector
     * @return one dot product per matrix row
     */
    public static double[] multiply(RealMatrix matrix, float[] V) {
        double[][] M = accessMatrixDataDirectly(matrix);
        double[] out = new double[M.length];
        for (int i = 0; i < M.length; ++i) {
            out[i] = dotRow(M[i], V);
        }
        return out;
    }

    /** Computes M * V, accumulating in double and narrowing each result entry to float. */
    private static float[] matrixMultiply(double[][] M, float[] V) {
        float[] out = new float[M.length];
        for (int i = 0; i < M.length; ++i) {
            out[i] = (float) dotRow(M[i], V);
        }
        return out;
    }

    /** Dot product of V with the first V.length entries of matrixRow, accumulated in double. */
    private static double dotRow(double[] matrixRow, float[] V) {
        double total = 0.0;
        for (int j = 0; j < V.length; ++j) {
            total += (double) V[j] * matrixRow[j];
        }
        return total;
    }

    /**
     * Computes M'M, where the vectors of M are its rows. The result's dimension is the length of
     * the first vector encountered.
     *
     * @param M non-empty set of vectors
     * @return the symmetric matrix M'M
     * @throws NullPointerException if M is empty
     * @throws IllegalArgumentException if a vector is longer than the first one
     */
    public static RealMatrix transposeTimesSelf(FastByIDMap<float[]> M) {
        Array2DRowRealMatrix result = null;
        double[][] resultData = null;
        int dimension = 0;
        for (FastByIDMap.MapEntry<float[]> entry : M.entrySet()) {
            float[] vector = entry.getValue();
            int length = vector.length;
            if (result == null) {
                dimension = length;
                result = new Array2DRowRealMatrix(dimension, dimension);
                resultData = result.getDataRef();
            }
            if (length > dimension) {
                throw new IllegalArgumentException(
                    "Vector length " + length + " exceeds matrix dimension " + dimension);
            }
            for (int row = 0; row < length; ++row) {
                // Widen before multiplying: a float x float product is exact in double, whereas a
                // float product would be rounded to 24 bits before reaching the double accumulator.
                double rowValue = vector[row];
                double[] resultRow = resultData[row];
                // M'M is symmetric, so only the upper triangle is accumulated; it is mirrored below.
                for (int col = row; col < length; ++col) {
                    resultRow[col] += rowValue * vector[col];
                }
            }
        }
        Preconditions.checkNotNull(result, "Cannot compute M'M of an empty matrix");
        for (int row = 1; row < dimension; ++row) {
            double[] resultRow = resultData[row];
            for (int col = 0; col < row; ++col) {
                resultRow[col] = resultData[col][row];
            }
        }
        return result;
    }

    /**
     * Renders a sparse matrix as a tab-separated table. The header row lists the sorted union of all
     * column IDs, and each following line is one row in ascending row-ID order. Every cell is padded
     * or truncated to {@link #PRINT_COLUMN_WIDTH} characters. A cell whose lookup yields NaN
     * (presumably an absent entry) is printed blank.
     *
     * @param M row-major sparse matrix: row ID -> (column ID -> value)
     * @return the rendered table
     */
    public static String matrixToString(FastByIDMap<FastByIDFloatMap> M) {
        StringBuilder result = new StringBuilder();
        long[] colKeys = unionColumnKeysInOrder(M);
        // Header: an empty corner cell, then one cell per column ID.
        appendWithPadOrTruncate("", result);
        for (long colKey : colKeys) {
            result.append('\t');
            appendWithPadOrTruncate(colKey, result);
        }
        result.append("\n\n");
        for (long rowKey : keysInOrder(M)) {
            appendWithPadOrTruncate(rowKey, result);
            FastByIDFloatMap row = M.get(rowKey);
            for (long colKey : colKeys) {
                result.append('\t');
                float value = row.get(colKey);
                if (Float.isNaN(value)) {
                    appendWithPadOrTruncate("", result);
                } else {
                    appendWithPadOrTruncate(value, result);
                }
            }
            result.append('\n');
        }
        result.append('\n');
        return result.toString();
    }

    /** Returns the keys of {@code map} in ascending order; map keys are already unique. */
    private static long[] keysInOrder(FastByIDMap<?> map) {
        long[] keys = new long[map.size()];
        int count = 0;
        LongPrimitiveIterator it = map.keySetIterator();
        while (it.hasNext()) {
            keys[count++] = it.nextLong();
        }
        Arrays.sort(keys);
        return keys;
    }

    /** Returns the de-duplicated union of every row's column IDs, in ascending order. */
    private static long[] unionColumnKeysInOrder(FastByIDMap<FastByIDFloatMap> M) {
        FastIDSet keys = new FastIDSet(1000, 1.25f);
        for (FastByIDMap.MapEntry<FastByIDFloatMap> entry : M.entrySet()) {
            LongPrimitiveIterator it = entry.getValue().keySetIterator();
            while (it.hasNext()) {
                keys.add(it.nextLong());
            }
        }
        long[] keysArray = keys.toArray();
        Arrays.sort(keysArray);
        return keysArray;
    }

    /** Appends an ID as a fixed-width cell. */
    private static void appendWithPadOrTruncate(long value, StringBuilder to) {
        appendWithPadOrTruncate(Long.toString(value), to);
    }

    /** Appends a value as a fixed-width cell; non-negative values get a leading space to align with '-'. */
    private static void appendWithPadOrTruncate(float value, StringBuilder to) {
        String stringValue = Float.toString(value);
        if (value >= 0.0f) {
            stringValue = ' ' + stringValue;
        }
        appendWithPadOrTruncate(stringValue, to);
    }

    /**
     * Appends {@code value} right-aligned in a cell of {@link #PRINT_COLUMN_WIDTH} characters,
     * left-padding with spaces or truncating to the first {@code PRINT_COLUMN_WIDTH} characters.
     */
    private static void appendWithPadOrTruncate(CharSequence value, StringBuilder to) {
        int length = value.length();
        if (length >= PRINT_COLUMN_WIDTH) {
            to.append(value, 0, PRINT_COLUMN_WIDTH);
            return;
        }
        for (int i = length; i < PRINT_COLUMN_WIDTH; ++i) {
            to.append(' ');
        }
        to.append(value);
    }
}

