Skip to content

Commit

Permalink
migrate link prediction add steps
Browse files Browse the repository at this point in the history
  • Loading branch information
lassewesth committed Oct 9, 2024
1 parent 8903c5d commit e7a31c6
Show file tree
Hide file tree
Showing 14 changed files with 195 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.neo4j.gds.config.RelationshipWeightConfig;
import org.neo4j.gds.config.ToMapConvertible;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.ml.pipeline.ExecutableNodePropertyStep;
import org.neo4j.gds.ml.pipeline.TrainingPipeline;
import org.neo4j.gds.settings.Neo4jSettings;
Expand Down Expand Up @@ -86,12 +86,12 @@ public void specificValidateBeforeExecution(GraphStore graphStore) {
}
}

public Map<String, List<String>> tasksByRelationshipProperty(ExecutionContext executionContext) {
public Map<String, List<String>> tasksByRelationshipProperty(ModelCatalog modelCatalog, String username) {
Map<String, List<String>> tasksByRelationshipProperty = new HashMap<>();

for (ExecutableNodePropertyStep existingStep : nodePropertySteps()) {
Map<String, Object> config = existingStep.config();
Optional<String> maybeProperty = extractRelationshipProperty(executionContext, config);
Optional<String> maybeProperty = extractRelationshipProperty(config, modelCatalog, username);

maybeProperty.ifPresent(property -> {
var tasks = tasksByRelationshipProperty.computeIfAbsent(property, key -> new ArrayList<>());
Expand All @@ -102,16 +102,17 @@ public Map<String, List<String>> tasksByRelationshipProperty(ExecutionContext ex
return tasksByRelationshipProperty;
}

private static Optional<String> extractRelationshipProperty(
ExecutionContext executionContext,
Map<String, Object> config
private Optional<String> extractRelationshipProperty(
Map<String, Object> config,
ModelCatalog modelCatalog,
String username
) {
if (config.containsKey(RELATIONSHIP_WEIGHT_PROPERTY)) {
var existingProperty = (String) config.get(RELATIONSHIP_WEIGHT_PROPERTY);
return Optional.of(existingProperty);
} else if (config.containsKey(MODEL_NAME_KEY)) {
return Optional.ofNullable(executionContext.modelCatalog().getUntyped(
executionContext.username(),
return Optional.ofNullable(modelCatalog.getUntyped(
username,
((String) config.get(MODEL_NAME_KEY))
))
.map(Model::trainConfig)
Expand All @@ -122,8 +123,10 @@ private static Optional<String> extractRelationshipProperty(
return Optional.empty();
}

public Optional<String> relationshipWeightProperty(ExecutionContext executionContext) {
var relationshipWeightPropertySet = tasksByRelationshipProperty(executionContext).entrySet();
public Optional<String> relationshipWeightProperty(ModelCatalog modelCatalog, String username) {
var relationshipWeightPropertySet = tasksByRelationshipProperty(
modelCatalog, username
).entrySet();
return relationshipWeightPropertySet.isEmpty()
? Optional.empty()
: Optional.of(relationshipWeightPropertySet.iterator().next().getKey());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public static MemoryEstimation estimate(
var splitEstimations = splitEstimation(
pipeline.splitConfig(),
configuration.targetRelationshipType(),
pipeline.relationshipWeightProperty(executionContext)
pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())
);

MemoryEstimation maxOverNodePropertySteps = NodePropertyStepExecutor.estimateNodePropertySteps(
Expand Down Expand Up @@ -156,7 +156,7 @@ public Map<DatasetSplits, PipelineGraphFilter> generateDatasetSplitGraphFilters(

@Override
public void splitDatasets() {
this.linkPredictionRelationshipSampler.splitAndSampleRelationships(pipeline.relationshipWeightProperty(executionContext));
this.linkPredictionRelationshipSampler.splitAndSampleRelationships(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,13 +212,13 @@ public boolean containsDependency(Class<?> type) {

var pipeline = new LinkPredictionTrainingPipeline();

assertThat(pipeline.relationshipWeightProperty(executionContext)).isEmpty();
assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isEmpty();

var step = new TestNodePropertyStep(Map.of("relationshipWeightProperty", "myWeight"));

pipeline.addNodePropertyStep(step);

assertThat(pipeline.relationshipWeightProperty(executionContext)).isPresent().get().isEqualTo("myWeight");
assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isPresent().get().isEqualTo("myWeight");
}

@Test
Expand Down Expand Up @@ -266,13 +266,13 @@ public boolean containsDependency(Class<?> type) {

var pipeline = new LinkPredictionTrainingPipeline();

assertThat(pipeline.relationshipWeightProperty(executionContext)).isEmpty();
assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isEmpty();

var step = new TestNodePropertyStep(Map.of("modelName", modelName));

pipeline.addNodePropertyStep(step);

assertThat(pipeline.relationshipWeightProperty(executionContext)).isPresent().get().isEqualTo("derivedWeight");
assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isPresent().get().isEqualTo("derivedWeight");
}

@Test
Expand Down Expand Up @@ -320,13 +320,13 @@ public boolean containsDependency(Class<?> type) {

var pipeline = new LinkPredictionTrainingPipeline();

assertThat(pipeline.relationshipWeightProperty(executionContext)).isEmpty();
assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isEmpty();

var step = new TestNodePropertyStep(Map.of("modelName", modelName));

pipeline.addNodePropertyStep(step);

assertThat(pipeline.relationshipWeightProperty(executionContext)).isEmpty();
assertThat(pipeline.relationshipWeightProperty(executionContext.modelCatalog(), executionContext.username())).isEmpty();
}

private static class TestNodePropertyStep implements ExecutableNodePropertyStep {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,25 @@

import org.neo4j.gds.BaseProc;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.ml.pipeline.NodePropertyStepFactory;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureStepFactory;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
import org.neo4j.gds.ml.pipeline.linkPipeline.linkfunctions.LinkFeatureStepConfigurationImpl;
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
import org.neo4j.gds.procedures.pipelines.PipelineInfoResult;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.neo4j.gds.config.RelationshipWeightConfig.RELATIONSHIP_WEIGHT_PROPERTY;
import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
import static org.neo4j.procedure.Mode.READ;

public class LinkPredictionPipelineAddStepProcs extends BaseProc {
@Context
public GraphDataScienceProcedures facade;

@Procedure(name = "gds.beta.pipeline.linkPrediction.addNodeProperty", mode = READ)
@Description("Add a node property step to an existing link prediction pipeline.")
Expand All @@ -47,12 +48,7 @@ public Stream<PipelineInfoResult> addNodeProperty(
@Name("procedureName") String taskName,
@Name("procedureConfiguration") Map<String, Object> procedureConfig
) {
var pipeline = PipelineCatalog.getTyped(username(), pipelineName, LinkPredictionTrainingPipeline.class);
validateRelationshipProperty(pipeline, procedureConfig);

pipeline.addNodePropertyStep(NodePropertyStepFactory.createNodePropertyStep(taskName, procedureConfig));

return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
return facade.pipelines().linkPrediction().addNodeProperty(pipelineName, taskName, procedureConfig);
}

@Procedure(name = "gds.beta.pipeline.linkPrediction.addFeature", mode = READ)
Expand All @@ -68,33 +64,6 @@ public Stream<PipelineInfoResult> addFeature(

pipeline.addFeatureStep(LinkFeatureStepFactory.create(featureType, parsedConfig));

return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
}

// check if adding would result in more than one relationshipWeightProperty
private void validateRelationshipProperty(
LinkPredictionTrainingPipeline pipeline,
Map<String, Object> procedureConfig
) {
if (!procedureConfig.containsKey(RELATIONSHIP_WEIGHT_PROPERTY)) return;
var maybeRelationshipProperty = pipeline.relationshipWeightProperty(executionContext());
if (maybeRelationshipProperty.isEmpty()) return;
var relationshipProperty = maybeRelationshipProperty.get();
var property = (String) procedureConfig.get(RELATIONSHIP_WEIGHT_PROPERTY);
if (relationshipProperty.equals(property)) return;

String tasks = pipeline.tasksByRelationshipProperty(executionContext())
.get(relationshipProperty)
.stream()
.map(s -> "`" + s + "`")
.collect(Collectors.joining(", "));
throw new IllegalArgumentException(formatWithLocale(
"Node property steps added to a pipeline may not have different non-null values for `%s`. " +
"Pipeline already contains tasks %s which use the value `%s`.",
RELATIONSHIP_WEIGHT_PROPERTY,
tasks,
relationshipProperty
));
return Stream.of(PipelineInfoResult.create(pipelineName, pipeline));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfig;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
import org.neo4j.gds.procedures.pipelines.PipelineInfoResult;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Internal;
import org.neo4j.procedure.Name;
Expand Down Expand Up @@ -56,7 +57,7 @@ public Stream<PipelineInfoResult> addLogisticRegression(
tunableTrainerConfig
);

return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
return Stream.of(PipelineInfoResult.create(pipelineName, pipeline));
}

@Procedure(name = "gds.beta.pipeline.linkPrediction.addRandomForest", mode = READ)
Expand All @@ -75,7 +76,7 @@ public Stream<PipelineInfoResult> addRandomForest(
tunableTrainerConfig
);

return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
return Stream.of(PipelineInfoResult.create(pipelineName, pipeline));
}

@Procedure(name = "gds.alpha.pipeline.linkPrediction.addRandomForest", mode = READ, deprecatedBy = "gds.beta.pipeline.linkPrediction.addRandomForest")
Expand Down Expand Up @@ -109,6 +110,6 @@ public Stream<PipelineInfoResult> addMLP(

pipeline.addTrainerConfig(TunableTrainerConfig.of(mlpClassifierConfig, TrainingMethod.MLPClassification));

return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
return Stream.of(PipelineInfoResult.create(pipelineName, pipeline));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.neo4j.gds.ml.pipeline.PipelineCompanion;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
import org.neo4j.gds.procedures.pipelines.PipelineInfoResult;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
Expand All @@ -42,7 +43,7 @@ public Stream<PipelineInfoResult> configureAutoTuning(@Name("pipelineName") Stri
username(),
pipelineName,
configMap,
pipeline -> new PipelineInfoResult(pipelineName, (LinkPredictionTrainingPipeline) pipeline)
pipeline -> PipelineInfoResult.create(pipelineName, (LinkPredictionTrainingPipeline) pipeline)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
import org.neo4j.gds.procedures.pipelines.PipelineInfoResult;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
Expand All @@ -47,6 +48,6 @@ public Stream<PipelineInfoResult> configureSplit(@Name("pipelineName") String pi

pipeline.setSplitConfig(config);

return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
return Stream.of(PipelineInfoResult.create(pipelineName, pipeline));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.neo4j.gds.core.StringIdentifierValidations;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
import org.neo4j.gds.procedures.pipelines.PipelineInfoResult;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
Expand All @@ -42,7 +43,7 @@ public Stream<PipelineInfoResult> create(@Name("pipelineName") String input) {
LinkPredictionTrainingPipeline pipeline = new LinkPredictionTrainingPipeline();
PipelineCatalog.set(username(), pipelineName, pipeline);

return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
return Stream.of(PipelineInfoResult.create(pipelineName, pipeline));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.procedures.pipelines;

import java.util.Map;
import java.util.stream.Stream;

public class LinkPredictionFacade {
private final PipelineApplications pipelineApplications;

LinkPredictionFacade(PipelineApplications pipelineApplications) {
this.pipelineApplications = pipelineApplications;
}

public Stream<PipelineInfoResult> addNodeProperty(
String pipelineNameAsString,
String taskName,
Map<String, Object> procedureConfig
) {
var pipelineName = PipelineName.parse(pipelineNameAsString);

var pipeline = pipelineApplications.addNodePropertyToLinkPredictionPipeline(
pipelineName,
taskName,
procedureConfig
);

var result = PipelineInfoResult.create(pipelineName.value, pipeline);

return Stream.of(result);
}
}
Loading

0 comments on commit e7a31c6

Please sign in to comment.