Skip to content

Commit

Permalink
Fix weighted bug progress logging
Browse files Browse the repository at this point in the history
Co-authored-by: Veselin Nikolov <[email protected]>
  • Loading branch information
IoannisPanagiotas and vnickolov committed Jan 26, 2024
1 parent 5f3705d commit de3c4b9
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 14 deletions.
40 changes: 26 additions & 14 deletions algo/src/main/java/org/neo4j/gds/traversal/RandomWalk.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ public final class RandomWalk extends Algorithm<Stream<long[]>> {

private final int concurrency;
private final ExecutorService executorService;
private final RandomWalkTaskSupplier taskSupplier;
private final Graph graph;
private final long randomSeed;
private final WalkParameters walkParameters;
private final List<Long> sourceNodes;
private final ExternalTerminationFlag externalTerminationFlag;
private final BlockingQueue<long[]> walks;

Expand Down Expand Up @@ -99,15 +102,33 @@ private RandomWalk(
this.executorService = executorService;
this.walks = new ArrayBlockingQueue<>(walkBufferSize);
this.externalTerminationFlag = new ExternalTerminationFlag(this);
long randomSeed = maybeRandomSeed.orElseGet(() -> new Random().nextLong());
this.graph = graph;
this.walkParameters = walkParameters;
this.sourceNodes = sourceNodes;
this.randomSeed = maybeRandomSeed.orElseGet(() -> new Random().nextLong());
}

@Override
public Stream<long[]> compute() {
progressTracker.beginSubTask("RandomWalk");
var taskSupplier = createRandomWalkTaskSupplier();

startWalkers(
taskSupplier,
() -> progressTracker.endSubTask("RandomWalk")
);
return streamWalks(walks);
}

RandomWalkTaskSupplier createRandomWalkTaskSupplier() {
var nextNodeSupplier = RandomWalkCompanion.nextNodeSupplier(graph, sourceNodes);
RandomWalkSampler.CumulativeWeightSupplier cumulativeWeightSupplier = RandomWalkCompanion.cumulativeWeights(
graph,
concurrency,
executorService,
progressTracker
);
var nextNodeSupplier = RandomWalkCompanion.nextNodeSupplier(graph, sourceNodes);
this.taskSupplier = new RandomWalkTaskSupplier(
return new RandomWalkTaskSupplier(
graph::concurrentCopy,
nextNodeSupplier,
cumulativeWeightSupplier,
Expand All @@ -119,16 +140,7 @@ private RandomWalk(
);
}

@Override
public Stream<long[]> compute() {
progressTracker.beginSubTask("RandomWalk");
startWalkers(
() -> progressTracker.endSubTask("RandomWalk")
);
return streamWalks(walks);
}

private void startWalkers(Runnable whenCompleteAction) {
private void startWalkers(RandomWalkTaskSupplier taskSupplier, Runnable whenCompleteAction) {
var tasks = IntStream
.range(0, this.concurrency)
.mapToObj(i -> taskSupplier.get())
Expand Down
79 changes: 79 additions & 0 deletions algo/src/test/java/org/neo4j/gds/traversal/RandomWalkTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,31 @@ class ProgressTracking {
@Inject
TestGraph graph;

@GdlGraph(graphNamePrefix = "weighted")
public static final String WEIGHTED_GDL =
"CREATE " +
" (a:Node)" +
", (b:Node)" +
", (c:Node)" +
", (d:Node)" +
", (e:Node)" +
", (f:Node)" +
", (a)-[:REL{w:10}]->(b)" +
", (a)-[:REL{w:10}]->(c)" +
", (a)-[:REL{w:5}]->(d)" +
", (b)-[:REL{w:5}]->(a)" +
", (b)-[:REL{w:5}]->(e)" +
", (c)-[:REL{w:5}]->(a)" +
", (c)-[:REL{w:5}]->(d)" +
", (c)-[:REL{w:5}]->(e)" +
", (d)-[:REL{w:5}]->(a)" +
", (d)-[:REL{w:5}]->(c)" +
", (d)-[:REL{w:5}]->(e)" +
", (e)-[:REL{w:5}]->(a)";

@Inject
private TestGraph weightedGraph;

@Test
void progressLogging() throws InterruptedException {

Expand Down Expand Up @@ -517,6 +542,60 @@ void progressLogging() throws InterruptedException {
);
}

@Test
void shouldLogProgressOnWeightedGraph() throws InterruptedException {

var config = RandomWalkStreamConfigImpl.builder()
.walkLength(10)
.concurrency(4)
.walksPerNode(1000)
.walkBufferSize(1000)
.returnFactor(0.1)
.inOutFactor(100000)
.randomSeed(87L)
.build();

var fact = new RandomWalkAlgorithmFactory<RandomWalkStreamConfig>();
var log = Neo4jProxy.testLog();
var taskStore = new PerDatabaseTaskStore();

var pt = new TestProgressTracker(
fact.progressTask(weightedGraph, config),
log,
config.concurrency(),
TaskRegistryFactory.local("rw", taskStore)
);

RandomWalk randomWalk = fact.build(weightedGraph, config, pt);

assertThatNoException().isThrownBy(() -> {
var randomWalksStream = randomWalk.compute();
// Make sure to consume the stream...
assertThat(randomWalksStream).hasSize(5000);
});

awaitEmptyTaskStore(taskStore);

assertThat(log.getMessages(TestLog.INFO))
.extracting(removingThreadId())
.extracting(replaceTimings())
.containsExactly(
"RandomWalk :: Start",
"RandomWalk :: DegreeCentrality :: Start",
"RandomWalk :: DegreeCentrality 100%",
"RandomWalk :: DegreeCentrality :: Finished",
"RandomWalk :: create walks :: Start",
"RandomWalk :: create walks 16%",
"RandomWalk :: create walks 33%",
"RandomWalk :: create walks 50%",
"RandomWalk :: create walks 66%",
"RandomWalk :: create walks 83%",
"RandomWalk :: create walks 100%",
"RandomWalk :: create walks :: Finished",
"RandomWalk :: Finished"
);
}

@Test
void shouldLeaveNoTasksBehind() {
var config = ImmutableRandomWalkStreamConfig.builder().build();
Expand Down

0 comments on commit de3c4b9

Please sign in to comment.