package org.carrot2.clustering.lingo;

import org.apache.mahout.math.matrix.DoubleFactory2D;
import org.carrot2.core.attribute.Processing;
import org.carrot2.matrix.MatrixUtils;
import org.carrot2.matrix.factorization.IMatrixFactorizationFactory;
import org.carrot2.matrix.factorization.IterationNumberGuesser;
import org.carrot2.matrix.factorization.IterativeMatrixFactorizationFactory;
import org.carrot2.matrix.factorization.KMeansMatrixFactorizationFactory;
import org.carrot2.matrix.factorization.LocalNonnegativeMatrixFactorizationFactory;
import org.carrot2.matrix.factorization.NonnegativeMatrixFactorizationEDFactory;
import org.carrot2.matrix.factorization.NonnegativeMatrixFactorizationKLFactory;
import org.carrot2.matrix.factorization.PartialSingularValueDecompositionFactory;
import org.carrot2.text.analysis.ITokenizer;
import org.carrot2.text.vsm.VectorSpaceModelContext;
import org.carrot2.util.attribute.Attribute;
import org.carrot2.util.attribute.Bindable;
import org.carrot2.util.attribute.Input;
import org.carrot2.util.attribute.Required;
import org.carrot2.util.attribute.constraint.ImplementingClasses;
import org.carrot2.util.attribute.constraint.IntRange;

/**
 * Derives the Lingo base matrix from the term-document matrix held in a
 * {@link LingoProcessingContext}, using a configurable matrix factorization.
 *
 * <p>The three public fields are bindable attributes (prefix
 * {@code LingoClusteringAlgorithm}) that select the factorization and control
 * how many base vectors are requested or retained.
 */
@Bindable(prefix = "LingoClusteringAlgorithm")
public class TermDocumentMatrixReducer {

    /**
     * Factory that creates the factorization applied to the term-document
     * matrix. Defaults to {@link NonnegativeMatrixFactorizationEDFactory}.
     * The class list below is non-strict, so other implementations are
     * accepted as well.
     */
    @ImplementingClasses(
        classes = {
            PartialSingularValueDecompositionFactory.class,
            NonnegativeMatrixFactorizationEDFactory.class,
            NonnegativeMatrixFactorizationKLFactory.class,
            LocalNonnegativeMatrixFactorizationFactory.class,
            KMeansMatrixFactorizationFactory.class
        },
        strict = false)
    @Processing
    @Required
    @Input
    @Attribute
    public IMatrixFactorizationFactory factorizationFactory = new NonnegativeMatrixFactorizationEDFactory();

    /**
     * Quality setting handed to {@link IterationNumberGuesser} when the
     * configured factory is iterative. Defaults to {@code HIGH}; it is not
     * read for non-iterative factories.
     */
    @Processing
    @Required
    @Input
    @Attribute
    public IterationNumberGuesser.FactorizationQuality factorizationQuality = IterationNumberGuesser.FactorizationQuality.HIGH;

    /**
     * Base value for the number of clusters. For iterative factories it
     * scales the requested factor count (see
     * {@link #getDesiredClusterCount(LingoProcessingContext)}); for all other
     * factories it caps the number of base matrix columns kept. Defaults to 30.
     */
    // NOTE(review): the lower bound is spelled as ITokenizer.TT_NUMERIC, which
    // looks like a decompiler constant substitution rather than an intentional
    // dependency on the tokenizer -- confirm the intended literal.
    @Processing
    @Input
    @Attribute
    @IntRange(min = ITokenizer.TT_NUMERIC, max = 100)
    public int desiredClusterCountBase = 30;

    /**
     * Computes {@code lingoProcessingContext.baseMatrix} from the
     * term-document matrix stored in {@code lingoProcessingContext.vsmContext}.
     *
     * <p>When the term-document matrix has no rows or no columns, the base
     * matrix is set to a new dense matrix of the same dimensions and nothing
     * else happens. Otherwise the term-document matrix is passed through
     * {@code MatrixUtils.normalizeColumnL2} and factorized, and the {@code U}
     * factor becomes the base matrix. For non-iterative factories the result
     * is cut down to at most {@link #desiredClusterCountBase} leading columns.
     *
     * <p>A decompiler note on the original indicated this method was
     * package-private in the source it was recovered from.
     *
     * @param lingoProcessingContext context supplying the term-document matrix
     *     and receiving the computed base matrix
     */
    public void reduce(LingoProcessingContext lingoProcessingContext) {
        final VectorSpaceModelContext vsm = lingoProcessingContext.vsmContext;
        final int columnCount = vsm.termDocumentMatrix.columns();
        final int rowCount = vsm.termDocumentMatrix.rows();

        // Degenerate input: nothing to factorize, publish an empty-shaped matrix.
        if (columnCount == 0 || rowCount == 0) {
            lingoProcessingContext.baseMatrix = DoubleFactory2D.dense.make(rowCount, columnCount);
            return;
        }

        // Iterative factories need the factor count and iteration estimate
        // configured before factorize() is called.
        final boolean iterative = this.factorizationFactory instanceof IterativeMatrixFactorizationFactory;
        if (iterative) {
            final IterativeMatrixFactorizationFactory iterativeFactory =
                (IterativeMatrixFactorizationFactory) this.factorizationFactory;
            iterativeFactory.setK(getDesiredClusterCount(lingoProcessingContext));
            IterationNumberGuesser.setEstimatedIterationsNumber(
                iterativeFactory, vsm.termDocumentMatrix, this.factorizationQuality);
        }

        // The return value is ignored and the same matrix is factorized next,
        // so this presumably normalizes in place -- verify against MatrixUtils.
        MatrixUtils.normalizeColumnL2(vsm.termDocumentMatrix, null);
        lingoProcessingContext.baseMatrix =
            this.factorizationFactory.factorize(vsm.termDocumentMatrix).getU();

        // Non-iterative factories were not given a target size above, so trim
        // any surplus columns here instead.
        if (!iterative && lingoProcessingContext.baseMatrix.columns() > this.desiredClusterCountBase) {
            lingoProcessingContext.baseMatrix = lingoProcessingContext.baseMatrix.viewPart(
                0, 0, lingoProcessingContext.baseMatrix.rows(), this.desiredClusterCountBase);
        }
    }

    /**
     * Returns the factor count to request from an iterative factorization:
     * {@code (desiredClusterCountBase / 10) * sqrt(documentCount)} truncated
     * to an int, and never more than the number of documents.
     *
     * @param context context whose preprocessing document list is counted
     * @return the desired number of clusters
     */
    // NOTE(review): truncation can produce 0 for small bases and few documents
    // (e.g. base 2 with 4 documents); confirm the factories tolerate k == 0.
    private int getDesiredClusterCount(LingoProcessingContext context) {
        final int documentCount = context.preprocessingContext.documents.size();
        final double scaledCount = (this.desiredClusterCountBase / 10.0d) * Math.sqrt(documentCount);
        return Math.min((int) scaledCount, documentCount);
    }
}
