Skip to content

Commit

Permalink
Merge pull request #8258 from IoannisPanagiotas/filtered-node-sim-ont…
Browse files Browse the repository at this point in the history
…o-facade

Filtered node similarity onto stream/stats facade
  • Loading branch information
IoannisPanagiotas authored Oct 12, 2023
2 parents b461ab7 + 1212c1a commit 4ab6dc4
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.neo4j.gds.algorithms.runner.AlgorithmRunner;
import org.neo4j.gds.api.DatabaseId;
import org.neo4j.gds.api.User;
import org.neo4j.gds.similarity.filterednodesim.FilteredNodeSimilarityBaseConfig;
import org.neo4j.gds.similarity.filterednodesim.FilteredNodeSimilarityFactory;
import org.neo4j.gds.similarity.nodesim.NodeSimilarityBaseConfig;
import org.neo4j.gds.similarity.nodesim.NodeSimilarityFactory;
import org.neo4j.gds.similarity.nodesim.NodeSimilarityResult;
Expand Down Expand Up @@ -50,4 +52,20 @@ AlgorithmComputationResult<NodeSimilarityResult> nodeSimilarity(
databaseId
);
}

/**
 * Runs the filtered node similarity computation for the given graph by delegating to the
 * algorithm runner.
 * <p>
 * The relationship weight property is read from {@code config}, and a fresh
 * {@link FilteredNodeSimilarityFactory} is handed to the runner on every call.
 *
 * @param graphName  forwarded unchanged to the algorithm runner
 * @param config     configuration forwarded to the runner; also the source of the weight property
 * @param user       forwarded unchanged to the algorithm runner
 * @param databaseId forwarded unchanged to the algorithm runner
 * @return whatever the algorithm runner produces for this invocation
 */
AlgorithmComputationResult<NodeSimilarityResult> filteredNodeSimilarity(
    String graphName,
    FilteredNodeSimilarityBaseConfig config,
    User user,
    DatabaseId databaseId
) {
    var relationshipWeightProperty = config.relationshipWeightProperty();

    // The factory is created inline so its type argument is still inferred from the runner's signature.
    return algorithmRunner.run(
        graphName,
        config,
        relationshipWeightProperty,
        new FilteredNodeSimilarityFactory<>(),
        user,
        databaseId
    );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.neo4j.gds.api.User;
import org.neo4j.gds.result.SimilarityStatistics;
import org.neo4j.gds.similarity.SimilarityGraphResult;
import org.neo4j.gds.similarity.filterednodesim.FilteredNodeSimilarityStatsConfig;
import org.neo4j.gds.similarity.nodesim.NodeSimilarityStatsConfig;

import java.util.function.Function;
Expand Down Expand Up @@ -71,6 +72,36 @@ public StatsResult<SimilaritySpecificFieldsWithDistribution> nodeSimilarity(
);
}

/**
 * Stats-mode entry point for filtered node similarity.
 * <p>
 * The algorithm is executed through {@code similarityAlgorithmsFacade} while being timed, and the
 * outcome is passed to {@code statsResult} together with the functions that extract the
 * similarity graph and build the mode-specific fields.
 *
 * @param graphName                     forwarded to the algorithms facade
 * @param configuration                 stats configuration forwarded to the algorithms facade
 * @param user                          forwarded to the algorithms facade
 * @param databaseId                    forwarded to the algorithms facade
 * @param computeSimilarityDistribution forwarded to {@code statsResult}
 * @return the stats result assembled by {@code statsResult}
 */
