/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models;

import java.util.List;
import org.immutables.value.Value;
import org.neo4j.gds.annotation.Configuration;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Trainer configuration that adds per-class settings on top of {@link TrainerConfig}.
 *
 * <p>Exposes a non-negative {@code focusWeight} and an optional list of per-class weights.
 * It also provides a helper that turns those weights into a primitive array sized for a
 * given number of classes.
 */
public interface ClassAwareTrainerConfig
extends TrainerConfig {

    /**
     * Returns the focus weight.
     *
     * <p>The value is constrained to be at least 0.0 and defaults to 0.0.
     * NOTE(review): its exact role depends on the trainer that consumes this config;
     * confirm there.
     *
     * @return the configured focus weight
     */
    @Configuration.DoubleRange(min=0.0)
    default double focusWeight() {
        return 0.0;
    }

    /**
     * Returns the user-supplied per-class weights.
     *
     * @return the per-class weights; an empty list by default, meaning "not specified"
     */
    @Value.Default
    default List<Double> classWeights() {
        return List.of();
    }

    /**
     * Builds the primitive weight array to use for {@code numberOfClasses} classes.
     *
     * <p>If no class weights were configured, every class gets weight 1.0. Otherwise the
     * configured weights are copied in list order. In that case the list must contain
     * exactly {@code numberOfClasses} entries.
     *
     * @param numberOfClasses the number of classes the returned array must cover
     * @return a new array of length {@code numberOfClasses}
     * @throws IllegalArgumentException if weights were configured but their count differs
     *                                  from {@code numberOfClasses}
     */
    @Configuration.Ignore
    default double[] initializeClassWeights(int numberOfClasses) {
        List<Double> userWeights = this.classWeights();

        if (userWeights.isEmpty()) {
            // No explicit weights: treat all classes uniformly.
            double[] uniformWeights = new double[numberOfClasses];
            int idx = 0;
            while (idx < uniformWeights.length) {
                uniformWeights[idx++] = 1.0;
            }
            return uniformWeights;
        }

        if (userWeights.size() != numberOfClasses) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "The classWeights list %s has %s entries, but it should have %s entries instead, which is the number of classes.",
                userWeights,
                userWeights.size(),
                numberOfClasses
            ));
        }

        // Copy (and unbox) the configured weights, preserving list order.
        double[] explicitWeights = new double[numberOfClasses];
        int slot = 0;
        for (Double weight : userWeights) {
            explicitWeights[slot++] = weight;
        }
        return explicitWeights;
    }
}

