diff --git a/applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/NodePropertyWriter.java b/applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/NodePropertyWriter.java
new file mode 100644
index 0000000000..f6529362ff
--- /dev/null
+++ b/applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/NodePropertyWriter.java
@@ -0,0 +1,186 @@
+/*
+ * Copyright (c) "Neo4j"
+ * Neo4j Sweden AB [http://neo4j.com]
+ *
+ * This file is part of Neo4j.
+ *
+ * Neo4j is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.neo4j.gds.applications.algorithms.machinery;
+
+import org.neo4j.gds.api.Graph;
+import org.neo4j.gds.api.GraphStore;
+import org.neo4j.gds.api.PropertyState;
+import org.neo4j.gds.api.ResultStore;
+import org.neo4j.gds.api.schema.PropertySchema;
+import org.neo4j.gds.applications.algorithms.metadata.NodePropertiesWritten;
+import org.neo4j.gds.config.WriteConfig;
+import org.neo4j.gds.core.concurrency.Concurrency;
+import org.neo4j.gds.core.concurrency.DefaultPool;
+import org.neo4j.gds.core.loading.Capabilities;
+import org.neo4j.gds.core.utils.progress.JobId;
+import org.neo4j.gds.core.utils.progress.TaskRegistryFactory;
+import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
+import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker;
+import org.neo4j.gds.core.write.NodeProperty;
+import org.neo4j.gds.core.write.NodePropertyExporter;
+import org.neo4j.gds.core.write.NodePropertyExporterBuilder;
+import org.neo4j.gds.logging.Log;
+import org.neo4j.gds.termination.TerminationFlag;
+
+import java.util.Collection;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Predicate;
+
+import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
+
+/**
+ * Common bits of node property writes, squirrelled away in one place: the pre-flight check and the tracked export
+ */
+public class NodePropertyWriter {
+    private final Log log;
+    private final NodePropertyExporterBuilder nodePropertyExporterBuilder;
+    private final TaskRegistryFactory taskRegistryFactory;
+    private final TerminationFlag terminationFlag;
+
+    public NodePropertyWriter(
+        Log log,
+        NodePropertyExporterBuilder nodePropertyExporterBuilder,
+        TaskRegistryFactory taskRegistryFactory,
+        TerminationFlag terminationFlag
+    ) {
+        this.log = log;
+        this.nodePropertyExporterBuilder = nodePropertyExporterBuilder;
+        this.taskRegistryFactory = taskRegistryFactory;
+        this.terminationFlag = terminationFlag;
+    }
+
+    /**
+     * Validates the properties, writes them through an exporter, and always releases the progress tracker.
+     * @param graph          supplies the node count, the id map and the node property schema
+     * @param graphStore     supplies the write mode through its capabilities
+     * @param resultStore    passed to the exporter builder unchanged
+     * @param nodeProperties the properties to validate and write
+     * @param jobId          passed to the exporter builder unchanged
+     * @param label          names the progress tracking task
+     * @param writeConfig    supplies the write concurrency
+     * @return the exporter's written-properties count, wrapped in {@link NodePropertiesWritten}
+     * @throws IllegalArgumentException if the write mode is {@code REMOTE}
+     * @throws IllegalStateException    if the write mode cannot write back, or a property is in an unexpected state
+     */
+    public NodePropertiesWritten writeNodeProperties(
+        Graph graph,
+        GraphStore graphStore,
+        Optional<ResultStore> resultStore,
+        Collection<NodeProperty> nodeProperties,
+        JobId jobId,
+        Label label,
+        WriteConfig writeConfig
+    ) {
+        var writeMode = graphStore.capabilities().writeMode();
+        var propertySchemas = graph.schema().nodeSchema().unionProperties();
+        preFlightCheck(writeMode, propertySchemas, nodeProperties);
+        var progressTracker = createProgressTracker(graph.nodeCount(), writeConfig.writeConcurrency(), label);
+
+        var nodePropertyExporter = nodePropertyExporterBuilder
+            .parallel(DefaultPool.INSTANCE, writeConfig.writeConcurrency())
+            .withIdMap(graph)
+            .withJobId(jobId)
+            .withProgressTracker(progressTracker)
+            .withResultStore(resultStore)
+            .withTerminationFlag(terminationFlag)
+            .build();
+
+        try {
+            return writeNodeProperties(nodePropertyExporter, nodeProperties);
+        } finally {
+            progressTracker.release();
+        }
+    }
+
+    /** Creates a progress tracker for the exporter's base task, named by the label and sized by the task volume. */
+    private ProgressTracker createProgressTracker(long taskVolume, Concurrency writeConcurrency, Label label) {
+        var task = NodePropertyExporter.baseTask(label.asString(), taskVolume);
+        return new TaskProgressTracker(task, log, writeConcurrency, taskRegistryFactory);
+    }
+
+    /** Selects the property states that may be written for the write mode; any other write mode is rejected. */
+    private Predicate<PropertyState> expectedPropertyStateForWriteMode(Capabilities.WriteMode writeMode) {
+        return switch (writeMode) {
+            case LOCAL ->
+                // We need to allow persistent and transient as for example algorithms that support seeding will reuse a
+                // mutated (transient) property to write back properties that are in fact backed by a database
+                state -> state == PropertyState.PERSISTENT || state == PropertyState.TRANSIENT;
+            case REMOTE ->
+                // We allow transient properties for the same reason as above
+                state -> state == PropertyState.REMOTE || state == PropertyState.TRANSIENT;
+            default -> throw new IllegalStateException(
+                formatWithLocale("Graph with write mode `%s` cannot write back to a database", writeMode)
+            );
+        };
+    }
+
+    /**
+     * Fails fast unless every property that exists in the schema has a state accepted for the write mode.
+     * @throws IllegalArgumentException if the write mode is {@code REMOTE}
+     * @throws IllegalStateException    if the write mode cannot write back, or a known property has an unexpected state
+     */
+    private void preFlightCheck(
+        Capabilities.WriteMode writeMode,
+        Map<String, PropertySchema> propertySchemas,
+        Collection<NodeProperty> nodeProperties
+    ) {
+        if (writeMode == Capabilities.WriteMode.REMOTE) {
+            throw new IllegalArgumentException("Missing arrow connection information");
+        }
+
+        var expectedPropertyState = expectedPropertyStateForWriteMode(writeMode);
+        var unexpectedProperties = nodeProperties.stream()
+            .filter(nodeProperty -> {
+                var propertySchema = propertySchemas.get(nodeProperty.key());
+                if (propertySchema == null) {
+                    // We are executing an algorithm write mode and the property we are writing is
+                    // not in the GraphStore, therefore we do not perform any more checks
+                    return false;
+                }
+                return !expectedPropertyState.test(propertySchema.state());
+            })
+            .map(nodeProperty -> formatWithLocale(
+                "NodeProperty{propertyKey=%s, propertyState=%s}",
+                nodeProperty.key(),
+                propertySchemas.get(nodeProperty.key()).state()
+            ))
+            .toList();
+
+        if (!unexpectedProperties.isEmpty()) {
+            // List the accepted states by name; formatting the predicate itself would only print the lambda's identity
+            var acceptedStates = java.util.Arrays.stream(PropertyState.values()).filter(expectedPropertyState).toList();
+            throw new IllegalStateException(formatWithLocale(
+                "Expected all properties to be of state `%s` but some properties differ: %s",
+                acceptedStates,
+                unexpectedProperties
+            ));
+        }
+    }
+
+    /** Writes the properties through the exporter and wraps the exporter's written-properties count. */
+    private NodePropertiesWritten writeNodeProperties(
+        NodePropertyExporter nodePropertyExporter,
+        Collection<NodeProperty> nodeProperties
+    ) {
+        nodePropertyExporter.write(nodeProperties);
+        return new NodePropertiesWritten(nodePropertyExporter.propertiesWritten());
+    }
+}
diff --git a/proc/common/src/main/java/org/neo4j/gds/WriteNodePropertiesComputationResultConsumer.java b/proc/common/src/main/java/org/neo4j/gds/WriteNodePropertiesComputationResultConsumer.java
index f3987a78a4..9a9162c440 100644
--- a/proc/common/src/main/java/org/neo4j/gds/WriteNodePropertiesComputationResultConsumer.java
+++ b/proc/common/src/main/java/org/neo4j/gds/WriteNodePropertiesComputationResultConsumer.java
@@ -19,35 +19,21 @@
*/
package org.neo4j.gds;
-import org.neo4j.gds.api.PropertyState;
-import org.neo4j.gds.api.schema.PropertySchema;
+import org.neo4j.gds.applications.algorithms.machinery.NodePropertyWriter;
+import org.neo4j.gds.applications.algorithms.machinery.StandardLabel;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.config.WritePropertyConfig;
-import org.neo4j.gds.core.concurrency.Concurrency;
-import org.neo4j.gds.core.concurrency.DefaultPool;
-import org.neo4j.gds.core.loading.Capabilities.WriteMode;
import org.neo4j.gds.core.utils.ProgressTimer;
-import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
-import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker;
-import org.neo4j.gds.core.write.NodeProperty;
-import org.neo4j.gds.core.write.NodePropertyExporter;
import org.neo4j.gds.executor.ComputationResult;
import org.neo4j.gds.executor.ComputationResultConsumer;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.result.AbstractResultBuilder;
-import java.util.Collection;
-import java.util.Map;
-import java.util.function.Predicate;
import java.util.stream.Stream;
import static org.neo4j.gds.LoggingUtil.runWithExceptionLogging;
-import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
-
-public class WriteNodePropertiesComputationResultConsumer, ALGO_RESULT, CONFIG extends WritePropertyConfig & AlgoBaseConfig, RESULT>
- implements
- ComputationResultConsumer> {
+public class WriteNodePropertiesComputationResultConsumer, ALGO_RESULT, CONFIG extends WritePropertyConfig & AlgoBaseConfig, RESULT> implements ComputationResultConsumer> {
private final ResultBuilderFunction resultBuilderFunction;
private final WriteNodePropertyListFunction nodePropertyListFunction;
private final String procedureName;
@@ -62,81 +48,6 @@ public WriteNodePropertiesComputationResultConsumer(
this.procedureName = procedureName;
}
- private static void validatePropertiesCanBeWritten(
- WriteMode writeMode,
- Map propertySchemas,
- Collection nodeProperties
- ) {
- if (writeMode == WriteMode.REMOTE) {
- throw new IllegalArgumentException("Missing arrow connection information");
- }
-
- var expectedPropertyState = expectedPropertyStateForWriteMode(writeMode);
-
- var unexpectedProperties = nodeProperties
- .stream()
- .filter(nodeProperty -> {
- var propertySchema = propertySchemas.get(nodeProperty.key());
- if (propertySchema == null) {
- // We are executing an algorithm write mode and the property we are writing is
- // not in the GraphStore, therefore we do not perform any more checks
- return false;
- }
- var propertyState = propertySchema.state();
- return !expectedPropertyState.test(propertyState);
- })
- .map(
- nodeProperty -> formatWithLocale(
- "NodeProperty{propertyKey=%s, propertyState=%s}",
- nodeProperty.key(),
- propertySchemas.get(nodeProperty.key()).state()
- )
- )
- .toList();
-
- if (!unexpectedProperties.isEmpty()) {
- throw new IllegalStateException(
- formatWithLocale(
- "Expected all properties to be of state `%s` but some properties differ: %s",
- expectedPropertyState,
- unexpectedProperties
- )
- );
- }
- }
-
- private static Predicate expectedPropertyStateForWriteMode(WriteMode writeMode) {
- switch (writeMode) {
- case LOCAL:
- // We need to allow persistent and transient as for example algorithms that support seeding will reuse a
- // mutated (transient) property to write back properties that are in fact backed by a database
- return state -> state == PropertyState.PERSISTENT || state == PropertyState.TRANSIENT;
- case REMOTE:
- // We allow transient properties for the same reason as above
- return state -> state == PropertyState.REMOTE || state == PropertyState.TRANSIENT;
- default:
- throw new IllegalStateException(
- formatWithLocale(
- "Graph with write mode `%s` cannot write back to a database",
- writeMode
- )
- );
- }
- }
-
- ProgressTracker createProgressTracker(
- long taskVolume,
- Concurrency writeConcurrency,
- ExecutionContext executionContext
- ) {
- return new TaskProgressTracker(
- NodePropertyExporter.baseTask(this.procedureName, taskVolume),
- executionContext.log(),
- writeConcurrency,
- executionContext.taskRegistryFactory()
- );
- }
-
@Override
public Stream consume(
ComputationResult computationResult,
@@ -158,48 +69,41 @@ public Stream consume(
});
}
- void writeToNeo(
+ private void writeToNeo(
AbstractResultBuilder> resultBuilder,
ComputationResult computationResult,
ExecutionContext executionContext
) {
try (ProgressTimer ignored = ProgressTimer.start(resultBuilder::withWriteMillis)) {
+ var log = executionContext.log();
+ var nodePropertyExporterBuilder = executionContext.nodePropertyExporterBuilder();
+ var taskRegistryFactory = executionContext.taskRegistryFactory();
+ var terminationFlag = computationResult.algorithm().terminationFlag;
+ var nodePropertyWriter = new NodePropertyWriter(
+ log,
+ nodePropertyExporterBuilder,
+ taskRegistryFactory,
+ terminationFlag
+ );
+
var graph = computationResult.graph();
+ var graphStore = computationResult.graphStore();
var config = computationResult.config();
- var progressTracker = createProgressTracker(
- graph.nodeCount(),
- config.writeConcurrency(),
- executionContext
- );
- var writeMode = computationResult.graphStore().capabilities().writeMode();
- var nodePropertySchema = graph.schema().nodeSchema().unionProperties();
+ var resultStore = config.resolveResultStore(computationResult.resultStore());
var nodeProperties = nodePropertyListFunction.apply(computationResult);
- validatePropertiesCanBeWritten(
- writeMode,
- nodePropertySchema,
- nodeProperties
+ var nodePropertiesWritten = nodePropertyWriter.writeNodeProperties(
+ graph,
+ graphStore,
+ resultStore,
+ nodeProperties,
+ config.jobId(),
+ new StandardLabel(procedureName),
+ config
);
- var resultStore = config.resolveResultStore(computationResult.resultStore());
- var exporter = executionContext
- .nodePropertyExporterBuilder()
- .withIdMap(graph)
- .withTerminationFlag(computationResult.algorithm().terminationFlag)
- .withProgressTracker(progressTracker)
- .withResultStore(resultStore)
- .withJobId(config.jobId())
- .parallel(DefaultPool.INSTANCE, config.writeConcurrency())
- .build();
-
- try {
- exporter.write(nodeProperties);
- } finally {
- progressTracker.release();
- }
-
resultBuilder.withNodeCount(computationResult.graph().nodeCount());
- resultBuilder.withNodePropertiesWritten(exporter.propertiesWritten());
+ resultBuilder.withNodePropertiesWritten(nodePropertiesWritten.value());
}
}
}
diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteProc.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteProc.java
index fd7255f536..f7db944e51 100644
--- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteProc.java
+++ b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteProc.java
@@ -19,13 +19,9 @@
*/
package org.neo4j.gds.ml.pipeline.node.classification.predict;
-import org.neo4j.gds.BaseProc;
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
-import org.neo4j.gds.core.model.ModelCatalog;
-import org.neo4j.gds.core.write.NodePropertyExporterBuilder;
-import org.neo4j.gds.executor.ExecutionContext;
-import org.neo4j.gds.executor.ProcedureExecutor;
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
+import org.neo4j.gds.procedures.pipelines.WriteResult;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
@@ -35,31 +31,20 @@
import java.util.Map;
import java.util.stream.Stream;
-import static org.neo4j.gds.ml.pipeline.PipelineCompanion.preparePipelineConfig;
import static org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPipelineConstants.ESTIMATE_PREDICT_DESCRIPTION;
import static org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPipelineConstants.PREDICT_DESCRIPTION;
-public class NodeClassificationPipelineWriteProc extends BaseProc {
+public class NodeClassificationPipelineWriteProc {
@Context
public GraphDataScienceProcedures facade;
- @Context
- public ModelCatalog internalModelCatalog;
-
- @Context
- public NodePropertyExporterBuilder nodePropertyExporterBuilder;
-
@Procedure(name = "gds.beta.pipeline.nodeClassification.predict.write", mode = Mode.WRITE)
@Description(PREDICT_DESCRIPTION)
public Stream write(
@Name(value = "graphName") String graphName,
@Name(value = "configuration", defaultValue = "{}") Map configuration
) {
- preparePipelineConfig(graphName, configuration);
- return new ProcedureExecutor<>(
- new NodeClassificationPipelineWriteSpec(),
- executionContext()
- ).compute(graphName, configuration);
+ return facade.pipelines().nodeClassificationWrite(graphName, configuration);
}
@Procedure(name = "gds.beta.pipeline.nodeClassification.predict.write.estimate", mode = Mode.READ)
@@ -70,9 +55,4 @@ public Stream estimate(
) {
return facade.pipelines().nodeClassificationWriteEstimate(graphNameOrConfiguration, algoConfiguration);
}
-
- @Override
- public ExecutionContext executionContext() {
- return super.executionContext().withNodePropertyExporterBuilder(nodePropertyExporterBuilder).withModelCatalog(internalModelCatalog);
- }
}
diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteSpec.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteSpec.java
index 6ae2963b01..3398427721 100644
--- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteSpec.java
+++ b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteSpec.java
@@ -33,6 +33,7 @@
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineExecutor;
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineWriteConfig;
import org.neo4j.gds.procedures.pipelines.PredictedProbabilities;
+import org.neo4j.gds.procedures.pipelines.WriteResult;
import java.util.List;
import java.util.Map;
diff --git a/proc/machine-learning/src/test/java/org/neo4j/gds/ml/pipeline/node/classification/NodeClassificationPipelineAddStepProcsTest.java b/proc/machine-learning/src/test/java/org/neo4j/gds/ml/pipeline/node/classification/NodeClassificationPipelineAddStepProcsTest.java
index 887abfcd87..a30b88663b 100644
--- a/proc/machine-learning/src/test/java/org/neo4j/gds/ml/pipeline/node/classification/NodeClassificationPipelineAddStepProcsTest.java
+++ b/proc/machine-learning/src/test/java/org/neo4j/gds/ml/pipeline/node/classification/NodeClassificationPipelineAddStepProcsTest.java
@@ -59,6 +59,7 @@ void setUp() {
null,
null,
null,
+ null,
new User(getUsername(), false),
null,
null,
@@ -334,6 +335,7 @@ private GraphDataScienceProcedures buildFacade() {
null,
null,
null,
+ null,
new User(getUsername(), false),
null,
null,
diff --git a/procedures/facade/src/main/java/org/neo4j/gds/procedures/LocalGraphDataScienceProcedures.java b/procedures/facade/src/main/java/org/neo4j/gds/procedures/LocalGraphDataScienceProcedures.java
index 2a26e51b17..0a0e626d0e 100644
--- a/procedures/facade/src/main/java/org/neo4j/gds/procedures/LocalGraphDataScienceProcedures.java
+++ b/procedures/facade/src/main/java/org/neo4j/gds/procedures/LocalGraphDataScienceProcedures.java
@@ -206,6 +206,7 @@ public static GraphDataScienceProcedures create(
writeContext.relationshipExporterBuilder(),
requestScopedDependencies.getTaskRegistryFactory(),
terminationMonitor,
+ requestScopedDependencies.getTerminationFlag(),
requestScopedDependencies.getUser(),
requestScopedDependencies.getUserLogRegistryFactory(),
progressTrackerCreator,
diff --git a/procedures/pipelines-facade/build.gradle b/procedures/pipelines-facade/build.gradle
index 4d1c233d3c..c9f036ac92 100644
--- a/procedures/pipelines-facade/build.gradle
+++ b/procedures/pipelines-facade/build.gradle
@@ -31,6 +31,7 @@ dependencies {
implementation project(':core')
implementation project(':core-write')
implementation project(':executor')
+ implementation project(':graph-schema-api')
implementation project(':logging')
implementation project(':memory-usage')
implementation project(':metrics-api')
diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictPipelineWriteResultBuilder.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictPipelineWriteResultBuilder.java
new file mode 100644
index 0000000000..874ec3cde9
--- /dev/null
+++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictPipelineWriteResultBuilder.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) "Neo4j"
+ * Neo4j Sweden AB [http://neo4j.com]
+ *
+ * This file is part of Neo4j.
+ *
+ * Neo4j is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.neo4j.gds.procedures.pipelines;
+
+import org.neo4j.gds.api.Graph;
+import org.neo4j.gds.applications.algorithms.machinery.AlgorithmProcessingTimings;
+import org.neo4j.gds.applications.algorithms.machinery.ResultBuilder;
+import org.neo4j.gds.applications.algorithms.metadata.NodePropertiesWritten;
+
+import java.util.Optional;
+
+/** Turns the outcome of a node classification predict write into the {@link WriteResult} handed to the caller. */
+class NodeClassificationPredictPipelineWriteResultBuilder implements ResultBuilder<
+    NodeClassificationPredictPipelineWriteConfig, NodeClassificationPipelineResult,
+    WriteResult, NodePropertiesWritten> {
+    private final NodeClassificationPredictPipelineWriteConfig configuration;
+    NodeClassificationPredictPipelineWriteResultBuilder(NodeClassificationPredictPipelineWriteConfig configuration) {
+        this.configuration = configuration;
+    }
+
+    @Override
+    public WriteResult build(
+        Graph graph,
+        NodeClassificationPredictPipelineWriteConfig configuration,
+        Optional<NodeClassificationPipelineResult> result,
+        AlgorithmProcessingTimings timings,
+        Optional<NodePropertiesWritten> metadata
+    ) {
+        // Without an algorithm result the metadata is never consulted; the empty result uses timings and configuration
+        if (result.isEmpty()) return WriteResult.emptyFrom(timings, this.configuration.toMap());
+        return new WriteResult(
+            timings.preProcessingMillis, timings.computeMillis, timings.mutateOrWriteMillis,
+            metadata.orElseThrow().value(),
+            this.configuration.toMap()
+        );
+    }
+}
diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictPipelineWriteStep.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictPipelineWriteStep.java
new file mode 100644
index 0000000000..59c6006fa1
--- /dev/null
+++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictPipelineWriteStep.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) "Neo4j"
+ * Neo4j Sweden AB [http://neo4j.com]
+ *
+ * This file is part of Neo4j.
+ *
+ * Neo4j is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.neo4j.gds.procedures.pipelines;
+
+import org.neo4j.gds.api.Graph;
+import org.neo4j.gds.api.GraphStore;
+import org.neo4j.gds.api.ResultStore;
+import org.neo4j.gds.applications.algorithms.machinery.NodePropertyWriter;
+import org.neo4j.gds.applications.algorithms.machinery.StandardLabel;
+import org.neo4j.gds.applications.algorithms.machinery.WriteStep;
+import org.neo4j.gds.applications.algorithms.metadata.NodePropertiesWritten;
+import org.neo4j.gds.core.utils.progress.JobId;
+
+import java.util.Optional;
+
+/** Write step that turns a predict result into node properties and delegates the write to {@link NodePropertyWriter}. */
+class NodeClassificationPredictPipelineWriteStep implements WriteStep<NodeClassificationPipelineResult, NodePropertiesWritten> {
+    private final NodePropertyWriter nodePropertyWriter;
+    private final NodeClassificationPredictPipelineWriteConfig configuration;
+    NodeClassificationPredictPipelineWriteStep(
+        NodePropertyWriter nodePropertyWriter,
+        NodeClassificationPredictPipelineWriteConfig configuration
+    ) {
+        this.nodePropertyWriter = nodePropertyWriter;
+        this.configuration = configuration;
+    }
+
+    @Override
+    public NodePropertiesWritten execute(
+        Graph graph,
+        GraphStore graphStore,
+        ResultStore resultStore,
+        NodeClassificationPipelineResult result,
+        JobId jobId
+    ) {
+        var nodeProperties = PredictedProbabilities.asProperties(
+            Optional.of(result),
+            configuration.writeProperty(),
+            configuration.predictedProbabilityProperty()
+        );
+        return nodePropertyWriter.writeNodeProperties(
+            graph,
+            graphStore,
+            Optional.of(resultStore),
+            nodeProperties,
+            jobId,
+            new StandardLabel("NodeClassificationPipelineWrite"),
+            configuration
+        );
+    }
+}
diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java
index 8d705b7d69..16293b0e80 100644
--- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java
+++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java
@@ -32,6 +32,7 @@
import org.neo4j.gds.applications.algorithms.machinery.GraphStoreService;
import org.neo4j.gds.applications.algorithms.machinery.Label;
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
+import org.neo4j.gds.applications.algorithms.machinery.NodePropertyWriter;
import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator;
import org.neo4j.gds.applications.algorithms.machinery.StandardLabel;
import org.neo4j.gds.applications.modelcatalog.ModelRepository;
@@ -60,6 +61,7 @@
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrain;
import org.neo4j.gds.model.ModelConfig;
import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade;
+import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.termination.TerminationMonitor;
import java.util.Map;
@@ -69,7 +71,7 @@
class PipelineApplications {
private final Log log;
- private final GraphStoreService gss;
+ private final GraphStoreService graphStoreService;
private final ModelCatalog modelCatalog;
private final PipelineRepository pipelineRepository;
@@ -92,6 +94,8 @@ class PipelineApplications {
private final NodeClassificationPredictPipelineEstimator nodeClassificationPredictPipelineEstimator;
private final NodeClassificationTrainSideEffectsFactory nodeClassificationTrainSideEffectsFactory;
+ private final NodePropertyWriter nodePropertyWriter;
+
private final AlgorithmsProcedureFacade algorithmsProcedureFacade;
private final AlgorithmEstimationTemplate algorithmEstimationTemplate;
private final AlgorithmProcessingTemplate algorithmProcessingTemplate;
@@ -117,12 +121,13 @@ class PipelineApplications {
ProgressTrackerCreator progressTrackerCreator,
NodeClassificationPredictPipelineEstimator nodeClassificationPredictPipelineEstimator,
NodeClassificationTrainSideEffectsFactory nodeClassificationTrainSideEffectsFactory,
+ NodePropertyWriter nodePropertyWriter,
AlgorithmsProcedureFacade algorithmsProcedureFacade,
AlgorithmEstimationTemplate algorithmEstimationTemplate,
AlgorithmProcessingTemplate algorithmProcessingTemplate
) {
this.log = log;
- this.gss = graphStoreService;
+ this.graphStoreService = graphStoreService;
this.modelCatalog = modelCatalog;
this.pipelineRepository = pipelineRepository;
this.closeableResourceRegistry = closeableResourceRegistry;
@@ -141,6 +146,7 @@ class PipelineApplications {
this.progressTrackerCreator = progressTrackerCreator;
this.nodeClassificationPredictPipelineEstimator = nodeClassificationPredictPipelineEstimator;
this.nodeClassificationTrainSideEffectsFactory = nodeClassificationTrainSideEffectsFactory;
+ this.nodePropertyWriter = nodePropertyWriter;
this.algorithmsProcedureFacade = algorithmsProcedureFacade;
this.algorithmEstimationTemplate = algorithmEstimationTemplate;
this.algorithmProcessingTemplate = algorithmProcessingTemplate;
@@ -161,6 +167,7 @@ static PipelineApplications create(
RelationshipExporterBuilder relationshipExporterBuilder,
TaskRegistryFactory taskRegistryFactory,
TerminationMonitor terminationMonitor,
+ TerminationFlag terminationFlag,
User user,
UserLogRegistryFactory userLogRegistryFactory,
PipelineConfigurationParser pipelineConfigurationParser,
@@ -182,6 +189,13 @@ static PipelineApplications create(
modelRepository
);
+ var nodePropertyWriter = new NodePropertyWriter(
+ log,
+ nodePropertyExporterBuilder,
+ taskRegistryFactory,
+ terminationFlag
+ );
+
return new PipelineApplications(
log,
graphStoreService,
@@ -203,6 +217,7 @@ static PipelineApplications create(
progressTrackerCreator,
nodeClassificationPredictPipelineEstimator,
nodeClassificationTrainSideEffectsFactory,
+ nodePropertyWriter,
algorithmsProcedureFacade,
algorithmEstimationTemplate,
algorithmProcessingTemplate
@@ -298,7 +313,7 @@ PredictMutateResult nodeClassificationPredictMutate(GraphName graphName, Map nodeClassificationPredictStream(
);
}
+ /** Parses the raw configuration and runs node classification predict in write mode through the processing template. */
+ WriteResult nodeClassificationPredictWrite(GraphName graphName, Map rawConfiguration) {
+     var writeConfiguration = pipelineConfigurationParser.parseNodeClassificationPredictWriteConfig(rawConfiguration);
+     var taskLabel = new StandardLabel("NodeClassificationPredictPipelineWrite");
+     var predictComputation = constructPredictComputation(writeConfiguration, taskLabel);
+     var resultBuilder = new NodeClassificationPredictPipelineWriteResultBuilder(writeConfiguration);
+     var writeStep = new NodeClassificationPredictPipelineWriteStep(nodePropertyWriter, writeConfiguration);
+     return algorithmProcessingTemplate.processAlgorithmForWrite(
+         Optional.empty(),
+         graphName,
+         writeConfiguration,
+         Optional.empty(),
+         taskLabel,
+         () -> nodeClassificationPredictMemoryEstimation(writeConfiguration),
+         predictComputation,
+         writeStep,
+         resultBuilder
+     );
+ }
+
Stream nodeClassificationTrain(
GraphName graphName,
Map rawConfiguration
diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineConfigurationParser.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineConfigurationParser.java
index 8e981061d5..21cb9d5299 100644
--- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineConfigurationParser.java
+++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineConfigurationParser.java
@@ -79,7 +79,7 @@ NodeClassificationPredictPipelineStreamConfig parseNodeClassificationPredictStre
return parseNodeClassificationPipelineConfig(NodeClassificationPredictPipelineStreamConfig::of, configuration);
}
- NodeClassificationPredictPipelineWriteConfig parseNodeClassificationWriteConfig(Map configuration) {
+ NodeClassificationPredictPipelineWriteConfig parseNodeClassificationPredictWriteConfig(Map configuration) {
return parseNodeClassificationPipelineConfig(NodeClassificationPredictPipelineWriteConfig::of, configuration);
}
diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java
index b014aa4209..03df661f8d 100644
--- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java
+++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java
@@ -43,6 +43,7 @@
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureStep;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade;
+import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.termination.TerminationMonitor;
import java.util.ArrayList;
@@ -86,6 +87,7 @@ public static PipelinesProcedureFacade create(
RelationshipExporterBuilder relationshipExporterBuilder,
TaskRegistryFactory taskRegistryFactory,
TerminationMonitor terminationMonitor,
+ TerminationFlag terminationFlag,
User user,
UserLogRegistryFactory userLogRegistryFactory,
ProgressTrackerCreator progressTrackerCreator,
@@ -115,6 +117,7 @@ public static PipelinesProcedureFacade create(
relationshipExporterBuilder,
taskRegistryFactory,
terminationMonitor,
+ terminationFlag,
user,
userLogRegistryFactory,
pipelineConfigurationParser,
@@ -335,12 +338,26 @@ public Stream nodeClassificationTrainEstimate(
return Stream.of(result);
}
+ /**
+  * Entry point for the node classification predict pipeline in write mode.
+  *
+  * Runs two preparation steps on the configuration, parses the graph name, delegates to
+  * {@code pipelineApplications.nodeClassificationPredictWrite} and wraps the single
+  * result in a one-element stream.
+  *
+  * @param graphNameAsString the graph name as a string; parsed with {@code GraphName.parse}
+  * @param configuration the configuration map passed to both preparation steps and
+  *                      then on to the application layer
+  * @return a stream containing exactly the one result returned by the application layer
+  */
+ public Stream nodeClassificationWrite(
+ String graphNameAsString,
+ Map configuration
+ ) {
+ // NOTE(review): both calls discard any return value, so they presumably validate
+ // and/or enrich `configuration` in place -- confirm against their implementations.
+ PipelineCompanion.preparePipelineConfig(graphNameAsString, configuration);
+ nodeClassificationPredictConfigPreProcessor.enhanceInputWithPipelineParameters(configuration);
+
+ var graphName = GraphName.parse(graphNameAsString);
+
+ var result = pipelineApplications.nodeClassificationPredictWrite(graphName, configuration);
+
+ return Stream.of(result);
+ }
+
public Stream nodeClassificationWriteEstimate(
Object graphNameOrConfiguration,
Map rawConfiguration
) {
PipelineCompanion.preparePipelineConfig(graphNameOrConfiguration, rawConfiguration);
- var configuration = pipelineConfigurationParser.parseNodeClassificationWriteConfig(rawConfiguration);
+ var configuration = pipelineConfigurationParser.parseNodeClassificationPredictWriteConfig(rawConfiguration);
var result = pipelineApplications.nodeClassificationPredictEstimate(
graphNameOrConfiguration,
diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/WriteResult.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/WriteResult.java
similarity index 76%
rename from proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/WriteResult.java
rename to procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/WriteResult.java
index 51c03f0956..0785f2e11d 100644
--- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/WriteResult.java
+++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/WriteResult.java
@@ -17,15 +17,15 @@
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
-package org.neo4j.gds.ml.pipeline.node.classification.predict;
+package org.neo4j.gds.procedures.pipelines;
-import org.neo4j.gds.result.AbstractResultBuilder;
+import org.neo4j.gds.applications.algorithms.machinery.AlgorithmProcessingTimings;
import org.neo4j.gds.procedures.algorithms.results.StandardWriteResult;
+import org.neo4j.gds.result.AbstractResultBuilder;
import java.util.Map;
public final class WriteResult extends StandardWriteResult {
-
public final long nodePropertiesWritten;
WriteResult(
@@ -45,8 +45,17 @@ public final class WriteResult extends StandardWriteResult {
this.nodePropertiesWritten = nodePropertiesWritten;
}
- static class Builder extends AbstractResultBuilder {
+ /**
+  * Creates a {@code WriteResult} from the given timings and configuration map.
+  *
+  * The three timing values are copied from {@code timings}; the fourth constructor
+  * argument is the literal {@code 0}.
+  * NOTE(review): that {@code 0} is presumably the {@code nodePropertiesWritten} count
+  * for an "empty" result -- confirm against the constructor's parameter order.
+  *
+  * @param timings source of the pre-processing, compute and mutate-or-write milliseconds
+  * @param configurationMap the configuration map passed through to the result unchanged
+  * @return a new {@code WriteResult} built from the values above
+  */
+ static WriteResult emptyFrom(AlgorithmProcessingTimings timings, Map configurationMap) {
+ return new WriteResult(
+ timings.preProcessingMillis,
+ timings.computeMillis,
+ timings.mutateOrWriteMillis,
+ 0,
+ configurationMap
+ );
+ }
+ public static class Builder extends AbstractResultBuilder {
@Override
public WriteResult build() {
return new WriteResult(
diff --git a/procedures/pipelines-facade/src/test/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacadeTest.java b/procedures/pipelines-facade/src/test/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacadeTest.java
index de8b67acd5..3ecd521de0 100644
--- a/procedures/pipelines-facade/src/test/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacadeTest.java
+++ b/procedures/pipelines-facade/src/test/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacadeTest.java
@@ -56,6 +56,7 @@ void createPipeline() {
null,
null,
null,
+ null,
null
);
var facade = new PipelinesProcedureFacade(null, null, applications);
@@ -98,6 +99,7 @@ void shouldNotCreatePipelineWhenOneExists() {
null,
null,
null,
+ null,
null
);
var facade = new PipelinesProcedureFacade(null, null, applications);