public StatsResult<SimilaritySpecificFieldsWithDistribution> filteredNodeSimilarity(
    String graphName,
    FilteredNodeSimilarityStatsConfig configuration,
    User user,
    DatabaseId databaseId,
    boolean computeSimilarityDistribution
) {
    // Execute the algorithm inside the timing wrapper; the elapsed milliseconds are reported below.
    var timedRun = AlgorithmRunner.runWithTiming(
        () -> similarityAlgorithmsFacade.filteredNodeSimilarity(graphName, configuration, user, databaseId)
    );

    return statsResult(
        timedRun.algorithmResult,
        similarityResult -> similarityResult.graphResult(),
        (similarityResult, similarityDistribution) -> {
            var similarityGraphResult = similarityResult.graphResult();
            var comparedNodes = similarityGraphResult.comparedNodes();
            var relationshipCount = similarityGraphResult.similarityGraph().relationshipCount();
            return new SimilaritySpecificFieldsWithDistribution(
                comparedNodes,
                relationshipCount,
                similarityDistribution
            );
        },
        timedRun.computeMilliseconds,
        // Supplier for the fields used when statsResult needs the empty placeholder.
        () -> SimilaritySpecificFieldsWithDistribution.EMPTY,
        computeSimilarityDistribution
    );
}

