Package org.neo4j.gds.ml.training
Class CrossValidation<MODEL_TYPE>
java.lang.Object
org.neo4j.gds.ml.training.CrossValidation<MODEL_TYPE>
Nested Class Summary
Nested Classes
Modifier and Type     Class
static interface      CrossValidation.ModelEvaluator<MODEL_TYPE>
static interface      CrossValidation.ModelTrainer<MODEL_TYPE>
Constructor Summary
Constructors
Constructor
CrossValidation(org.neo4j.gds.core.utils.progress.tasks.ProgressTracker progressTracker, org.neo4j.gds.termination.TerminationFlag terminationFlag, List<? extends Metric> metrics, int validationFolds, Optional<Long> randomSeed, CrossValidation.ModelTrainer<MODEL_TYPE> modelTrainer, CrossValidation.ModelEvaluator<MODEL_TYPE> modelEvaluator)
Method Summary
Modifier and Type: static List<org.neo4j.gds.core.utils.progress.tasks.Task>
Method: progressTasks(int validationFolds, int numberOfModelSelectionTrials, long trainSetSize)

Modifier and Type: void
Method: selectModel(org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray outerTrainSet, org.eclipse.collections.api.block.function.primitive.LongToLongFunction targets, SortedSet<Long> distinctInternalTargets, TrainingStatistics trainingStatistics, Iterator<TrainerConfig> modelCandidates)
Constructor Details
CrossValidation
public CrossValidation(org.neo4j.gds.core.utils.progress.tasks.ProgressTracker progressTracker, org.neo4j.gds.termination.TerminationFlag terminationFlag, List<? extends Metric> metrics, int validationFolds, Optional<Long> randomSeed, CrossValidation.ModelTrainer<MODEL_TYPE> modelTrainer, CrossValidation.ModelEvaluator<MODEL_TYPE> modelEvaluator)
Method Details
progressTasks
public static List<org.neo4j.gds.core.utils.progress.tasks.Task> progressTasks(int validationFolds, int numberOfModelSelectionTrials, long trainSetSize)
selectModel
public void selectModel(org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray outerTrainSet, org.eclipse.collections.api.block.function.primitive.LongToLongFunction targets, SortedSet<Long> distinctInternalTargets, TrainingStatistics trainingStatistics, Iterator<TrainerConfig> modelCandidates)
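
This page does not describe what selectModel does internally. Judging only by the class name and the parameters (validationFolds, randomSeed, modelTrainer, modelEvaluator, modelCandidates), it presumably runs k-fold cross-validation over a stream of candidate configurations. The following is a minimal, self-contained sketch of that general technique in plain Java. It does not use the GDS API: the class KFoldSketch, the methods selectBest and demo, and the functional-interface stand-ins for a trainer and an evaluator are all hypothetical names introduced for illustration, and the fold assignment (shuffle, then index modulo fold count) is an assumption rather than the library's documented behaviour.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.function.BiFunction;
import java.util.function.ToDoubleBiFunction;

// Hypothetical illustration; none of these names exist in org.neo4j.gds.
class KFoldSketch {

    // Returns the candidate with the highest mean validation score across all folds.
    static <C, M> C selectBest(
        List<Long> trainSet,
        int folds,
        long seed,
        List<C> candidates,
        BiFunction<List<Long>, C, M> trainer,          // stand-in for a model trainer
        ToDoubleBiFunction<M, List<Long>> evaluator    // stand-in for a model evaluator
    ) {
        // Shuffle once with a fixed seed so every candidate sees the same folds.
        List<Long> shuffled = new ArrayList<>(trainSet);
        Collections.shuffle(shuffled, new Random(seed));

        C best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (C candidate : candidates) {
            double total = 0.0;
            for (int fold = 0; fold < folds; fold++) {
                List<Long> train = new ArrayList<>();
                List<Long> validation = new ArrayList<>();
                // Element i is held out in fold (i % folds) and used for training otherwise.
                for (int i = 0; i < shuffled.size(); i++) {
                    (i % folds == fold ? validation : train).add(shuffled.get(i));
                }
                M model = trainer.apply(train, candidate);
                total += evaluator.applyAsDouble(model, validation);
            }
            double mean = total / folds;
            if (mean > bestScore) {
                bestScore = mean;
                best = candidate;
            }
        }
        return best;
    }

    // Toy setup: each "model" is a constant prediction and every target equals 3.0,
    // so the score is the negated absolute error of that constant.
    static double demo() {
        List<Long> ids = new ArrayList<>();
        for (long id = 0; id < 10; id++) {
            ids.add(id);
        }
        BiFunction<List<Long>, Double, Double> trainer = (trainIds, constant) -> constant;
        ToDoubleBiFunction<Double, List<Long>> evaluator =
            (model, validationIds) -> -Math.abs(model - 3.0);
        return selectBest(ids, 5, 42L, Arrays.asList(1.0, 3.0, 5.0), trainer, evaluator);
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints 3.0
    }
}
```

Keeping the trainer and the evaluator as separate callbacks mirrors the shape of the constructor above, which takes a ModelTrainer and a ModelEvaluator as distinct arguments. How the real class records its results in TrainingStatistics is not covered by this sketch.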