/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.gradientdescent;

import java.util.List;
import java.util.PrimitiveIterator;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.functions.Constant;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.models.Features;

/**
 * An objective used by the gradient-descent training code in this package.
 *
 * <p>Implementations expose their weights, a per-batch loss, and a model-data value.
 * The interface also provides a helper that gathers a batch's feature vectors into a
 * constant matrix.
 *
 * @param <DATA> the type returned by {@link #modelData()}
 */
public interface Objective<DATA> {

    /**
     * Returns the weights of this objective.
     *
     * <p>NOTE(review): presumably these are the variables a gradient-descent optimizer
     * updates — confirm against implementations and callers.
     *
     * @return the list of weight variables
     */
    List<Weights<? extends Tensor<?>>> weights();

    /**
     * Builds the loss for the given batch as a scalar-valued computation variable.
     *
     * @param batch the batch of elements to compute the loss for
     * @param trainSize a {@code long} passed through to implementations; NOTE(review): the
     *     name is inferred, presumably the total training-set size (e.g. for scaling a
     *     per-batch penalty) — confirm against implementations
     * @return a variable whose value is a {@link Scalar}
     */
    Variable<Scalar> loss(Batch batch, long trainSize);

    /**
     * Returns the model data associated with this objective.
     *
     * @return the model data, of the interface's {@code DATA} type
     */
    DATA modelData();

    /**
     * Gathers the feature vectors of a batch's elements into a constant matrix.
     *
     * <p>The matrix has {@code batch.size()} rows and {@code features.featureDimension()}
     * columns. Row {@code i} is set to {@code features.get(id)} for the {@code i}-th id
     * yielded by {@code batch.elementIds()}.
     *
     * <p>NOTE(review): this assumes {@code elementIds()} yields at most {@code batch.size()}
     * ids. Rows not written keep whatever {@code new Matrix(...)} initialises them to.
     *
     * @param batch the batch whose element ids select the rows
     * @param features the feature source, indexed by element id
     * @return a constant wrapping the assembled batch feature matrix
     */
    static Constant<Matrix> batchFeatureMatrix(Batch batch, Features features) {
        Matrix batchFeatures = new Matrix(batch.size(), features.featureDimension());
        int batchFeaturesOffset = 0;
        PrimitiveIterator.OfLong batchIterator = batch.elementIds();
        while (batchIterator.hasNext()) {
            long elementId = batchIterator.nextLong();
            batchFeatures.setRow(batchFeaturesOffset++, features.get(elementId));
        }
        // Diamond instead of the decompiled raw `new Constant((Tensor) ...)`, which bypassed
        // generic type checking and produced an unchecked conversion to Constant<Matrix>.
        return new Constant<>(batchFeatures);
    }
}

