diff --git a/pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipeline.java b/pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipeline.java index 1ed5fe8501f..5f21bb854be 100644 --- a/pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipeline.java +++ b/pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipeline.java @@ -23,7 +23,7 @@ import org.neo4j.gds.config.RelationshipWeightConfig; import org.neo4j.gds.config.ToMapConvertible; import org.neo4j.gds.core.model.Model; -import org.neo4j.gds.executor.ExecutionContext; +import org.neo4j.gds.core.model.ModelCatalog; import org.neo4j.gds.ml.pipeline.ExecutableNodePropertyStep; import org.neo4j.gds.ml.pipeline.TrainingPipeline; import org.neo4j.gds.settings.Neo4jSettings; @@ -86,12 +86,12 @@ public void specificValidateBeforeExecution(GraphStore graphStore) { } } - public Map> tasksByRelationshipProperty(ExecutionContext executionContext) { + public Map> tasksByRelationshipProperty(ModelCatalog modelCatalog, String username) { Map> tasksByRelationshipProperty = new HashMap<>(); for (ExecutableNodePropertyStep existingStep : nodePropertySteps()) { Map config = existingStep.config(); - Optional maybeProperty = extractRelationshipProperty(executionContext, config); + Optional maybeProperty = extractRelationshipProperty(config, modelCatalog, username); maybeProperty.ifPresent(property -> { var tasks = tasksByRelationshipProperty.computeIfAbsent(property, key -> new ArrayList<>()); @@ -102,16 +102,17 @@ public Map> tasksByRelationshipProperty(ExecutionContext ex return tasksByRelationshipProperty; } - private static Optional extractRelationshipProperty( - ExecutionContext executionContext, - Map config + private Optional extractRelationshipProperty( + Map config, + ModelCatalog modelCatalog, + String username ) { if (config.containsKey(RELATIONSHIP_WEIGHT_PROPERTY)) { var existingProperty = (String) config.get(RELATIONSHIP_WEIGHT_PROPERTY); return Optional.of(existingProperty); } else if (config.containsKey(MODEL_NAME_KEY)) { - return Optional.ofNullable(executionContext.modelCatalog().getUntyped( - executionContext.username(), + return Optional.ofNullable(modelCatalog.getUntyped( + username, ((String) config.get(MODEL_NAME_KEY)) )) .map(Model::trainConfig) @@ -122,8 +123,10 @@ private static Optional extractRelationshipProperty( return Optional.empty(); } - public Optional relationshipWeightProperty(ExecutionContext executionContext) { - var relationshipWeightPropertySet = tasksByRelationshipProperty(executionContext).entrySet(); + public Optional relationshipWeightProperty(ModelCatalog modelCatalog, String username) { + var relationshipWeightPropertySet = tasksByRelationshipProperty( + modelCatalog, username + ).entrySet(); return relationshipWeightPropertySet.isEmpty() ? Optional.empty() : Optional.of(relationshipWeightPropertySet.iterator().next().getKey()); diff --git a/pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrainPipelineExecutor.java b/pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrainPipelineExecutor.java index 23c780fa373..53a4c7d1572 100644 --- a/pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrainPipelineExecutor.java +++ b/pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrainPipelineExecutor.java @@ -112,7 +112,7 @@ public static MemoryEstimation estimate( var splitEstimations = splitEstimation( pipeline.splitConfig(), configuration.targetRelationshipType(), - pipeline.relationshipWeightProperty(executionContext) + pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username()) ); MemoryEstimation maxOverNodePropertySteps = NodePropertyStepExecutor.estimateNodePropertySteps( @@ -156,7 +156,7 @@ public Map generateDatasetSplitGraphFilters( @Override public void splitDatasets() { - this.linkPredictionRelationshipSampler.splitAndSampleRelationships(pipeline.relationshipWeightProperty(executionContext)); + this.linkPredictionRelationshipSampler.splitAndSampleRelationships(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())); } @Override diff --git a/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipelineTest.java b/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipelineTest.java index 40eb44a5f47..e9181fdd0d6 100644 --- a/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipelineTest.java +++ b/pipeline/src/test/java/org/neo4j/gds/ml/pipeline/linkPipeline/LinkPredictionTrainingPipelineTest.java @@ -212,13 +212,13 @@ public boolean containsDependency(Class type) { var pipeline = new LinkPredictionTrainingPipeline(); - assertThat(pipeline.relationshipWeightProperty(executionContext)).isEmpty(); + assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isEmpty(); var step = new TestNodePropertyStep(Map.of("relationshipWeightProperty", "myWeight")); pipeline.addNodePropertyStep(step); - assertThat(pipeline.relationshipWeightProperty(executionContext)).isPresent().get().isEqualTo("myWeight"); + assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isPresent().get().isEqualTo("myWeight"); } @Test @@ -266,13 +266,13 @@ public boolean containsDependency(Class type) { var pipeline = new LinkPredictionTrainingPipeline(); - assertThat(pipeline.relationshipWeightProperty(executionContext)).isEmpty(); + assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isEmpty(); var step = new TestNodePropertyStep(Map.of("modelName", modelName)); pipeline.addNodePropertyStep(step); - assertThat(pipeline.relationshipWeightProperty(executionContext)).isPresent().get().isEqualTo("derivedWeight"); + assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isPresent().get().isEqualTo("derivedWeight"); } @Test @@ -320,13 +320,13 @@ public boolean containsDependency(Class type) { var pipeline = new LinkPredictionTrainingPipeline(); - assertThat(pipeline.relationshipWeightProperty(executionContext)).isEmpty(); + assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isEmpty(); var step = new TestNodePropertyStep(Map.of("modelName", modelName)); pipeline.addNodePropertyStep(step); - assertThat(pipeline.relationshipWeightProperty(executionContext)).isEmpty(); + assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isEmpty(); } private static class TestNodePropertyStep implements ExecutableNodePropertyStep { diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineAddStepProcs.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineAddStepProcs.java index 5be2d2260d2..ec48835ccb4 100644 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineAddStepProcs.java +++ b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineAddStepProcs.java @@ -21,24 +21,25 @@ import org.neo4j.gds.BaseProc; import org.neo4j.gds.core.CypherMapWrapper; -import org.neo4j.gds.ml.pipeline.NodePropertyStepFactory; import org.neo4j.gds.ml.pipeline.PipelineCatalog; import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureStepFactory; import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline; import org.neo4j.gds.ml.pipeline.linkPipeline.linkfunctions.LinkFeatureStepConfigurationImpl; +import org.neo4j.gds.procedures.GraphDataScienceProcedures; +import org.neo4j.gds.procedures.pipelines.PipelineInfoResult; +import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; import org.neo4j.procedure.Name; import org.neo4j.procedure.Procedure; import java.util.Map; -import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.neo4j.gds.config.RelationshipWeightConfig.RELATIONSHIP_WEIGHT_PROPERTY; -import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; import static org.neo4j.procedure.Mode.READ; public class LinkPredictionPipelineAddStepProcs extends BaseProc { + @Context + public GraphDataScienceProcedures facade; @Procedure(name = "gds.beta.pipeline.linkPrediction.addNodeProperty", mode = READ) @Description("Add a node property step to an existing link prediction pipeline.") @@ -47,12 +48,7 @@ public Stream addNodeProperty( @Name("procedureName") String taskName, @Name("procedureConfiguration") Map procedureConfig ) { - var pipeline = PipelineCatalog.getTyped(username(), pipelineName, LinkPredictionTrainingPipeline.class); - validateRelationshipProperty(pipeline, procedureConfig); - - pipeline.addNodePropertyStep(NodePropertyStepFactory.createNodePropertyStep(taskName, procedureConfig)); - - return Stream.of(new PipelineInfoResult(pipelineName, pipeline)); + return facade.pipelines().linkPrediction().addNodeProperty(pipelineName, taskName, procedureConfig); } @Procedure(name = "gds.beta.pipeline.linkPrediction.addFeature", mode = READ) @@ -68,33 +64,6 @@ public Stream addFeature( pipeline.addFeatureStep(LinkFeatureStepFactory.create(featureType, parsedConfig)); - return Stream.of(new PipelineInfoResult(pipelineName, pipeline)); - } - - // check if adding would result in more than one relationshipWeightProperty - private void validateRelationshipProperty( - LinkPredictionTrainingPipeline pipeline, - Map procedureConfig - ) { - if (!procedureConfig.containsKey(RELATIONSHIP_WEIGHT_PROPERTY)) return; - var maybeRelationshipProperty = pipeline.relationshipWeightProperty(executionContext()); - if (maybeRelationshipProperty.isEmpty()) return; - var relationshipProperty = maybeRelationshipProperty.get(); - var property = (String) procedureConfig.get(RELATIONSHIP_WEIGHT_PROPERTY); - if (relationshipProperty.equals(property)) return; - - String tasks = pipeline.tasksByRelationshipProperty(executionContext()) - .get(relationshipProperty) - .stream() - .map(s -> "`" + s + "`") - .collect(Collectors.joining(", ")); - throw new IllegalArgumentException(formatWithLocale( - "Node property steps added to a pipeline may not have different non-null values for `%s`. " + - "Pipeline already contains tasks %s which use the value `%s`.", - RELATIONSHIP_WEIGHT_PROPERTY, - tasks, - relationshipProperty - )); + return Stream.of(PipelineInfoResult.create(pipelineName, pipeline)); } - } diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineAddTrainerMethodProcs.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineAddTrainerMethodProcs.java index 32a3bbc5f4f..7a8770c4547 100644 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineAddTrainerMethodProcs.java +++ b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineAddTrainerMethodProcs.java @@ -28,6 +28,7 @@ import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfig; import org.neo4j.gds.ml.pipeline.PipelineCatalog; import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline; +import org.neo4j.gds.procedures.pipelines.PipelineInfoResult; import org.neo4j.procedure.Description; import org.neo4j.procedure.Internal; import org.neo4j.procedure.Name; @@ -56,7 +57,7 @@ public Stream addLogisticRegression( tunableTrainerConfig ); - return Stream.of(new PipelineInfoResult(pipelineName, pipeline)); + return Stream.of(PipelineInfoResult.create(pipelineName, pipeline)); } @Procedure(name = "gds.beta.pipeline.linkPrediction.addRandomForest", mode = READ) @@ -75,7 +76,7 @@ public Stream addRandomForest( tunableTrainerConfig ); - return Stream.of(new PipelineInfoResult(pipelineName, pipeline)); + return Stream.of(PipelineInfoResult.create(pipelineName, pipeline)); } @Procedure(name = "gds.alpha.pipeline.linkPrediction.addRandomForest", mode = READ, deprecatedBy = "gds.beta.pipeline.linkPrediction.addRandomForest") @@ -109,6 +110,6 @@ public Stream addMLP( pipeline.addTrainerConfig(TunableTrainerConfig.of(mlpClassifierConfig, TrainingMethod.MLPClassification)); - return Stream.of(new PipelineInfoResult(pipelineName, pipeline)); + return Stream.of(PipelineInfoResult.create(pipelineName, pipeline)); } } diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineConfigureAutoTuningProc.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineConfigureAutoTuningProc.java index eff6a95d69c..2db3069a5d8 100644 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineConfigureAutoTuningProc.java +++ b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineConfigureAutoTuningProc.java @@ -23,6 +23,7 @@ import org.neo4j.gds.ml.pipeline.PipelineCompanion; import org.neo4j.gds.ml.pipeline.PipelineCatalog; import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline; +import org.neo4j.gds.procedures.pipelines.PipelineInfoResult; import org.neo4j.procedure.Description; import org.neo4j.procedure.Name; import org.neo4j.procedure.Procedure; @@ -42,7 +43,7 @@ public Stream configureAutoTuning(@Name("pipelineName") Stri username(), pipelineName, configMap, - pipeline -> new PipelineInfoResult(pipelineName, (LinkPredictionTrainingPipeline) pipeline) + pipeline -> PipelineInfoResult.create(pipelineName, (LinkPredictionTrainingPipeline) pipeline) ); } } diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineConfigureSplitProc.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineConfigureSplitProc.java index 39d3fc3b523..c46d1c04ba8 100644 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineConfigureSplitProc.java +++ b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineConfigureSplitProc.java @@ -24,6 +24,7 @@ import org.neo4j.gds.ml.pipeline.PipelineCatalog; import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig; import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline; +import org.neo4j.gds.procedures.pipelines.PipelineInfoResult; import org.neo4j.procedure.Description; import org.neo4j.procedure.Name; import org.neo4j.procedure.Procedure; @@ -47,6 +48,6 @@ public Stream configureSplit(@Name("pipelineName") String pi pipeline.setSplitConfig(config); - return Stream.of(new PipelineInfoResult(pipelineName, pipeline)); + return Stream.of(PipelineInfoResult.create(pipelineName, pipeline)); } } diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineCreateProc.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineCreateProc.java index d179772a07e..9dd8f39b6ed 100644 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineCreateProc.java +++ b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineCreateProc.java @@ -23,6 +23,7 @@ import org.neo4j.gds.core.StringIdentifierValidations; import org.neo4j.gds.ml.pipeline.PipelineCatalog; import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline; +import org.neo4j.gds.procedures.pipelines.PipelineInfoResult; import org.neo4j.procedure.Description; import org.neo4j.procedure.Name; import org.neo4j.procedure.Procedure; @@ -42,7 +43,7 @@ public Stream create(@Name("pipelineName") String input) { LinkPredictionTrainingPipeline pipeline = new LinkPredictionTrainingPipeline(); PipelineCatalog.set(username(), pipelineName, pipeline); - return Stream.of(new PipelineInfoResult(pipelineName, pipeline)); + return Stream.of(PipelineInfoResult.create(pipelineName, pipeline)); } } diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/LinkPredictionFacade.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/LinkPredictionFacade.java new file mode 100644 index 00000000000..1fc5446684d --- /dev/null +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/LinkPredictionFacade.java @@ -0,0 +1,49 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.procedures.pipelines; + +import java.util.Map; +import java.util.stream.Stream; + +public class LinkPredictionFacade { + private final PipelineApplications pipelineApplications; + + LinkPredictionFacade(PipelineApplications pipelineApplications) { + this.pipelineApplications = pipelineApplications; + } + + public Stream addNodeProperty( + String pipelineNameAsString, + String taskName, + Map procedureConfig + ) { + var pipelineName = PipelineName.parse(pipelineNameAsString); + + var pipeline = pipelineApplications.addNodePropertyToLinkPredictionPipeline( + pipelineName, + taskName, + procedureConfig + ); + + var result = PipelineInfoResult.create(pipelineName.value, pipeline); + + return Stream.of(result); + } +} diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java index 16293b0e808..67a28725501 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java @@ -52,6 +52,7 @@ import org.neo4j.gds.ml.pipeline.NodePropertyStepFactory; import org.neo4j.gds.ml.pipeline.PipelineCatalog; import org.neo4j.gds.ml.pipeline.TrainingPipeline; +import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline; import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureStep; import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictionSplitConfig; import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline; @@ -67,8 +68,12 @@ import java.util.Map; import java.util.Optional; import java.util.function.Consumer; +import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.neo4j.gds.config.RelationshipWeightConfig.RELATIONSHIP_WEIGHT_PROPERTY; +import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; + class PipelineApplications { private final Log log; private final GraphStoreService graphStoreService; @@ -224,7 +229,22 @@ static PipelineApplications create( ); } - NodeClassificationTrainingPipeline addNodeProperty( + LinkPredictionTrainingPipeline addNodePropertyToLinkPredictionPipeline( + PipelineName pipelineName, + String taskName, + Map procedureConfig + ) { + var pipeline = pipelineRepository.getLinkPredictionTrainingPipeline(user, pipelineName); + validateRelationshipProperty(pipeline, procedureConfig); + + var nodePropertyStep = NodePropertyStepFactory.createNodePropertyStep(taskName, procedureConfig); + + pipeline.addNodePropertyStep(nodePropertyStep); + + return pipeline; + } + + NodeClassificationTrainingPipeline addNodePropertyToNodeClassificationPipeline( PipelineName pipelineName, String taskName, Map procedureConfig @@ -534,4 +554,33 @@ private MemoryEstimation nodeClassificationTrainEstimation(NodeClassificationPip .add(estimate) .build(); } + + /** + * check if adding would result in more than one relationshipWeightProperty + */ + private void validateRelationshipProperty( + LinkPredictionTrainingPipeline pipeline, + Map procedureConfig + ) { + if (!procedureConfig.containsKey(RELATIONSHIP_WEIGHT_PROPERTY)) return; + var maybeRelationshipProperty = pipeline.relationshipWeightProperty(modelCatalog, user.getUsername()); + if (maybeRelationshipProperty.isEmpty()) return; + var relationshipProperty = maybeRelationshipProperty.get(); + var property = (String) procedureConfig.get(RELATIONSHIP_WEIGHT_PROPERTY); + if (relationshipProperty.equals(property)) return; + + String tasks = pipeline.tasksByRelationshipProperty(modelCatalog, user.getUsername()) + .get(relationshipProperty) + .stream() + .map(s -> "`" + s + "`") + .collect(Collectors.joining(", ")); + + throw new IllegalArgumentException(formatWithLocale( + "Node property steps added to a pipeline may not have different non-null values for `%s`. " + + "Pipeline already contains tasks %s which use the value `%s`.", + RELATIONSHIP_WEIGHT_PROPERTY, + tasks, + relationshipProperty + )); + } } diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/PipelineInfoResult.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineInfoResult.java similarity index 54% rename from proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/PipelineInfoResult.java rename to procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineInfoResult.java index 4415fee4256..5c9d3f66b82 100644 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/PipelineInfoResult.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineInfoResult.java @@ -17,7 +17,7 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ -package org.neo4j.gds.ml.linkmodels.pipeline; +package org.neo4j.gds.procedures.pipelines; import org.neo4j.gds.ml.pipeline.ExecutableNodePropertyStep; import org.neo4j.gds.ml.pipeline.TrainingPipeline; @@ -26,9 +26,8 @@ import java.util.List; import java.util.Map; -import java.util.stream.Collectors; -public class PipelineInfoResult { +public final class PipelineInfoResult { public final String name; public final List> nodePropertySteps; public final List> featureSteps; @@ -36,16 +35,38 @@ public class PipelineInfoResult { public final Map autoTuningConfig; public final Object parameterSpace; - PipelineInfoResult(String pipelineName, LinkPredictionTrainingPipeline pipeline) { - this.name = pipelineName; - this.nodePropertySteps = pipeline + private PipelineInfoResult( + String name, + List> nodePropertySteps, + List> featureSteps, + Map splitConfig, + Map autoTuningConfig, + Object parameterSpace + ) { + this.name = name; + this.nodePropertySteps = nodePropertySteps; + this.featureSteps = featureSteps; + this.splitConfig = splitConfig; + this.autoTuningConfig = autoTuningConfig; + this.parameterSpace = parameterSpace; + } + + public static PipelineInfoResult create(String pipelineName, LinkPredictionTrainingPipeline pipeline) { + var nodePropertySteps = pipeline .nodePropertySteps() .stream() .map(ExecutableNodePropertyStep::toMap) - .collect(Collectors.toList()); - this.featureSteps = pipeline.featureSteps().stream().map(LinkFeatureStep::toMap).collect(Collectors.toList()); - this.splitConfig = pipeline.splitConfig().toMap(); - this.autoTuningConfig = pipeline.autoTuningConfig().toMap(); - this.parameterSpace = TrainingPipeline.toMapParameterSpace(pipeline.trainingParameterSpace()); + .toList(); + + var featureSteps = pipeline.featureSteps().stream().map(LinkFeatureStep::toMap).toList(); + + return new PipelineInfoResult( + pipelineName, + nodePropertySteps, + featureSteps, + pipeline.splitConfig().toMap(), + pipeline.autoTuningConfig().toMap(), + TrainingPipeline.toMapParameterSpace(pipeline.trainingParameterSpace()) + ); } } diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineRepository.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineRepository.java index 37bf346baeb..cc741a86312 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineRepository.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineRepository.java @@ -22,6 +22,7 @@ import org.neo4j.gds.api.User; import org.neo4j.gds.ml.pipeline.PipelineCatalog; import org.neo4j.gds.ml.pipeline.TrainingPipeline; +import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline; import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline; import java.util.Optional; @@ -55,6 +56,10 @@ Stream getAll(User user) { return PipelineCatalog.getAllPipelines(user.getUsername()); } + LinkPredictionTrainingPipeline getLinkPredictionTrainingPipeline(User user, PipelineName pipelineName) { + return PipelineCatalog.getTyped(user.getUsername(), pipelineName.value, LinkPredictionTrainingPipeline.class); + } + NodeClassificationTrainingPipeline getNodeClassificationTrainingPipeline(User user, PipelineName pipelineName) { return PipelineCatalog.getTyped( user.getUsername(), diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java index 03df661f8dc..6f54c93322a 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java @@ -62,14 +62,18 @@ public final class PipelinesProcedureFacade { private final PipelineApplications pipelineApplications; + private final LinkPredictionFacade linkPredictionFacade; + PipelinesProcedureFacade( NodeClassificationPredictConfigPreProcessor nodeClassificationPredictConfigPreProcessor, PipelineConfigurationParser pipelineConfigurationParser, - PipelineApplications pipelineApplications + PipelineApplications pipelineApplications, + LinkPredictionFacade linkPredictionFacade ) { this.nodeClassificationPredictConfigPreProcessor = nodeClassificationPredictConfigPreProcessor; this.pipelineConfigurationParser = pipelineConfigurationParser; this.pipelineApplications = pipelineApplications; + this.linkPredictionFacade = linkPredictionFacade; } public static PipelinesProcedureFacade create( @@ -127,10 +131,13 @@ public static PipelinesProcedureFacade create( algorithmProcessingTemplate ); + var linkPredictionFacade = new LinkPredictionFacade(pipelineApplications); + return new PipelinesProcedureFacade( nodeClassificationPredictConfigPreProcessor, pipelineConfigurationParser, - pipelineApplications + pipelineApplications, + linkPredictionFacade ); } @@ -160,7 +167,11 @@ public Stream addNodeProperty( ) { var pipelineName = PipelineName.parse(pipelineNameAsString); - var pipeline = pipelineApplications.addNodeProperty(pipelineName, taskName, procedureConfig); + var pipeline = pipelineApplications.addNodePropertyToNodeClassificationPipeline( + pipelineName, + taskName, + procedureConfig + ); var result = NodePipelineInfoResult.create(pipelineName, pipeline); @@ -416,4 +427,8 @@ private List parseNodeProperties(Object nodeProperties) { throw new IllegalArgumentException("The value of `nodeProperties` is required to be a list of strings."); } + + public LinkPredictionFacade linkPrediction() { + return linkPredictionFacade; + } } diff --git a/procedures/pipelines-facade/src/test/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacadeTest.java b/procedures/pipelines-facade/src/test/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacadeTest.java index 3ecd521de0c..7f45242f75a 100644 --- a/procedures/pipelines-facade/src/test/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacadeTest.java +++ b/procedures/pipelines-facade/src/test/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacadeTest.java @@ -59,7 +59,7 @@ void createPipeline() { null, null ); - var facade = new PipelinesProcedureFacade(null, null, applications); + var facade = new PipelinesProcedureFacade(null, null, applications, null); var result = facade.createPipeline("myPipeline").findAny().orElseThrow(); @@ -102,7 +102,7 @@ void shouldNotCreatePipelineWhenOneExists() { null, null ); - var facade = new PipelinesProcedureFacade(null, null, applications); + var facade = new PipelinesProcedureFacade(null, null, applications, null); assertThatIllegalStateException() .isThrownBy(() -> facade.createPipeline("myPipeline")) @@ -112,6 +112,7 @@ void shouldNotCreatePipelineWhenOneExists() { @Test void shouldNotCreatePipelineWithInvalidName() { var facade = new PipelinesProcedureFacade( + null, null, null, null