diff --git a/algo/src/main/java/org/neo4j/gds/algorithms/similarity/SimilarityAlgorithmsFacade.java b/algo/src/main/java/org/neo4j/gds/algorithms/similarity/SimilarityAlgorithmsFacade.java index e873899e08..d4068f1fd0 100644 --- a/algo/src/main/java/org/neo4j/gds/algorithms/similarity/SimilarityAlgorithmsFacade.java +++ b/algo/src/main/java/org/neo4j/gds/algorithms/similarity/SimilarityAlgorithmsFacade.java @@ -23,6 +23,8 @@ import org.neo4j.gds.algorithms.runner.AlgorithmRunner; import org.neo4j.gds.api.DatabaseId; import org.neo4j.gds.api.User; +import org.neo4j.gds.similarity.filterednodesim.FilteredNodeSimilarityBaseConfig; +import org.neo4j.gds.similarity.filterednodesim.FilteredNodeSimilarityFactory; import org.neo4j.gds.similarity.nodesim.NodeSimilarityBaseConfig; import org.neo4j.gds.similarity.nodesim.NodeSimilarityFactory; import org.neo4j.gds.similarity.nodesim.NodeSimilarityResult; @@ -50,4 +52,20 @@ AlgorithmComputationResult nodeSimilarity( databaseId ); } + + AlgorithmComputationResult filteredNodeSimilarity( + String graphName, + FilteredNodeSimilarityBaseConfig config, + User user, + DatabaseId databaseId + ) { + return algorithmRunner.run( + graphName, + config, + config.relationshipWeightProperty(), + new FilteredNodeSimilarityFactory<>(), + user, + databaseId + ); + } } diff --git a/algo/src/main/java/org/neo4j/gds/algorithms/similarity/SimilarityAlgorithmsStatsBusinessFacade.java b/algo/src/main/java/org/neo4j/gds/algorithms/similarity/SimilarityAlgorithmsStatsBusinessFacade.java index 4509432ed7..0f57f46630 100644 --- a/algo/src/main/java/org/neo4j/gds/algorithms/similarity/SimilarityAlgorithmsStatsBusinessFacade.java +++ b/algo/src/main/java/org/neo4j/gds/algorithms/similarity/SimilarityAlgorithmsStatsBusinessFacade.java @@ -28,6 +28,7 @@ import org.neo4j.gds.api.User; import org.neo4j.gds.result.SimilarityStatistics; import org.neo4j.gds.similarity.SimilarityGraphResult; +import org.neo4j.gds.similarity.filterednodesim.FilteredNodeSimilarityStatsConfig; import org.neo4j.gds.similarity.nodesim.NodeSimilarityStatsConfig; import java.util.function.Function; @@ -71,6 +72,36 @@ public StatsResult nodeSimilarity( ); } + public StatsResult filteredNodeSimilarity( + String graphName, + FilteredNodeSimilarityStatsConfig configuration, + User user, + DatabaseId databaseId, + boolean computeSimilarityDistribution + ) { + // 1. Run the algorithm and time the execution + var intermediateResult = AlgorithmRunner.runWithTiming( + () -> similarityAlgorithmsFacade.filteredNodeSimilarity(graphName, configuration, user, databaseId) + ); + var algorithmResult = intermediateResult.algorithmResult; + + return statsResult( + algorithmResult, + result -> result.graphResult(), + ((result, similarityDistribution) -> { + var graphResult = result.graphResult(); + return new SimilaritySpecificFieldsWithDistribution( + graphResult.comparedNodes(), + graphResult.similarityGraph().relationshipCount(), + similarityDistribution + ); + }), + intermediateResult.computeMilliseconds, + () -> SimilaritySpecificFieldsWithDistribution.EMPTY, + computeSimilarityDistribution + ); + } + StatsResult statsResult( AlgorithmComputationResult algorithmResult, Function similarityGraphResultSupplier, diff --git a/algo/src/main/java/org/neo4j/gds/algorithms/similarity/SimilarityAlgorithmsStreamBusinessFacade.java b/algo/src/main/java/org/neo4j/gds/algorithms/similarity/SimilarityAlgorithmsStreamBusinessFacade.java index b5fd61ea0c..aa6561c716 100644 --- a/algo/src/main/java/org/neo4j/gds/algorithms/similarity/SimilarityAlgorithmsStreamBusinessFacade.java +++ b/algo/src/main/java/org/neo4j/gds/algorithms/similarity/SimilarityAlgorithmsStreamBusinessFacade.java @@ -23,6 +23,7 @@ import org.neo4j.gds.algorithms.StreamComputationResult; import org.neo4j.gds.api.DatabaseId; import org.neo4j.gds.api.User; +import org.neo4j.gds.similarity.filterednodesim.FilteredNodeSimilarityStreamConfig; import org.neo4j.gds.similarity.nodesim.NodeSimilarityResult; import org.neo4j.gds.similarity.nodesim.NodeSimilarityStreamConfig; @@ -46,6 +47,18 @@ public StreamComputationResult nodeSimilarity( return createStreamComputationResult(result); } + public StreamComputationResult filteredNodeSimilarity( + String graphName, + FilteredNodeSimilarityStreamConfig config, + User user, + DatabaseId databaseId + + ) { + var result = similarityAlgorithmsFacade.filteredNodeSimilarity(graphName, config, user, databaseId); + + return createStreamComputationResult(result); + } + // FIXME: the following method is duplicate, find a good place for it. private StreamComputationResult createStreamComputationResult(AlgorithmComputationResult result) { return StreamComputationResult.of( diff --git a/proc/similarity/src/main/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityStatsProc.java b/proc/similarity/src/main/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityStatsProc.java index e9656bfbdc..c9bbce0455 100644 --- a/proc/similarity/src/main/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityStatsProc.java +++ b/proc/similarity/src/main/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityStatsProc.java @@ -20,10 +20,10 @@ package org.neo4j.gds.similarity.filterednodesim; import org.neo4j.gds.BaseProc; -import org.neo4j.gds.executor.MemoryEstimationExecutor; -import org.neo4j.gds.executor.ProcedureExecutor; +import org.neo4j.gds.procedures.GraphDataScience; import org.neo4j.gds.procedures.similarity.SimilarityStatsResult; import org.neo4j.gds.results.MemoryEstimateResult; +import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; import org.neo4j.procedure.Internal; import org.neo4j.procedure.Name; @@ -37,16 +37,16 @@ public class FilteredNodeSimilarityStatsProc extends BaseProc { + @Context + public GraphDataScience facade; + @Procedure(value = "gds.nodeSimilarity.filtered.stats", mode = READ) @Description(DESCRIPTION) public Stream stats( @Name(value = "graphName") String graphName, @Name(value = "configuration", defaultValue = "{}") Map configuration ){ - return new ProcedureExecutor<>( - new FilteredNodeSimilarityStatsSpec(), - executionContext() - ).compute(graphName, configuration); + return facade.similarity().filteredNodeSimilarityStats(graphName, configuration); } @Procedure(value = "gds.nodeSimilarity.filtered.stats.estimate", mode = READ) @@ -55,11 +55,7 @@ public Stream estimate( @Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration, @Name(value = "algoConfiguration") Map algoConfiguration ) { - return new MemoryEstimationExecutor<>( - new FilteredNodeSimilarityStatsSpec(), - executionContext(), - transactionContext() - ).computeEstimate(graphNameOrConfiguration, algoConfiguration); + return facade.similarity().filteredNodeSimilarityEstimateStats(graphNameOrConfiguration, algoConfiguration); } @Deprecated(forRemoval = true) diff --git a/proc/similarity/src/main/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityStreamProc.java b/proc/similarity/src/main/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityStreamProc.java index 0519a64491..2aceca7ec9 100644 --- a/proc/similarity/src/main/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityStreamProc.java +++ b/proc/similarity/src/main/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityStreamProc.java @@ -20,10 +20,10 @@ package org.neo4j.gds.similarity.filterednodesim; import org.neo4j.gds.BaseProc; -import org.neo4j.gds.executor.MemoryEstimationExecutor; -import org.neo4j.gds.executor.ProcedureExecutor; +import org.neo4j.gds.procedures.GraphDataScience; import org.neo4j.gds.results.MemoryEstimateResult; import org.neo4j.gds.similarity.SimilarityResult; +import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; import org.neo4j.procedure.Internal; import org.neo4j.procedure.Name; @@ -36,6 +36,10 @@ public class FilteredNodeSimilarityStreamProc extends BaseProc { + @Context + public GraphDataScience facade; + + static final String DESCRIPTION = "The Filtered Node Similarity algorithm compares a set of nodes based on the nodes they are connected to. " + "Two nodes are considered similar if they share many of the same neighbors. " + @@ -48,10 +52,7 @@ public Stream stream( @Name(value = "graphName") String graphName, @Name(value = "configuration", defaultValue = "{}") Map configuration ){ - return new ProcedureExecutor<>( - new FilteredNodeSimilarityStreamSpec(), - executionContext() - ).compute(graphName, configuration); + return facade.similarity().filteredNodeSimilarityStream(graphName, configuration); } @Procedure(value = "gds.nodeSimilarity.filtered.stream.estimate", mode = READ) @@ -60,11 +61,7 @@ public Stream estimate( @Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration, @Name(value = "algoConfiguration") Map algoConfiguration ) { - return new MemoryEstimationExecutor<>( - new FilteredNodeSimilarityStreamSpec(), - executionContext(), - transactionContext() - ).computeEstimate(graphNameOrConfiguration, algoConfiguration); + return facade.similarity().filteredNodeSimilarityEstimateStream(graphNameOrConfiguration, algoConfiguration); } @Deprecated(forRemoval = true) diff --git a/procedures/facade/src/main/java/org/neo4j/gds/procedures/similarity/SimilarityProcedureFacade.java b/procedures/facade/src/main/java/org/neo4j/gds/procedures/similarity/SimilarityProcedureFacade.java index d042fc7c4f..149bb0a401 100644 --- a/procedures/facade/src/main/java/org/neo4j/gds/procedures/similarity/SimilarityProcedureFacade.java +++ b/procedures/facade/src/main/java/org/neo4j/gds/procedures/similarity/SimilarityProcedureFacade.java @@ -33,6 +33,8 @@ import org.neo4j.gds.procedures.community.ConfigurationParser; import org.neo4j.gds.results.MemoryEstimateResult; import org.neo4j.gds.similarity.SimilarityResult; +import org.neo4j.gds.similarity.filterednodesim.FilteredNodeSimilarityStatsConfig; +import org.neo4j.gds.similarity.filterednodesim.FilteredNodeSimilarityStreamConfig; import org.neo4j.gds.similarity.nodesim.NodeSimilarityStatsConfig; import org.neo4j.gds.similarity.nodesim.NodeSimilarityStreamConfig; @@ -127,6 +129,56 @@ public Stream nodeSimilarityEstimateStats( return Stream.of(estimateBusinessFacade.nodeSimilarity(graphNameOrConfiguration, config)); } + //filtered + public Stream filteredNodeSimilarityStream( + String graphName, + Map configuration + ) { + var streamConfig = createStreamConfig(configuration, FilteredNodeSimilarityStreamConfig::of); + + var computationResult = streamBusinessFacade.filteredNodeSimilarity( + graphName, + streamConfig, + user, + databaseId + ); + + return NodeSimilarityComputationResultTransformer.toStreamResult(computationResult); + } + + public Stream filteredNodeSimilarityStats( + String graphName, + Map configuration + ) { + var statsConfig = createConfig(configuration, FilteredNodeSimilarityStatsConfig::of); + + var computationResult = statsBusinessFacade.filteredNodeSimilarity( + graphName, + statsConfig, + user, + databaseId, + procedureReturnColumns.contains("similarityDistribution") + ); + + return Stream.of(NodeSimilarityComputationResultTransformer.toStatsResult(computationResult, statsConfig)); + } + + public Stream filteredNodeSimilarityEstimateStream( + Object graphNameOrConfiguration, + Map algoConfiguration + ) { + var config = createConfig(algoConfiguration, FilteredNodeSimilarityStreamConfig::of); + return Stream.of(estimateBusinessFacade.nodeSimilarity(graphNameOrConfiguration, config)); + } + + public Stream filteredNodeSimilarityEstimateStats( + Object graphNameOrConfiguration, + Map algoConfiguration + ) { + var config = createConfig(algoConfiguration, FilteredNodeSimilarityStatsConfig::of); + return Stream.of(estimateBusinessFacade.nodeSimilarity(graphNameOrConfiguration, config)); + } + // FIXME: the following two methods are duplicate, find a good place for them. private C createStreamConfig( Map configuration,