From 37afa7860e458b3f57d9688d5ad1204a1e9eea96 Mon Sep 17 00:00:00 2001 From: Lasse Westh-Nielsen Date: Tue, 8 Oct 2024 13:59:59 +0200 Subject: [PATCH] migrate node classification write --- .../machinery/NodePropertyWriter.java | 186 ++++++++++++++++++ ...dePropertiesComputationResultConsumer.java | 148 +++----------- .../NodeClassificationPipelineWriteProc.java | 26 +-- .../NodeClassificationPipelineWriteSpec.java | 1 + ...lassificationPipelineAddStepProcsTest.java | 2 + .../LocalGraphDataScienceProcedures.java | 1 + procedures/pipelines-facade/build.gradle | 1 + ...tionPredictPipelineWriteResultBuilder.java | 54 +++++ ...lassificationPredictPipelineWriteStep.java | 68 +++++++ .../pipelines/PipelineApplications.java | 41 +++- .../PipelineConfigurationParser.java | 2 +- .../pipelines/PipelinesProcedureFacade.java | 19 +- .../procedures/pipelines}/WriteResult.java | 17 +- .../PipelinesProcedureFacadeTest.java | 2 + 14 files changed, 414 insertions(+), 154 deletions(-) create mode 100644 applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/NodePropertyWriter.java create mode 100644 procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictPipelineWriteResultBuilder.java create mode 100644 procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictPipelineWriteStep.java rename {proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict => procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines}/WriteResult.java (76%) diff --git a/applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/NodePropertyWriter.java b/applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/NodePropertyWriter.java new file mode 100644 index 0000000000..f6529362ff --- /dev/null +++ b/applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/NodePropertyWriter.java @@ -0,0 +1,186 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.applications.algorithms.machinery; + +import org.neo4j.gds.api.Graph; +import org.neo4j.gds.api.GraphStore; +import org.neo4j.gds.api.PropertyState; +import org.neo4j.gds.api.ResultStore; +import org.neo4j.gds.api.schema.PropertySchema; +import org.neo4j.gds.applications.algorithms.metadata.NodePropertiesWritten; +import org.neo4j.gds.config.WriteConfig; +import org.neo4j.gds.core.concurrency.Concurrency; +import org.neo4j.gds.core.concurrency.DefaultPool; +import org.neo4j.gds.core.loading.Capabilities; +import org.neo4j.gds.core.utils.progress.JobId; +import org.neo4j.gds.core.utils.progress.TaskRegistryFactory; +import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; +import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker; +import org.neo4j.gds.core.write.NodeProperty; +import org.neo4j.gds.core.write.NodePropertyExporter; +import org.neo4j.gds.core.write.NodePropertyExporterBuilder; +import org.neo4j.gds.logging.Log; +import org.neo4j.gds.termination.TerminationFlag; + +import java.util.Collection; +import java.util.Map; +import java.util.Optional; +import java.util.function.Predicate; + +import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; + +/** + * Common bits of node property writes, squirrelled away in one place + */ +public class NodePropertyWriter { + private final Log log; + + private final NodePropertyExporterBuilder nodePropertyExporterBuilder; + private final TaskRegistryFactory taskRegistryFactory; + private final TerminationFlag terminationFlag; + + public NodePropertyWriter( + Log log, + NodePropertyExporterBuilder nodePropertyExporterBuilder, + TaskRegistryFactory taskRegistryFactory, + TerminationFlag terminationFlag + ) { + this.log = log; + this.nodePropertyExporterBuilder = nodePropertyExporterBuilder; + this.taskRegistryFactory = taskRegistryFactory; + this.terminationFlag = terminationFlag; + } + + public NodePropertiesWritten writeNodeProperties( + Graph graph, + GraphStore graphStore, + Optional resultStore, + Collection nodeProperties, + JobId jobId, + Label label, + WriteConfig writeConfig + ) { + preFlightCheck( + graphStore.capabilities().writeMode(), + graph.schema().nodeSchema().unionProperties(), + nodeProperties + ); + + var progressTracker = createProgressTracker(graph.nodeCount(), writeConfig.writeConcurrency(), label); + + var nodePropertyExporter = nodePropertyExporterBuilder + .parallel(DefaultPool.INSTANCE, writeConfig.writeConcurrency()) + .withIdMap(graph) + .withJobId(jobId) + .withProgressTracker(progressTracker) + .withResultStore(resultStore) + .withTerminationFlag(terminationFlag) + .build(); + + try { + return writeNodeProperties(nodePropertyExporter, nodeProperties); + } finally { + progressTracker.release(); + } + } + + private ProgressTracker createProgressTracker( + long taskVolume, + Concurrency writeConcurrency, + Label label + ) { + var task = NodePropertyExporter.baseTask(label.asString(), taskVolume); + + return new TaskProgressTracker( + task, + log, + writeConcurrency, + taskRegistryFactory + ); + } + + private Predicate expectedPropertyStateForWriteMode(Capabilities.WriteMode writeMode) { + return switch (writeMode) { + case LOCAL -> + // We need to allow persistent and transient as for example algorithms that support seeding will reuse a + // mutated (transient) property to write back properties that are in fact backed by a database + state -> state == PropertyState.PERSISTENT || state == PropertyState.TRANSIENT; + case REMOTE -> + // We allow transient properties for the same reason as above + state -> state == PropertyState.REMOTE || state == PropertyState.TRANSIENT; + default -> throw new IllegalStateException( + formatWithLocale( + "Graph with write mode `%s` cannot write back to a database", + writeMode + ) + ); + }; + } + + private void preFlightCheck( + Capabilities.WriteMode writeMode, + Map propertySchemas, + Collection nodeProperties + ) { + if (writeMode == Capabilities.WriteMode.REMOTE) throw new IllegalArgumentException( + "Missing arrow connection information"); + + var expectedPropertyState = expectedPropertyStateForWriteMode(writeMode); + + var unexpectedProperties = nodeProperties.stream() + .filter(nodeProperty -> { + var propertySchema = propertySchemas.get(nodeProperty.key()); + if (propertySchema == null) { + // We are executing an algorithm write mode and the property we are writing is + // not in the GraphStore, therefore we do not perform any more checks + return false; + } + var propertyState = propertySchema.state(); + return !expectedPropertyState.test(propertyState); + }) + .map( + nodeProperty -> formatWithLocale( + "NodeProperty{propertyKey=%s, propertyState=%s}", + nodeProperty.key(), + propertySchemas.get(nodeProperty.key()).state() + ) + ) + .toList(); + + if (!unexpectedProperties.isEmpty()) { + throw new IllegalStateException( + formatWithLocale( + "Expected all properties to be of state `%s` but some properties differ: %s", + expectedPropertyState, + unexpectedProperties + ) + ); + } + } + + private NodePropertiesWritten writeNodeProperties( + NodePropertyExporter nodePropertyExporter, + Collection nodeProperties + ) { + nodePropertyExporter.write(nodeProperties); + + return new NodePropertiesWritten(nodePropertyExporter.propertiesWritten()); + } +} diff --git a/proc/common/src/main/java/org/neo4j/gds/WriteNodePropertiesComputationResultConsumer.java b/proc/common/src/main/java/org/neo4j/gds/WriteNodePropertiesComputationResultConsumer.java index f3987a78a4..9a9162c440 100644 --- a/proc/common/src/main/java/org/neo4j/gds/WriteNodePropertiesComputationResultConsumer.java +++ b/proc/common/src/main/java/org/neo4j/gds/WriteNodePropertiesComputationResultConsumer.java @@ -19,35 +19,21 @@ */ package org.neo4j.gds; -import org.neo4j.gds.api.PropertyState; -import org.neo4j.gds.api.schema.PropertySchema; +import org.neo4j.gds.applications.algorithms.machinery.NodePropertyWriter; +import org.neo4j.gds.applications.algorithms.machinery.StandardLabel; import org.neo4j.gds.config.AlgoBaseConfig; import org.neo4j.gds.config.WritePropertyConfig; -import org.neo4j.gds.core.concurrency.Concurrency; -import org.neo4j.gds.core.concurrency.DefaultPool; -import org.neo4j.gds.core.loading.Capabilities.WriteMode; import org.neo4j.gds.core.utils.ProgressTimer; -import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; -import org.neo4j.gds.core.utils.progress.tasks.TaskProgressTracker; -import org.neo4j.gds.core.write.NodeProperty; -import org.neo4j.gds.core.write.NodePropertyExporter; import org.neo4j.gds.executor.ComputationResult; import org.neo4j.gds.executor.ComputationResultConsumer; import org.neo4j.gds.executor.ExecutionContext; import org.neo4j.gds.result.AbstractResultBuilder; -import java.util.Collection; -import java.util.Map; -import java.util.function.Predicate; import java.util.stream.Stream; import static org.neo4j.gds.LoggingUtil.runWithExceptionLogging; -import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; - -public class WriteNodePropertiesComputationResultConsumer, ALGO_RESULT, CONFIG extends WritePropertyConfig & AlgoBaseConfig, RESULT> - implements - ComputationResultConsumer> { +public class WriteNodePropertiesComputationResultConsumer, ALGO_RESULT, CONFIG extends WritePropertyConfig & AlgoBaseConfig, RESULT> implements ComputationResultConsumer> { private final ResultBuilderFunction resultBuilderFunction; private final WriteNodePropertyListFunction nodePropertyListFunction; private final String procedureName; @@ -62,81 +48,6 @@ public WriteNodePropertiesComputationResultConsumer( this.procedureName = procedureName; } - private static void validatePropertiesCanBeWritten( - WriteMode writeMode, - Map propertySchemas, - Collection nodeProperties - ) { - if (writeMode == WriteMode.REMOTE) { - throw new IllegalArgumentException("Missing arrow connection information"); - } - - var expectedPropertyState = expectedPropertyStateForWriteMode(writeMode); - - var unexpectedProperties = nodeProperties - .stream() - .filter(nodeProperty -> { - var propertySchema = propertySchemas.get(nodeProperty.key()); - if (propertySchema == null) { - // We are executing an algorithm write mode and the property we are writing is - // not in the GraphStore, therefore we do not perform any more checks - return false; - } - var propertyState = propertySchema.state(); - return !expectedPropertyState.test(propertyState); - }) - .map( - nodeProperty -> formatWithLocale( - "NodeProperty{propertyKey=%s, propertyState=%s}", - nodeProperty.key(), - propertySchemas.get(nodeProperty.key()).state() - ) - ) - .toList(); - - if (!unexpectedProperties.isEmpty()) { - throw new IllegalStateException( - formatWithLocale( - "Expected all properties to be of state `%s` but some properties differ: %s", - expectedPropertyState, - unexpectedProperties - ) - ); - } - } - - private static Predicate expectedPropertyStateForWriteMode(WriteMode writeMode) { - switch (writeMode) { - case LOCAL: - // We need to allow persistent and transient as for example algorithms that support seeding will reuse a - // mutated (transient) property to write back properties that are in fact backed by a database - return state -> state == PropertyState.PERSISTENT || state == PropertyState.TRANSIENT; - case REMOTE: - // We allow transient properties for the same reason as above - return state -> state == PropertyState.REMOTE || state == PropertyState.TRANSIENT; - default: - throw new IllegalStateException( - formatWithLocale( - "Graph with write mode `%s` cannot write back to a database", - writeMode - ) - ); - } - } - - ProgressTracker createProgressTracker( - long taskVolume, - Concurrency writeConcurrency, - ExecutionContext executionContext - ) { - return new TaskProgressTracker( - NodePropertyExporter.baseTask(this.procedureName, taskVolume), - executionContext.log(), - writeConcurrency, - executionContext.taskRegistryFactory() - ); - } - @Override public Stream consume( ComputationResult computationResult, @@ -158,48 +69,41 @@ public Stream consume( }); } - void writeToNeo( + private void writeToNeo( AbstractResultBuilder resultBuilder, ComputationResult computationResult, ExecutionContext executionContext ) { try (ProgressTimer ignored = ProgressTimer.start(resultBuilder::withWriteMillis)) { + var log = executionContext.log(); + var nodePropertyExporterBuilder = executionContext.nodePropertyExporterBuilder(); + var taskRegistryFactory = executionContext.taskRegistryFactory(); + var terminationFlag = computationResult.algorithm().terminationFlag; + var nodePropertyWriter = new NodePropertyWriter( + log, + nodePropertyExporterBuilder, + taskRegistryFactory, + terminationFlag + ); + var graph = computationResult.graph(); + var graphStore = computationResult.graphStore(); var config = computationResult.config(); - var progressTracker = createProgressTracker( - graph.nodeCount(), - config.writeConcurrency(), - executionContext - ); - var writeMode = computationResult.graphStore().capabilities().writeMode(); - var nodePropertySchema = graph.schema().nodeSchema().unionProperties(); + var resultStore = config.resolveResultStore(computationResult.resultStore()); var nodeProperties = nodePropertyListFunction.apply(computationResult); - validatePropertiesCanBeWritten( - writeMode, - nodePropertySchema, - nodeProperties + var nodePropertiesWritten = nodePropertyWriter.writeNodeProperties( + graph, + graphStore, + resultStore, + nodeProperties, + config.jobId(), + new StandardLabel(procedureName), + config ); - var resultStore = config.resolveResultStore(computationResult.resultStore()); - var exporter = executionContext - .nodePropertyExporterBuilder() - .withIdMap(graph) - .withTerminationFlag(computationResult.algorithm().terminationFlag) - .withProgressTracker(progressTracker) - .withResultStore(resultStore) - .withJobId(config.jobId()) - .parallel(DefaultPool.INSTANCE, config.writeConcurrency()) - .build(); - - try { - exporter.write(nodeProperties); - } finally { - progressTracker.release(); - } - resultBuilder.withNodeCount(computationResult.graph().nodeCount()); - resultBuilder.withNodePropertiesWritten(exporter.propertiesWritten()); + resultBuilder.withNodePropertiesWritten(nodePropertiesWritten.value()); } } } diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteProc.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteProc.java index fd7255f536..f7db944e51 100644 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteProc.java +++ b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteProc.java @@ -19,13 +19,9 @@ */ package org.neo4j.gds.ml.pipeline.node.classification.predict; -import org.neo4j.gds.BaseProc; import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult; -import org.neo4j.gds.core.model.ModelCatalog; -import org.neo4j.gds.core.write.NodePropertyExporterBuilder; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.executor.ProcedureExecutor; import org.neo4j.gds.procedures.GraphDataScienceProcedures; +import org.neo4j.gds.procedures.pipelines.WriteResult; import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; import org.neo4j.procedure.Mode; @@ -35,31 +31,20 @@ import java.util.Map; import java.util.stream.Stream; -import static org.neo4j.gds.ml.pipeline.PipelineCompanion.preparePipelineConfig; import static org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPipelineConstants.ESTIMATE_PREDICT_DESCRIPTION; import static org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPipelineConstants.PREDICT_DESCRIPTION; -public class NodeClassificationPipelineWriteProc extends BaseProc { +public class NodeClassificationPipelineWriteProc { @Context public GraphDataScienceProcedures facade; - @Context - public ModelCatalog internalModelCatalog; - - @Context - public NodePropertyExporterBuilder nodePropertyExporterBuilder; - @Procedure(name = "gds.beta.pipeline.nodeClassification.predict.write", mode = Mode.WRITE) @Description(PREDICT_DESCRIPTION) public Stream write( @Name(value = "graphName") String graphName, @Name(value = "configuration", defaultValue = "{}") Map configuration ) { - preparePipelineConfig(graphName, configuration); - return new ProcedureExecutor<>( - new NodeClassificationPipelineWriteSpec(), - executionContext() - ).compute(graphName, configuration); + return facade.pipelines().nodeClassificationWrite(graphName, configuration); } @Procedure(name = "gds.beta.pipeline.nodeClassification.predict.write.estimate", mode = Mode.READ) @@ -70,9 +55,4 @@ public Stream estimate( ) { return facade.pipelines().nodeClassificationWriteEstimate(graphNameOrConfiguration, algoConfiguration); } - - @Override - public ExecutionContext executionContext() { - return super.executionContext().withNodePropertyExporterBuilder(nodePropertyExporterBuilder).withModelCatalog(internalModelCatalog); - } } diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteSpec.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteSpec.java index 6ae2963b01..3398427721 100644 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteSpec.java +++ b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteSpec.java @@ -33,6 +33,7 @@ import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineExecutor; import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineWriteConfig; import org.neo4j.gds.procedures.pipelines.PredictedProbabilities; +import org.neo4j.gds.procedures.pipelines.WriteResult; import java.util.List; import java.util.Map; diff --git a/proc/machine-learning/src/test/java/org/neo4j/gds/ml/pipeline/node/classification/NodeClassificationPipelineAddStepProcsTest.java b/proc/machine-learning/src/test/java/org/neo4j/gds/ml/pipeline/node/classification/NodeClassificationPipelineAddStepProcsTest.java index 887abfcd87..a30b88663b 100644 --- a/proc/machine-learning/src/test/java/org/neo4j/gds/ml/pipeline/node/classification/NodeClassificationPipelineAddStepProcsTest.java +++ b/proc/machine-learning/src/test/java/org/neo4j/gds/ml/pipeline/node/classification/NodeClassificationPipelineAddStepProcsTest.java @@ -59,6 +59,7 @@ void setUp() { null, null, null, + null, new User(getUsername(), false), null, null, @@ -334,6 +335,7 @@ private GraphDataScienceProcedures buildFacade() { null, null, null, + null, new User(getUsername(), false), null, null, diff --git a/procedures/facade/src/main/java/org/neo4j/gds/procedures/LocalGraphDataScienceProcedures.java b/procedures/facade/src/main/java/org/neo4j/gds/procedures/LocalGraphDataScienceProcedures.java index 2a26e51b17..0a0e626d0e 100644 --- a/procedures/facade/src/main/java/org/neo4j/gds/procedures/LocalGraphDataScienceProcedures.java +++ b/procedures/facade/src/main/java/org/neo4j/gds/procedures/LocalGraphDataScienceProcedures.java @@ -206,6 +206,7 @@ public static GraphDataScienceProcedures create( writeContext.relationshipExporterBuilder(), requestScopedDependencies.getTaskRegistryFactory(), terminationMonitor, + requestScopedDependencies.getTerminationFlag(), requestScopedDependencies.getUser(), requestScopedDependencies.getUserLogRegistryFactory(), progressTrackerCreator, diff --git a/procedures/pipelines-facade/build.gradle b/procedures/pipelines-facade/build.gradle index 4d1c233d3c..c9f036ac92 100644 --- a/procedures/pipelines-facade/build.gradle +++ b/procedures/pipelines-facade/build.gradle @@ -31,6 +31,7 @@ dependencies { implementation project(':core') implementation project(':core-write') implementation project(':executor') + implementation project(':graph-schema-api') implementation project(':logging') implementation project(':memory-usage') implementation project(':metrics-api') diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictPipelineWriteResultBuilder.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictPipelineWriteResultBuilder.java new file mode 100644 index 0000000000..874ec3cde9 --- /dev/null +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictPipelineWriteResultBuilder.java @@ -0,0 +1,54 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.procedures.pipelines; + +import org.neo4j.gds.api.Graph; +import org.neo4j.gds.applications.algorithms.machinery.AlgorithmProcessingTimings; +import org.neo4j.gds.applications.algorithms.machinery.ResultBuilder; +import org.neo4j.gds.applications.algorithms.metadata.NodePropertiesWritten; + +import java.util.Optional; + +class NodeClassificationPredictPipelineWriteResultBuilder implements ResultBuilder { + private final NodeClassificationPredictPipelineWriteConfig configuration; + + NodeClassificationPredictPipelineWriteResultBuilder(NodeClassificationPredictPipelineWriteConfig configuration) { + this.configuration = configuration; + } + + @Override + public WriteResult build( + Graph graph, + NodeClassificationPredictPipelineWriteConfig configuration, + Optional result, + AlgorithmProcessingTimings timings, + Optional metadata + ) { + if (result.isEmpty()) return WriteResult.emptyFrom(timings, this.configuration.toMap()); + + return new WriteResult( + timings.preProcessingMillis, + timings.computeMillis, + timings.mutateOrWriteMillis, + metadata.orElseThrow().value(), + this.configuration.toMap() + ); + } +} diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictPipelineWriteStep.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictPipelineWriteStep.java new file mode 100644 index 0000000000..59c6006fa1 --- /dev/null +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictPipelineWriteStep.java @@ -0,0 +1,68 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.procedures.pipelines; + +import org.neo4j.gds.api.Graph; +import org.neo4j.gds.api.GraphStore; +import org.neo4j.gds.api.ResultStore; +import org.neo4j.gds.applications.algorithms.machinery.NodePropertyWriter; +import org.neo4j.gds.applications.algorithms.machinery.StandardLabel; +import org.neo4j.gds.applications.algorithms.machinery.WriteStep; +import org.neo4j.gds.applications.algorithms.metadata.NodePropertiesWritten; +import org.neo4j.gds.core.utils.progress.JobId; + +import java.util.Optional; + +class NodeClassificationPredictPipelineWriteStep implements WriteStep { + private final NodePropertyWriter nodePropertyWriter; + private final NodeClassificationPredictPipelineWriteConfig configuration; + + NodeClassificationPredictPipelineWriteStep( + NodePropertyWriter nodePropertyWriter, + NodeClassificationPredictPipelineWriteConfig configuration + ) { + this.nodePropertyWriter = nodePropertyWriter; + this.configuration = configuration; + } + + @Override + public NodePropertiesWritten execute( + Graph graph, + GraphStore graphStore, + ResultStore resultStore, + NodeClassificationPipelineResult result, + JobId jobId + ) { + var nodeProperties = PredictedProbabilities.asProperties( + Optional.of(result), + configuration.writeProperty(), + configuration.predictedProbabilityProperty() + ); + + return nodePropertyWriter.writeNodeProperties( + graph, + graphStore, + Optional.of(resultStore), nodeProperties, + jobId, + new StandardLabel("NodeClassificationPipelineWrite"), + configuration + ); + } +} diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java index 8d705b7d69..16293b0e80 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java @@ -32,6 +32,7 @@ import org.neo4j.gds.applications.algorithms.machinery.GraphStoreService; import org.neo4j.gds.applications.algorithms.machinery.Label; import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult; +import org.neo4j.gds.applications.algorithms.machinery.NodePropertyWriter; import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator; import org.neo4j.gds.applications.algorithms.machinery.StandardLabel; import org.neo4j.gds.applications.modelcatalog.ModelRepository; @@ -60,6 +61,7 @@ import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrain; import org.neo4j.gds.model.ModelConfig; import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade; +import org.neo4j.gds.termination.TerminationFlag; import org.neo4j.gds.termination.TerminationMonitor; import java.util.Map; @@ -69,7 +71,7 @@ class PipelineApplications { private final Log log; - private final GraphStoreService gss; + private final GraphStoreService graphStoreService; private final ModelCatalog modelCatalog; private final PipelineRepository pipelineRepository; @@ -92,6 +94,8 @@ class PipelineApplications { private final NodeClassificationPredictPipelineEstimator nodeClassificationPredictPipelineEstimator; private final NodeClassificationTrainSideEffectsFactory nodeClassificationTrainSideEffectsFactory; + private final NodePropertyWriter nodePropertyWriter; + private final AlgorithmsProcedureFacade algorithmsProcedureFacade; private final AlgorithmEstimationTemplate algorithmEstimationTemplate; private final AlgorithmProcessingTemplate algorithmProcessingTemplate; @@ -117,12 +121,13 @@ class PipelineApplications { ProgressTrackerCreator progressTrackerCreator, NodeClassificationPredictPipelineEstimator nodeClassificationPredictPipelineEstimator, NodeClassificationTrainSideEffectsFactory nodeClassificationTrainSideEffectsFactory, + NodePropertyWriter nodePropertyWriter, AlgorithmsProcedureFacade algorithmsProcedureFacade, AlgorithmEstimationTemplate algorithmEstimationTemplate, AlgorithmProcessingTemplate algorithmProcessingTemplate ) { this.log = log; - this.gss = graphStoreService; + this.graphStoreService = graphStoreService; this.modelCatalog = modelCatalog; this.pipelineRepository = pipelineRepository; this.closeableResourceRegistry = closeableResourceRegistry; @@ -141,6 +146,7 @@ class PipelineApplications { this.progressTrackerCreator = progressTrackerCreator; this.nodeClassificationPredictPipelineEstimator = nodeClassificationPredictPipelineEstimator; this.nodeClassificationTrainSideEffectsFactory = nodeClassificationTrainSideEffectsFactory; + this.nodePropertyWriter = nodePropertyWriter; this.algorithmsProcedureFacade = algorithmsProcedureFacade; this.algorithmEstimationTemplate = algorithmEstimationTemplate; this.algorithmProcessingTemplate = algorithmProcessingTemplate; @@ -161,6 +167,7 @@ static PipelineApplications create( RelationshipExporterBuilder relationshipExporterBuilder, TaskRegistryFactory taskRegistryFactory, TerminationMonitor terminationMonitor, + TerminationFlag terminationFlag, User user, UserLogRegistryFactory userLogRegistryFactory, PipelineConfigurationParser pipelineConfigurationParser, @@ -182,6 +189,13 @@ static PipelineApplications create( modelRepository ); + var nodePropertyWriter = new NodePropertyWriter( + log, + nodePropertyExporterBuilder, + taskRegistryFactory, + terminationFlag + ); + return new PipelineApplications( log, graphStoreService, @@ -203,6 +217,7 @@ static PipelineApplications create( progressTrackerCreator, nodeClassificationPredictPipelineEstimator, nodeClassificationTrainSideEffectsFactory, + nodePropertyWriter, algorithmsProcedureFacade, algorithmEstimationTemplate, algorithmProcessingTemplate @@ -298,7 +313,7 @@ PredictMutateResult nodeClassificationPredictMutate(GraphName graphName, Map nodeClassificationPredictStream( ); } + WriteResult nodeClassificationPredictWrite(GraphName graphName, Map rawConfiguration) { + var configuration = pipelineConfigurationParser.parseNodeClassificationPredictWriteConfig(rawConfiguration); + var label = new StandardLabel("NodeClassificationPredictPipelineWrite"); + var computation = constructPredictComputation(configuration, label); + var writeStep = new NodeClassificationPredictPipelineWriteStep(nodePropertyWriter, configuration); + var resultBuilder = new NodeClassificationPredictPipelineWriteResultBuilder(configuration); + + return algorithmProcessingTemplate.processAlgorithmForWrite( + Optional.empty(), + graphName, + configuration, + Optional.empty(), + label, + () -> nodeClassificationPredictMemoryEstimation(configuration), + computation, + writeStep, + resultBuilder + ); + } + Stream nodeClassificationTrain( GraphName graphName, Map rawConfiguration diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineConfigurationParser.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineConfigurationParser.java index 8e981061d5..21cb9d5299 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineConfigurationParser.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineConfigurationParser.java @@ -79,7 +79,7 @@ NodeClassificationPredictPipelineStreamConfig parseNodeClassificationPredictStre return parseNodeClassificationPipelineConfig(NodeClassificationPredictPipelineStreamConfig::of, configuration); } - NodeClassificationPredictPipelineWriteConfig parseNodeClassificationWriteConfig(Map configuration) { + NodeClassificationPredictPipelineWriteConfig parseNodeClassificationPredictWriteConfig(Map configuration) { return parseNodeClassificationPipelineConfig(NodeClassificationPredictPipelineWriteConfig::of, configuration); } diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java index b014aa4209..03df661f8d 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacade.java @@ -43,6 +43,7 @@ import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureStep; import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline; import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade; +import org.neo4j.gds.termination.TerminationFlag; import org.neo4j.gds.termination.TerminationMonitor; import java.util.ArrayList; @@ -86,6 +87,7 @@ public static PipelinesProcedureFacade create( RelationshipExporterBuilder relationshipExporterBuilder, TaskRegistryFactory taskRegistryFactory, TerminationMonitor terminationMonitor, + TerminationFlag terminationFlag, User user, UserLogRegistryFactory userLogRegistryFactory, ProgressTrackerCreator progressTrackerCreator, @@ -115,6 +117,7 @@ public static PipelinesProcedureFacade create( relationshipExporterBuilder, taskRegistryFactory, terminationMonitor, + terminationFlag, user, userLogRegistryFactory, pipelineConfigurationParser, @@ -335,12 +338,26 @@ public Stream nodeClassificationTrainEstimate( return Stream.of(result); } + public Stream nodeClassificationWrite( + String graphNameAsString, + Map configuration + ) { + PipelineCompanion.preparePipelineConfig(graphNameAsString, configuration); + nodeClassificationPredictConfigPreProcessor.enhanceInputWithPipelineParameters(configuration); + + var graphName = GraphName.parse(graphNameAsString); + + var result = pipelineApplications.nodeClassificationPredictWrite(graphName, configuration); + + return Stream.of(result); + } + public Stream nodeClassificationWriteEstimate( Object graphNameOrConfiguration, Map rawConfiguration ) { PipelineCompanion.preparePipelineConfig(graphNameOrConfiguration, rawConfiguration); - var configuration = pipelineConfigurationParser.parseNodeClassificationWriteConfig(rawConfiguration); + var configuration = pipelineConfigurationParser.parseNodeClassificationPredictWriteConfig(rawConfiguration); var result = pipelineApplications.nodeClassificationPredictEstimate( graphNameOrConfiguration, diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/WriteResult.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/WriteResult.java similarity index 76% rename from proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/WriteResult.java rename to procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/WriteResult.java index 51c03f0956..0785f2e11d 100644 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/WriteResult.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/WriteResult.java @@ -17,15 +17,15 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ -package org.neo4j.gds.ml.pipeline.node.classification.predict; +package org.neo4j.gds.procedures.pipelines; -import org.neo4j.gds.result.AbstractResultBuilder; +import org.neo4j.gds.applications.algorithms.machinery.AlgorithmProcessingTimings; import org.neo4j.gds.procedures.algorithms.results.StandardWriteResult; +import org.neo4j.gds.result.AbstractResultBuilder; import java.util.Map; public final class WriteResult extends StandardWriteResult { - public final long nodePropertiesWritten; WriteResult( @@ -45,8 +45,17 @@ public final class WriteResult extends StandardWriteResult { this.nodePropertiesWritten = nodePropertiesWritten; } - static class Builder extends AbstractResultBuilder { + static WriteResult emptyFrom(AlgorithmProcessingTimings timings, Map configurationMap) { + return new WriteResult( + timings.preProcessingMillis, + timings.computeMillis, + timings.mutateOrWriteMillis, + 0, + configurationMap + ); + } + public static class Builder extends AbstractResultBuilder { @Override public WriteResult build() { return new WriteResult( diff --git a/procedures/pipelines-facade/src/test/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacadeTest.java b/procedures/pipelines-facade/src/test/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacadeTest.java index de8b67acd5..3ecd521de0 100644 --- a/procedures/pipelines-facade/src/test/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacadeTest.java +++ b/procedures/pipelines-facade/src/test/java/org/neo4j/gds/procedures/pipelines/PipelinesProcedureFacadeTest.java @@ -56,6 +56,7 @@ void createPipeline() { null, null, null, + null, null ); var facade = new PipelinesProcedureFacade(null, null, applications); @@ -98,6 +99,7 @@ void shouldNotCreatePipelineWhenOneExists() { null, null, null, + null, null ); var facade = new PipelinesProcedureFacade(null, null, applications);