diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineTrainProc.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineTrainProc.java index 6761e862c4..263bb9a2f7 100644 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineTrainProc.java +++ b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineTrainProc.java @@ -22,9 +22,9 @@ import org.neo4j.gds.BaseProc; import org.neo4j.gds.core.model.ModelCatalog; import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.executor.MemoryEstimationExecutor; import org.neo4j.gds.executor.ProcedureExecutor; import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult; +import org.neo4j.gds.procedures.GraphDataScienceProcedures; import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; import org.neo4j.procedure.Mode; @@ -37,6 +37,9 @@ import static org.neo4j.gds.ml.pipeline.PipelineCompanion.preparePipelineConfig; public class NodeClassificationPipelineTrainProc extends BaseProc { + @Context + public GraphDataScienceProcedures facade; + @Context public ModelCatalog modelCatalog; @@ -59,12 +62,7 @@ public Stream estimate( @Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration, @Name(value = "algoConfiguration") Map algoConfiguration ) { - preparePipelineConfig(graphNameOrConfiguration, algoConfiguration); - return new MemoryEstimationExecutor<>( - new NodeClassificationPipelineTrainSpec(), - executionContext(), - transactionContext() - ).computeEstimate(graphNameOrConfiguration, algoConfiguration); + return facade.pipelines().nodeClassificationTrainEstimate(graphNameOrConfiguration, algoConfiguration); } @Override diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java index 523e3bb41b..55f4fde42e 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java @@ -54,6 +54,7 @@ import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline; import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineModelInfo; import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig; +import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrain; import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade; import org.neo4j.gds.termination.TerminationMonitor; @@ -264,11 +265,11 @@ Optional> getSingle(PipelineName pipelineName) { return pipelineRepository.getSingle(user, pipelineName); } - MemoryEstimateResult nodeClassificationEstimate( + MemoryEstimateResult nodeClassificationPredictEstimate( Object graphNameOrConfiguration, NodeClassificationPredictPipelineBaseConfig configuration ) { - var estimate = nodeClassificationMemoryEstimation(configuration); + var estimate = nodeClassificationPredictMemoryEstimation(configuration); var memoryEstimation = MemoryEstimations.builder("Node Classification Predict Pipeline Executor") .add("Pipeline executor", estimate) @@ -291,7 +292,7 @@ PredictMutateResult nodeClassificationMutate(GraphName graphName, Map nodeClassificationMemoryEstimation(configuration), + () -> nodeClassificationPredictMemoryEstimation(configuration), computation, mutateStep, resultBuilder @@ -314,12 +315,38 @@ Stream nodeClassificationStream( configuration, Optional.empty(), label, - () -> nodeClassificationMemoryEstimation(configuration), + () -> nodeClassificationPredictMemoryEstimation(configuration), computation, resultBuilder ); } + MemoryEstimateResult nodeClassificationTrainEstimate( + Object graphNameOrConfiguration, + NodeClassificationPipelineTrainConfig configuration + ) { + var specifiedUser = new User(configuration.username(), false); + var pipelineName = PipelineName.parse(configuration.pipeline()); + + var pipeline = pipelineRepository.getNodeClassificationTrainingPipeline( + specifiedUser, + pipelineName + ); + + var estimate = NodeClassificationTrain.estimate( + pipeline, + configuration, + modelCatalog, + algorithmsProcedureFacade + ); + + var memoryEstimation = MemoryEstimations.builder("Node Classification Train") + .add(estimate) + .build(); + + return algorithmEstimationTemplate.estimate(configuration, graphNameOrConfiguration, memoryEstimation); + } + NodeClassificationTrainingPipeline selectFeatures( PipelineName pipelineName, Iterable nodeFeatureSteps @@ -386,7 +413,7 @@ private Model configura } NodeClassificationPredictPipelineMutateConfig parseNodeClassificationPredictPipelineMutateConfig(Map configuration) { - return parseNodeClassificationPredictPipelineConfig( - NodeClassificationPredictPipelineMutateConfig::of, - configuration - ); + return parseNodeClassificationPipelineConfig(NodeClassificationPredictPipelineMutateConfig::of, configuration); } NodeClassificationPredictPipelineStreamConfig parseNodeClassificationPredictPipelineStreamConfig(Map configuration) { - return parseNodeClassificationPredictPipelineConfig( - NodeClassificationPredictPipelineStreamConfig::of, - configuration - ); + return parseNodeClassificationPipelineConfig(NodeClassificationPredictPipelineStreamConfig::of, configuration); + } + + NodeClassificationPipelineTrainConfig parseNodeClassificationPipelineTrainConfig(Map configuration) { + return parseNodeClassificationPipelineConfig(NodeClassificationPipelineTrainConfig::of, configuration); } NodePropertyPredictionSplitConfig parseNodePropertyPredictionSplitConfig(Map rawConfiguration) { @@ -115,7 +114,7 @@ private CONFIGURATION parseConfiguration( /** * Dumb scaffolding */ - private CONFIGURATION parseNodeClassificationPredictPipelineConfig( + private CONFIGURATION parseNodeClassificationPipelineConfig( BiFunction parser, Map configuration ) { diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java index fac79e8f37..c89e85ca74 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java @@ -271,7 +271,7 @@ public Stream nodeClassificationMutateEstimate( var configuration = pipelineConfigurationParser.parseNodeClassificationPredictPipelineMutateConfig( rawConfiguration); - var result = pipelineApplications.nodeClassificationEstimate( + var result = pipelineApplications.nodeClassificationPredictEstimate( graphNameOrConfiguration, configuration ); @@ -301,7 +301,22 @@ public Stream nodeClassificationStreamEstimate( var configuration = pipelineConfigurationParser.parseNodeClassificationPredictPipelineStreamConfig( rawConfiguration); - var result = pipelineApplications.nodeClassificationEstimate( + var result = pipelineApplications.nodeClassificationPredictEstimate( + graphNameOrConfiguration, + configuration + ); + + return Stream.of(result); + } + + public Stream nodeClassificationTrainEstimate( + Object graphNameOrConfiguration, + Map rawConfiguration + ) { + PipelineCompanion.preparePipelineConfig(graphNameOrConfiguration, rawConfiguration); + var configuration = pipelineConfigurationParser.parseNodeClassificationPipelineTrainConfig(rawConfiguration); + + var result = pipelineApplications.nodeClassificationTrainEstimate( graphNameOrConfiguration, configuration );