Skip to content

Commit

Permalink
migrate estimation cli {(k)spanning, steiner} to application layer
Browse files Browse the repository at this point in the history
  • Loading branch information
lassewesth committed Nov 26, 2024
1 parent e054743 commit 0a9aa96
Show file tree
Hide file tree
Showing 21 changed files with 98 additions and 984 deletions.

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.neo4j.gds.Orientation;
import org.neo4j.gds.TestProgressTracker;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator;
import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies;
import org.neo4j.gds.applications.algorithms.pathfinding.PathFindingAlgorithms;
import org.neo4j.gds.compat.TestLog;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.warnings.EmptyUserLogRegistryFactory;
import org.neo4j.gds.extension.GdlExtension;
import org.neo4j.gds.extension.GdlGraph;
import org.neo4j.gds.extension.IdFunction;
Expand Down Expand Up @@ -268,34 +270,36 @@ void shouldWorkForComponentSmallerThanK() {

@Test
void shouldLogProgress() {
var config = KSpanningTreeBaseConfigImpl.builder().sourceNode(idFunction.of("a")).k(2).build();
var factory = new KSpanningTreeAlgorithmFactory<>();
var log = new GdsTestLog();
var progressTracker = new TestProgressTracker(
factory.progressTask(graph, config),
log,
new Concurrency(1),
EmptyTaskRegistryFactory.INSTANCE
);
factory.build(graph, config, progressTracker).compute();
var requestScopedDependencies = RequestScopedDependencies.builder()
.with(EmptyTaskRegistryFactory.INSTANCE)
.with(TerminationFlag.RUNNING_TRUE)
.with(EmptyUserLogRegistryFactory.INSTANCE)
.build();
var progressTrackerCreator = new ProgressTrackerCreator(log, requestScopedDependencies);
var pathFindingAlgorithms = new PathFindingAlgorithms(requestScopedDependencies, progressTrackerCreator);

var config = KSpanningTreeBaseConfigImpl.builder().sourceNode(idFunction.of("a")).k(2).build();
pathFindingAlgorithms.kSpanningTree(graph, config);

assertThat(log.getMessages(TestLog.INFO))
.extracting(removingThreadId())
.extracting(replaceTimings())
.containsExactly(
"KSpanningTree :: Start",
"KSpanningTree :: SpanningTree :: Start",
"KSpanningTree :: SpanningTree 30%",
"KSpanningTree :: SpanningTree 50%",
"KSpanningTree :: SpanningTree 80%",
"KSpanningTree :: SpanningTree 100%",
"KSpanningTree :: SpanningTree :: Finished",
"KSpanningTree :: Remove relationships :: Start",
"KSpanningTree :: Remove relationships 20%",
"KSpanningTree :: Remove relationships 40%",
"KSpanningTree :: Remove relationships 60%",
"KSpanningTree :: Remove relationships 100%",
"KSpanningTree :: Remove relationships :: Finished",
"KSpanningTree :: Finished"
"K Spanning Tree :: Start",
"K Spanning Tree :: SpanningTree :: Start",
"K Spanning Tree :: SpanningTree 30%",
"K Spanning Tree :: SpanningTree 50%",
"K Spanning Tree :: SpanningTree 80%",
"K Spanning Tree :: SpanningTree 100%",
"K Spanning Tree :: SpanningTree :: Finished",
"K Spanning Tree :: Remove relationships :: Start",
"K Spanning Tree :: Remove relationships 20%",
"K Spanning Tree :: Remove relationships 40%",
"K Spanning Tree :: Remove relationships 60%",
"K Spanning Tree :: Remove relationships 100%",
"K Spanning Tree :: Remove relationships :: Finished",
"K Spanning Tree :: Finished"
);
}

Expand Down
26 changes: 15 additions & 11 deletions algo/src/test/java/org/neo4j/gds/spanningtree/PrimTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.neo4j.gds.Orientation;
import org.neo4j.gds.TestProgressTracker;
import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator;
import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies;
import org.neo4j.gds.applications.algorithms.pathfinding.PathFindingAlgorithms;
import org.neo4j.gds.compat.TestLog;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.warnings.EmptyUserLogRegistryFactory;
import org.neo4j.gds.extension.GdlExtension;
import org.neo4j.gds.extension.GdlGraph;
import org.neo4j.gds.extension.IdFunction;
Expand Down Expand Up @@ -150,16 +152,18 @@ void testMinimum(String nodeId, String parentA, String parentB, String parentC,

@Test
void shouldLogProgress() {
var parameters = new SpanningTreeParameters(PrimOperators.MIN_OPERATOR, graph.toOriginalNodeId("a"));
var factory = new SpanningTreeAlgorithmFactory<>();
var log = new GdsTestLog();
var progressTracker = new TestProgressTracker(
factory.progressTask(graph),
log,
new Concurrency(1),
EmptyTaskRegistryFactory.INSTANCE
);
factory.build(graph, parameters, progressTracker).compute();
var requestScopedDependencies = RequestScopedDependencies.builder()
.with(EmptyTaskRegistryFactory.INSTANCE)
.with(TerminationFlag.RUNNING_TRUE)
.with(EmptyUserLogRegistryFactory.INSTANCE)
.build();
var progressTrackerCreator = new ProgressTrackerCreator(log, requestScopedDependencies);
var pathFindingAlgorithms = new PathFindingAlgorithms(requestScopedDependencies, progressTrackerCreator);

var config = SpanningTreeBaseConfigImpl.builder().sourceNode(graph.toOriginalNodeId("a")).build();
pathFindingAlgorithms.spanningTree(graph, config);

assertThat(log.getMessages(TestLog.INFO))
.extracting(removingThreadId())
.extracting(replaceTimings())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
package org.neo4j.gds.spanningtree;

import org.junit.jupiter.api.Test;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.applications.algorithms.pathfinding.PathFindingAlgorithms;
import org.neo4j.gds.gdl.GdlFactory;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand All @@ -29,9 +29,9 @@ class SpanningTreeAlgorithmFactoryTest {
@Test
void shouldThrowIfNotUndirected() {
var graph = GdlFactory.of("(a)-[:foo{cost:1.0}]->(b)").build().getUnion();
var parameters = new SpanningTreeParameters(PrimOperators.MIN_OPERATOR, 0);
var spanningTreeAlgorithmFactory = new SpanningTreeAlgorithmFactory<>();
assertThatThrownBy(() -> spanningTreeAlgorithmFactory.build(graph, parameters, ProgressTracker.NULL_TRACKER))
var pathFindingAlgorithms = new PathFindingAlgorithms(null, null);

assertThatThrownBy(() -> pathFindingAlgorithms.spanningTree(graph, null))
.hasMessageContaining("undirected");
}
}
Loading

0 comments on commit 0a9aa96

Please sign in to comment.