From 73d18a773cfc5a0ef34439e5fd753eda19824de7 Mon Sep 17 00:00:00 2001 From: Brian Shi Date: Mon, 25 Sep 2023 16:33:23 +0100 Subject: [PATCH] Fix topKComputer concurrency Co-authored-by: Olga Razvenskaia Co-authored-by: Martin Junghanns --- .../src/main/java/org/neo4j/gds/ml/kge/TopKMapComputer.java | 3 ++- .../test/java/org/neo4j/gds/ml/kge/TopKMapComputerTest.java | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/kge/TopKMapComputer.java b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/kge/TopKMapComputer.java index 275afa9fb7..edbc9fdbdf 100644 --- a/proc/machine-learning/src/main/java/org/neo4j/gds/ml/kge/TopKMapComputer.java +++ b/proc/machine-learning/src/main/java/org/neo4j/gds/ml/kge/TopKMapComputer.java @@ -100,9 +100,9 @@ public KGEPredictResult compute() { concurrency, terminationFlag, stream -> { - LongLongPredicate isCandidateLinkPredicate = isCandidateLink(concurrentGraph.get()); stream.forEach(node1 -> { terminationFlag.assertRunning(); + LongLongPredicate isCandidateLinkPredicate = isCandidateLink(concurrentGraph.get()); LinkScorer linkScorer = threadLocalScorer.get(); linkScorer.init(node1); @@ -138,6 +138,7 @@ private long estimateWorkload() { } private LongLongPredicate isCandidateLink(Graph graph) { + //exists O(n) return (s, t) -> s != t && !graph.exists(s, t); } } diff --git a/proc/machine-learning/src/test/java/org/neo4j/gds/ml/kge/TopKMapComputerTest.java b/proc/machine-learning/src/test/java/org/neo4j/gds/ml/kge/TopKMapComputerTest.java index 6d908cecd1..1683e5a306 100644 --- a/proc/machine-learning/src/test/java/org/neo4j/gds/ml/kge/TopKMapComputerTest.java +++ b/proc/machine-learning/src/test/java/org/neo4j/gds/ml/kge/TopKMapComputerTest.java @@ -73,7 +73,7 @@ void shouldComputeTopKMapTransE() { var sourceNodes = create(0, 1, 2); var targetNodes = create(3, 4, 5); var topK = 1; - var concurrency = 1; + var concurrency = 4; var computer = new TopKMapComputer( graph, @@ -111,7 +111,7 @@ void shouldComputeTopKMapDistMult() { var sourceNodes = create(0, 1, 2); var targetNodes = create(3, 4, 5); var topK = 1; - var concurrency = 1; + var concurrency = 4; var computer = new TopKMapComputer( graph, @@ -153,7 +153,7 @@ void shouldComputeOverCorrectFiltering() { var sourceNodes = create(0, 1, 2); var targetNodes = create(0, 1, 2, 3); var topK = 10; - var concurrency = 1; + var concurrency = 4; var computer = new TopKMapComputer( graph,