Skip to content

Commit

Permalink
fix injection of termination flag into similarity algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
lassewesth committed May 22, 2024
1 parent 4f470f4 commit 2e5367f
Show file tree
Hide file tree
Showing 14 changed files with 155 additions and 413 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,16 @@
import org.neo4j.gds.similarity.knn.NeighborList;
import org.neo4j.gds.similarity.knn.SimilarityFunction;
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.Optional;
import java.util.concurrent.ExecutorService;

/**
* Filtered KNN is the same as ordinary KNN, _but_ we allow users to regulate final output in two ways.
*
* Consider each result item to be a relationship from one node to another, with a score.
*
* Firstly, we enable source node filtering. This limits the result to only contain relationships where the source node matches the filter.
* Secondly, we enable target node filtering. This limits the result to only contain relationships where the target node matches the filter.
*
* In both cases the source or target node set can be given either as explicitly specified nodes or as all nodes with a given label.
*/
@SuppressWarnings("ClassWithOnlyPrivateConstructors")
Expand All @@ -56,15 +54,15 @@ public class FilteredKnn extends Algorithm<FilteredKnnResult> {
private final TargetNodeFiltering targetNodeFiltering;
private final NodeFilter sourceNodeFilter;

public static FilteredKnn createWithoutSeeding(Graph graph, FilteredKnnBaseConfig config, KnnContext context) {
return create(graph, config, context, Optional.empty());
public static FilteredKnn createWithoutSeeding(Graph graph, FilteredKnnBaseConfig config, KnnContext context, TerminationFlag terminationFlag) {
return create(graph, config, context, Optional.empty(), terminationFlag);
}

// Somewhat speculative: we expect this to serve as the entry point when seeding is wanted
public static FilteredKnn createWithDefaultSeeding(Graph graph, FilteredKnnBaseConfig config, KnnContext context) {
public static FilteredKnn createWithDefaultSeeding(Graph graph, FilteredKnnBaseConfig config, KnnContext context, TerminationFlag terminationFlag) {
var similarityFunction = new SimilarityFunction(SimilarityComputer.ofProperties(graph, config.nodeProperties()));

return create(graph, config, context, Optional.of(similarityFunction));
return create(graph, config, context, Optional.of(similarityFunction), terminationFlag);
}

/**
Expand All @@ -74,7 +72,13 @@ public static FilteredKnn createWithDefaultSeeding(Graph graph, FilteredKnnBaseC
*
* @param optionalSimilarityFunction An actual similarity function if you want seeding, empty otherwise
*/
static FilteredKnn create(Graph graph, FilteredKnnBaseConfig config, KnnContext context, Optional<SimilarityFunction> optionalSimilarityFunction) {
static FilteredKnn create(
Graph graph,
FilteredKnnBaseConfig config,
KnnContext context,
Optional<SimilarityFunction> optionalSimilarityFunction,
TerminationFlag terminationFlag
) {
var targetNodeFilter = config.targetNodeFilter().toNodeFilter(graph);
var sourceNodeFilter = config.sourceNodeFilter().toNodeFilter(graph);

Expand Down Expand Up @@ -107,22 +111,25 @@ static FilteredKnn create(Graph graph, FilteredKnnBaseConfig config, KnnContext
config.initialSampler(),
similarityFunction,
new KnnNeighborFilterFactory(graph.nodeCount()),
targetNodeFiltering
targetNodeFiltering,
terminationFlag
);

return new FilteredKnn(context.progressTracker(), knn, targetNodeFiltering, sourceNodeFilter);
return new FilteredKnn(context.progressTracker(), knn, targetNodeFiltering, sourceNodeFilter, terminationFlag);
}

private FilteredKnn(
ProgressTracker progressTracker,
Knn delegate,
TargetNodeFiltering targetNodeFiltering,
NodeFilter sourceNodeFilter
NodeFilter sourceNodeFilter,
TerminationFlag terminationFlag
) {
super(progressTracker);
this.delegate = delegate;
this.targetNodeFiltering = targetNodeFiltering;
this.sourceNodeFilter = sourceNodeFilter;
this.terminationFlag = terminationFlag;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.neo4j.gds.similarity.knn.ImmutableKnnContext;
import org.neo4j.gds.similarity.knn.KnnContext;
import org.neo4j.gds.similarity.knn.KnnFactory;
import org.neo4j.gds.termination.TerminationFlag;

public class FilteredKnnFactory<CONFIG extends FilteredKnnBaseConfig> extends GraphAlgorithmFactory<FilteredKnn, CONFIG> {
private static final String FILTERED_KNN_TASK_NAME = "Filtered KNN";
Expand All @@ -37,7 +38,20 @@ public class FilteredKnnFactory<CONFIG extends FilteredKnnBaseConfig> extends Gr
private final TriFunction<Graph, CONFIG, KnnContext, FilteredKnn> seededFilteredKnnSupplier;

public FilteredKnnFactory() {
this(FilteredKnn::createWithoutSeeding, FilteredKnn::createWithDefaultSeeding);
this(
(graph, config, knnContext) -> FilteredKnn.createWithoutSeeding(
graph,
config,
knnContext,
TerminationFlag.RUNNING_TRUE
),
(graph, config, knnContext) -> FilteredKnn.createWithDefaultSeeding(
graph,
config,
knnContext,
TerminationFlag.RUNNING_TRUE
)
);
}

FilteredKnnFactory(
Expand Down
12 changes: 9 additions & 3 deletions algo/src/main/java/org/neo4j/gds/similarity/knn/Knn.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.Optional;
import java.util.SplittableRandom;
Expand All @@ -42,7 +43,8 @@ public static Knn create(
KnnParameters parameters,
SimilarityComputer similarityComputer,
NeighborFilterFactory neighborFilterFactory,
KnnContext context
KnnContext context,
TerminationFlag terminationFlag
) {
var similarityFunction = new SimilarityFunction(similarityComputer);
return new Knn(
Expand All @@ -60,7 +62,8 @@ public static Knn create(
parameters.samplerType(),
similarityFunction,
neighborFilterFactory,
NeighbourConsumers.no_op
NeighbourConsumers.no_op,
terminationFlag
);
}

Expand Down Expand Up @@ -92,7 +95,8 @@ public Knn(
KnnSampler.SamplerType initialSamplerType,
SimilarityFunction similarityFunction,
NeighborFilterFactory neighborFilterFactory,
NeighbourConsumers neighborConsumers
NeighbourConsumers neighborConsumers,
TerminationFlag terminationFlag
) {
super(progressTracker);
this.graph = graph;
Expand Down Expand Up @@ -136,6 +140,8 @@ public Knn(
splittableRandom,
progressTracker
);

this.terminationFlag = terminationFlag;
}

public ExecutorService executorService() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.List;

Expand All @@ -54,7 +55,8 @@ public Knn build(
.builder()
.progressTracker(progressTracker)
.executor(DefaultPool.INSTANCE)
.build()
.build(),
TerminationFlag.RUNNING_TRUE
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.neo4j.gds.similarity.SimilarityGraphResult;
import org.neo4j.gds.similarity.SimilarityResult;
import org.neo4j.gds.similarity.filtering.NodeFilter;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.wcc.Wcc;
import org.neo4j.gds.wcc.WccAlgorithmFactory;
import org.neo4j.gds.wcc.WccParameters;
Expand Down Expand Up @@ -74,6 +75,10 @@ public class NodeSimilarity extends Algorithm<NodeSimilarityResult> {
private Function<Long, LongStream> sourceNodesStream;
private BiFunction<Long, Long, LongStream> targetNodesStream;

/**
* @deprecated Use the overload that takes a {@link TerminationFlag} instead; this one always passes {@code TerminationFlag.RUNNING_TRUE}.
*/
@Deprecated
public NodeSimilarity(
Graph graph,
NodeSimilarityParameters parameters,
Expand All @@ -88,10 +93,35 @@ public NodeSimilarity(
executorService,
progressTracker,
NodeFilter.ALLOW_EVERYTHING,
NodeFilter.ALLOW_EVERYTHING
NodeFilter.ALLOW_EVERYTHING,
TerminationFlag.RUNNING_TRUE
);
}

/**
* Creates a node similarity algorithm with an explicitly injected termination flag.
*
* Delegates to the full constructor, passing {@link NodeFilter#ALLOW_EVERYTHING} as both the
* source node filter and the target node filter.
*
* @param graph forwarded unchanged to the full constructor
* @param parameters forwarded unchanged to the full constructor
* @param concurrency forwarded unchanged to the full constructor
* @param executorService forwarded unchanged to the full constructor
* @param progressTracker forwarded unchanged to the full constructor
* @param terminationFlag the termination flag to inject, forwarded unchanged to the full constructor
*/
public NodeSimilarity(
Graph graph,
NodeSimilarityParameters parameters,
Concurrency concurrency,
ExecutorService executorService,
ProgressTracker progressTracker,
TerminationFlag terminationFlag
) {
this(
graph,
parameters,
concurrency,
executorService,
progressTracker,
// source node filter, then target node filter
NodeFilter.ALLOW_EVERYTHING,
NodeFilter.ALLOW_EVERYTHING,
terminationFlag
);
}

/**
* @deprecated Use the overload that takes a {@link TerminationFlag} instead; this one always passes {@code TerminationFlag.RUNNING_TRUE}.
*/
@Deprecated
public NodeSimilarity(
Graph graph,
NodeSimilarityParameters parameters,
Expand All @@ -100,6 +130,28 @@ public NodeSimilarity(
ProgressTracker progressTracker,
NodeFilter sourceNodeFilter,
NodeFilter targetNodeFilter
) {
this(
graph,
parameters,
concurrency,
executorService,
progressTracker,
sourceNodeFilter,
targetNodeFilter,
TerminationFlag.RUNNING_TRUE
);
}

public NodeSimilarity(
Graph graph,
NodeSimilarityParameters parameters,
Concurrency concurrency,
ExecutorService executorService,
ProgressTracker progressTracker,
NodeFilter sourceNodeFilter,
NodeFilter targetNodeFilter,
TerminationFlag terminationFlag
) {
super(progressTracker);
this.graph = graph;
Expand All @@ -113,6 +165,7 @@ public NodeSimilarity(
this.sourceNodes = new BitSet(graph.nodeCount());
this.targetNodes = new BitSet(graph.nodeCount());
this.weighted = this.parameters.hasRelationshipWeightProperty();
this.terminationFlag = terminationFlag;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.neo4j.gds.similarity.filtering.NodeFilterSpecFactory;
import org.neo4j.gds.similarity.knn.KnnContext;
import org.neo4j.gds.similarity.knn.KnnNodePropertySpec;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.List;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -65,7 +66,7 @@ void shouldIdMapTheSourceNodeFilter() {
.sourceNodeFilter(NodeFilterSpecFactory.create(lowestOriginalId))
.build();

var knn = FilteredKnn.createWithoutSeeding(graph, config, KnnContext.empty());
var knn = FilteredKnn.createWithoutSeeding(graph, config, KnnContext.empty(), TerminationFlag.RUNNING_TRUE);

var result = knn.compute();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.neo4j.gds.similarity.knn.ImmutableKnnContext;
import org.neo4j.gds.similarity.knn.KnnContext;
import org.neo4j.gds.similarity.knn.KnnNodePropertySpec;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.Comparator;
import java.util.List;
Expand Down Expand Up @@ -83,7 +84,7 @@ void shouldRunJustLikeKnnWhenYouDoNotSpecifySourceNodeFilterOrTargetNodeFilter()
.build();
var knnContext = ImmutableKnnContext.builder().build();

var knn = FilteredKnn.createWithoutSeeding(graph, knnConfig, knnContext);
var knn = FilteredKnn.createWithoutSeeding(graph, knnConfig, knnContext, TerminationFlag.RUNNING_TRUE);
var result = knn.compute();

assertThat(result).isNotNull();
Expand Down Expand Up @@ -138,7 +139,7 @@ void shouldOnlyProduceResultsForFilteredSourceNode() {
.sourceNodeFilter(graph.toOriginalNodeId(filteredSourceNode))
.build();
var knnContext = KnnContext.empty();
var knn = FilteredKnn.createWithoutSeeding(graph, config, knnContext);
var knn = FilteredKnn.createWithoutSeeding(graph, config, knnContext, TerminationFlag.RUNNING_TRUE);
var result = knn.compute();

assertThat(result.similarityResultStream()
Expand All @@ -160,7 +161,7 @@ void shouldOnlyProduceResultsForMultipleFilteredSourceNode() {
.sourceNodeFilter(filteredNodes.stream().map(graph::toOriginalNodeId).collect(Collectors.toList()))
.build();
var knnContext = KnnContext.empty();
var knn = FilteredKnn.createWithoutSeeding(graph, config, knnContext);
var knn = FilteredKnn.createWithoutSeeding(graph, config, knnContext, TerminationFlag.RUNNING_TRUE);
var result = knn.compute();

assertThat(result.similarityResultStream()
Expand Down Expand Up @@ -194,7 +195,7 @@ void shouldOnlyProduceResultsForFilteredTargetNode() {
.targetNodeFilter(graph.toOriginalNodeId(targetNode))
.build();
var knnContext = KnnContext.empty();
var knn = FilteredKnn.createWithoutSeeding(graph, config, knnContext);
var knn = FilteredKnn.createWithoutSeeding(graph, config, knnContext, TerminationFlag.RUNNING_TRUE);
var result = knn.compute();

assertThat(result.similarityResultStream()
Expand All @@ -216,7 +217,7 @@ void shouldOnlyProduceResultsForFilteredTargetNodes() {
.targetNodeFilter(targetNodes.stream().map(graph::toOriginalNodeId).collect(Collectors.toList()))
.build();
var knnContext = KnnContext.empty();
var knn = FilteredKnn.createWithoutSeeding(graph, config, knnContext);
var knn = FilteredKnn.createWithoutSeeding(graph, config, knnContext, TerminationFlag.RUNNING_TRUE);
var result = knn.compute();

assertThat(result.similarityResultStream()
Expand Down Expand Up @@ -244,7 +245,7 @@ void shouldIgnoreDuplicates() {
.topK(42)
.build();
var knnContext = KnnContext.empty();
var knn = FilteredKnn.createWithoutSeeding(graph, config, knnContext);
var knn = FilteredKnn.createWithoutSeeding(graph, config, knnContext, TerminationFlag.RUNNING_TRUE);
var result = knn.compute();

/*
Expand Down Expand Up @@ -306,7 +307,7 @@ void shouldSeedResultSet() {
.concurrency(1)
.build();
var knnContext = KnnContext.empty();
var knn = FilteredKnn.createWithDefaultSeeding(graph, config, knnContext);
var knn = FilteredKnn.createWithDefaultSeeding(graph, config, knnContext, TerminationFlag.RUNNING_TRUE);
var result = knn.compute();

/*
Expand Down
Loading

0 comments on commit 2e5367f

Please sign in to comment.