Skip to content

Commit

Permalink
Merge pull request #9687 from lassewesth/ncsomething3
Browse files Browse the repository at this point in the history
migrate node classification train estimate
  • Loading branch information
lassewesth authored Oct 3, 2024
2 parents 1307a11 + 8698f34 commit 0bb9087
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import org.neo4j.gds.BaseProc;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.executor.MemoryEstimationExecutor;
import org.neo4j.gds.executor.ProcedureExecutor;
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
Expand All @@ -37,6 +37,9 @@
import static org.neo4j.gds.ml.pipeline.PipelineCompanion.preparePipelineConfig;

public class NodeClassificationPipelineTrainProc extends BaseProc {
@Context
public GraphDataScienceProcedures facade;

@Context
public ModelCatalog modelCatalog;

Expand All @@ -59,12 +62,7 @@ public Stream<MemoryEstimateResult> estimate(
@Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration,
@Name(value = "algoConfiguration") Map<String, Object> algoConfiguration
) {
preparePipelineConfig(graphNameOrConfiguration, algoConfiguration);
return new MemoryEstimationExecutor<>(
new NodeClassificationPipelineTrainSpec(),
executionContext(),
transactionContext()
).computeEstimate(graphNameOrConfiguration, algoConfiguration);
return facade.pipelines().nodeClassificationTrainEstimate(graphNameOrConfiguration, algoConfiguration);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineModelInfo;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrain;
import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade;
import org.neo4j.gds.termination.TerminationMonitor;

Expand Down Expand Up @@ -264,11 +265,11 @@ Optional<TrainingPipeline<?>> getSingle(PipelineName pipelineName) {
return pipelineRepository.getSingle(user, pipelineName);
}

MemoryEstimateResult nodeClassificationEstimate(
MemoryEstimateResult nodeClassificationPredictEstimate(
Object graphNameOrConfiguration,
NodeClassificationPredictPipelineBaseConfig configuration
) {
var estimate = nodeClassificationMemoryEstimation(configuration);
var estimate = nodeClassificationPredictMemoryEstimation(configuration);

var memoryEstimation = MemoryEstimations.builder("Node Classification Predict Pipeline Executor")
.add("Pipeline executor", estimate)
Expand All @@ -291,7 +292,7 @@ PredictMutateResult nodeClassificationMutate(GraphName graphName, Map<String, Ob
configuration,
Optional.empty(),
label,
() -> nodeClassificationMemoryEstimation(configuration),
() -> nodeClassificationPredictMemoryEstimation(configuration),
computation,
mutateStep,
resultBuilder
Expand All @@ -314,12 +315,38 @@ Stream<NodeClassificationStreamResult> nodeClassificationStream(
configuration,
Optional.empty(),
label,
() -> nodeClassificationMemoryEstimation(configuration),
() -> nodeClassificationPredictMemoryEstimation(configuration),
computation,
resultBuilder
);
}

/**
 * Estimates memory for training a node classification pipeline.
 * <p>
 * Looks up the training pipeline named by {@code configuration.pipeline()} for the user named by
 * {@code configuration.username()}, asks {@link NodeClassificationTrain} for its estimate, wraps that
 * estimate under the label "Node Classification Train", and hands the result to the estimation template.
 *
 * @param graphNameOrConfiguration passed through unchanged to the estimation template
 * @param configuration            the parsed train configuration; supplies the user name and pipeline name
 * @return the memory estimate result produced by the estimation template
 */
MemoryEstimateResult nodeClassificationTrainEstimate(
    Object graphNameOrConfiguration,
    NodeClassificationPipelineTrainConfig configuration
) {
    // NOTE(review): the `false` passed to User is presumably an "is admin" flag -- confirm against User.
    var trainingPipeline = pipelineRepository.getNodeClassificationTrainingPipeline(
        new User(configuration.username(), false),
        PipelineName.parse(configuration.pipeline())
    );

    var trainEstimation = NodeClassificationTrain.estimate(
        trainingPipeline,
        configuration,
        modelCatalog,
        algorithmsProcedureFacade
    );

    var labelledEstimation = MemoryEstimations
        .builder("Node Classification Train")
        .add(trainEstimation)
        .build();

    return algorithmEstimationTemplate.estimate(
        configuration,
        graphNameOrConfiguration,
        labelledEstimation
    );
}

NodeClassificationTrainingPipeline selectFeatures(
PipelineName pipelineName,
Iterable<NodeFeatureStep> nodeFeatureSteps
Expand Down Expand Up @@ -386,7 +413,7 @@ private Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig,
);
}

private MemoryEstimation nodeClassificationMemoryEstimation(NodeClassificationPredictPipelineBaseConfig configuration) {
private MemoryEstimation nodeClassificationPredictMemoryEstimation(NodeClassificationPredictPipelineBaseConfig configuration) {
var model = getTrainedNCPipelineModel(configuration.modelName(), configuration.username());

return nodeClassificationPredictPipelineEstimator.estimate(model, configuration);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfig;
import org.neo4j.gds.ml.pipeline.AutoTuningConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictionSplitConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;

import java.util.Collection;
import java.util.Map;
Expand Down Expand Up @@ -67,17 +68,15 @@ TunableTrainerConfig parseMLPClassifierTrainConfig(Map<String, Object> configura
}

NodeClassificationPredictPipelineMutateConfig parseNodeClassificationPredictPipelineMutateConfig(Map<String, Object> configuration) {
return parseNodeClassificationPredictPipelineConfig(
NodeClassificationPredictPipelineMutateConfig::of,
configuration
);
return parseNodeClassificationPipelineConfig(NodeClassificationPredictPipelineMutateConfig::of, configuration);
}

NodeClassificationPredictPipelineStreamConfig parseNodeClassificationPredictPipelineStreamConfig(Map<String, Object> configuration) {
return parseNodeClassificationPredictPipelineConfig(
NodeClassificationPredictPipelineStreamConfig::of,
configuration
);
return parseNodeClassificationPipelineConfig(NodeClassificationPredictPipelineStreamConfig::of, configuration);
}

/**
 * Parses a raw configuration map into a {@link NodeClassificationPipelineTrainConfig}.
 * <p>
 * Delegates to the shared node classification parsing helper, using
 * {@code NodeClassificationPipelineTrainConfig::of} as the factory.
 *
 * @param configuration the raw, user-supplied configuration map
 * @return the parsed train configuration
 */
NodeClassificationPipelineTrainConfig parseNodeClassificationPipelineTrainConfig(Map<String, Object> configuration) {
    BiFunction<String, CypherMapWrapper, NodeClassificationPipelineTrainConfig> trainConfigFactory =
        NodeClassificationPipelineTrainConfig::of;

    return parseNodeClassificationPipelineConfig(trainConfigFactory, configuration);
}

NodePropertyPredictionSplitConfig parseNodePropertyPredictionSplitConfig(Map<String, Object> rawConfiguration) {
Expand Down Expand Up @@ -115,7 +114,7 @@ private <CONFIGURATION> CONFIGURATION parseConfiguration(
/**
* Dumb scaffolding
*/
private <CONFIGURATION> CONFIGURATION parseNodeClassificationPredictPipelineConfig(
private <CONFIGURATION> CONFIGURATION parseNodeClassificationPipelineConfig(
BiFunction<String, CypherMapWrapper, CONFIGURATION> parser,
Map<String, Object> configuration
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ public Stream<MemoryEstimateResult> nodeClassificationMutateEstimate(
var configuration = pipelineConfigurationParser.parseNodeClassificationPredictPipelineMutateConfig(
rawConfiguration);

var result = pipelineApplications.nodeClassificationEstimate(
var result = pipelineApplications.nodeClassificationPredictEstimate(
graphNameOrConfiguration,
configuration
);
Expand Down Expand Up @@ -301,7 +301,22 @@ public Stream<MemoryEstimateResult> nodeClassificationStreamEstimate(
var configuration = pipelineConfigurationParser.parseNodeClassificationPredictPipelineStreamConfig(
rawConfiguration);

var result = pipelineApplications.nodeClassificationEstimate(
var result = pipelineApplications.nodeClassificationPredictEstimate(
graphNameOrConfiguration,
configuration
);

return Stream.of(result);
}

public Stream<MemoryEstimateResult> nodeClassificationTrainEstimate(
Object graphNameOrConfiguration,
Map<String, Object> rawConfiguration
) {
PipelineCompanion.preparePipelineConfig(graphNameOrConfiguration, rawConfiguration);
var configuration = pipelineConfigurationParser.parseNodeClassificationPipelineTrainConfig(rawConfiguration);

var result = pipelineApplications.nodeClassificationTrainEstimate(
graphNameOrConfiguration,
configuration
);
Expand Down

0 comments on commit 0bb9087

Please sign in to comment.