Skip to content

Commit

Permalink
migrate node classification stream
Browse files Browse the repository at this point in the history
  • Loading branch information
lassewesth committed Oct 2, 2024
1 parent ffc9497 commit 6532188
Show file tree
Hide file tree
Showing 10 changed files with 219 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,9 @@
*/
package org.neo4j.gds.ml.pipeline.node.classification.predict;

import org.neo4j.gds.BaseProc;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.executor.MemoryEstimationExecutor;
import org.neo4j.gds.executor.ProcedureExecutor;
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
import org.neo4j.gds.procedures.pipelines.NodeClassificationStreamResult;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
Expand All @@ -34,25 +31,20 @@
import java.util.Map;
import java.util.stream.Stream;

import static org.neo4j.gds.ml.pipeline.PipelineCompanion.preparePipelineConfig;
import static org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPipelineConstants.ESTIMATE_PREDICT_DESCRIPTION;
import static org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPipelineConstants.PREDICT_DESCRIPTION;

public class NodeClassificationPipelineStreamProc extends BaseProc {
public class NodeClassificationPipelineStreamProc {
@Context
public ModelCatalog internalModelCatalog;
public GraphDataScienceProcedures facade;

@Procedure(name = "gds.beta.pipeline.nodeClassification.predict.stream", mode = Mode.READ)
@Description(PREDICT_DESCRIPTION)
public Stream<NodeClassificationStreamResult> stream(
@Name(value = "graphName") String graphName,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
) {
preparePipelineConfig(graphName, configuration);
return new ProcedureExecutor<>(
new NodeClassificationPipelineStreamSpec(),
executionContext()
).compute(graphName, configuration);
return facade.pipelines().nodeClassificationStream(graphName, configuration);
}

@Procedure(name = "gds.beta.pipeline.nodeClassification.predict.stream.estimate", mode = Mode.READ)
Expand All @@ -61,16 +53,6 @@ public Stream<MemoryEstimateResult> estimate(
@Name(value = "graphName") Object graphName,
@Name(value = "configuration") Map<String, Object> configuration
) {
preparePipelineConfig(graphName, configuration);
return new MemoryEstimationExecutor<>(
new NodeClassificationPipelineStreamSpec(),
executionContext(),
transactionContext()
).computeEstimate(graphName, configuration);
}

@Override
public ExecutionContext executionContext() {
return super.executionContext().withModelCatalog(internalModelCatalog);
return facade.pipelines().nodeClassificationStreamEstimate(graphName, configuration);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.procedures.pipelines.NodeClassificationPipelineResult;
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineExecutor;
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineStreamConfig;
import org.neo4j.gds.procedures.pipelines.NodeClassificationStreamResult;

import java.util.Arrays;
import java.util.List;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction;
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineStreamConfig;
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineStreamConfigImpl;

import static org.neo4j.gds.ml.pipeline.node.NodePropertyPredictPipelineFilterUtil.generatePredictPipelineFilter;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade;
import org.neo4j.gds.termination.TerminationMonitor;

final class NodeClassificationPredictPipelineMutateComputation implements Computation<NodeClassificationPipelineResult> {
final class NodeClassificationPredictPipelineComputation implements Computation<NodeClassificationPipelineResult> {
private final Log log;
private final ModelCatalog modelCatalog;

Expand All @@ -64,10 +64,10 @@ final class NodeClassificationPredictPipelineMutateComputation implements Comput
private final AlgorithmsProcedureFacade algorithmsProcedureFacade;

private final TrainedNCPipelineModel trainedNCPipelineModel;
private final NodeClassificationPredictPipelineMutateConfig configuration;
private final NodeClassificationPredictPipelineBaseConfig configuration;
private final Label label;

private NodeClassificationPredictPipelineMutateComputation(
private NodeClassificationPredictPipelineComputation(
Log log,
ModelCatalog modelCatalog,
CloseableResourceRegistry closeableResourceRegistry,
Expand All @@ -84,7 +84,7 @@ private NodeClassificationPredictPipelineMutateComputation(
UserLogRegistryFactory userLogRegistryFactory,
ProgressTrackerCreator progressTrackerCreator,
AlgorithmsProcedureFacade algorithmsProcedureFacade,
NodeClassificationPredictPipelineMutateConfig configuration,
NodeClassificationPredictPipelineBaseConfig configuration,
Label label,
TrainedNCPipelineModel trainedNCPipelineModel
) {
Expand All @@ -109,7 +109,7 @@ private NodeClassificationPredictPipelineMutateComputation(
this.label = label;
}

static NodeClassificationPredictPipelineMutateComputation create(
static NodeClassificationPredictPipelineComputation create(
Log log,
ModelCatalog modelCatalog,
CloseableResourceRegistry closeableResourceRegistry,
Expand All @@ -126,12 +126,12 @@ static NodeClassificationPredictPipelineMutateComputation create(
UserLogRegistryFactory userLogRegistryFactory,
ProgressTrackerCreator progressTrackerCreator,
AlgorithmsProcedureFacade algorithmsProcedureFacade,
NodeClassificationPredictPipelineMutateConfig configuration,
NodeClassificationPredictPipelineBaseConfig configuration,
Label label
) {
var trainedNCPipelineModel = new TrainedNCPipelineModel(modelCatalog);

return new NodeClassificationPredictPipelineMutateComputation(
return new NodeClassificationPredictPipelineComputation(
log,
modelCatalog,
closeableResourceRegistry,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.ml.pipeline.node.classification.predict;
package org.neo4j.gds.procedures.pipelines;

import org.neo4j.gds.annotation.Configuration;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineBaseConfig;

@Configuration
public interface NodeClassificationPredictPipelineStreamConfig
extends NodeClassificationPredictPipelineBaseConfig
{
public interface NodeClassificationPredictPipelineStreamConfig extends NodeClassificationPredictPipelineBaseConfig {
@Override
default boolean includePredictedProbabilities() {
return false;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.procedures.pipelines;

import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.applications.algorithms.machinery.StreamResultBuilder;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.ml.pipeline.ImmutablePipelineGraphFilter;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;

/**
 * Turns the outcome of a node classification predict pipeline run into stream-mode result rows.
 * <p>
 * One {@link NodeClassificationStreamResult} is produced per node of the graph that the graph store
 * returns for the node labels named in the configuration.
 */
class NodeClassificationPredictPipelineStreamResultBuilder implements StreamResultBuilder<NodeClassificationPipelineResult, NodeClassificationStreamResult> {
    private final NodeClassificationPredictPipelineStreamConfig configuration;

    NodeClassificationPredictPipelineStreamResultBuilder(NodeClassificationPredictPipelineStreamConfig configuration) {
        this.configuration = configuration;
    }

    /**
     * Builds the result stream.
     *
     * @param unused     not read; the graph is re-derived from {@code graphStore} using the configured node labels
     * @param graphStore the store the filtered graph is taken from
     * @param result     the pipeline result, if any
     * @return an empty stream when {@code result} is empty, otherwise one row per node of the filtered graph
     */
    @Override
    public Stream<NodeClassificationStreamResult> build(
        Graph unused,
        GraphStore graphStore,
        Optional<NodeClassificationPipelineResult> result
    ) {
        if (result.isEmpty()) {
            return Stream.empty();
        }

        // The filter is assembled from both labels and relationship types, but only its labels select the graph.
        var graphFilter = ImmutablePipelineGraphFilter.builder()
            .nodeLabels(configuration.nodeLabelIdentifiers(graphStore))
            .relationshipTypes(configuration.internalRelationshipTypes(graphStore))
            .build();
        var filteredGraph = graphStore.getGraph(graphFilter.nodeLabels());

        var pipelineResult = result.get();
        var classes = pipelineResult.predictedClasses();
        var probabilities = pipelineResult.predictedProbabilities();

        return LongStream
            .range(IdMap.START_NODE_ID, filteredGraph.nodeCount())
            .mapToObj(internalId -> new NodeClassificationStreamResult(
                filteredGraph.toOriginalNodeId(internalId),
                classes.get(internalId),
                nodePropertiesAsList(probabilities, internalId)
            ));
    }

    /**
     * Boxes the probability vector stored for {@code nodeId} into a list.
     *
     * @return the boxed probabilities, or {@code null} when no probabilities were predicted
     */
    private static List<Double> nodePropertiesAsList(
        Optional<HugeObjectArray<double[]>> predictedProbabilities,
        long nodeId
    ) {
        if (predictedProbabilities.isEmpty()) {
            // null (not an empty list) is what the row carries when probabilities are absent
            return null;
        }

        var nodeProbabilities = predictedProbabilities.get().get(nodeId);
        return Arrays.stream(nodeProbabilities).boxed().collect(Collectors.toList());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.gds.ml.pipeline.node.classification.predict;
package org.neo4j.gds.procedures.pipelines;

import java.util.List;

public final class NodeClassificationStreamResult {

public long nodeId;
public long predictedClass;
public List<Double> predictedProbabilities;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmEstimationTemplate;
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmProcessingTemplate;
import org.neo4j.gds.applications.algorithms.machinery.GraphStoreService;
import org.neo4j.gds.applications.algorithms.machinery.Label;
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator;
import org.neo4j.gds.applications.algorithms.machinery.StandardLabel;
Expand Down Expand Up @@ -263,6 +264,62 @@ Optional<TrainingPipeline<?>> getSingle(PipelineName pipelineName) {
return pipelineRepository.getSingle(user, pipelineName);
}

/**
 * Estimates memory for running the node classification predict pipeline.
 * <p>
 * The estimation obtained for the configured model is wrapped under a
 * "Node Classification Predict Pipeline Executor" root and passed to the estimation template.
 *
 * @param graphNameOrConfiguration forwarded unchanged to the estimation template
 * @param configuration            the predict configuration to estimate for
 */
MemoryEstimateResult nodeClassificationEstimate(
    Object graphNameOrConfiguration,
    NodeClassificationPredictPipelineBaseConfig configuration
) {
    MemoryEstimation pipelineEstimation = nodeClassificationMemoryEstimation(configuration);

    var executorEstimationBuilder = MemoryEstimations.builder("Node Classification Predict Pipeline Executor");
    var executorEstimation = executorEstimationBuilder
        .add("Pipeline executor", pipelineEstimation)
        .build();

    return algorithmEstimationTemplate.estimate(
        configuration,
        graphNameOrConfiguration,
        executorEstimation
    );
}

/**
 * Runs node classification prediction in mutate mode.
 * <p>
 * Parses the raw configuration, then assembles the computation, the mutate step and the result builder
 * and hands them to the processing template's mutate entry point.
 *
 * @param graphName        name of the graph to run against
 * @param rawConfiguration user-supplied configuration map, parsed as a mutate configuration
 */
PredictMutateResult nodeClassificationMutate(GraphName graphName, Map<String, Object> rawConfiguration) {
    var mutateConfiguration = pipelineConfigurationParser
        .parseNodeClassificationPredictPipelineMutateConfig(rawConfiguration);

    var mutateLabel = new StandardLabel("NodeClassificationPredictPipelineMutate");
    var predictComputation = constructComputation(mutateConfiguration, mutateLabel);
    var writeBackStep = new NodeClassificationPredictPipelineMutateStep(gss, mutateConfiguration);
    var mutateResultBuilder = new NodeClassificationPredictPipelineMutateResultBuilder(mutateConfiguration);

    return algorithmProcessingTemplate.processAlgorithmForMutate(
        Optional.empty(),
        graphName,
        mutateConfiguration,
        Optional.empty(),
        mutateLabel,
        // evaluated by the template on demand, not here
        () -> nodeClassificationMemoryEstimation(mutateConfiguration),
        predictComputation,
        writeBackStep,
        mutateResultBuilder
    );
}

/**
 * Runs node classification prediction in stream mode.
 * <p>
 * Parses the raw configuration, then assembles the computation and the stream result builder
 * and hands them to the processing template's stream entry point.
 *
 * @param graphName        name of the graph to run against
 * @param rawConfiguration user-supplied configuration map, parsed as a stream configuration
 * @return the rows produced by the stream result builder
 */
Stream<NodeClassificationStreamResult> nodeClassificationStream(
    GraphName graphName,
    Map<String, Object> rawConfiguration
) {
    var streamConfiguration = pipelineConfigurationParser
        .parseNodeClassificationPredictPipelineStreamConfig(rawConfiguration);

    var streamLabel = new StandardLabel("NodeClassificationPredictPipelineStream");
    var predictComputation = constructComputation(streamConfiguration, streamLabel);
    var streamResultBuilder = new NodeClassificationPredictPipelineStreamResultBuilder(streamConfiguration);

    return algorithmProcessingTemplate.processAlgorithmForStream(
        Optional.empty(),
        graphName,
        streamConfiguration,
        Optional.empty(),
        streamLabel,
        // evaluated by the template on demand, not here
        () -> nodeClassificationMemoryEstimation(streamConfiguration),
        predictComputation,
        streamResultBuilder
    );
}

NodeClassificationTrainingPipeline selectFeatures(
PipelineName pipelineName,
Iterable<NodeFeatureStep> nodeFeatureSteps
Expand Down Expand Up @@ -290,12 +347,11 @@ private NodeClassificationTrainingPipeline configure(
return pipeline;
}

PredictMutateResult nodeClassificationMutate(GraphName graphName, Map<String, Object> rawConfiguration) {
var configuration = pipelineConfigurationParser.parseNodeClassificationPredictPipelineMutateConfig(rawConfiguration);

var label = new StandardLabel("NodeClassificationPredictPipelineMutate");

var computation = NodeClassificationPredictPipelineMutateComputation.create(
private NodeClassificationPredictPipelineComputation constructComputation(
NodeClassificationPredictPipelineBaseConfig configuration,
Label label
) {
return NodeClassificationPredictPipelineComputation.create(
log,
modelCatalog,
closeableResourceRegistry,
Expand All @@ -315,41 +371,6 @@ PredictMutateResult nodeClassificationMutate(GraphName graphName, Map<String, Ob
configuration,
label
);

var mutateStep = new NodeClassificationPredictPipelineMutateStep(gss, configuration);

var resultBuilder = new NodeClassificationPredictPipelineMutateResultBuilder(configuration);

return algorithmProcessingTemplate.processAlgorithmForMutate(
Optional.empty(),
graphName,
configuration,
Optional.empty(),
label,
() -> nodeClassificationMutateMemoryEstimation(configuration),
computation,
mutateStep,
resultBuilder
);
}

MemoryEstimateResult nodeClassificationMutateEstimate(
Object graphNameOrConfiguration,
NodeClassificationPredictPipelineMutateConfig configuration
) {
var estimate = nodeClassificationMutateMemoryEstimation(configuration);

var memoryEstimation = MemoryEstimations.builder("Node Classification Predict Pipeline Executor")
.add("Pipeline executor", estimate)
.build();

return algorithmEstimationTemplate.estimate(configuration, graphNameOrConfiguration, memoryEstimation);
}

private MemoryEstimation nodeClassificationMutateMemoryEstimation(NodeClassificationPredictPipelineBaseConfig configuration) {
var model = getTrainedNCPipelineModel(configuration.modelName(), configuration.username());

return nodeClassificationPredictPipelineEstimator.estimate(model, configuration);
}

private Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> getTrainedNCPipelineModel(
Expand All @@ -364,4 +385,10 @@ private Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig,
NodeClassificationPipelineModelInfo.class
);
}

/**
 * Looks up the trained model named in the configuration (for the configured user) and asks the
 * predict pipeline estimator for its memory estimation.
 */
private MemoryEstimation nodeClassificationMemoryEstimation(NodeClassificationPredictPipelineBaseConfig configuration) {
    return nodeClassificationPredictPipelineEstimator.estimate(
        getTrainedNCPipelineModel(configuration.modelName(), configuration.username()),
        configuration
    );
}
}
Loading

0 comments on commit 6532188

Please sign in to comment.