From 66607811048950dc687b4437250d45d7a3708054 Mon Sep 17 00:00:00 2001 From: Lasse Westh-Nielsen Date: Thu, 23 May 2024 11:18:14 +0200 Subject: [PATCH] introducing algorithm machinery to manage running algorithms and prodding progress tracker the right way --- .../org/neo4j/gds/paths/traverse/BFS.java | 17 ++- .../paths/traverse/BfsAlgorithmFactory.java | 4 +- .../paths/traverse/BFSComplexTreeTest.java | 4 +- .../paths/traverse/BFSOnBiggerGraphTest.java | 4 +- .../org/neo4j/gds/paths/traverse/BFSTest.java | 16 ++- .../paths/traverse/BFSTridentGraphTest.java | 4 +- .../centrality/CentralityAlgorithms.java | 9 +- .../machinery/AlgorithmMachinery.java | 54 +++++++++ .../machinery/AlgorithmMachineryTest.java | 104 ++++++++++++++++++ .../machinery/FailingAlgorithm.java | 36 ++++++ .../machinery/RegurgitatingAlgorithm.java | 36 ++++++ .../pathfinding/PathFindingAlgorithms.java | 57 ++++++---- .../traverse/BreadthFirstSearch.java | 9 +- .../traverse/DepthFirstSearch.java | 4 +- .../similarity/SimilarityAlgorithms.java | 32 ++++-- 15 files changed, 341 insertions(+), 49 deletions(-) create mode 100644 applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmMachinery.java create mode 100644 applications/algorithms/machinery/src/test/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmMachineryTest.java create mode 100644 applications/algorithms/machinery/src/test/java/org/neo4j/gds/applications/algorithms/machinery/FailingAlgorithm.java create mode 100644 applications/algorithms/machinery/src/test/java/org/neo4j/gds/applications/algorithms/machinery/RegurgitatingAlgorithm.java diff --git a/algo/src/main/java/org/neo4j/gds/paths/traverse/BFS.java b/algo/src/main/java/org/neo4j/gds/paths/traverse/BFS.java index be699e3060..845a2927c1 100644 --- a/algo/src/main/java/org/neo4j/gds/paths/traverse/BFS.java +++ b/algo/src/main/java/org/neo4j/gds/paths/traverse/BFS.java @@ -31,6 +31,7 @@ import org.neo4j.gds.collections.ha.HugeLongArray; import org.neo4j.gds.core.utils.paged.ParalleLongPageCreator; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; +import org.neo4j.gds.termination.TerminationFlag; import java.util.ArrayList; import java.util.Collection; @@ -93,7 +94,8 @@ public static BFS create( Aggregator aggregatorFunction, Concurrency concurrency, ProgressTracker progressTracker, - long maximumDepth + long maximumDepth, + TerminationFlag terminationFlag ) { return create( graph, @@ -103,7 +105,8 @@ public static BFS create( concurrency, progressTracker, DEFAULT_DELTA, - maximumDepth + maximumDepth, + terminationFlag ); } @@ -115,7 +118,8 @@ static BFS create( Concurrency concurrency, ProgressTracker progressTracker, int delta, - long maximumDepth + long maximumDepth, + TerminationFlag terminationFlag ) { var nodeCount = graph.nodeCount(); @@ -135,7 +139,8 @@ static BFS create( concurrency, progressTracker, delta, - maximumDepth + maximumDepth, + terminationFlag ); } @@ -150,7 +155,8 @@ private BFS( Concurrency concurrency, ProgressTracker progressTracker, int delta, - long maximumDepth + long maximumDepth, + TerminationFlag terminationFlag ) { super(progressTracker); this.graph = graph; @@ -163,6 +169,7 @@ private BFS( this.traversedNodes = traversedNodes; this.weights = weights; this.visited = visited; + this.terminationFlag = terminationFlag; } @Override diff --git a/algo/src/main/java/org/neo4j/gds/paths/traverse/BfsAlgorithmFactory.java b/algo/src/main/java/org/neo4j/gds/paths/traverse/BfsAlgorithmFactory.java index aef504c483..9d2f7781a3 100644 --- a/algo/src/main/java/org/neo4j/gds/paths/traverse/BfsAlgorithmFactory.java +++ b/algo/src/main/java/org/neo4j/gds/paths/traverse/BfsAlgorithmFactory.java @@ -23,6 +23,7 @@ import org.neo4j.gds.api.Graph; import org.neo4j.gds.mem.MemoryEstimation; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; +import org.neo4j.gds.termination.TerminationFlag; import java.util.List; import java.util.stream.Collectors; @@ -59,7 +60,8 @@ public BFS build(Graph graph, CONFIG configuration, ProgressTracker progressTrac aggregatorFunction, configuration.concurrency(), progressTracker, - configuration.maxDepth() + configuration.maxDepth(), + TerminationFlag.RUNNING_TRUE ); } diff --git a/algo/src/test/java/org/neo4j/gds/paths/traverse/BFSComplexTreeTest.java b/algo/src/test/java/org/neo4j/gds/paths/traverse/BFSComplexTreeTest.java index 921ffd6115..28b6561e16 100644 --- a/algo/src/test/java/org/neo4j/gds/paths/traverse/BFSComplexTreeTest.java +++ b/algo/src/test/java/org/neo4j/gds/paths/traverse/BFSComplexTreeTest.java @@ -28,6 +28,7 @@ import org.neo4j.gds.extension.GdlGraph; import org.neo4j.gds.extension.Inject; import org.neo4j.gds.extension.TestGraph; +import org.neo4j.gds.termination.TerminationFlag; import java.util.List; import java.util.stream.Stream; @@ -113,7 +114,8 @@ void testBfsToTargetOut(int concurrency, int delta) { new Concurrency(concurrency), ProgressTracker.NULL_TRACKER, delta, - BFS.ALL_DEPTHS_ALLOWED + BFS.ALL_DEPTHS_ALLOWED, + TerminationFlag.RUNNING_TRUE ).compute().toArray(); assertThat(nodes) diff --git a/algo/src/test/java/org/neo4j/gds/paths/traverse/BFSOnBiggerGraphTest.java b/algo/src/test/java/org/neo4j/gds/paths/traverse/BFSOnBiggerGraphTest.java index 3bf52f0d3d..3cdd590cf6 100644 --- a/algo/src/test/java/org/neo4j/gds/paths/traverse/BFSOnBiggerGraphTest.java +++ b/algo/src/test/java/org/neo4j/gds/paths/traverse/BFSOnBiggerGraphTest.java @@ -28,6 +28,7 @@ import org.neo4j.gds.extension.GdlGraph; import org.neo4j.gds.extension.Inject; import org.neo4j.gds.extension.TestGraph; +import org.neo4j.gds.termination.TerminationFlag; import java.util.stream.Stream; @@ -112,7 +113,8 @@ void testBfsToTargetOut(int concurrency, int delta) { new Concurrency(concurrency), ProgressTracker.NULL_TRACKER, delta, - BFS.ALL_DEPTHS_ALLOWED + BFS.ALL_DEPTHS_ALLOWED, + TerminationFlag.RUNNING_TRUE ).compute().toArray(); assertThat(nodes) diff --git a/algo/src/test/java/org/neo4j/gds/paths/traverse/BFSTest.java b/algo/src/test/java/org/neo4j/gds/paths/traverse/BFSTest.java index e43d5f71ff..310271ffe8 100644 --- a/algo/src/test/java/org/neo4j/gds/paths/traverse/BFSTest.java +++ b/algo/src/test/java/org/neo4j/gds/paths/traverse/BFSTest.java @@ -33,6 +33,7 @@ import org.neo4j.gds.extension.Inject; import org.neo4j.gds.extension.TestGraph; import org.neo4j.gds.paths.traverse.ExitPredicate.Result; +import org.neo4j.gds.termination.TerminationFlag; import java.util.stream.Stream; @@ -107,7 +108,8 @@ void testBfsToTargetOut(int concurrency) { (s, t, w) -> 1., new Concurrency(concurrency), ProgressTracker.NULL_TRACKER, - BFS.ALL_DEPTHS_ALLOWED + BFS.ALL_DEPTHS_ALLOWED, + TerminationFlag.RUNNING_TRUE ).compute().toArray(); // algorithms return mapped ids @@ -133,7 +135,8 @@ void testBfsToTargetIn(int concurrency) { Aggregator.NO_AGGREGATION, new Concurrency(concurrency), ProgressTracker.NULL_TRACKER, - BFS.ALL_DEPTHS_ALLOWED + BFS.ALL_DEPTHS_ALLOWED, + TerminationFlag.RUNNING_TRUE ).compute().toArray(); assertEquals(7, nodes.length); } @@ -156,7 +159,8 @@ void testBfsMaxDepthOut(int concurrency) { (s, t, w) -> w + 1., new Concurrency(concurrency), ProgressTracker.NULL_TRACKER, - maxHops - 1 + maxHops - 1, + TerminationFlag.RUNNING_TRUE ).compute().toArray(); assertThat(nodes).isEqualTo( @@ -172,7 +176,8 @@ void testBfsOnLoopGraph(int concurrency) { Aggregator.NO_AGGREGATION, new Concurrency(concurrency), ProgressTracker.NULL_TRACKER, - BFS.ALL_DEPTHS_ALLOWED + BFS.ALL_DEPTHS_ALLOWED, + TerminationFlag.RUNNING_TRUE ).compute(); } @@ -189,7 +194,8 @@ void shouldLogProgress(int concurrency) { Aggregator.NO_AGGREGATION, new Concurrency(concurrency), progressTracker, - BFS.ALL_DEPTHS_ALLOWED + BFS.ALL_DEPTHS_ALLOWED, + TerminationFlag.RUNNING_TRUE ).compute(); var messagesInOrder = testLog.getMessages(INFO); diff --git a/algo/src/test/java/org/neo4j/gds/paths/traverse/BFSTridentGraphTest.java b/algo/src/test/java/org/neo4j/gds/paths/traverse/BFSTridentGraphTest.java index dfebc37e09..28d1507091 100644 --- a/algo/src/test/java/org/neo4j/gds/paths/traverse/BFSTridentGraphTest.java +++ b/algo/src/test/java/org/neo4j/gds/paths/traverse/BFSTridentGraphTest.java @@ -29,6 +29,7 @@ import org.neo4j.gds.extension.GdlGraph; import org.neo4j.gds.extension.Inject; import org.neo4j.gds.extension.TestGraph; +import org.neo4j.gds.termination.TerminationFlag; import java.util.Arrays; import java.util.List; @@ -96,7 +97,8 @@ void testBfsToTargetOut(int concurrency, int delta) { new Concurrency(concurrency), ProgressTracker.NULL_TRACKER, delta, - BFS.ALL_DEPTHS_ALLOWED + BFS.ALL_DEPTHS_ALLOWED, + TerminationFlag.RUNNING_TRUE ).compute().toArray(); assertThat(nodes) diff --git a/applications/algorithms/centrality/src/main/java/org/neo4j/gds/applications/algorithms/centrality/CentralityAlgorithms.java b/applications/algorithms/centrality/src/main/java/org/neo4j/gds/applications/algorithms/centrality/CentralityAlgorithms.java index 845398ea0e..ac7b765e9a 100644 --- a/applications/algorithms/centrality/src/main/java/org/neo4j/gds/applications/algorithms/centrality/CentralityAlgorithms.java +++ b/applications/algorithms/centrality/src/main/java/org/neo4j/gds/applications/algorithms/centrality/CentralityAlgorithms.java @@ -20,6 +20,7 @@ package org.neo4j.gds.applications.algorithms.centrality; import org.neo4j.gds.api.Graph; +import org.neo4j.gds.applications.algorithms.machinery.AlgorithmMachinery; import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator; import org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking; import org.neo4j.gds.betweenness.BetweennessCentrality; @@ -41,6 +42,8 @@ import org.neo4j.gds.termination.TerminationFlag; public class CentralityAlgorithms { + private final AlgorithmMachinery algorithmMachinery = new AlgorithmMachinery(); + private final ProgressTrackerCreator progressTrackerCreator; private final TerminationFlag terminationFlag; @@ -79,7 +82,7 @@ BetwennessCentralityResult betweennessCentrality(Graph graph, BetweennessCentral terminationFlag ); - return algorithm.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); } ClosenessCentralityResult closenessCentrality(Graph graph, ClosenessCentralityBaseConfig configuration) { @@ -103,7 +106,7 @@ ClosenessCentralityResult closenessCentrality(Graph graph, ClosenessCentralityBa progressTracker ); - return algorithm.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); } DegreeCentralityResult degreeCentrality(Graph graph, DegreeCentralityConfig configuration) { @@ -122,6 +125,6 @@ DegreeCentralityResult degreeCentrality(Graph graph, DegreeCentralityConfig conf progressTracker ); - return algorithm.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); } } diff --git a/applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmMachinery.java b/applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmMachinery.java new file mode 100644 index 0000000000..f23ba6f17c --- /dev/null +++ b/applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmMachinery.java @@ -0,0 +1,54 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.applications.algorithms.machinery; + +import org.neo4j.gds.Algorithm; +import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; + +/** + * I wish this did not exist quite like this; it is where we encapsulate running an algorithm, + * managing termination, and handling (progress tracker) resources. + * Somehow I wish that was encapsulated more naturally, but as you can hear from this use of language, + * the design has not crystallized yet. + * At least nothing here is tied to termination flag. + */ +public class AlgorithmMachinery { + /** + * Runs algorithm. + * Optionally releases progress tracker. + * Exceptionally marks progress tracker state as failed. + * + * @return algorithm result, or an error in the form of an exception + */ + public RESULT runAlgorithmsAndManageProgressTracker( + Algorithm algorithm, + ProgressTracker progressTracker, + boolean shouldReleaseProgressTracker + ) { + try { + return algorithm.compute(); + } catch (Exception e) { + progressTracker.endSubTaskWithFailure(); + throw e; + } finally { + if (shouldReleaseProgressTracker) progressTracker.release(); + } + } +} diff --git a/applications/algorithms/machinery/src/test/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmMachineryTest.java b/applications/algorithms/machinery/src/test/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmMachineryTest.java new file mode 100644 index 0000000000..2fd92342e1 --- /dev/null +++ b/applications/algorithms/machinery/src/test/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmMachineryTest.java @@ -0,0 +1,104 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.applications.algorithms.machinery; + +import org.junit.jupiter.api.Test; +import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +class AlgorithmMachineryTest { + @Test + void shouldRunAlgorithm() { + var algorithmMachinery = new AlgorithmMachinery(); + + var progressTracker = mock(ProgressTracker.class); + var result = algorithmMachinery.runAlgorithmsAndManageProgressTracker( + new RegurgitatingAlgorithm("Hello, world!"), + progressTracker, + false + ); + + assertThat(result).isEqualTo("Hello, world!"); + + verifyNoInteractions(progressTracker); + } + + @Test + void shouldReleaseProgressTrackerWhenAsked() { + var algorithmMachinery = new AlgorithmMachinery(); + + var progressTracker = mock(ProgressTracker.class); + var result = algorithmMachinery.runAlgorithmsAndManageProgressTracker( + new RegurgitatingAlgorithm("Dodgers win world series!"), + progressTracker, + true + ); + + assertThat(result).isEqualTo("Dodgers win world series!"); + + verify(progressTracker).release(); + } + + @Test + void shouldMarkProgressTracker() { + var algorithmMachinery = new AlgorithmMachinery(); + + var progressTracker = mock(ProgressTracker.class); + var exception = new RuntimeException("Whoops!"); + try { + algorithmMachinery.runAlgorithmsAndManageProgressTracker( + new FailingAlgorithm(exception), + progressTracker, + false + ); + fail(); + } catch (Exception e) { + assertThat(e).hasMessage("Whoops!"); + } + + verify(progressTracker).endSubTaskWithFailure(); + } + + @Test + void shouldMarkProgressTrackerAndReleaseIt() { + var algorithmMachinery = new AlgorithmMachinery(); + + var progressTracker = mock(ProgressTracker.class); + var exception = new RuntimeException("Yeah, no..."); + try { + algorithmMachinery.runAlgorithmsAndManageProgressTracker( + new FailingAlgorithm(exception), + progressTracker, + true + ); + fail(); + } catch (Exception e) { + assertThat(e).hasMessage("Yeah, no..."); + } + + verify(progressTracker).endSubTaskWithFailure(); + verify(progressTracker).release(); + } +} diff --git a/applications/algorithms/machinery/src/test/java/org/neo4j/gds/applications/algorithms/machinery/FailingAlgorithm.java b/applications/algorithms/machinery/src/test/java/org/neo4j/gds/applications/algorithms/machinery/FailingAlgorithm.java new file mode 100644 index 0000000000..33930a0b20 --- /dev/null +++ b/applications/algorithms/machinery/src/test/java/org/neo4j/gds/applications/algorithms/machinery/FailingAlgorithm.java @@ -0,0 +1,36 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.applications.algorithms.machinery; + +import org.neo4j.gds.Algorithm; + +class FailingAlgorithm extends Algorithm { + private final RuntimeException exception; + + FailingAlgorithm(RuntimeException exception) { + super(null); + this.exception = exception; + } + + @Override + public Void compute() { + throw exception; + } +} diff --git a/applications/algorithms/machinery/src/test/java/org/neo4j/gds/applications/algorithms/machinery/RegurgitatingAlgorithm.java b/applications/algorithms/machinery/src/test/java/org/neo4j/gds/applications/algorithms/machinery/RegurgitatingAlgorithm.java new file mode 100644 index 0000000000..fbe28b690d --- /dev/null +++ b/applications/algorithms/machinery/src/test/java/org/neo4j/gds/applications/algorithms/machinery/RegurgitatingAlgorithm.java @@ -0,0 +1,36 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Neo4j is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +package org.neo4j.gds.applications.algorithms.machinery; + +import org.neo4j.gds.Algorithm; + +class RegurgitatingAlgorithm extends Algorithm { + private final String message; + + RegurgitatingAlgorithm(String message) { + super(null); + this.message = message; + } + + @Override + public String compute() { + return message; + } +} diff --git a/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithms.java b/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithms.java index 8a10947c2d..f2b6808e92 100644 --- a/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithms.java +++ b/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithms.java @@ -25,6 +25,7 @@ import org.neo4j.gds.allshortestpaths.MSBFSAllShortestPaths; import org.neo4j.gds.allshortestpaths.WeightedAllShortestPaths; import org.neo4j.gds.api.Graph; +import org.neo4j.gds.applications.algorithms.machinery.AlgorithmMachinery; import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator; import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies; import org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking; @@ -82,7 +83,7 @@ * Associated mode-specific validation is also done in layers above. */ public class PathFindingAlgorithms { - // global scoped dependencies + private final AlgorithmMachinery algorithmMachinery = new AlgorithmMachinery(); // request scoped parameters private final RequestScopedDependencies requestScopedDependencies; @@ -121,9 +122,15 @@ BellmanFordResult bellmanFord(Graph graph, BellmanFordBaseConfig configuration) configuration.concurrency() ); - return algorithm.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, false); } + /** + * Here is an example of how resource management and structure collide. + * Progress tracker is constructed here for BreadthFirstSearch, then inside it is delegated to BFS. + * Ergo we apply the progress tracker resource machinery inside. + * But it is not great innit. + */ HugeLongArray breadthFirstSearch(Graph graph, BfsBaseConfig configuration) { var progressTracker = createProgressTracker(configuration, Tasks.leaf(LabelForProgressTracking.BFS.value)); @@ -132,7 +139,8 @@ HugeLongArray breadthFirstSearch(Graph graph, BfsBaseConfig configuration) { return algorithm.compute( graph, configuration, - progressTracker + progressTracker, + requestScopedDependencies.getTerminationFlag() ); } @@ -146,9 +154,15 @@ PathFindingResult deltaStepping(Graph graph, AllShortestPathsDeltaBaseConfig con ); var progressTracker = createProgressTracker(configuration, iterativeTask); var algorithm = DeltaStepping.of(graph, configuration, DefaultPool.INSTANCE, progressTracker); - return algorithm.compute(); + + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, false); } + /** + * Moar resource shenanigans + * + * @see #breadthFirstSearch(org.neo4j.gds.api.Graph, org.neo4j.gds.paths.traverse.BfsBaseConfig) + */ HugeLongArray depthFirstSearch(Graph graph, DfsBaseConfig configuration) { var progressTracker = createProgressTracker(configuration, Tasks.leaf(LabelForProgressTracking.DFS.value)); @@ -184,7 +198,7 @@ SpanningTree kSpanningTree(Graph graph, KSpanningTreeBaseConfig configuration) { requestScopedDependencies.getTerminationFlag() ); - return algorithm.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); } PathFindingResult longestPath(Graph graph, AlgoBaseConfig configuration) { @@ -200,7 +214,7 @@ PathFindingResult longestPath(Graph graph, AlgoBaseConfig configuration) { requestScopedDependencies.getTerminationFlag() ); - return algorithm.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, false); } Stream randomWalk(Graph graph, RandomWalkBaseConfig configuration) { @@ -224,7 +238,7 @@ Stream randomWalk(Graph graph, RandomWalkBaseConfig configuration) { requestScopedDependencies.getTerminationFlag() ); - return algorithm.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, false); } PathFindingResult singlePairShortestPathAStar(Graph graph, ShortestPathAStarBaseConfig configuration) { @@ -240,7 +254,7 @@ PathFindingResult singlePairShortestPathAStar(Graph graph, ShortestPathAStarBase requestScopedDependencies.getTerminationFlag() ); - return algorithm.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, false); } /** @@ -255,7 +269,7 @@ PathFindingResult singlePairShortestPathDijkstra(Graph graph, DijkstraSourceTarg Tasks.leaf(LabelForProgressTracking.Dijkstra.value, graph.relationshipCount()) ); - var dijkstra = Dijkstra.sourceTarget( + var algorithm = Dijkstra.sourceTarget( graph, configuration.sourceNode(), configuration.targetsList(), @@ -265,7 +279,7 @@ PathFindingResult singlePairShortestPathDijkstra(Graph graph, DijkstraSourceTarg requestScopedDependencies.getTerminationFlag() ); - return dijkstra.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, false); } PathFindingResult singlePairShortestPathYens(Graph graph, ShortestPathYensBaseConfig configuration) { @@ -278,7 +292,7 @@ PathFindingResult singlePairShortestPathYens(Graph graph, ShortestPathYensBaseCo yensTask ); - var yens = Yens.sourceTarget( + var algorithm = Yens.sourceTarget( graph, configuration, configuration.concurrency(), @@ -286,7 +300,7 @@ PathFindingResult singlePairShortestPathYens(Graph graph, ShortestPathYensBaseCo requestScopedDependencies.getTerminationFlag() ); - return yens.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, false); } PathFindingResult singleSourceShortestPathDijkstra(Graph graph, DijkstraBaseConfig configuration) { @@ -295,7 +309,7 @@ PathFindingResult singleSourceShortestPathDijkstra(Graph graph, DijkstraBaseConf Tasks.leaf(LabelForProgressTracking.SingleSourceDijkstra.value, graph.relationshipCount()) ); - var dijkstra = Dijkstra.singleSource( + var algorithm = Dijkstra.singleSource( graph, configuration.sourceNode(), false, @@ -304,7 +318,7 @@ PathFindingResult singleSourceShortestPathDijkstra(Graph graph, DijkstraBaseConf requestScopedDependencies.getTerminationFlag() ); - return dijkstra.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, false); } SpanningTree spanningTree(Graph graph, SpanningTreeBaseConfig configuration) { @@ -314,9 +328,12 @@ SpanningTree spanningTree(Graph graph, SpanningTreeBaseConfig configuration) { } var parameters = configuration.toParameters(); - var progressTracker = createProgressTracker(configuration, Tasks.leaf(LabelForProgressTracking.SpanningTree.value)); + var progressTracker = createProgressTracker( + configuration, + Tasks.leaf(LabelForProgressTracking.SpanningTree.value) + ); - var prim = new Prim( + var algorithm = new Prim( graph, parameters.objective(), graph.toMappedNodeId(parameters.sourceNode()), @@ -324,7 +341,7 @@ SpanningTree spanningTree(Graph graph, SpanningTreeBaseConfig configuration) { requestScopedDependencies.getTerminationFlag() ); - return prim.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); } SteinerTreeResult steinerTree(Graph graph, SteinerTreeBaseConfig configuration) { @@ -345,7 +362,7 @@ SteinerTreeResult steinerTree(Graph graph, SteinerTreeBaseConfig configuration) Tasks.task(LabelForProgressTracking.SteinerTree.value, subtasks) ); - var steiner = new ShortestPathsSteinerAlgorithm( + var algorithm = new ShortestPathsSteinerAlgorithm( graph, mappedSourceNodeId, mappedTargetNodeIds, @@ -357,7 +374,7 @@ SteinerTreeResult steinerTree(Graph graph, SteinerTreeBaseConfig configuration) requestScopedDependencies.getTerminationFlag() ); - return steiner.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); } public TopologicalSortResult topologicalSort(Graph graph, TopologicalSortBaseConfig configuration) { @@ -377,7 +394,7 @@ public TopologicalSortResult topologicalSort(Graph graph, TopologicalSortBaseCon requestScopedDependencies.getTerminationFlag() ); - return algorithm.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); } private MSBFSASPAlgorithm selectAlgorithm(Graph graph, AllShortestPathsConfig configuration) { diff --git a/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/traverse/BreadthFirstSearch.java b/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/traverse/BreadthFirstSearch.java index 8318f11b1c..e106739346 100644 --- a/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/traverse/BreadthFirstSearch.java +++ b/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/traverse/BreadthFirstSearch.java @@ -20,6 +20,7 @@ package org.neo4j.gds.applications.algorithms.pathfinding.traverse; import org.neo4j.gds.api.Graph; +import org.neo4j.gds.applications.algorithms.machinery.AlgorithmMachinery; import org.neo4j.gds.collections.ha.HugeLongArray; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.paths.traverse.Aggregator; @@ -28,12 +29,13 @@ import org.neo4j.gds.paths.traverse.ExitPredicate; import org.neo4j.gds.paths.traverse.OneHopAggregator; import org.neo4j.gds.paths.traverse.TargetExitPredicate; +import org.neo4j.gds.termination.TerminationFlag; import java.util.List; import java.util.stream.Collectors; public class BreadthFirstSearch { - public HugeLongArray compute(Graph graph, BfsBaseConfig configuration, ProgressTracker progressTracker) { + public HugeLongArray compute(Graph graph, BfsBaseConfig configuration, ProgressTracker progressTracker, TerminationFlag terminationFlag) { ExitPredicate exitFunction; Aggregator aggregatorFunction; // target node given; terminate if target is reached @@ -62,9 +64,10 @@ public HugeLongArray compute(Graph graph, BfsBaseConfig configuration, ProgressT aggregatorFunction, configuration.concurrency(), progressTracker, - configuration.maxDepth() + configuration.maxDepth(), + terminationFlag ); - return bfs.compute(); + return new AlgorithmMachinery().runAlgorithmsAndManageProgressTracker(bfs, progressTracker, true); } } diff --git a/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/traverse/DepthFirstSearch.java b/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/traverse/DepthFirstSearch.java index 295bcb1cd5..728a78f568 100644 --- a/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/traverse/DepthFirstSearch.java +++ b/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/traverse/DepthFirstSearch.java @@ -20,6 +20,7 @@ package org.neo4j.gds.applications.algorithms.pathfinding.traverse; import org.neo4j.gds.api.Graph; +import org.neo4j.gds.applications.algorithms.machinery.AlgorithmMachinery; import org.neo4j.gds.collections.ha.HugeLongArray; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.paths.traverse.Aggregator; @@ -63,6 +64,7 @@ public HugeLongArray compute(Graph graph, DfsBaseConfig configuration, ProgressT maxDepth, progressTracker ); - return dfs.compute(); + + return new AlgorithmMachinery().runAlgorithmsAndManageProgressTracker(dfs, progressTracker, true); } } diff --git a/applications/algorithms/similarity/src/main/java/org/neo4j/gds/applications/algorithms/similarity/SimilarityAlgorithms.java b/applications/algorithms/similarity/src/main/java/org/neo4j/gds/applications/algorithms/similarity/SimilarityAlgorithms.java index a2e40da257..ad67bd29dc 100644 --- a/applications/algorithms/similarity/src/main/java/org/neo4j/gds/applications/algorithms/similarity/SimilarityAlgorithms.java +++ b/applications/algorithms/similarity/src/main/java/org/neo4j/gds/applications/algorithms/similarity/SimilarityAlgorithms.java @@ -20,6 +20,7 @@ package org.neo4j.gds.applications.algorithms.similarity; import org.neo4j.gds.api.Graph; +import org.neo4j.gds.applications.algorithms.machinery.AlgorithmMachinery; import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator; import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies; import org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking; @@ -50,10 +51,15 @@ import static org.neo4j.gds.applications.algorithms.metadata.LabelForProgressTracking.KNN; public class SimilarityAlgorithms { + private final AlgorithmMachinery algorithmMachinery = new AlgorithmMachinery(); + private final ProgressTrackerCreator progressTrackerCreator; private final RequestScopedDependencies requestScopedDependencies; - public SimilarityAlgorithms(ProgressTrackerCreator progressTrackerCreator, RequestScopedDependencies requestScopedDependencies) { + public SimilarityAlgorithms( + ProgressTrackerCreator progressTrackerCreator, + RequestScopedDependencies requestScopedDependencies + ) { this.progressTrackerCreator = progressTrackerCreator; this.requestScopedDependencies = requestScopedDependencies; } @@ -70,9 +76,9 @@ FilteredKnnResult filteredKnn(Graph graph, FilteredKnnBaseConfig configuration) .executor(DefaultPool.INSTANCE) .build(); - var filteredKnn = selectAlgorithmConfiguration(graph, configuration, knnContext); + var algorithm = selectAlgorithmConfiguration(graph, configuration, knnContext); - return filteredKnn.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); } NodeSimilarityResult filteredNodeSimilarity(Graph graph, FilteredNodeSimilarityBaseConfig configuration) { @@ -98,7 +104,7 @@ NodeSimilarityResult filteredNodeSimilarity(Graph graph, FilteredNodeSimilarityB requestScopedDependencies.getTerminationFlag() ); - return algorithm.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); } KnnResult knn(Graph graph, KnnBaseConfig configuration) { @@ -136,7 +142,7 @@ KnnResult knn(Graph graph, KnnBaseConfig configuration) { requestScopedDependencies.getTerminationFlag() ); - return algorithm.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); } NodeSimilarityResult nodeSimilarity(Graph graph, NodeSimilarityBaseConfig configuration) { @@ -165,7 +171,7 @@ NodeSimilarityResult nodeSimilarity(Graph graph, NodeSimilarityBaseConfig config requestScopedDependencies.getTerminationFlag() ); - return algorithm.compute(); + return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); } private Task filteredNodeSimilarityProgressTask(Graph graph, boolean runWcc) { @@ -185,9 +191,19 @@ private FilteredKnn selectAlgorithmConfiguration( KnnContext knnContext ) { if (configuration.seedTargetNodes()) { - return FilteredKnn.createWithDefaultSeeding(graph, configuration, knnContext, requestScopedDependencies.getTerminationFlag()); + return FilteredKnn.createWithDefaultSeeding( + graph, + configuration, + knnContext, + requestScopedDependencies.getTerminationFlag() + ); } - return FilteredKnn.createWithoutSeeding(graph, configuration, knnContext, requestScopedDependencies.getTerminationFlag()); + return FilteredKnn.createWithoutSeeding( + graph, + configuration, + knnContext, + requestScopedDependencies.getTerminationFlag() + ); } }