From 67aa074c1dac08cf431388d0c856130834af0d60 Mon Sep 17 00:00:00 2001 From: Lasse Westh-Nielsen Date: Mon, 25 Nov 2024 13:32:13 +0100 Subject: [PATCH] kill Yens and A* factories --- algo/build.gradle | 1 + .../java/org/neo4j/gds/paths/astar/AStar.java | 12 ---- .../neo4j/gds/paths/astar/AStarFactory.java | 55 --------------- .../neo4j/gds/paths/dijkstra/Dijkstra.java | 21 ------ .../gds/paths/dijkstra/DijkstraFactory.java | 13 ++-- .../org/neo4j/gds/paths/yens/YensFactory.java | 64 ----------------- .../org/neo4j/gds/paths/astar/AStarTest.java | 68 +++++++++---------- .../paths/bellmanford/BellmanFordTest.java | 3 +- .../gds/paths/delta/DeltaSteppingTest.java | 3 +- .../gds/paths/dijkstra/DijkstraTest.java | 9 ++- .../gds/paths/yens/YensParallelEdgesTest.java | 18 +++-- .../org/neo4j/gds/paths/yens/YensTest.java | 61 ++++++++--------- .../pathfinding/PathFindingAlgorithms.java | 12 +++- .../progress/tasks/TaskProgressLogger.java | 2 +- 14 files changed, 99 insertions(+), 243 deletions(-) delete mode 100644 algo/src/main/java/org/neo4j/gds/paths/astar/AStarFactory.java delete mode 100644 algo/src/main/java/org/neo4j/gds/paths/yens/YensFactory.java diff --git a/algo/build.gradle b/algo/build.gradle index fa07fc4646..b61f348995 100644 --- a/algo/build.gradle +++ b/algo/build.gradle @@ -80,4 +80,5 @@ dependencies { testImplementation project(':centrality-algorithms') testImplementation project(':node-embedding-algorithms') + testImplementation project(':path-finding-algorithms') } diff --git a/algo/src/main/java/org/neo4j/gds/paths/astar/AStar.java b/algo/src/main/java/org/neo4j/gds/paths/astar/AStar.java index 1461d6db06..ee36d80e1a 100644 --- a/algo/src/main/java/org/neo4j/gds/paths/astar/AStar.java +++ b/algo/src/main/java/org/neo4j/gds/paths/astar/AStar.java @@ -44,18 +44,6 @@ private AStar(Dijkstra dijkstra, TerminationFlag terminationFlag) { this.terminationFlag = terminationFlag; } - /** - * @deprecated Use the one with termination flag - */ - @Deprecated - public static AStar sourceTarget( - Graph graph, - ShortestPathAStarBaseConfig config, - ProgressTracker progressTracker - ) { - return sourceTarget(graph, config, progressTracker, TerminationFlag.RUNNING_TRUE); - } - public static AStar sourceTarget( Graph graph, ShortestPathAStarBaseConfig config, diff --git a/algo/src/main/java/org/neo4j/gds/paths/astar/AStarFactory.java b/algo/src/main/java/org/neo4j/gds/paths/astar/AStarFactory.java deleted file mode 100644 index 174031ca25..0000000000 --- a/algo/src/main/java/org/neo4j/gds/paths/astar/AStarFactory.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.paths.astar; - -import org.neo4j.gds.GraphAlgorithmFactory; -import org.neo4j.gds.api.Graph; -import org.neo4j.gds.mem.MemoryEstimation; -import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; -import org.neo4j.gds.core.utils.progress.tasks.Task; -import org.neo4j.gds.paths.astar.config.ShortestPathAStarBaseConfig; -import org.neo4j.gds.paths.dijkstra.DijkstraFactory; - -public class AStarFactory extends GraphAlgorithmFactory { - - @Override - public MemoryEstimation memoryEstimation(CONFIG configuration) { - return new AStarMemoryEstimateDefinition().memoryEstimation(); - } - - @Override - public Task progressTask(Graph graph, CONFIG config) { - return DijkstraFactory.dijkstraProgressTask(taskName(), graph); - } - - @Override - public String taskName() { - return "AStar"; - } - - @Override - public AStar build( - Graph graph, - CONFIG configuration, - ProgressTracker progressTracker - ) { - return AStar.sourceTarget(graph, configuration, progressTracker); - } -} diff --git a/algo/src/main/java/org/neo4j/gds/paths/dijkstra/Dijkstra.java b/algo/src/main/java/org/neo4j/gds/paths/dijkstra/Dijkstra.java index d9c26564a4..bf612de0c0 100644 --- a/algo/src/main/java/org/neo4j/gds/paths/dijkstra/Dijkstra.java +++ b/algo/src/main/java/org/neo4j/gds/paths/dijkstra/Dijkstra.java @@ -94,27 +94,6 @@ public static Dijkstra sourceTarget( ); } - /** - * @deprecated Use the other one with termination flag - */ - @Deprecated - public static Dijkstra singleSource( - Graph graph, - long originalNodeId, - boolean trackRelationships, - Optional heuristicFunction, - ProgressTracker progressTracker - ) { - return singleSource( - graph, - originalNodeId, - trackRelationships, - heuristicFunction, - progressTracker, - TerminationFlag.RUNNING_TRUE - ); - } - /** * Configure Dijkstra to compute all single-source shortest path. */ diff --git a/algo/src/main/java/org/neo4j/gds/paths/dijkstra/DijkstraFactory.java b/algo/src/main/java/org/neo4j/gds/paths/dijkstra/DijkstraFactory.java index 39132ebb7e..97598e50ae 100644 --- a/algo/src/main/java/org/neo4j/gds/paths/dijkstra/DijkstraFactory.java +++ b/algo/src/main/java/org/neo4j/gds/paths/dijkstra/DijkstraFactory.java @@ -19,7 +19,6 @@ */ package org.neo4j.gds.paths.dijkstra; -import org.jetbrains.annotations.NotNull; import org.neo4j.gds.GraphAlgorithmFactory; import org.neo4j.gds.api.Graph; import org.neo4j.gds.mem.MemoryEstimation; @@ -27,6 +26,7 @@ import org.neo4j.gds.core.utils.progress.tasks.Task; import org.neo4j.gds.core.utils.progress.tasks.Tasks; import org.neo4j.gds.paths.dijkstra.config.DijkstraBaseConfig; +import org.neo4j.gds.termination.TerminationFlag; import java.util.Optional; @@ -44,13 +44,7 @@ public String taskName() { @Override public Task progressTask(Graph graph, CONFIG config) { - return dijkstraProgressTask(taskName(), graph); - } - - - @NotNull - public static Task dijkstraProgressTask(String taskName, Graph graph) { - return Tasks.leaf(taskName, graph.relationshipCount()); + return Tasks.leaf(taskName(), graph.relationshipCount()); } public static class AllShortestPathsDijkstraFactory extends DijkstraFactory { @@ -65,7 +59,8 @@ public Dijkstra build( configuration.sourceNode(), false, Optional.empty(), - progressTracker + progressTracker, + TerminationFlag.RUNNING_TRUE ); } } diff --git a/algo/src/main/java/org/neo4j/gds/paths/yens/YensFactory.java b/algo/src/main/java/org/neo4j/gds/paths/yens/YensFactory.java deleted file mode 100644 index d6c9bdab14..0000000000 --- a/algo/src/main/java/org/neo4j/gds/paths/yens/YensFactory.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Neo4j is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package org.neo4j.gds.paths.yens; - -import org.neo4j.gds.GraphAlgorithmFactory; -import org.neo4j.gds.api.Graph; -import org.neo4j.gds.mem.MemoryEstimation; -import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; -import org.neo4j.gds.core.utils.progress.tasks.Task; -import org.neo4j.gds.core.utils.progress.tasks.Tasks; -import org.neo4j.gds.paths.dijkstra.DijkstraFactory; -import org.neo4j.gds.paths.yens.config.ShortestPathYensBaseConfig; -import org.neo4j.gds.termination.TerminationFlag; - -public class YensFactory extends GraphAlgorithmFactory { - - @Override - public MemoryEstimation memoryEstimation(ShortestPathYensBaseConfig configuration) { - return new YensMemoryEstimateDefinition(configuration.k()).memoryEstimation(); - } - - @Override - public Task progressTask(Graph graph, CONFIG config) { - var initTask = DijkstraFactory.dijkstraProgressTask("Dijkstra", graph); - return Tasks.task(taskName(), initTask, Tasks.leaf("Path growing", config.k() - 1)); - } - - @Override - public String taskName() { - return "Yens"; - } - - @Override - public Yens build( - Graph graph, - CONFIG configuration, - ProgressTracker progressTracker - ) { - return Yens.sourceTarget( - graph, - configuration, - configuration.concurrency(), - progressTracker, - TerminationFlag.RUNNING_TRUE - ); - } -} diff --git a/algo/src/test/java/org/neo4j/gds/paths/astar/AStarTest.java b/algo/src/test/java/org/neo4j/gds/paths/astar/AStarTest.java index 7f16b9cef0..4485e6c239 100644 --- a/algo/src/test/java/org/neo4j/gds/paths/astar/AStarTest.java +++ b/algo/src/test/java/org/neo4j/gds/paths/astar/AStarTest.java @@ -22,24 +22,24 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; -import org.neo4j.gds.TestProgressTracker; +import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator; +import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies; +import org.neo4j.gds.applications.algorithms.pathfinding.PathFindingAlgorithms; import org.neo4j.gds.compat.TestLog; -import org.neo4j.gds.core.concurrency.Concurrency; import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; +import org.neo4j.gds.core.utils.warnings.EmptyUserLogRegistryFactory; import org.neo4j.gds.extension.GdlExtension; import org.neo4j.gds.extension.GdlGraph; import org.neo4j.gds.extension.Inject; import org.neo4j.gds.extension.TestGraph; import org.neo4j.gds.logging.GdsTestLog; import org.neo4j.gds.paths.astar.config.ShortestPathAStarStreamConfigImpl; +import org.neo4j.gds.termination.TerminationFlag; -import java.util.List; -import java.util.Set; -import java.util.concurrent.atomic.AtomicLong; - +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.neo4j.gds.assertj.Extractors.removingThreadId; import static org.neo4j.gds.paths.PathTestUtil.expected; @GdlExtension @@ -114,7 +114,7 @@ void sourceTarget() { .build(); var path = AStar - .sourceTarget(graph, config, ProgressTracker.NULL_TRACKER) + .sourceTarget(graph, config, ProgressTracker.NULL_TRACKER, TerminationFlag.RUNNING_TRUE) .compute() .findFirst() .get(); @@ -124,38 +124,36 @@ void sourceTarget() { @Test void shouldLogProgress() { + var log = new GdsTestLog(); + var requestScopedDependencies = RequestScopedDependencies.builder() + .with(EmptyTaskRegistryFactory.INSTANCE) + .with(TerminationFlag.RUNNING_TRUE) + .with(EmptyUserLogRegistryFactory.INSTANCE) + .build(); + var progressTrackerCreator = new ProgressTrackerCreator(log, requestScopedDependencies); + var pathFindingAlgorithms = new PathFindingAlgorithms(requestScopedDependencies, progressTrackerCreator); var config = defaultSourceTargetConfigBuilder() .sourceNode(graph.toOriginalNodeId("nA")) .targetNode(graph.toOriginalNodeId("nX")) .build(); - - var progressTask = new AStarFactory<>().progressTask(graph, config); - var log = new GdsTestLog(); - var progressTracker = new TestProgressTracker(progressTask, log, new Concurrency(1), EmptyTaskRegistryFactory.INSTANCE); - - AStar.sourceTarget(graph, config, progressTracker) - .compute() - .pathSet(); - - List progresses = progressTracker.getProgresses(); - assertEquals(1, progresses.size()); - assertEquals(9, progresses.get(0).get()); - - assertTrue(log.containsMessage(TestLog.INFO, "AStar :: Start")); - assertTrue(log.containsMessage(TestLog.INFO, "AStar 5%")); - assertTrue(log.containsMessage(TestLog.INFO, "AStar 17%")); - assertTrue(log.containsMessage(TestLog.INFO, "AStar 23%")); - assertTrue(log.containsMessage(TestLog.INFO, "AStar 29%")); - assertTrue(log.containsMessage(TestLog.INFO, "AStar 35%")); - assertTrue(log.containsMessage(TestLog.INFO, "AStar 41%")); - assertTrue(log.containsMessage(TestLog.INFO, "AStar 47%")); - assertTrue(log.containsMessage(TestLog.INFO, "AStar 52%")); - assertTrue(log.containsMessage(TestLog.INFO, "AStar :: Finished")); - - // no duplicate entries in progress logger - var logMessages = log.getMessages(TestLog.INFO); - assertEquals(Set.copyOf(logMessages).size(), logMessages.size()); + pathFindingAlgorithms.singlePairShortestPathAStar(graph, config).pathSet(); + + assertThat(log.getMessages(TestLog.INFO)) + .extracting(removingThreadId()) + .contains( + "AStar :: Start", + "AStar 5%", + "AStar 17%", + "AStar 23%", + "AStar 29%", + "AStar 35%", + "AStar 41%", + "AStar 47%", + "AStar 52%", + "AStar 100%", + "AStar :: Finished" + ); } // Validated against https://www.vcalc.com/wiki/vCalc/Haversine+-+Distance diff --git a/algo/src/test/java/org/neo4j/gds/paths/bellmanford/BellmanFordTest.java b/algo/src/test/java/org/neo4j/gds/paths/bellmanford/BellmanFordTest.java index 93cc4e1c9c..4dedd95750 100644 --- a/algo/src/test/java/org/neo4j/gds/paths/bellmanford/BellmanFordTest.java +++ b/algo/src/test/java/org/neo4j/gds/paths/bellmanford/BellmanFordTest.java @@ -37,6 +37,7 @@ import org.neo4j.gds.logging.GdsTestLog; import org.neo4j.gds.paths.delta.config.AllShortestPathsDeltaStreamConfigImpl; import org.neo4j.gds.paths.dijkstra.Dijkstra; +import org.neo4j.gds.termination.TerminationFlag; import java.util.Optional; @@ -287,7 +288,7 @@ void shouldGiveSameResultsAsDijkstra() { .shortestPaths(); var dijkstraAlgo = Dijkstra - .singleSource(newGraph, config.sourceNode(), true, Optional.empty(), ProgressTracker.NULL_TRACKER) + .singleSource(newGraph, config.sourceNode(), true, Optional.empty(), ProgressTracker.NULL_TRACKER, TerminationFlag.RUNNING_TRUE) .compute(); double[] bellman = new double[nodeCount]; diff --git a/algo/src/test/java/org/neo4j/gds/paths/delta/DeltaSteppingTest.java b/algo/src/test/java/org/neo4j/gds/paths/delta/DeltaSteppingTest.java index 6146be1284..ec668e5566 100644 --- a/algo/src/test/java/org/neo4j/gds/paths/delta/DeltaSteppingTest.java +++ b/algo/src/test/java/org/neo4j/gds/paths/delta/DeltaSteppingTest.java @@ -46,6 +46,7 @@ import org.neo4j.gds.logging.GdsTestLog; import org.neo4j.gds.paths.delta.config.AllShortestPathsDeltaStreamConfigImpl; import org.neo4j.gds.paths.dijkstra.Dijkstra; +import org.neo4j.gds.termination.TerminationFlag; import java.util.List; import java.util.Optional; @@ -378,7 +379,7 @@ void shouldGiveSameResultsAsDijkstra() { ).compute(); var dijkstraAlgo = Dijkstra - .singleSource(newGraph, config.sourceNode(), true, Optional.empty(), ProgressTracker.NULL_TRACKER) + .singleSource(newGraph, config.sourceNode(), true, Optional.empty(), ProgressTracker.NULL_TRACKER, TerminationFlag.RUNNING_TRUE) .compute(); double[] delta = new double[nodeCount]; diff --git a/algo/src/test/java/org/neo4j/gds/paths/dijkstra/DijkstraTest.java b/algo/src/test/java/org/neo4j/gds/paths/dijkstra/DijkstraTest.java index 1f3548eb39..2a6b0b3754 100644 --- a/algo/src/test/java/org/neo4j/gds/paths/dijkstra/DijkstraTest.java +++ b/algo/src/test/java/org/neo4j/gds/paths/dijkstra/DijkstraTest.java @@ -236,7 +236,8 @@ void singleSource() { config.sourceNode(), false, Optional.empty(), - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ).compute() .pathSet(); @@ -265,7 +266,8 @@ void singleSourceFromDisconnectedNode() { config.sourceNode(), false, Optional.empty(), - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ) .compute() .pathSet(); @@ -380,7 +382,8 @@ void singleSource() { config.sourceNode(), false, Optional.empty(), - ProgressTracker.NULL_TRACKER + ProgressTracker.NULL_TRACKER, + TerminationFlag.RUNNING_TRUE ) .compute() .pathSet(); diff --git a/algo/src/test/java/org/neo4j/gds/paths/yens/YensParallelEdgesTest.java b/algo/src/test/java/org/neo4j/gds/paths/yens/YensParallelEdgesTest.java index e766940dc7..5a94e74bb3 100644 --- a/algo/src/test/java/org/neo4j/gds/paths/yens/YensParallelEdgesTest.java +++ b/algo/src/test/java/org/neo4j/gds/paths/yens/YensParallelEdgesTest.java @@ -21,24 +21,24 @@ import org.junit.jupiter.api.Test; import org.neo4j.gds.api.Graph; +import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies; +import org.neo4j.gds.applications.algorithms.pathfinding.PathFindingAlgorithms; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; import org.neo4j.gds.extension.GdlExtension; import org.neo4j.gds.extension.GdlGraph; import org.neo4j.gds.extension.IdFunction; import org.neo4j.gds.extension.Inject; import org.neo4j.gds.paths.yens.config.ShortestPathYensStreamConfigImpl; +import org.neo4j.gds.termination.TerminationFlag; import static org.assertj.core.api.Assertions.assertThat; @GdlExtension class YensParallelEdgesTest { - static ShortestPathYensStreamConfigImpl.Builder defaultSourceTargetConfigBuilder() { - return ShortestPathYensStreamConfigImpl.builder() - .concurrency(1); + return ShortestPathYensStreamConfigImpl.builder().concurrency(1); } - @GdlGraph private static final String DB_CYPHER = "CREATE" + @@ -64,16 +64,20 @@ static ShortestPathYensStreamConfigImpl.Builder defaultSourceTargetConfigBuilder @Inject private IdFunction idFunction; - @Test void shouldWorkWithParallelEdges() { + var requestScopedDependencies = RequestScopedDependencies.builder() + .with(TerminationFlag.RUNNING_TRUE) + .build(); + var pathFindingAlgorithms = new PathFindingAlgorithms(requestScopedDependencies, null); + var config = defaultSourceTargetConfigBuilder() .sourceNode(idFunction.of("a")) .targetNode(idFunction.of("d")) .k(9) .build(); - var yens = new YensFactory<>().build(graph, config, ProgressTracker.NULL_TRACKER); - var result = yens.compute(); + var result = pathFindingAlgorithms.singlePairShortestPathYens(graph, config, ProgressTracker.NULL_TRACKER); + var associatedCosts = result.pathSet().stream().mapToInt(path -> (int) path.totalCost()).toArray(); assertThat(associatedCosts.length).isEqualTo(9); assertThat(associatedCosts).doesNotContain(200); //paths are 1 + (1..3)+(1..3) diff --git a/algo/src/test/java/org/neo4j/gds/paths/yens/YensTest.java b/algo/src/test/java/org/neo4j/gds/paths/yens/YensTest.java index 1a82118691..9f17153c32 100644 --- a/algo/src/test/java/org/neo4j/gds/paths/yens/YensTest.java +++ b/algo/src/test/java/org/neo4j/gds/paths/yens/YensTest.java @@ -26,12 +26,14 @@ import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; -import org.neo4j.gds.TestProgressTracker; +import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator; +import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies; +import org.neo4j.gds.applications.algorithms.pathfinding.PathFindingAlgorithms; import org.neo4j.gds.compat.TestLog; import org.neo4j.gds.core.Aggregation; -import org.neo4j.gds.core.concurrency.Concurrency; import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory; import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker; +import org.neo4j.gds.core.utils.warnings.EmptyUserLogRegistryFactory; import org.neo4j.gds.extension.GdlExtension; import org.neo4j.gds.extension.GdlGraph; import org.neo4j.gds.extension.IdFunction; @@ -58,13 +60,10 @@ @GdlExtension class YensTest { - static ShortestPathYensStreamConfigImpl.Builder defaultSourceTargetConfigBuilder(int concurrency) { - return ShortestPathYensStreamConfigImpl.builder() - .concurrency(concurrency); + return ShortestPathYensStreamConfigImpl.builder().concurrency(concurrency); } - // https://en.wikipedia.org/wiki/Yen%27s_algorithm#/media/File:Yen's_K-Shortest_Path_Algorithm,_K=3,_A_to_F.gif @GdlGraph(aggregation = Aggregation.SINGLE) private static final String DB_CYPHER = @@ -144,26 +143,26 @@ static Stream> pathInput() { @ParameterizedTest @MethodSource("pathInput") void compute(Collection expectedPaths) { - assertResult(graph, expectedPaths, false, 4); + assertResult(graph, expectedPaths, false); } @Test void shouldLogProgress() { - int k = 3; + var log = new GdsTestLog(); + var requestScopedDependencies = RequestScopedDependencies.builder() + .with(EmptyTaskRegistryFactory.INSTANCE) + .with(TerminationFlag.RUNNING_TRUE) + .with(EmptyUserLogRegistryFactory.INSTANCE) + .build(); + var progressTrackerCreator = new ProgressTrackerCreator(log, requestScopedDependencies); + var pathFindingAlgorithms = new PathFindingAlgorithms(requestScopedDependencies, progressTrackerCreator); var config = defaultSourceTargetConfigBuilder(1) .sourceNode(graph.toOriginalNodeId("c")) .targetNode(graph.toOriginalNodeId("h")) - .k(k) + .k(3) .build(); - - var progressTask = new YensFactory<>().progressTask(graph, config); - var log = new GdsTestLog(); - var progressTracker = new TestProgressTracker(progressTask, log, new Concurrency(1), EmptyTaskRegistryFactory.INSTANCE); - - Yens.sourceTarget(graph, config, config.concurrency(), progressTracker, TerminationFlag.RUNNING_TRUE) - .compute() - .pathSet(); + pathFindingAlgorithms.singlePairShortestPathYens(graph, config).pathSet(); assertThat(log.getMessages(TestLog.INFO)) .extracting(removingThreadId()) @@ -187,21 +186,21 @@ void shouldLogProgress() { @Test void shouldLogProgressIfNothingToDo() { - int k = 3; + var log = new GdsTestLog(); + var requestScopedDependencies = RequestScopedDependencies.builder() + .with(EmptyTaskRegistryFactory.INSTANCE) + .with(TerminationFlag.RUNNING_TRUE) + .with(EmptyUserLogRegistryFactory.INSTANCE) + .build(); + var progressTrackerCreator = new ProgressTrackerCreator(log, requestScopedDependencies); + var pathFindingAlgorithms = new PathFindingAlgorithms(requestScopedDependencies, progressTrackerCreator); var config = defaultSourceTargetConfigBuilder(1) .sourceNode(graph.toOriginalNodeId("z")) .targetNode(graph.toOriginalNodeId("h")) - .k(k) + .k(3) .build(); - - var progressTask = new YensFactory<>().progressTask(graph, config); - var log = new GdsTestLog(); - var progressTracker = new TestProgressTracker(progressTask, log, new Concurrency(1), EmptyTaskRegistryFactory.INSTANCE); - - Yens.sourceTarget(graph, config, config.concurrency(), progressTracker, TerminationFlag.RUNNING_TRUE) - .compute() - .pathSet(); + pathFindingAlgorithms.singlePairShortestPathYens(graph, config).pathSet(); assertThat(log.getMessages(TestLog.INFO)) .extracting(removingThreadId()) @@ -215,12 +214,10 @@ void shouldLogProgressIfNothingToDo() { ); } - private static void assertResult( TestGraph graph, Collection expectedPaths, - boolean trackRelationships, - int concurrency + boolean trackRelationships ) { var expectedPathResults = expectedPathResults(graph::toMappedNodeId, expectedPaths, trackRelationships); @@ -235,7 +232,7 @@ private static void assertResult( throw new IllegalArgumentException("All expected paths must have the same source and target nodes."); } - var config = defaultSourceTargetConfigBuilder(concurrency) + var config = defaultSourceTargetConfigBuilder(4) .sourceNode(graph.toOriginalNodeId(firstResult.sourceNode())) .targetNode(graph.toOriginalNodeId(firstResult.targetNode())) .k(expectedPathResults.size()) @@ -359,7 +356,7 @@ Stream> pathInput() { @ParameterizedTest @MethodSource("pathInput") void compute(Collection expectedPaths) { - assertResult(graph, expectedPaths, true, 4); + assertResult(graph, expectedPaths, true); } } } diff --git a/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithms.java b/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithms.java index 8fc41a2d0a..36fe494725 100644 --- a/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithms.java +++ b/applications/algorithms/path-finding/src/main/java/org/neo4j/gds/applications/algorithms/pathfinding/PathFindingAlgorithms.java @@ -289,7 +289,7 @@ HugeAtomicLongArray randomWalkCountingNodeVisits(Graph graph, RandomWalkBaseConf return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true); } - PathFindingResult singlePairShortestPathAStar(Graph graph, ShortestPathAStarBaseConfig configuration) { + public PathFindingResult singlePairShortestPathAStar(Graph graph, ShortestPathAStarBaseConfig configuration) { var progressTracker = createProgressTracker( configuration, Tasks.leaf(AlgorithmLabel.AStar.asString(), graph.relationshipCount()) @@ -330,7 +330,7 @@ PathFindingResult singlePairShortestPathDijkstra(Graph graph, DijkstraSourceTarg return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, false); } - PathFindingResult singlePairShortestPathYens(Graph graph, ShortestPathYensBaseConfig configuration) { + public PathFindingResult singlePairShortestPathYens(Graph graph, ShortestPathYensBaseConfig configuration) { var initialTask = Tasks.leaf(AlgorithmLabel.Dijkstra.asString(), graph.relationshipCount()); var pathGrowingTask = Tasks.leaf("Path growing", configuration.k() - 1); var yensTask = Tasks.task(AlgorithmLabel.Yens.asString(), initialTask, pathGrowingTask); @@ -340,6 +340,14 @@ PathFindingResult singlePairShortestPathYens(Graph graph, ShortestPathYensBaseCo yensTask ); + return singlePairShortestPathYens(graph, configuration, progressTracker); + } + + public PathFindingResult singlePairShortestPathYens( + Graph graph, + ShortestPathYensBaseConfig configuration, + ProgressTracker progressTracker + ) { var algorithm = Yens.sourceTarget( graph, configuration, diff --git a/progress-tracking/src/main/java/org/neo4j/gds/core/utils/progress/tasks/TaskProgressLogger.java b/progress-tracking/src/main/java/org/neo4j/gds/core/utils/progress/tasks/TaskProgressLogger.java index 532bc6edc8..f72797285b 100644 --- a/progress-tracking/src/main/java/org/neo4j/gds/core/utils/progress/tasks/TaskProgressLogger.java +++ b/progress-tracking/src/main/java/org/neo4j/gds/core/utils/progress/tasks/TaskProgressLogger.java @@ -27,7 +27,7 @@ import static org.neo4j.gds.utils.StringFormatting.formatWithLocale; -class TaskProgressLogger extends BatchingProgressLogger { +public class TaskProgressLogger extends BatchingProgressLogger { private final Task baseTask; private final TaskVisitor loggingLeafTaskVisitor;