Class LogisticRegressionObjective

java.lang.Object
org.neo4j.gds.ml.models.logisticregression.LogisticRegressionObjective
All Implemented Interfaces:
Objective<LogisticRegressionData>

public class LogisticRegressionObjective extends Object implements Objective<LogisticRegressionData>
  • Constructor Details

    • LogisticRegressionObjective

      public LogisticRegressionObjective(LogisticRegressionClassifier classifier, double penalty, Features features, org.neo4j.gds.collections.ha.HugeIntArray labels, double focusWeight, double[] classWeights)
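The constructor combines a `penalty`, a `focusWeight` and per-class `classWeights`, which suggests a class-weighted focal cross-entropy objective. The sketch below is a standalone illustration of one common formulation of such a per-example loss. It does not use the GDS classes and is not taken from the GDS source. It assumes that `focusWeight` acts as the focal exponent (γ) and that `classWeights[c]` scales the term for true class `c`. The names `FocalLossSketch` and `exampleLoss` are hypothetical.

```java
/** Hypothetical standalone sketch of a class-weighted focal loss; not the GDS implementation. */
final class FocalLossSketch {

    /**
     * Per-example loss: -classWeights[label] * (1 - p)^focusWeight * ln(p),
     * where p is the predicted probability of the true label.
     * With focusWeight = 0 and all class weights 1, this reduces to plain cross-entropy.
     */
    static double exampleLoss(double[] probabilities, int label, double focusWeight, double[] classWeights) {
        // Clamp p away from 0 so that ln(p) stays finite.
        double p = Math.max(probabilities[label], 1e-15);
        // Modulating factor: shrinks the loss of examples that are already well classified.
        double modulating = Math.pow(1.0 - p, focusWeight);
        return -classWeights[label] * modulating * Math.log(p);
    }

    public static void main(String[] args) {
        double[] probs = {0.7, 0.2, 0.1};
        double[] uniform = {1.0, 1.0, 1.0};
        // focusWeight = 0: plain cross-entropy, i.e. -ln(0.7)
        System.out.println(exampleLoss(probs, 0, 0.0, uniform));
        // focusWeight = 2: the same well-classified example contributes less
        System.out.println(exampleLoss(probs, 0, 2.0, uniform));
    }
}
```

Under this assumed formulation, raising the focus weight shifts the training signal toward hard, misclassified examples, while class weights counteract label imbalance.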
  • Method Details

    • sizeOfBatchInBytes

      public static long sizeOfBatchInBytes(boolean isReduced, int batchSize, int numberOfFeatures, int numberOfClasses)
    • weights

      public List<org.neo4j.gds.ml.core.functions.Weights<? extends org.neo4j.gds.ml.core.tensor.Tensor<?>>> weights()
      Specified by:
      weights in interface Objective<LogisticRegressionData>
    • loss

      public org.neo4j.gds.ml.core.Variable<org.neo4j.gds.ml.core.tensor.Scalar> loss(org.neo4j.gds.ml.core.batch.Batch batch, long trainSize)
      Specified by:
      loss in interface Objective<LogisticRegressionData>
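Because `loss` receives both a `Batch` and the total `trainSize`, a natural reading is that the regularization penalty is apportioned to each batch in proportion to its share of the training set. The sketch below illustrates that idea with plain arrays. It assumes three things: the data term is the mean per-example loss over the batch, the penalty is `penalty * (batchSize / trainSize) * ||W||²`, and bias terms are excluded from the norm. None of this is confirmed by the page above, and `BatchLossSketch`, `batchLoss` and `l2NormSquared` are hypothetical names.

```java
/** Hypothetical standalone sketch of a batch loss with a batch-scaled L2 penalty; not the GDS implementation. */
final class BatchLossSketch {

    /** Squared L2 norm of a weight matrix (bias terms assumed to be excluded). */
    static double l2NormSquared(double[][] weights) {
        double sum = 0.0;
        for (double[] row : weights) {
            for (double w : row) {
                sum += w * w;
            }
        }
        return sum;
    }

    /**
     * Batch loss = mean per-example loss over the batch
     *            + penalty * (batchSize / trainSize) * ||W||^2.
     * Because the batch sizes of one pass over the training set sum to trainSize,
     * the penalty terms of all batches in that pass add up to penalty * ||W||^2.
     */
    static double batchLoss(double[] exampleLosses, double[][] weights, double penalty, long trainSize) {
        int batchSize = exampleLosses.length;
        double dataLoss = 0.0;
        for (double l : exampleLosses) {
            dataLoss += l;
        }
        dataLoss /= batchSize;
        double penaltyTerm = penalty * ((double) batchSize / trainSize) * l2NormSquared(weights);
        return dataLoss + penaltyTerm;
    }

    public static void main(String[] args) {
        double[] losses = {0.5, 1.5};             // per-example losses for a batch of 2
        double[][] w = {{1.0, 2.0}, {0.0, 2.0}};  // ||W||^2 = 9
        // mean data loss = 1.0; penalty term = 0.1 * (2 / 10) * 9, approximately 0.18
        System.out.println(batchLoss(losses, w, 0.1, 10));
    }
}
```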
    • modelData

      public LogisticRegressionData modelData()
      Description copied from interface: Objective
      Returns the data, such as weights, needed to store or load the model.
      Specified by:
      modelData in interface Objective<LogisticRegressionData>
      Returns:
      the data