<RESULT, ASF extends SimilaritySpecificFields> StatsResult<ASF> statsResult(
AlgorithmComputationResult<RESULT> algorithmResult,
Function<RESULT, SimilarityGraphResult> similarityGraphResultSupplier,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.neo4j.gds.algorithms.StreamComputationResult;
import org.neo4j.gds.api.DatabaseId;
import org.neo4j.gds.api.User;
import org.neo4j.gds.similarity.filterednodesim.FilteredNodeSimilarityStreamConfig;
import org.neo4j.gds.similarity.nodesim.NodeSimilarityResult;
import org.neo4j.gds.similarity.nodesim.NodeSimilarityStreamConfig;

Expand All @@ -46,6 +47,18 @@ public StreamComputationResult<NodeSimilarityResult> nodeSimilarity(
return createStreamComputationResult(result);
}

/**
 * Stream-mode entry point for filtered node similarity.
 * <p>
 * Delegates the computation to {@code similarityAlgorithmsFacade} and wraps its outcome via
 * {@code createStreamComputationResult}.
 *
 * @param graphName  forwarded to the algorithms facade
 * @param config     stream configuration forwarded to the algorithms facade
 * @param user       forwarded to the algorithms facade
 * @param databaseId forwarded to the algorithms facade
 * @return the wrapped computation result
 */
public StreamComputationResult<NodeSimilarityResult> filteredNodeSimilarity(
    String graphName,
    FilteredNodeSimilarityStreamConfig config,
    User user,
    DatabaseId databaseId
) {
    var algorithmResult = similarityAlgorithmsFacade.filteredNodeSimilarity(
        graphName,
        config,
        user,
        databaseId
    );
    return createStreamComputationResult(algorithmResult);
}

// FIXME: the following method is a duplicate; find a good place for it.
private <RESULT> StreamComputationResult<RESULT> createStreamComputationResult(AlgorithmComputationResult<RESULT> result) {
return StreamComputationResult.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
package org.neo4j.gds.similarity.filterednodesim;

import org.neo4j.gds.BaseProc;
import org.neo4j.gds.executor.MemoryEstimationExecutor;
import org.neo4j.gds.executor.ProcedureExecutor;
import org.neo4j.gds.procedures.GraphDataScience;
import org.neo4j.gds.procedures.similarity.SimilarityStatsResult;
import org.neo4j.gds.results.MemoryEstimateResult;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Internal;
import org.neo4j.procedure.Name;
Expand All @@ -37,16 +37,16 @@

public class FilteredNodeSimilarityStatsProc extends BaseProc {

@Context
public GraphDataScience facade;

/**
 * Procedure entry point registered as {@code gds.nodeSimilarity.filtered.stats} in READ mode.
 *
 * @param graphName     passed through unchanged to the underlying call
 * @param configuration procedure configuration map; the annotation defaults it to {@code {}}
 * @return a stream of {@link SimilarityStatsResult} rows
 */
@Procedure(value = "gds.nodeSimilarity.filtered.stats", mode = READ)
@Description(DESCRIPTION)
public Stream<SimilarityStatsResult> stats(
@Name(value = "graphName") String graphName,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
){
// NOTE(review): two return statements follow. This looks like a diff rendered without +/- markers:
// the ProcedureExecutor call is presumably the removed pre-change body and the facade delegation
// below it the added replacement. Confirm against the commit before treating either as live code.
return new ProcedureExecutor<>(
new FilteredNodeSimilarityStatsSpec(),
executionContext()
).compute(graphName, configuration);
return facade.similarity().filteredNodeSimilarityStats(graphName, configuration);
}

@Procedure(value = "gds.nodeSimilarity.filtered.stats.estimate", mode = READ)
Expand All @@ -55,11 +55,7 @@ public Stream<MemoryEstimateResult> estimate(
@Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration,
@Name(value = "algoConfiguration") Map<String, Object> algoConfiguration
) {
return new MemoryEstimationExecutor<>(
new FilteredNodeSimilarityStatsSpec(),
executionContext(),
transactionContext()
).computeEstimate(graphNameOrConfiguration, algoConfiguration);
return facade.similarity().filteredNodeSimilarityEstimateStats(graphNameOrConfiguration, algoConfiguration);
}

@Deprecated(forRemoval = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
package org.neo4j.gds.similarity.filterednodesim;

import org.neo4j.gds.BaseProc;
import org.neo4j.gds.executor.MemoryEstimationExecutor;
import org.neo4j.gds.executor.ProcedureExecutor;
import org.neo4j.gds.procedures.GraphDataScience;
import org.neo4j.gds.results.MemoryEstimateResult;
import org.neo4j.gds.similarity.SimilarityResult;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Internal;
import org.neo4j.procedure.Name;
Expand All @@ -36,6 +36,10 @@

public class FilteredNodeSimilarityStreamProc extends BaseProc {

@Context
public GraphDataScience facade;


static final String DESCRIPTION =
"The Filtered Node Similarity algorithm compares a set of nodes based on the nodes they are connected to. " +
"Two nodes are considered similar if they share many of the same neighbors. " +
Expand All @@ -48,10 +52,7 @@ public Stream<SimilarityResult> stream(
@Name(value = "graphName") String graphName,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
){
return new ProcedureExecutor<>(
new FilteredNodeSimilarityStreamSpec(),
executionContext()
).compute(graphName, configuration);
return facade.similarity().filteredNodeSimilarityStream(graphName, configuration);
}

@Procedure(value = "gds.nodeSimilarity.filtered.stream.estimate", mode = READ)
Expand All @@ -60,11 +61,7 @@ public Stream<MemoryEstimateResult> estimate(
@Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration,
@Name(value = "algoConfiguration") Map<String, Object> algoConfiguration
) {
return new MemoryEstimationExecutor<>(
new FilteredNodeSimilarityStreamSpec(),
executionContext(),
transactionContext()
).computeEstimate(graphNameOrConfiguration, algoConfiguration);
return facade.similarity().filteredNodeSimilarityEstimateStream(graphNameOrConfiguration, algoConfiguration);
}

@Deprecated(forRemoval = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import org.neo4j.gds.procedures.community.ConfigurationParser;
import org.neo4j.gds.results.MemoryEstimateResult;
import org.neo4j.gds.similarity.SimilarityResult;
import org.neo4j.gds.similarity.filterednodesim.FilteredNodeSimilarityStatsConfig;
import org.neo4j.gds.similarity.filterednodesim.FilteredNodeSimilarityStreamConfig;
import org.neo4j.gds.similarity.nodesim.NodeSimilarityStatsConfig;
import org.neo4j.gds.similarity.nodesim.NodeSimilarityStreamConfig;

Expand Down Expand Up @@ -127,6 +129,56 @@ public Stream<MemoryEstimateResult> nodeSimilarityEstimateStats(
return Stream.of(estimateBusinessFacade.nodeSimilarity(graphNameOrConfiguration, config));
}

// Filtered node similarity
/**
 * Stream mode of filtered node similarity.
 * <p>
 * Parses the raw configuration map into a {@link FilteredNodeSimilarityStreamConfig}, runs the
 * computation through {@code streamBusinessFacade} and converts the outcome into result rows.
 *
 * @param graphName     forwarded to the stream business facade
 * @param configuration raw procedure configuration to parse
 * @return the rows produced by {@code NodeSimilarityComputationResultTransformer.toStreamResult}
 */
public Stream<SimilarityResult> filteredNodeSimilarityStream(
    String graphName,
    Map<String, Object> configuration
) {
    var parsedConfig = createStreamConfig(configuration, FilteredNodeSimilarityStreamConfig::of);
    var streamComputationResult = streamBusinessFacade.filteredNodeSimilarity(
        graphName,
        parsedConfig,
        user,
        databaseId
    );
    return NodeSimilarityComputationResultTransformer.toStreamResult(streamComputationResult);
}

/**
 * Stats mode of filtered node similarity.
 * <p>
 * Parses the raw configuration map into a {@link FilteredNodeSimilarityStatsConfig}, runs the
 * computation through {@code statsBusinessFacade} and emits a single stats row.
 *
 * @param graphName     forwarded to the stats business facade
 * @param configuration raw procedure configuration to parse
 * @return a one-element stream holding the transformed stats result
 */
public Stream<SimilarityStatsResult> filteredNodeSimilarityStats(
    String graphName,
    Map<String, Object> configuration
) {
    var parsedConfig = createConfig(configuration, FilteredNodeSimilarityStatsConfig::of);

    // The flag handed to the facade mirrors whether this column is among the procedure's return columns.
    boolean computeSimilarityDistribution = procedureReturnColumns.contains("similarityDistribution");

    var statsComputationResult = statsBusinessFacade.filteredNodeSimilarity(
        graphName,
        parsedConfig,
        user,
        databaseId,
        computeSimilarityDistribution
    );
    var statsRow = NodeSimilarityComputationResultTransformer.toStatsResult(statsComputationResult, parsedConfig);

    return Stream.of(statsRow);
}

/**
 * Memory estimation for the stream mode of filtered node similarity.
 * <p>
 * Parses {@code algoConfiguration} into a {@link FilteredNodeSimilarityStreamConfig} and emits the
 * single estimate returned by {@code estimateBusinessFacade}.
 *
 * @param graphNameOrConfiguration forwarded unchanged to the estimate facade
 * @param algoConfiguration        raw algorithm configuration to parse
 * @return a one-element stream holding the memory estimate
 */
public Stream<MemoryEstimateResult> filteredNodeSimilarityEstimateStream(
    Object graphNameOrConfiguration,
    Map<String, Object> algoConfiguration
) {
    var estimationConfig = createConfig(algoConfiguration, FilteredNodeSimilarityStreamConfig::of);
    // NOTE(review): this delegates to the plain nodeSimilarity estimate; presumably intentional
    // for the filtered variant — confirm.
    var memoryEstimate = estimateBusinessFacade.nodeSimilarity(graphNameOrConfiguration, estimationConfig);
    return Stream.of(memoryEstimate);
}

/**
 * Memory estimation for the stats mode of filtered node similarity.
 * <p>
 * Parses {@code algoConfiguration} into a {@link FilteredNodeSimilarityStatsConfig} and emits the
 * single estimate returned by {@code estimateBusinessFacade}.
 *
 * @param graphNameOrConfiguration forwarded unchanged to the estimate facade
 * @param algoConfiguration        raw algorithm configuration to parse
 * @return a one-element stream holding the memory estimate
 */
public Stream<MemoryEstimateResult> filteredNodeSimilarityEstimateStats(
    Object graphNameOrConfiguration,
    Map<String, Object> algoConfiguration
) {
    var estimationConfig = createConfig(algoConfiguration, FilteredNodeSimilarityStatsConfig::of);
    // NOTE(review): this delegates to the plain nodeSimilarity estimate; presumably intentional
    // for the filtered variant — confirm.
    var memoryEstimate = estimateBusinessFacade.nodeSimilarity(graphNameOrConfiguration, estimationConfig);
    return Stream.of(memoryEstimate);
}

// FIXME: the following two methods are duplicates; find a good place for them.
private <C extends AlgoBaseConfig> C createStreamConfig(
Map<String, Object> configuration,
Expand Down

0 comments on commit 4ab6dc4

Please sign in to comment.