From 0d5e1c09eb65496a9e11fa50338aed15beb2dcd0 Mon Sep 17 00:00:00 2001 From: Lasse Westh-Nielsen Date: Wed, 4 Dec 2024 14:30:42 +0100 Subject: [PATCH 1/2] cleanup --- .../java/org/neo4j/gds/procedures/pipelines/Configurer.java | 2 +- .../neo4j/gds/procedures/pipelines/PipelineApplications.java | 2 +- .../gds/procedures/pipelines/PipelineConfigurationParser.java | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/Configurer.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/Configurer.java index 4357de4347..644860d97d 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/Configurer.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/Configurer.java @@ -30,7 +30,7 @@ import java.util.function.Supplier; import java.util.stream.Stream; -class Configurer { +public class Configurer { private final PipelineRepository pipelineRepository; private final User user; diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java index 52c2116209..a68cfaeb53 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java @@ -78,7 +78,7 @@ import static org.neo4j.gds.config.RelationshipWeightConfig.RELATIONSHIP_WEIGHT_PROPERTY; import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; -class PipelineApplications { +public class PipelineApplications { private final Log log; private final GraphStoreCatalogService graphStoreCatalogService; private final GraphStoreService graphStoreService; diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineConfigurationParser.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineConfigurationParser.java index 9e2c619d03..8d46958ea2 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineConfigurationParser.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineConfigurationParser.java @@ -42,7 +42,7 @@ import java.util.function.BiFunction; import java.util.function.Function; -class PipelineConfigurationParser { +public class PipelineConfigurationParser { private final User user; PipelineConfigurationParser(User user) { From df5a90dee856cd39c235840c3c449c88bce79dbc Mon Sep 17 00:00:00 2001 From: Lasse Westh-Nielsen Date: Thu, 5 Dec 2024 11:52:30 +0100 Subject: [PATCH 2/2] some Sherlock Holmesing reveals that, while pipeline estimations are _reachable_, they are not _usable_ and therefore they are safe to delete. Their transitive closure is ginormous, there is a whole facility we can get rid of (private preconditions), it is just beautiful. a great day at the office! --- .../neo4j/gds/core/model/ModelCatalog.java | 11 -- .../gds/core/model/OpenModelCatalog.java | 13 -- .../neo4j/gds/VerifyThatModelCanBeStored.java | 50 ------- ...redictionPipelineMutateResultConsumer.java | 99 ------------- .../LinkPredictionPipelineMutateSpec.java | 81 ----------- .../LinkPredictionPipelineStreamSpec.java | 87 ------------ ...ictionPredictPipelineAlgorithmFactory.java | 134 ------------------ .../LinkPredictionPipelineTrainSpec.java | 112 --------------- ...edictionTrainPipelineAlgorithmFactory.java | 115 --------------- .../NodeClassificationPipelineMutateSpec.java | 88 ------------ .../NodeClassificationPipelineStreamSpec.java | 120 ---------------- .../NodeClassificationPipelineTrainSpec.java | 114 --------------- .../NodeClassificationPipelineWriteSpec.java | 86 ----------- ...cationPredictPipelineAlgorithmFactory.java | 114 --------------- .../NodeRegressionPipelineTrainSpec.java | 110 -------------- .../NodeRegressionPipelineMutateSpec.java | 106 -------------- .../NodeRegressionPipelineStreamSpec.java | 98 ------------- ...essionPredictPipelineAlgorithmFactory.java | 119 ---------------- ...onPredictPipelineAlgorithmFactoryTest.java | 115 --------------- .../predict/PathFindingMutateResultTest.java | 39 ----- ...tatePropertyComputationResultConsumer.java | 9 +- .../org/neo4j/gds/test/TestMutateSpec.java | 1 - .../procedures/pipelines/MutateResult.java | 48 ------- .../pipelines/PredictMutateResult.java | 14 -- .../gds/procedures/pipelines/WriteResult.java | 14 -- ...assificationPredictConfigPreProcessor.java | 15 +- ...deRegressionPredictConfigPreProcessor.java | 13 -- 27 files changed, 8 insertions(+), 1917 deletions(-) delete mode 100644 proc/common/src/main/java/org/neo4j/gds/VerifyThatModelCanBeStored.java delete mode 100644 proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineMutateResultConsumer.java delete mode 100644 proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineMutateSpec.java delete mode 100644 proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineStreamSpec.java delete mode 100644 proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPredictPipelineAlgorithmFactory.java delete mode 100644 proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionPipelineTrainSpec.java delete mode 100644 proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionTrainPipelineAlgorithmFactory.java delete mode 100644 proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineMutateSpec.java delete mode 100644 proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineStreamSpec.java delete mode 100644 proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineTrainSpec.java delete mode 100644 proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteSpec.java delete mode 100644 proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPredictPipelineAlgorithmFactory.java delete mode 100644 proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/NodeRegressionPipelineTrainSpec.java delete mode 100644 proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/predict/NodeRegressionPipelineMutateSpec.java delete mode 100644 proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/predict/NodeRegressionPipelineStreamSpec.java delete mode 100644 proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/predict/NodeRegressionPredictPipelineAlgorithmFactory.java delete mode 100644 proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPredictPipelineAlgorithmFactoryTest.java delete mode 100644 proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/PathFindingMutateResultTest.java rename proc/{common/src/main/java/org/neo4j/gds => test/src/main/java/org/neo4j/gds/test}/MutatePropertyComputationResultConsumer.java (87%) diff --git a/model-catalog-api/src/main/java/org/neo4j/gds/core/model/ModelCatalog.java b/model-catalog-api/src/main/java/org/neo4j/gds/core/model/ModelCatalog.java index bc893371c5..c6dabeaccd 100644 --- a/model-catalog-api/src/main/java/org/neo4j/gds/core/model/ModelCatalog.java +++ b/model-catalog-api/src/main/java/org/neo4j/gds/core/model/ModelCatalog.java @@ -22,7 +22,6 @@ import org.jetbrains.annotations.Nullable; import org.neo4j.gds.core.model.Model.CustomInfo; import org.neo4j.gds.model.ModelConfig; -import org.neo4j.graphdb.GraphDatabaseService; import java.nio.file.Path; import java.util.Collection; @@ -62,10 +61,6 @@ Model get( Model publish(String username, String modelName); - void checkLicenseBeforeStoreModel(GraphDatabaseService db, String detail); - - Path getModelDirectory(GraphDatabaseService db); - Model store(String username, String modelName, Path modelDir); boolean isEmpty(); @@ -146,12 +141,6 @@ public boolean exists(String username, String modelName) { return null; } - @Override - public void checkLicenseBeforeStoreModel(GraphDatabaseService db, String detail) { } - - @Override - public Path getModelDirectory(GraphDatabaseService db) { return null; } - @Override public Model store(String username, String modelName, Path modelDir) { return null; } diff --git a/open-model-catalog/src/main/java/org/neo4j/gds/core/model/OpenModelCatalog.java b/open-model-catalog/src/main/java/org/neo4j/gds/core/model/OpenModelCatalog.java index 958556e1ef..cdc6d3cc93 100644 --- a/open-model-catalog/src/main/java/org/neo4j/gds/core/model/OpenModelCatalog.java +++ b/open-model-catalog/src/main/java/org/neo4j/gds/core/model/OpenModelCatalog.java @@ -23,7 +23,6 @@ import org.jetbrains.annotations.Nullable; import org.neo4j.gds.core.model.Model.CustomInfo; import org.neo4j.gds.model.ModelConfig; -import org.neo4j.graphdb.GraphDatabaseService; import java.nio.file.Path; import java.util.ArrayList; @@ -145,18 +144,6 @@ public boolean exists(String username, String modelName) { ); } - @Override - public void checkLicenseBeforeStoreModel(GraphDatabaseService db, String detail) { } - - @Override - public Path getModelDirectory(org.neo4j.graphdb.GraphDatabaseService db) { - throw new IllegalStateException( - "There is no model directory path. Storing models is not available in openGDS. " + - "Please consider licensing the Graph Data Science library. " + - "See documentation at https://neo4j.com/docs/graph-data-science/" - ); - } - @Override public Model store(String username, String modelName, Path modelDir) { throw new IllegalStateException( diff --git a/proc/common/src/main/java/org/neo4j/gds/VerifyThatModelCanBeStored.java b/proc/common/src/main/java/org/neo4j/gds/VerifyThatModelCanBeStored.java deleted file mode 100644 index bc1f2fb9ee..0000000000 --- a/proc/common/src/main/java/org/neo4j/gds/VerifyThatModelCanBeStored.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds; - -import org.neo4j.gds.config.AlgoBaseConfig; -import org.neo4j.gds.config.GraphProjectConfig; -import org.neo4j.gds.core.model.ModelCatalog; -import org.neo4j.gds.executor.validation.BeforeLoadValidation; -import org.neo4j.gds.model.ModelConfig; - -public final class VerifyThatModelCanBeStored implements BeforeLoadValidation { - private final ModelCatalog modelCatalog; - private final String username; - private final String modelType; - - public VerifyThatModelCanBeStored(ModelCatalog modelCatalog, String username, String modelType) { - this.modelCatalog = modelCatalog; - this.username = username; - this.modelType = modelType; - } - - @Override - public void validateConfigsBeforeLoad( - GraphProjectConfig graphProjectConfig, - TRAIN_CONFIG config - ) { - modelCatalog.verifyModelCanBeStored( - username, - config.modelName(), - modelType - ); - } -} diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineMutateResultConsumer.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineMutateResultConsumer.java deleted file mode 100644 index 4743f57878..0000000000 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineMutateResultConsumer.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.linkmodels.pipeline.predict; - -import org.neo4j.gds.MutateComputationResultConsumer; -import org.neo4j.gds.NodeLabel; -import org.neo4j.gds.Orientation; -import org.neo4j.gds.RelationshipType; -import org.neo4j.gds.ResultBuilderFunction; -import org.neo4j.gds.core.Aggregation; -import org.neo4j.gds.core.concurrency.DefaultPool; -import org.neo4j.gds.core.concurrency.ParallelUtil; -import org.neo4j.gds.core.loading.construction.GraphFactory; -import org.neo4j.gds.executor.ComputationResult; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.ml.linkmodels.LinkPredictionResult; -import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineExecutor; -import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineMutateConfig; -import org.neo4j.gds.procedures.pipelines.MutateResult; -import org.neo4j.gds.result.AbstractResultBuilder; -import org.neo4j.gds.termination.TerminationFlag; - -import java.util.Collection; -import java.util.stream.Stream; - -class LinkPredictionPipelineMutateResultConsumer extends MutateComputationResultConsumer { - LinkPredictionPipelineMutateResultConsumer( - ResultBuilderFunction resultBuilderFunction - ) { - super(resultBuilderFunction); - } - - @Override - protected void updateGraphStore( - AbstractResultBuilder resultBuilder, - ComputationResult computationResult, - ExecutionContext executionContext - ) { - var graphStore = computationResult.graphStore(); - Collection labelFilter = computationResult.algorithm().labelFilter().predictNodeLabels(); - var graph = graphStore.getGraph(labelFilter); - - var config = computationResult.config(); - var concurrency = config.concurrency(); - var mutateRelationshipType = RelationshipType.of(config.mutateRelationshipType()); - - var relationshipsBuilder = GraphFactory.initRelationshipsBuilder() - .aggregation(Aggregation.SINGLE) - .nodes(graph) - .relationshipType(mutateRelationshipType) - .orientation(Orientation.UNDIRECTED) - .addPropertyConfig(GraphFactory.PropertyConfig.of(config.mutateProperty())) - .concurrency(concurrency) - .executorService(DefaultPool.INSTANCE) - .build(); - - var resultWithHistogramBuilder = (MutateResult.Builder) resultBuilder; - var predictedLinkStream = computationResult.result() - .map(LinkPredictionResult::stream) - .orElseGet(Stream::empty); - ParallelUtil.parallelStreamConsume( - predictedLinkStream, - concurrency, - TerminationFlag.wrap(executionContext.terminationMonitor()), - stream -> stream.forEach(predictedLink -> { - relationshipsBuilder.addFromInternal( - graph.toRootNodeId(predictedLink.sourceId()), - graph.toRootNodeId(predictedLink.targetId()), - predictedLink.probability() - ); - resultWithHistogramBuilder.recordHistogramValue(predictedLink.probability()); - })); - - var relationships = relationshipsBuilder.build(); - - - computationResult - .graphStore() - .addRelationshipType(relationships); - resultBuilder.withRelationshipsWritten(relationships.topology().elementCount()); - } -} diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineMutateSpec.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineMutateSpec.java deleted file mode 100644 index 53634a4e3c..0000000000 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineMutateSpec.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.linkmodels.pipeline.predict; - -import org.neo4j.gds.executor.AlgorithmSpec; -import org.neo4j.gds.executor.ComputationResult; -import org.neo4j.gds.executor.ComputationResultConsumer; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.executor.GdsCallable; -import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction; -import org.neo4j.gds.ml.linkmodels.LinkPredictionResult; -import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineExecutor; -import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineMutateConfig; -import org.neo4j.gds.procedures.pipelines.MutateResult; - -import java.util.Collections; -import java.util.stream.Stream; - -import static org.neo4j.gds.executor.ExecutionMode.MUTATE_RELATIONSHIP; -import static org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionPipelineCompanion.PREDICT_DESCRIPTION; - -@GdsCallable(name = "gds.beta.pipeline.linkPrediction.predict.mutate", description = PREDICT_DESCRIPTION, executionMode = MUTATE_RELATIONSHIP) -public class LinkPredictionPipelineMutateSpec implements AlgorithmSpec< - LinkPredictionPredictPipelineExecutor, - LinkPredictionResult, - LinkPredictionPredictPipelineMutateConfig, - Stream, - LinkPredictionPredictPipelineAlgorithmFactory> { - @Override - public String name() { - return "LinkPredictionPipelineMutate"; - } - - @Override - public LinkPredictionPredictPipelineAlgorithmFactory algorithmFactory( - ExecutionContext executionContext) { - return new LinkPredictionPredictPipelineAlgorithmFactory<>(executionContext); - } - - @Override - public NewConfigFunction newConfigFunction() { - return LinkPredictionPredictPipelineMutateConfig::of; - } - - @Override - public ComputationResultConsumer> computationResultConsumer() { - return new LinkPredictionPipelineMutateResultConsumer(this::resultBuilder); - } - - private MutateResult.Builder resultBuilder( - ComputationResult computeResult, - ExecutionContext executionContext - ) { - var builder = new MutateResult.Builder() - .withSamplingStats(computeResult.result() - .map(LinkPredictionResult::samplingStats) - .orElseGet(Collections::emptyMap)); - - if (executionContext.returnColumns().contains("probabilityDistribution")) { - builder.withHistogram(); - } - return builder; - } -} diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineStreamSpec.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineStreamSpec.java deleted file mode 100644 index 5daa1d9048..0000000000 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineStreamSpec.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.linkmodels.pipeline.predict; - -import org.neo4j.gds.NodeLabel; -import org.neo4j.gds.executor.AlgorithmSpec; -import org.neo4j.gds.executor.ComputationResultConsumer; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.executor.GdsCallable; -import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction; -import org.neo4j.gds.ml.linkmodels.LinkPredictionResult; -import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineExecutor; -import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineStreamConfig; -import org.neo4j.gds.procedures.pipelines.StreamResult; - -import java.util.Collection; -import java.util.stream.Stream; - -import static org.neo4j.gds.LoggingUtil.runWithExceptionLogging; -import static org.neo4j.gds.executor.ExecutionMode.STREAM; -import static org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionPipelineCompanion.PREDICT_DESCRIPTION; - -@GdsCallable( - name = "gds.beta.pipeline.linkPrediction.predict.stream", description = PREDICT_DESCRIPTION, executionMode = STREAM -) -public class LinkPredictionPipelineStreamSpec implements - AlgorithmSpec< - LinkPredictionPredictPipelineExecutor, - LinkPredictionResult, - LinkPredictionPredictPipelineStreamConfig, - Stream, - LinkPredictionPredictPipelineAlgorithmFactory> { - @Override - public String name() { - return "LinkPredictionPipelineStream"; - } - - @Override - public LinkPredictionPredictPipelineAlgorithmFactory algorithmFactory( - ExecutionContext executionContext - ) { - return new LinkPredictionPredictPipelineAlgorithmFactory<>(executionContext); - } - - @Override - public NewConfigFunction newConfigFunction() { - return LinkPredictionPredictPipelineStreamConfig::of; - } - - @Override - public ComputationResultConsumer> computationResultConsumer() { - return (computationResult, executionContext) -> runWithExceptionLogging( - "Result streaming failed", - executionContext.log(), - () -> computationResult.result() - .map(result -> { - var graphStore = computationResult.graphStore(); - Collection labelFilter = computationResult.algorithm().labelFilter().predictNodeLabels(); - var graph = graphStore.getGraph(labelFilter); - - return result.stream() - .map(predictedLink -> new StreamResult( - graph.toOriginalNodeId(predictedLink.sourceId()), - graph.toOriginalNodeId(predictedLink.targetId()), - predictedLink.probability() - )); - }).orElseGet(Stream::empty) - ); - } -} diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPredictPipelineAlgorithmFactory.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPredictPipelineAlgorithmFactory.java deleted file mode 100644 index 15d62fde0e..0000000000 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPredictPipelineAlgorithmFactory.java +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.linkmodels.pipeline.predict; - -import org.neo4j.gds.GraphStoreAlgorithmFactory; -import org.neo4j.gds.api.GraphStore; -import org.neo4j.gds.core.GraphDimensions; -import org.neo4j.gds.core.loading.CatalogRequest; -import org.neo4j.gds.core.loading.GraphStoreCatalog; -import org.neo4j.gds.core.model.ModelCatalog; -import org.neo4j.gds.mem.MemoryEstimation; -import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; -import org.neo4j.gds.core.utils.progress.tasks.Task; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.ml.models.ClassifierFactory; -import org.neo4j.gds.procedures.pipelines.LPGraphStoreFilterFactory; -import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineBaseConfig; -import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineExecutor; - -import static org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionPipelineCompanion.getTrainedLPPipelineModel; -import static org.neo4j.gds.ml.pipeline.PipelineCompanion.ANONYMOUS_GRAPH; - -public class LinkPredictionPredictPipelineAlgorithmFactory extends GraphStoreAlgorithmFactory { - private final ExecutionContext executionContext; - private final ModelCatalog modelCatalog; - - LinkPredictionPredictPipelineAlgorithmFactory(ExecutionContext executionContext) { - super(); - this.executionContext = executionContext; - this.modelCatalog = executionContext.modelCatalog(); - } - - @Override - public Task progressTask(GraphStore graphStore, CONFIG config) { - var pipeline = getTrainedLPPipelineModel(modelCatalog, config.modelName(), config.username()) - .customInfo() - .pipeline(); - - return LinkPredictionPredictPipelineExecutor.progressTask(taskName(), pipeline, graphStore, config); - } - - @Override - public String taskName() { - return "Link Prediction Predict Pipeline"; - } - - @Override - public LinkPredictionPredictPipelineExecutor build( - GraphStore graphStore, - CONFIG configuration, - ProgressTracker progressTracker - ) { - var model = getTrainedLPPipelineModel( - modelCatalog, - configuration.modelName(), - configuration.username() - ); - - var trainConfig = model.trainConfig(); - var lpGraphStoreFilter = LPGraphStoreFilterFactory.generate(executionContext.log(), trainConfig, configuration, graphStore); - - return new LinkPredictionPredictPipelineExecutor( - model.customInfo().pipeline(), - ClassifierFactory.create(model.data()), - lpGraphStoreFilter, - configuration, - executionContext, - graphStore, - progressTracker - ); - } - - @Override - public MemoryEstimation memoryEstimation(CONFIG configuration) { - var model = getTrainedLPPipelineModel( - modelCatalog, - configuration.modelName(), - configuration.username() - ); - var linkPredictionPipeline = model.customInfo().pipeline(); - - return LinkPredictionPredictPipelineExecutor.estimate( - modelCatalog, - linkPredictionPipeline, - configuration, - model.data(), - executionContext.algorithmsProcedureFacade() - ); - } - - @Override - public GraphDimensions estimatedGraphDimensionTransformer(GraphDimensions graphDimensions, CONFIG config) { - var model = getTrainedLPPipelineModel( - modelCatalog, - config.modelName(), - config.username() - ); - - //Don't have nodeLabel information for filtering to give better estimation - if (config.graphName().equals(ANONYMOUS_GRAPH)) return graphDimensions; - - var graphStore = GraphStoreCatalog - .get(CatalogRequest.of(config.username(), executionContext.databaseId()), config.graphName()) - .graphStore(); - - var lpNodeLabelFilter = LPGraphStoreFilterFactory.generate(executionContext.log(), model.trainConfig(), config, graphStore); - - //Taking nodePropertyStepsLabels since they are superset of source&target nodeLabels, to give the upper bound estimation - //In the future we can add nodeCount per label info to GraphDimensions to make more exact estimations - return GraphDimensions - .builder() - .from(graphDimensions) - .nodeCount(graphStore.getGraph(lpNodeLabelFilter.nodePropertyStepsBaseLabels()).nodeCount()) - .build(); - } - -} diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionPipelineTrainSpec.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionPipelineTrainSpec.java deleted file mode 100644 index b2a54b059b..0000000000 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionPipelineTrainSpec.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.linkmodels.pipeline.train; - -import org.neo4j.gds.VerifyThatModelCanBeStored; -import org.neo4j.gds.compat.GdsVersionInfoProvider; -import org.neo4j.gds.executor.AlgorithmSpec; -import org.neo4j.gds.executor.ComputationResultConsumer; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.executor.GdsCallable; -import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction; -import org.neo4j.gds.executor.validation.BeforeLoadValidation; -import org.neo4j.gds.executor.validation.ValidationConfiguration; -import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline; -import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig; -import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainPipelineExecutor; -import org.neo4j.gds.procedures.pipelines.LinkPredictionTrainResult; -import org.neo4j.graphdb.GraphDatabaseService; - -import java.util.List; -import java.util.stream.Stream; - -import static org.neo4j.gds.executor.ExecutionMode.TRAIN; - -@GdsCallable(name = "gds.beta.pipeline.linkPrediction.train", description = "Trains a link prediction model based on a pipeline", executionMode = TRAIN) -public class LinkPredictionPipelineTrainSpec implements AlgorithmSpec< - LinkPredictionTrainPipelineExecutor, - LinkPredictionTrainPipelineExecutor.LinkPredictionTrainPipelineResult, - LinkPredictionTrainConfig, - Stream, - LinkPredictionTrainPipelineAlgorithmFactory - > { - @Override - public String name() { - return "LinkPredictionPipelineTrain"; - } - - @Override - public LinkPredictionTrainPipelineAlgorithmFactory algorithmFactory(ExecutionContext executionContext) { - var gdsVersion = GdsVersionInfoProvider.GDS_VERSION_INFO.gdsVersion(); - return new LinkPredictionTrainPipelineAlgorithmFactory(executionContext, gdsVersion); - } - - @Override - public NewConfigFunction newConfigFunction() { - return LinkPredictionTrainConfig::of; - } - - @Override - public ComputationResultConsumer> computationResultConsumer() { - return (computationResult, executionContext) -> { - return computationResult.result().map(result -> { - var model = result.model(); - var modelCatalog = executionContext.modelCatalog(); - assert modelCatalog != null : "ModelCatalog should have been set in the ExecutionContext by this point!!!"; - modelCatalog.set(model); - - if (computationResult.config().storeModelToDisk()) { - try { - // FIXME: This works but is not what we want to do! - var databaseService = executionContext.dependencyResolver() - .resolveDependency(GraphDatabaseService.class); - modelCatalog.checkLicenseBeforeStoreModel(databaseService, "Store a model"); - var modelDir = modelCatalog.getModelDirectory(databaseService); - modelCatalog.store(model.creator(), model.name(), modelDir); - } catch (Exception e) { - executionContext.log().error("Failed to store model to disk after training.", e.getMessage()); - throw e; - } - } - return Stream.of(new LinkPredictionTrainResult(model, result.trainingStatistics(), computationResult.computeMillis() - )); - }).orElseGet(Stream::empty); - }; - } - - @Override - public ValidationConfiguration validationConfig(ExecutionContext executionContext) { - return new ValidationConfiguration<>() { - @Override - public List> beforeLoadValidations() { - var modelCatalog = executionContext.modelCatalog(); - assert modelCatalog != null : "ModelCatalog should have been set in the ExecutionContext by this point!!!"; - return List.of( - new VerifyThatModelCanBeStored<>( - modelCatalog, - executionContext.username(), - LinkPredictionTrainingPipeline.MODEL_TYPE - ) - ); - } - }; - } - -} diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionTrainPipelineAlgorithmFactory.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionTrainPipelineAlgorithmFactory.java deleted file mode 100644 index 15379907d1..0000000000 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/train/LinkPredictionTrainPipelineAlgorithmFactory.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.linkmodels.pipeline.train; - -import org.neo4j.gds.GraphStoreAlgorithmFactory; -import org.neo4j.gds.api.GraphStore; -import org.neo4j.gds.core.GraphDimensions; -import org.neo4j.gds.mem.MemoryEstimation; -import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; -import org.neo4j.gds.core.utils.progress.tasks.Task; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.ml.pipeline.PipelineCatalog; -import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline; -import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig; -import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainPipelineExecutor; - -import static org.neo4j.gds.ml.pipeline.PipelineCompanion.validateMainMetric; - -public class LinkPredictionTrainPipelineAlgorithmFactory extends GraphStoreAlgorithmFactory { - private final ExecutionContext executionContext; - - private final String gdsVersion; - - LinkPredictionTrainPipelineAlgorithmFactory(ExecutionContext executionContext, String gdsVersion) { - this.executionContext = executionContext; - this.gdsVersion = gdsVersion; - } - - @Override - public LinkPredictionTrainPipelineExecutor build( - GraphStore graphStore, - LinkPredictionTrainConfig trainConfig, - ProgressTracker progressTracker - ) { - var pipeline = PipelineCatalog.getTyped( - trainConfig.username(), - trainConfig.pipeline(), - LinkPredictionTrainingPipeline.class - ); - - validateMainMetric(pipeline, trainConfig.mainMetric().name()); - - return new LinkPredictionTrainPipelineExecutor( - pipeline, - trainConfig, - executionContext, - graphStore, - progressTracker - ); - } - - @Override - public String taskName() { - return "Link Prediction Train Pipeline"; - } - - @Override - public Task progressTask(GraphStore graphStore, LinkPredictionTrainConfig config) { - var relationshipCount = config - .internalRelationshipTypes(graphStore) - .stream() - .mapToLong(graphStore::relationshipCount) - .sum(); - return LinkPredictionTrainPipelineExecutor.progressTask( - taskName(), - PipelineCatalog.getTyped(config.username(), config.pipeline(), LinkPredictionTrainingPipeline.class), - relationshipCount - ); - } - - @Override - public MemoryEstimation memoryEstimation(LinkPredictionTrainConfig configuration) { - var pipeline = PipelineCatalog.getTyped( - configuration.username(), - configuration.pipeline(), - LinkPredictionTrainingPipeline.class - ); - - return LinkPredictionTrainPipelineExecutor.estimate( - pipeline, - configuration, - executionContext.modelCatalog(), - executionContext.algorithmsProcedureFacade(), - executionContext.username() - ); - } - - @Override - public GraphDimensions estimatedGraphDimensionTransformer(GraphDimensions graphDimensions, LinkPredictionTrainConfig config) { - // inject expected relationship set sizes which are used in the estimation of the TrainPipelineExecutor - // this allows to compute the MemoryTree over a single graphDimension - var splitConfig = PipelineCatalog - .getTyped(config.username(), config.pipeline(), LinkPredictionTrainingPipeline.class) - .splitConfig(); - - return splitConfig.expectedGraphDimensions(graphDimensions, config.targetRelationshipType()); - } -} diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineMutateSpec.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineMutateSpec.java deleted file mode 100644 index 24389f4b74..0000000000 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineMutateSpec.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.pipeline.node.classification.predict; - -import org.neo4j.gds.GraphStoreUpdater; -import org.neo4j.gds.MutateComputationResultConsumer; -import org.neo4j.gds.executor.AlgorithmSpec; -import org.neo4j.gds.executor.ComputationResult; -import org.neo4j.gds.executor.ComputationResultConsumer; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.executor.GdsCallable; -import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction; -import org.neo4j.gds.procedures.pipelines.NodeClassificationPipelineResult; -import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineExecutor; -import org.neo4j.gds.procedures.pipelines.PredictMutateResult; -import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictConfigPreProcessor; -import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineMutateConfig; -import org.neo4j.gds.procedures.pipelines.PredictedProbabilities; -import org.neo4j.gds.result.AbstractResultBuilder; - -import java.util.Map; -import java.util.stream.Stream; - -import static org.neo4j.gds.executor.ExecutionMode.MUTATE_NODE_PROPERTY; -import static org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPipelineConstants.PREDICT_DESCRIPTION; - -@GdsCallable(name = "gds.beta.pipeline.nodeClassification.predict.mutate", description = PREDICT_DESCRIPTION, executionMode = MUTATE_NODE_PROPERTY) -public class NodeClassificationPipelineMutateSpec implements AlgorithmSpec, NodeClassificationPredictPipelineAlgorithmFactory> { - @Override - public String name() { - return "NodeClassificationPipelineMutate"; - } - - @Override - public NodeClassificationPredictPipelineAlgorithmFactory algorithmFactory(ExecutionContext executionContext) { - return new NodeClassificationPredictPipelineAlgorithmFactory<>(executionContext); - } - - @Override - public NewConfigFunction newConfigFunction() { - return NodeClassificationPredictPipelineMutateConfig::of; - } - - @Override - public ComputationResultConsumer> computationResultConsumer() { - return new MutateComputationResultConsumer<>((computationResult, executionContext) -> new PredictMutateResult.Builder()) { - @Override - protected void updateGraphStore( - AbstractResultBuilder resultBuilder, - ComputationResult computationResult, - ExecutionContext executionContext - ) { - GraphStoreUpdater.UpdateGraphStore( - resultBuilder, - computationResult, - executionContext, - PredictedProbabilities.asProperties( - computationResult.result(), - computationResult.config().mutateProperty(), - computationResult.config().predictedProbabilityProperty() - ) - ); - } - }; - } - - @Override - public void preProcessConfig(Map userInput, ExecutionContext executionContext) { - NodeClassificationPredictConfigPreProcessor.enhanceInputWithPipelineParameters(userInput, executionContext); - } -} diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineStreamSpec.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineStreamSpec.java deleted file mode 100644 index 3dca4523dc..0000000000 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineStreamSpec.java +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.pipeline.node.classification.predict; - -import org.neo4j.gds.api.IdMap; -import org.neo4j.gds.core.model.Model; -import org.neo4j.gds.collections.ha.HugeObjectArray; -import org.neo4j.gds.executor.AlgorithmSpec; -import org.neo4j.gds.executor.ComputationResultConsumer; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.executor.GdsCallable; -import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction; -import org.neo4j.gds.ml.models.BaseModelData; -import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig; -import org.neo4j.gds.procedures.pipelines.NodeClassificationPipelineResult; -import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineExecutor; -import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineStreamConfig; -import org.neo4j.gds.procedures.pipelines.NodeClassificationStreamResult; - -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; -import java.util.stream.LongStream; -import java.util.stream.Stream; - -import static org.neo4j.gds.LoggingUtil.runWithExceptionLogging; -import static org.neo4j.gds.executor.ExecutionMode.STREAM; -import static org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPipelineConstants.PREDICT_DESCRIPTION; - -@GdsCallable(name = "gds.beta.pipeline.nodeClassification.predict.stream", description = PREDICT_DESCRIPTION, executionMode = STREAM) -public class NodeClassificationPipelineStreamSpec implements AlgorithmSpec, NodeClassificationPredictPipelineAlgorithmFactory> { - @Override - public String name() { - return "NodeClassificationPipelineStream"; - } - - @Override - public NodeClassificationPredictPipelineAlgorithmFactory algorithmFactory( - ExecutionContext executionContext - ) { - return new NodeClassificationPredictPipelineAlgorithmFactory<>(executionContext); - } - - @Override - public NewConfigFunction newConfigFunction() { - return NodeClassificationPredictPipelineStreamConfig::of; - } - - @Override - public ComputationResultConsumer> computationResultConsumer() { - return (computationResult, executionContext) -> runWithExceptionLogging( - "Result streaming failed", - executionContext.log(), - () -> computationResult.result() - .map(result -> { - var pipelineGraphFilter = computationResult.algorithm().nodePropertyStepFilter(); - var graph = computationResult.graphStore().getGraph(pipelineGraphFilter.nodeLabels()); - - var predictedClasses = result.predictedClasses(); - var predictedProbabilities = result.predictedProbabilities(); - return LongStream - .range(IdMap.START_NODE_ID, graph.nodeCount()) - .mapToObj(nodeId -> - new NodeClassificationStreamResult( - graph.toOriginalNodeId(nodeId), - predictedClasses.get(nodeId), - nodePropertiesAsList(predictedProbabilities, nodeId) - ) - ); - }).orElseGet(Stream::empty) - ); - } - - private static List nodePropertiesAsList( - Optional> predictedProbabilities, - long nodeId - ) { - return predictedProbabilities.map(p -> { - var values = p.get(nodeId); - return Arrays.stream(values).boxed().collect(Collectors.toList()); - }).orElse(null); - } - - @Override - public void preProcessConfig(Map userInput, ExecutionContext executionContext) { - if (!userInput.containsKey("modelName")) return; - - var modelName = userInput.get("modelName"); - - var model = executionContext.modelCatalog().get( - executionContext.username(), - (String) modelName, - BaseModelData.class, - NodeClassificationPipelineTrainConfig.class, - Model.CustomInfo.class - ); - - if (!userInput.containsKey("targetNodeLabels")) userInput.put("targetNodeLabels", model.trainConfig().targetNodeLabels()); - if (!userInput.containsKey("relationshipTypes")) userInput.put("relationshipTypes", model.trainConfig().relationshipTypes()); - } -} diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineTrainSpec.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineTrainSpec.java deleted file mode 100644 index ce80d13e3d..0000000000 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineTrainSpec.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.pipeline.node.classification.predict; - -import org.neo4j.gds.VerifyThatModelCanBeStored; -import org.neo4j.gds.compat.GdsVersionInfoProvider; -import org.neo4j.gds.core.model.Model; -import org.neo4j.gds.executor.AlgorithmSpec; -import org.neo4j.gds.executor.ComputationResult; -import org.neo4j.gds.executor.ComputationResultConsumer; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.executor.GdsCallable; -import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction; -import org.neo4j.gds.executor.validation.BeforeLoadValidation; -import org.neo4j.gds.executor.validation.ValidationConfiguration; -import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline; -import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationModelResult; -import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig; -import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrainAlgorithm; -import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrainPipelineAlgorithmFactory; -import org.neo4j.gds.procedures.pipelines.NodeClassificationPipelineTrainResult; -import org.neo4j.graphdb.GraphDatabaseService; - -import java.util.List; -import java.util.stream.Stream; - -import static org.neo4j.gds.executor.ExecutionMode.TRAIN; - -@GdsCallable(name = "gds.beta.pipeline.nodeClassification.train", description = "Trains a node classification model based on a pipeline", executionMode = TRAIN) -public class NodeClassificationPipelineTrainSpec implements AlgorithmSpec, NodeClassificationTrainPipelineAlgorithmFactory> { - - @Override - public String name() { - return "NodeClassificationPipelineTrain"; - } - - @Override - public NodeClassificationTrainPipelineAlgorithmFactory algorithmFactory(ExecutionContext executionContext) { - return new NodeClassificationTrainPipelineAlgorithmFactory(executionContext, GdsVersionInfoProvider.GDS_VERSION_INFO.gdsVersion()); - } - - @Override - public NewConfigFunction newConfigFunction() { - return NodeClassificationPipelineTrainConfig::of; - } - - @Override - public ComputationResultConsumer> computationResultConsumer() { - return (computationResult, executionContext) -> { - if (computationResult.result().isPresent()) { - var model = (Model) computationResult.result().get().model(); - var modelCatalog = executionContext.modelCatalog(); - modelCatalog.set(model); - - if (computationResult.config().storeModelToDisk()) { - try { - // FIXME: This works but is not what we want to do! - var databaseService = executionContext.dependencyResolver() - .resolveDependency(GraphDatabaseService.class); - - modelCatalog.checkLicenseBeforeStoreModel(databaseService, "Store a model"); - var modelDir = modelCatalog.getModelDirectory(databaseService); - modelCatalog.store(model.creator(), model.name(), modelDir); - } catch (Exception e) { - executionContext.log().error("Failed to store model to disk after training.", e.getMessage()); - throw e; - } - } - return Stream.of(constructProcResult(computationResult)); - } - - return Stream.empty(); - }; - } - - @Override - public ValidationConfiguration validationConfig(ExecutionContext executionContext) { - return new ValidationConfiguration<>() { - @Override - public List> beforeLoadValidations() { - return List.of( - new VerifyThatModelCanBeStored<>(executionContext.modelCatalog(), executionContext.username(), NodeClassificationTrainingPipeline.MODEL_TYPE) - ); - } - }; - } - - private NodeClassificationPipelineTrainResult constructProcResult( - ComputationResult< - NodeClassificationTrainAlgorithm, - NodeClassificationModelResult, - NodeClassificationPipelineTrainConfig> computationResult - ) { - var transformedResult = computationResult.result(); - return new NodeClassificationPipelineTrainResult(transformedResult, computationResult.computeMillis()); - } -} diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteSpec.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteSpec.java deleted file mode 100644 index 3398427721..0000000000 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPipelineWriteSpec.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.pipeline.node.classification.predict; - -import org.neo4j.gds.WriteNodePropertiesComputationResultConsumer; -import org.neo4j.gds.WriteNodePropertyListFunction; -import org.neo4j.gds.core.write.NodeProperty; -import org.neo4j.gds.executor.AlgorithmSpec; -import org.neo4j.gds.executor.ComputationResult; -import org.neo4j.gds.executor.ComputationResultConsumer; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.executor.GdsCallable; -import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction; -import org.neo4j.gds.procedures.pipelines.NodeClassificationPipelineResult; -import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictConfigPreProcessor; -import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineExecutor; -import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineWriteConfig; -import org.neo4j.gds.procedures.pipelines.PredictedProbabilities; -import org.neo4j.gds.procedures.pipelines.WriteResult; - -import java.util.List; -import java.util.Map; -import java.util.stream.Stream; - -import static org.neo4j.gds.executor.ExecutionMode.WRITE_NODE_PROPERTY; -import static org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPipelineConstants.PREDICT_DESCRIPTION; - -@GdsCallable(name = "gds.beta.pipeline.nodeClassification.predict.write", description = PREDICT_DESCRIPTION, executionMode = WRITE_NODE_PROPERTY) -public class NodeClassificationPipelineWriteSpec implements AlgorithmSpec, NodeClassificationPredictPipelineAlgorithmFactory> { - @Override - public String name() { - return "NodeClassificationPipelineWrite"; - } - - @Override - public NodeClassificationPredictPipelineAlgorithmFactory algorithmFactory(ExecutionContext executionContext) { - return new NodeClassificationPredictPipelineAlgorithmFactory<>(executionContext); - } - - @Override - public NewConfigFunction newConfigFunction() { - return NodeClassificationPredictPipelineWriteConfig::of; - } - - @Override - public ComputationResultConsumer> computationResultConsumer() { - var writeNodePropertyListFunction = new WriteNodePropertyListFunction() { - @Override - public List apply(ComputationResult computationResult) { - return PredictedProbabilities.asProperties( - computationResult.result(), - computationResult.config().writeProperty(), - computationResult.config().predictedProbabilityProperty() - ); - } - }; - - return new WriteNodePropertiesComputationResultConsumer<>( - (computationResult, executionContext) -> new WriteResult.Builder(), - writeNodePropertyListFunction, - name() - ); - } - - @Override - public void preProcessConfig(Map userInput, ExecutionContext executionContext) { - NodeClassificationPredictConfigPreProcessor.enhanceInputWithPipelineParameters(userInput, executionContext); - } -} diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPredictPipelineAlgorithmFactory.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPredictPipelineAlgorithmFactory.java deleted file mode 100644 index 5c100c26c7..0000000000 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPredictPipelineAlgorithmFactory.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.pipeline.node.classification.predict; - -import org.neo4j.gds.GraphStoreAlgorithmFactory; -import org.neo4j.gds.api.GraphStore; -import org.neo4j.gds.core.model.Model; -import org.neo4j.gds.core.model.ModelCatalog; -import org.neo4j.gds.mem.MemoryEstimation; -import org.neo4j.gds.mem.MemoryEstimations; -import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; -import org.neo4j.gds.core.utils.progress.tasks.Task; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.ml.core.subgraph.LocalIdMap; -import org.neo4j.gds.ml.models.Classifier; -import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineModelInfo; -import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig; -import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineBaseConfig; -import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineExecutor; -import org.neo4j.gds.procedures.pipelines.TrainedNCPipelineModel; - -public class NodeClassificationPredictPipelineAlgorithmFactory - - extends GraphStoreAlgorithmFactory -{ - - private final ModelCatalog modelCatalog; - private final ExecutionContext executionContext; - - NodeClassificationPredictPipelineAlgorithmFactory(ExecutionContext executionContext) { - super(); - this.modelCatalog = executionContext.modelCatalog(); - this.executionContext = executionContext; - } - - @Override - public Task progressTask(GraphStore graphStore, CONFIG config) { - var trainingPipeline = getTrainedNCPipelineModel( - modelCatalog, - config.modelName(), - config.username() - ).customInfo().pipeline(); - - return NodeClassificationPredictPipelineExecutor.progressTask(taskName(), trainingPipeline, graphStore); - } - - @Override - public String taskName() { - return "Node Classification Predict Pipeline"; - } - - @Override - public NodeClassificationPredictPipelineExecutor build( - GraphStore graphStore, - CONFIG configuration, - ProgressTracker progressTracker - ) { - var model = getTrainedNCPipelineModel( - modelCatalog, - configuration.modelName(), - configuration.username() - ); - var nodeClassificationPipeline = model.customInfo().pipeline(); - var classIdMap = LocalIdMap.of(model.customInfo().classes()); - - return new NodeClassificationPredictPipelineExecutor( - nodeClassificationPipeline, - configuration, - executionContext, - graphStore, - progressTracker, - model.data(), - classIdMap - ); - } - - @Override - public MemoryEstimation memoryEstimation(CONFIG configuration) { - var trainedNCPipelineModel = new TrainedNCPipelineModel(modelCatalog); - - var model = trainedNCPipelineModel.get(configuration.modelName(), configuration.username()); - - return MemoryEstimations.builder(NodeClassificationPredictPipelineExecutor.class.getSimpleName()) - .add("Pipeline executor", NodeClassificationPredictPipelineExecutor.estimate(model, configuration, modelCatalog, executionContext.algorithmsProcedureFacade())) - .build(); - } - - private static Model getTrainedNCPipelineModel( - ModelCatalog modelCatalog, - String modelName, - String username - ) { - var trainedNCPipelineModel = new TrainedNCPipelineModel(modelCatalog); - - return trainedNCPipelineModel.get(modelName, username); - } -} diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/NodeRegressionPipelineTrainSpec.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/NodeRegressionPipelineTrainSpec.java deleted file mode 100644 index b39d276c22..0000000000 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/NodeRegressionPipelineTrainSpec.java +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.pipeline.node.regression; - -import org.neo4j.gds.VerifyThatModelCanBeStored; -import org.neo4j.gds.executor.AlgorithmSpec; -import org.neo4j.gds.executor.ComputationResultConsumer; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.executor.GdsCallable; -import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction; -import org.neo4j.gds.executor.validation.BeforeLoadValidation; -import org.neo4j.gds.executor.validation.ValidationConfiguration; -import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline; -import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionPipelineTrainConfig; -import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainAlgorithm; -import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainPipelineAlgorithmFactory; -import org.neo4j.gds.procedures.pipelines.NodeRegressionPipelineTrainResult; -import org.neo4j.graphdb.GraphDatabaseService; - -import java.util.List; -import java.util.stream.Stream; - -import static org.neo4j.gds.executor.ExecutionMode.TRAIN; - -@GdsCallable(name = "gds.alpha.pipeline.nodeRegression.train", description = "Trains a node regression model based on a pipeline", executionMode = TRAIN) -public class NodeRegressionPipelineTrainSpec implements AlgorithmSpec< - NodeRegressionTrainAlgorithm, - org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainResult.NodeRegressionTrainPipelineResult, - NodeRegressionPipelineTrainConfig, - Stream, - NodeRegressionTrainPipelineAlgorithmFactory> { - @Override - public String name() { - return "NodeRegressionPipelineTrain"; - } - - @Override - public NodeRegressionTrainPipelineAlgorithmFactory algorithmFactory(ExecutionContext executionContext) { - return new NodeRegressionTrainPipelineAlgorithmFactory(executionContext); - } - - @Override - public NewConfigFunction newConfigFunction() { - return NodeRegressionPipelineTrainConfig::of; - } - - @Override - public ComputationResultConsumer> computationResultConsumer() { - return (computationResult, executionContext) -> { - return computationResult.result().map(result -> { - var model = result.model(); - var modelCatalog = executionContext.modelCatalog(); - assert modelCatalog != null : "ModelCatalog should have been set in the ExecutionContext by this point!!!"; - modelCatalog.set(model); - - if (computationResult.config().storeModelToDisk()) { - try { - // FIXME: This works but is not what we want to do! - var databaseService = executionContext.dependencyResolver() - .resolveDependency(GraphDatabaseService.class); - modelCatalog.checkLicenseBeforeStoreModel(databaseService, "Store a model"); - var modelDir = modelCatalog.getModelDirectory(databaseService); - modelCatalog.store(model.creator(), model.name(), modelDir); - } catch (Exception e) { - executionContext.log().error("Failed to store model to disk after training.", e.getMessage()); - throw e; - } - } - return Stream.of(new NodeRegressionPipelineTrainResult(model, result.trainingStatistics(), computationResult.computeMillis() - )); - }).orElseGet(Stream::empty); - }; - } - - @Override - public ValidationConfiguration validationConfig(ExecutionContext executionContext) { - return new ValidationConfiguration<>() { - @Override - public List> beforeLoadValidations() { - var modelCatalog = executionContext.modelCatalog(); - assert modelCatalog != null : "ModelCatalog should have been set in the ExecutionContext by this point!!!"; - return List.of( - new VerifyThatModelCanBeStored<>( - modelCatalog, - executionContext.username(), - LinkPredictionTrainingPipeline.MODEL_TYPE - ) - ); - } - }; - } - -} diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/predict/NodeRegressionPipelineMutateSpec.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/predict/NodeRegressionPipelineMutateSpec.java deleted file mode 100644 index a14115d412..0000000000 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/predict/NodeRegressionPipelineMutateSpec.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.pipeline.node.regression.predict; - -import org.neo4j.gds.MutatePropertyComputationResultConsumer; -import org.neo4j.gds.api.properties.nodes.EmptyDoubleNodePropertyValues; -import org.neo4j.gds.api.properties.nodes.NodePropertyValues; -import org.neo4j.gds.api.properties.nodes.NodePropertyValuesAdapter; -import org.neo4j.gds.collections.ha.HugeDoubleArray; -import org.neo4j.gds.core.write.NodeProperty; -import org.neo4j.gds.executor.AlgorithmSpec; -import org.neo4j.gds.executor.ComputationResult; -import org.neo4j.gds.executor.ComputationResultConsumer; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.executor.GdsCallable; -import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction; -import org.neo4j.gds.procedures.pipelines.NodeRegressionPredictConfigPreProcessor; -import org.neo4j.gds.procedures.pipelines.NodeRegressionPredictPipelineExecutor; -import org.neo4j.gds.procedures.pipelines.NodeRegressionPredictPipelineMutateConfig; -import org.neo4j.gds.procedures.pipelines.PredictMutateResult; - -import java.util.List; -import java.util.Map; -import java.util.stream.Stream; - -import static org.neo4j.gds.executor.ExecutionMode.MUTATE_NODE_PROPERTY; -import static org.neo4j.gds.ml.pipeline.node.regression.NodeRegressionProcCompanion.PREDICT_DESCRIPTION; - -@GdsCallable( - name = "gds.alpha.pipeline.nodeRegression.predict.mutate", description = PREDICT_DESCRIPTION, - executionMode = MUTATE_NODE_PROPERTY -) -public class NodeRegressionPipelineMutateSpec - implements AlgorithmSpec< - NodeRegressionPredictPipelineExecutor, - HugeDoubleArray, - NodeRegressionPredictPipelineMutateConfig, - Stream, - NodeRegressionPredictPipelineAlgorithmFactory> { - @Override - public String name() { - return "NodeRegressionPipelineMutate"; - } - - @Override - public NodeRegressionPredictPipelineAlgorithmFactory algorithmFactory( - ExecutionContext executionContext - ) { - return new NodeRegressionPredictPipelineAlgorithmFactory<>(executionContext); - } - - @Override - public NewConfigFunction newConfigFunction() { - return NodeRegressionPredictPipelineMutateConfig::of; - } - - @Override - public void preProcessConfig(Map userInput, ExecutionContext executionContext) { - NodeRegressionPredictConfigPreProcessor.enhanceInputWithPipelineParameters(userInput, executionContext); - } - - @Override - public ComputationResultConsumer> computationResultConsumer() { - return new MutatePropertyComputationResultConsumer<>( - this::nodePropertyList, - this::resultBuilder - ); - } - - private List nodePropertyList(ComputationResult computationResult) { - return List.of(NodeProperty.of( - computationResult.config().mutateProperty(), - nodeProperties(computationResult) - )); - } - - private NodePropertyValues nodeProperties(ComputationResult computationResult) { - return computationResult.result() - .map(NodePropertyValuesAdapter::adapt) - .orElse(EmptyDoubleNodePropertyValues.INSTANCE); - } - - private PredictMutateResult.Builder resultBuilder( - ComputationResult computeResult, - ExecutionContext executionContext - ) { - return new PredictMutateResult.Builder(); - } -} diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/predict/NodeRegressionPipelineStreamSpec.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/predict/NodeRegressionPipelineStreamSpec.java deleted file mode 100644 index 72c47f00cf..0000000000 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/predict/NodeRegressionPipelineStreamSpec.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.pipeline.node.regression.predict; - -import org.neo4j.gds.api.Graph; -import org.neo4j.gds.api.IdMap; -import org.neo4j.gds.api.properties.nodes.NodePropertyValues; -import org.neo4j.gds.api.properties.nodes.NodePropertyValuesAdapter; -import org.neo4j.gds.collections.ha.HugeDoubleArray; -import org.neo4j.gds.executor.AlgorithmSpec; -import org.neo4j.gds.executor.ComputationResultConsumer; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.executor.GdsCallable; -import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction; -import org.neo4j.gds.procedures.pipelines.NodeRegressionPredictConfigPreProcessor; -import org.neo4j.gds.procedures.pipelines.NodeRegressionPredictPipelineBaseConfig; -import org.neo4j.gds.procedures.pipelines.NodeRegressionPredictPipelineExecutor; -import org.neo4j.gds.procedures.pipelines.NodeRegressionStreamResult; - -import java.util.Map; -import java.util.stream.LongStream; -import java.util.stream.Stream; - -import static org.neo4j.gds.LoggingUtil.runWithExceptionLogging; -import static org.neo4j.gds.executor.ExecutionMode.STREAM; -import static org.neo4j.gds.ml.pipeline.node.regression.NodeRegressionProcCompanion.PREDICT_DESCRIPTION; - -@GdsCallable( - name = "gds.alpha.pipeline.nodeRegression.predict.stream", description = PREDICT_DESCRIPTION, - executionMode = STREAM -) -public class NodeRegressionPipelineStreamSpec - implements AlgorithmSpec< - NodeRegressionPredictPipelineExecutor, - HugeDoubleArray, - NodeRegressionPredictPipelineBaseConfig, - Stream, - NodeRegressionPredictPipelineAlgorithmFactory> { - @Override - public String name() { - return "NodeRegressionPipelineStream"; - } - - @Override - public NodeRegressionPredictPipelineAlgorithmFactory algorithmFactory( - ExecutionContext executionContext - ) { - return new NodeRegressionPredictPipelineAlgorithmFactory<>(executionContext); - } - - @Override - public NewConfigFunction newConfigFunction() { - return NodeRegressionPredictPipelineBaseConfig::of; - } - - @Override - public void preProcessConfig(Map userInput, ExecutionContext executionContext) { - NodeRegressionPredictConfigPreProcessor.enhanceInputWithPipelineParameters(userInput, executionContext); - } - - @Override - public ComputationResultConsumer> computationResultConsumer() { - return (computationResult, executionContext) -> - runWithExceptionLogging( - "Result streaming failed", - executionContext.log(), - () -> computationResult.result() - .map(result -> { - Graph graph = computationResult.graph(); - NodePropertyValues nodePropertyValues = NodePropertyValuesAdapter.adapt(result); - return LongStream - .range(IdMap.START_NODE_ID, graph.nodeCount()) - .filter(nodePropertyValues::hasValue) - .mapToObj(nodeId -> new NodeRegressionStreamResult( - graph.toOriginalNodeId(nodeId), - nodePropertyValues.doubleValue(nodeId) - )); - }).orElseGet(Stream::empty) - ); - } -} diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/predict/NodeRegressionPredictPipelineAlgorithmFactory.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/predict/NodeRegressionPredictPipelineAlgorithmFactory.java deleted file mode 100644 index e1fd1bf860..0000000000 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/predict/NodeRegressionPredictPipelineAlgorithmFactory.java +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.pipeline.node.regression.predict; - -import org.neo4j.gds.GraphStoreAlgorithmFactory; -import org.neo4j.gds.api.GraphStore; -import org.neo4j.gds.core.model.Model; -import org.neo4j.gds.core.model.ModelCatalog; -import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; -import org.neo4j.gds.core.utils.progress.tasks.Task; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.ml.models.Regressor; -import org.neo4j.gds.ml.models.linearregression.LinearRegressionData; -import org.neo4j.gds.ml.models.linearregression.LinearRegressor; -import org.neo4j.gds.ml.models.randomforest.RandomForestRegressor; -import org.neo4j.gds.ml.models.randomforest.RandomForestRegressorData; -import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionPipelineModelInfo; -import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionPipelineTrainConfig; -import org.neo4j.gds.procedures.pipelines.NodeRegressionPredictPipelineBaseConfig; -import org.neo4j.gds.procedures.pipelines.NodeRegressionPredictPipelineExecutor; - -public class NodeRegressionPredictPipelineAlgorithmFactory - - extends GraphStoreAlgorithmFactory -{ - - private final ModelCatalog modelCatalog; - private final ExecutionContext executionContext; - - NodeRegressionPredictPipelineAlgorithmFactory(ExecutionContext executionContext) { - super(); - this.modelCatalog = executionContext.modelCatalog(); - this.executionContext = executionContext; - } - - @Override - public Task progressTask(GraphStore graphStore, CONFIG config) { - var trainingPipeline = getTrainedNRPipelineModel( - modelCatalog, - config.modelName(), - config.username() - ).customInfo() - .pipeline(); - - return NodeRegressionPredictPipelineExecutor.progressTask(taskName(), trainingPipeline, graphStore); - } - - @Override - public String taskName() { - return "Node Classification Predict Pipeline"; - } - - @Override - public NodeRegressionPredictPipelineExecutor build( - GraphStore graphStore, - CONFIG configuration, - ProgressTracker progressTracker - ) { - var model = getTrainedNRPipelineModel( - modelCatalog, - configuration.modelName(), - configuration.username() - ); - - return new NodeRegressionPredictPipelineExecutor( - model.customInfo().pipeline(), - configuration, - executionContext, - graphStore, - progressTracker, - regressorFrom(model.data()) - ); - } - - private static Regressor regressorFrom( - Regressor.RegressorData regressorData - ) { - switch (regressorData.trainerMethod()) { - case LinearRegression: - return new LinearRegressor((LinearRegressionData) regressorData); - case RandomForestRegression: - return new RandomForestRegressor((RandomForestRegressorData) regressorData); - default: - throw new IllegalStateException("No such regressor: " + regressorData.trainerMethod().name()); - } - } - - - private static Model getTrainedNRPipelineModel( - ModelCatalog modelCatalog, - String modelName, - String username - ) { - return modelCatalog.get( - username, - modelName, - Regressor.RegressorData.class, - NodeRegressionPipelineTrainConfig.class, - NodeRegressionPipelineModelInfo.class - ); - } -} diff --git a/proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPredictPipelineAlgorithmFactoryTest.java b/proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPredictPipelineAlgorithmFactoryTest.java deleted file mode 100644 index e9619f80b1..0000000000 --- a/proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPredictPipelineAlgorithmFactoryTest.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.linkmodels.pipeline.predict; - -import org.junit.jupiter.api.Test; -import org.neo4j.gds.api.schema.GraphSchema; -import org.neo4j.gds.core.CypherMapWrapper; -import org.neo4j.gds.core.GraphDimensions; -import org.neo4j.gds.core.concurrency.Concurrency; -import org.neo4j.gds.core.model.InjectModelCatalog; -import org.neo4j.gds.core.model.Model; -import org.neo4j.gds.core.model.ModelCatalog; -import org.neo4j.gds.core.model.ModelCatalogExtension; -import org.neo4j.gds.executor.ExecutionContext; -import org.neo4j.gds.ml.core.functions.Weights; -import org.neo4j.gds.ml.core.tensor.Matrix; -import org.neo4j.gds.ml.metrics.ModelCandidateStats; -import org.neo4j.gds.ml.models.logisticregression.ImmutableLogisticRegressionData; -import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig; -import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionModelInfo; -import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionPredictPipeline; -import org.neo4j.gds.ml.pipeline.linkPipeline.linkfunctions.L2FeatureStep; -import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfigImpl; -import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineStreamConfig; - -import java.util.List; -import java.util.Map; -import java.util.stream.Stream; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline.MODEL_TYPE; - -@ModelCatalogExtension -class LinkPredictionPredictPipelineAlgorithmFactoryTest { - - @InjectModelCatalog - ModelCatalog modelCatalog; - - @Test - void estimate() { - withModelInCatalog(); - var factory = new LinkPredictionPredictPipelineAlgorithmFactory<>(ExecutionContext.EMPTY.withModelCatalog(modelCatalog)); - var config = LinkPredictionPredictPipelineStreamConfig.of( - "testUser", - CypherMapWrapper.create( - Map.of( - "graphName", "g", - "modelName", "model", - "threshold", 0L, - "mutateRelationshipType", "PREDICTED", - "topN", 3 - ) - ) - ); - var estimate = factory - .memoryEstimation(config) - .estimate(GraphDimensions.of(10), new Concurrency(4)); - - assertThat(estimate.memoryUsage().toString()).isEqualTo("548 Bytes"); - } - - private void withModelInCatalog() { - var weights = new double[]{2.0, 1.0, -3.0}; - var pipeline = LinkPredictionPredictPipeline.from(Stream.of(), Stream.of(new L2FeatureStep(List.of("a", "b", "c")))); - - var modelData = ImmutableLogisticRegressionData.of( - 2, - new Weights<>(new Matrix( - weights, - 1, - weights.length - )), - Weights.ofVector(0.0) - ); - - modelCatalog.set(Model.of( - MODEL_TYPE, - GraphSchema.empty(), - modelData, - LinkPredictionTrainConfigImpl.builder() - .modelUser("testUser") - .modelName("model") - .pipeline("DUMMY") - .sourceNodeLabel("N") - .targetNodeLabel("N") - .targetRelationshipType("T") - .graphName("g") - .negativeClassWeight(1.0) - .build(), - LinkPredictionModelInfo.of( - Map.of(), - Map.of(), - ModelCandidateStats.of(LogisticRegressionTrainConfig.DEFAULT, Map.of(), Map.of()), - pipeline - ) - )); - } -} diff --git a/proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/PathFindingMutateResultTest.java b/proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/PathFindingMutateResultTest.java deleted file mode 100644 index 6c202eb2e1..0000000000 --- a/proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/PathFindingMutateResultTest.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.ml.linkmodels.pipeline.predict; - -import org.junit.jupiter.api.Test; -import org.neo4j.gds.procedures.pipelines.MutateResult; - -import static org.assertj.core.api.Assertions.assertThatNoException; - -class PathFindingMutateResultTest { - - @Test - void shouldRecordResults() { - var resultWithHistogramBuilder = new MutateResult.Builder().withHistogram(); - assertThatNoException().isThrownBy(() -> { - resultWithHistogramBuilder.recordHistogramValue(0.9); - resultWithHistogramBuilder.recordHistogramValue(5E-10); - resultWithHistogramBuilder.recordHistogramValue(5E-20); - }); - } - -} diff --git a/proc/common/src/main/java/org/neo4j/gds/MutatePropertyComputationResultConsumer.java b/proc/test/src/main/java/org/neo4j/gds/test/MutatePropertyComputationResultConsumer.java similarity index 87% rename from proc/common/src/main/java/org/neo4j/gds/MutatePropertyComputationResultConsumer.java rename to proc/test/src/main/java/org/neo4j/gds/test/MutatePropertyComputationResultConsumer.java index 34d0205365..b66d56b14c 100644 --- a/proc/common/src/main/java/org/neo4j/gds/MutatePropertyComputationResultConsumer.java +++ b/proc/test/src/main/java/org/neo4j/gds/test/MutatePropertyComputationResultConsumer.java @@ -17,8 +17,13 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ -package org.neo4j.gds; +package org.neo4j.gds.test; +import org.neo4j.gds.Algorithm; +import org.neo4j.gds.GraphStoreUpdater; +import org.neo4j.gds.MutateComputationResultConsumer; +import org.neo4j.gds.MutateNodePropertyListFunction; +import org.neo4j.gds.ResultBuilderFunction; import org.neo4j.gds.config.MutateNodePropertyConfig; import org.neo4j.gds.executor.ComputationResult; import org.neo4j.gds.executor.ExecutionContext; @@ -28,7 +33,7 @@ public class MutatePropertyComputationResultConsumer { private final MutateNodePropertyListFunction nodePropertyListFunction; - public MutatePropertyComputationResultConsumer( + MutatePropertyComputationResultConsumer( MutateNodePropertyListFunction nodePropertyListFunction, ResultBuilderFunction resultBuilderFunction ) { diff --git a/proc/test/src/main/java/org/neo4j/gds/test/TestMutateSpec.java b/proc/test/src/main/java/org/neo4j/gds/test/TestMutateSpec.java index 774d1bd7fd..a44fa31a70 100644 --- a/proc/test/src/main/java/org/neo4j/gds/test/TestMutateSpec.java +++ b/proc/test/src/main/java/org/neo4j/gds/test/TestMutateSpec.java @@ -20,7 +20,6 @@ package org.neo4j.gds.test; import org.neo4j.gds.GraphAlgorithmFactory; -import org.neo4j.gds.MutatePropertyComputationResultConsumer; import org.neo4j.gds.api.Graph; import org.neo4j.gds.api.properties.nodes.LongNodePropertyValues; import org.neo4j.gds.mem.MemoryEstimation; diff --git a/procedures/pipelines-facade-api/src/main/java/org/neo4j/gds/procedures/pipelines/MutateResult.java b/procedures/pipelines-facade-api/src/main/java/org/neo4j/gds/procedures/pipelines/MutateResult.java index 2d111e3dad..6d4a254433 100644 --- a/procedures/pipelines-facade-api/src/main/java/org/neo4j/gds/procedures/pipelines/MutateResult.java +++ b/procedures/pipelines-facade-api/src/main/java/org/neo4j/gds/procedures/pipelines/MutateResult.java @@ -19,13 +19,8 @@ */ package org.neo4j.gds.procedures.pipelines; -import org.HdrHistogram.ConcurrentDoubleHistogram; -import org.jetbrains.annotations.Nullable; import org.neo4j.gds.applications.algorithms.machinery.AlgorithmProcessingTimings; -import org.neo4j.gds.core.ProcedureConstants; import org.neo4j.gds.procedures.algorithms.results.StandardMutateResult; -import org.neo4j.gds.result.AbstractResultBuilder; -import org.neo4j.gds.result.HistogramUtils; import java.util.Collections; import java.util.Map; @@ -68,47 +63,4 @@ static MutateResult emptyFrom(AlgorithmProcessingTimings timings, Map { - private Map samplingStats = null; - @Nullable - private ConcurrentDoubleHistogram histogram = null; - - @Override - public MutateResult build() { - return new MutateResult( - preProcessingMillis, - computeMillis, - mutateMillis, - relationshipsWritten, - config.toMap(), - histogram == null ? Map.of() : HistogramUtils.similaritySummary(histogram), - samplingStats - ); - } - - public Builder withHistogram() { - if (histogram != null) { - return this; - } - - this.histogram = new ConcurrentDoubleHistogram(ProcedureConstants.HISTOGRAM_PRECISION_DEFAULT); - return this; - } - - public void recordHistogramValue(double value) { - if (histogram == null) { - return; - } - - //HISTOGRAM_PRECISION_DEFAULT hence numberOfSignificantValueDigits is 1E-5, so it can't separate 0 and 1E-5 - //Therefore we can floor at 1E-6 and smaller probabilities between 0 and 1E-6 is unnecessary. - if (value >= 1E-6) histogram.recordValue(value); else histogram.recordValue(1E-6); - } - - public Builder withSamplingStats(Map samplingStats) { - this.samplingStats = samplingStats; - return this; - } - } } diff --git a/procedures/pipelines-facade-api/src/main/java/org/neo4j/gds/procedures/pipelines/PredictMutateResult.java b/procedures/pipelines-facade-api/src/main/java/org/neo4j/gds/procedures/pipelines/PredictMutateResult.java index 55f59e3ffc..cc9933223b 100644 --- a/procedures/pipelines-facade-api/src/main/java/org/neo4j/gds/procedures/pipelines/PredictMutateResult.java +++ b/procedures/pipelines-facade-api/src/main/java/org/neo4j/gds/procedures/pipelines/PredictMutateResult.java @@ -20,7 +20,6 @@ package org.neo4j.gds.procedures.pipelines; import org.neo4j.gds.applications.algorithms.machinery.AlgorithmProcessingTimings; -import org.neo4j.gds.result.AbstractResultBuilder; import org.neo4j.gds.procedures.algorithms.results.StandardMutateResult; import java.util.Map; @@ -54,17 +53,4 @@ static PredictMutateResult emptyFrom(AlgorithmProcessingTimings timings, Map { - @Override - public PredictMutateResult build() { - return new PredictMutateResult( - preProcessingMillis, - computeMillis, - mutateMillis, - nodePropertiesWritten, - config.toMap() - ); - } - } } diff --git a/procedures/pipelines-facade-api/src/main/java/org/neo4j/gds/procedures/pipelines/WriteResult.java b/procedures/pipelines-facade-api/src/main/java/org/neo4j/gds/procedures/pipelines/WriteResult.java index 20197d5fc1..0c0e5e0e2a 100644 --- a/procedures/pipelines-facade-api/src/main/java/org/neo4j/gds/procedures/pipelines/WriteResult.java +++ b/procedures/pipelines-facade-api/src/main/java/org/neo4j/gds/procedures/pipelines/WriteResult.java @@ -21,7 +21,6 @@ import org.neo4j.gds.applications.algorithms.machinery.AlgorithmProcessingTimings; import org.neo4j.gds.procedures.algorithms.results.StandardWriteResult; -import org.neo4j.gds.result.AbstractResultBuilder; import java.util.Map; @@ -54,17 +53,4 @@ static WriteResult emptyFrom(AlgorithmProcessingTimings timings, Map { - @Override - public WriteResult build() { - return new WriteResult( - preProcessingMillis, - computeMillis, - writeMillis, - nodePropertiesWritten, - config.toMap() - ); - } - } } diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictConfigPreProcessor.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictConfigPreProcessor.java index f88945c5fa..ae919d4c8e 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictConfigPreProcessor.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeClassificationPredictConfigPreProcessor.java @@ -22,7 +22,6 @@ import org.neo4j.gds.api.User; import org.neo4j.gds.core.model.Model; import org.neo4j.gds.core.model.ModelCatalog; -import org.neo4j.gds.executor.ExecutionContext; import org.neo4j.gds.ml.models.BaseModelData; import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig; @@ -31,7 +30,7 @@ /** * Enhance user input by adding targetNodeLabels and relationshipTypes from training parameters if appropriate */ -public final class NodeClassificationPredictConfigPreProcessor { +final class NodeClassificationPredictConfigPreProcessor { private final ModelCatalog modelCatalog; private final User user; @@ -40,18 +39,6 @@ public final class NodeClassificationPredictConfigPreProcessor { this.user = user; } - public static void enhanceInputWithPipelineParameters( - Map userInput, - ExecutionContext executionContext - ) { - var modelCatalog = executionContext.modelCatalog(); - var user = new User(executionContext.username(), executionContext.isGdsAdmin()); - - var preProcessor = new NodeClassificationPredictConfigPreProcessor(modelCatalog, user); - - preProcessor.enhanceInputWithPipelineParameters(userInput); - } - void enhanceInputWithPipelineParameters(Map userInput) { if (!userInput.containsKey("modelName")) return; diff --git a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeRegressionPredictConfigPreProcessor.java b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeRegressionPredictConfigPreProcessor.java index 16cd43a82f..74cdc7cc38 100644 --- a/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeRegressionPredictConfigPreProcessor.java +++ b/procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/NodeRegressionPredictConfigPreProcessor.java @@ -22,7 +22,6 @@ import org.neo4j.gds.api.User; import org.neo4j.gds.core.model.Model; import org.neo4j.gds.core.model.ModelCatalog; -import org.neo4j.gds.executor.ExecutionContext; import org.neo4j.gds.ml.models.BaseModelData; import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPipelineBaseTrainConfig; @@ -39,18 +38,6 @@ public class NodeRegressionPredictConfigPreProcessor { this.user = user; } - public static void enhanceInputWithPipelineParameters( - Map userInput, - ExecutionContext executionContext - ) { - var modelCatalog = executionContext.modelCatalog(); - var user = new User(executionContext.username(), executionContext.isGdsAdmin()); - - var preProcessor = new NodeRegressionPredictConfigPreProcessor(modelCatalog, user); - - preProcessor.enhanceInputWithPipelineParameters(userInput); - } - void enhanceInputWithPipelineParameters(Map userInput) { Optional.ofNullable(userInput.get("modelName")) .map(modelName -> {