/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.metrics.classification;

import java.util.Comparator;
import java.util.Objects;
import java.util.stream.DoubleStream;
import org.neo4j.gds.collections.LongMultiSet;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.metrics.classification.ClassificationMetric;
import org.neo4j.gds.ml.metrics.classification.F1Score;

/**
 * Classification metric that averages per-class F1 scores, weighting each class by its count in
 * the supplied {@code globalClassCounts}.
 *
 * <p>Instances are compared purely by runtime class: any two {@code F1Weighted} objects are
 * {@link #equals(Object) equal} regardless of the class-id map or class counts they hold, and
 * {@link #hashCode()} is kept consistent with that.
 */
public class F1Weighted
implements ClassificationMetric {
    public static final String NAME = "F1_WEIGHTED";
    private final LocalIdMap classIdMap;
    private final LongMultiSet globalClassCounts;

    /**
     * @param classIdMap        mappings whose entries each yield one {@link F1Score}
     *                          (constructed from the mapping's key and value)
     * @param globalClassCounts per-class counts used as weights; their total is the normaliser
     */
    public F1Weighted(LocalIdMap classIdMap, LongMultiSet globalClassCounts) {
        this.classIdMap = classIdMap;
        this.globalClassCounts = globalClassCounts;
    }

    @Override
    public String name() {
        return NAME;
    }

    /**
     * Returns the natural ordering of scores. NOTE(review): presumably callers treat a larger
     * score as better; confirm against {@link ClassificationMetric} users.
     */
    @Override
    public Comparator<Double> comparator() {
        return Comparator.naturalOrder();
    }

    /**
     * Computes the weighted F1 score of {@code predictions} against {@code targets}.
     *
     * <p>For every mapping in the class-id map, an {@link F1Score} is computed and multiplied by
     * that class's count in {@code globalClassCounts}. The weighted sum is then divided by the
     * total of all counts. A class that appears in {@code globalClassCounts} but not in the
     * class-id map therefore contributes only to the denominator.
     *
     * @param targets     expected class labels
     * @param predictions predicted class labels
     * @return the weighted F1 score, or {@code 0.0} if {@code globalClassCounts} is empty
     */
    @Override
    public double compute(HugeIntArray targets, HugeIntArray predictions) {
        if (this.globalClassCounts.size() == 0L) {
            // Empty count set: short-circuit. The total below would presumably be zero, which
            // would make the division undefined.
            return 0.0;
        }
        DoubleStream weightedScores = this.classIdMap.getMappings().mapToDouble(idMap -> {
            long weight = this.globalClassCounts.count(idMap.key);
            return (double) weight * new F1Score(idMap.key, idMap.value).compute(targets, predictions);
        });
        return weightedScores.sum() / (double) this.globalClassCounts.sum();
    }

    /**
     * Two metrics are equal iff they have exactly the same runtime class; the held class-id map
     * and counts are intentionally ignored.
     */
    @Override
    public boolean equals(Object o) {
        return o != null && this.getClass().equals(o.getClass());
    }

    /**
     * Returns a hash consistent with {@link #equals(Object)}. Equality depends only on the
     * runtime class, so the hash must too. Hashing {@code classIdMap} would let two equal
     * instances produce different hashes.
     */
    @Override
    public int hashCode() {
        return this.getClass().hashCode();
    }

    @Override
    public String toString() {
        return NAME;
    }
}

