Skip to content

Commit

Permalink
Termination flag for sampling algorithms - abstract NodesSampler (#9724)
Browse files Browse the repository at this point in the history
Support termination flag in sampling algorithms
  • Loading branch information
orazve authored Oct 11, 2024
1 parent f4f3d01 commit 2c19e8c
Show file tree
Hide file tree
Showing 17 changed files with 131 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,7 @@ public RandomWalkSamplingResult sampleRandomWalkWithRestarts(
DatabaseId databaseId,
TaskRegistryFactory taskRegistryFactory,
UserLogRegistryFactory userLogRegistryFactory,
TerminationFlag terminationFlag,
String graphName,
String originGraphName,
Map<String, Object> configuration
Expand All @@ -972,6 +973,7 @@ public RandomWalkSamplingResult sampleRandomWalkWithRestarts(
databaseId,
taskRegistryFactory,
userLogRegistryFactory,
terminationFlag,
graphName,
originGraphName,
configuration,
Expand All @@ -985,6 +987,7 @@ public RandomWalkSamplingResult sampleCommonNeighbourAwareRandomWalk(
DatabaseId databaseId,
TaskRegistryFactory taskRegistryFactory,
UserLogRegistryFactory userLogRegistryFactory,
TerminationFlag terminationFlag,
String graphNameAsString,
String originGraphName,
Map<String, Object> configuration
Expand All @@ -994,6 +997,7 @@ public RandomWalkSamplingResult sampleCommonNeighbourAwareRandomWalk(
databaseId,
taskRegistryFactory,
userLogRegistryFactory,
terminationFlag,
graphNameAsString,
originGraphName,
configuration,
Expand Down Expand Up @@ -1104,6 +1108,7 @@ private RandomWalkSamplingResult sampleRandomWalk(
DatabaseId databaseId,
TaskRegistryFactory taskRegistryFactory,
UserLogRegistryFactory userLogRegistryFactory,
TerminationFlag terminationFlag,
String graphNameAsString,
String originGraphNameAsString,
Map<String, Object> configuration,
Expand All @@ -1125,6 +1130,7 @@ private RandomWalkSamplingResult sampleRandomWalk(
userLogRegistryFactory,
graphStore,
graphProjectConfig,
terminationFlag,
originGraphName,
graphName,
configuration,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ RandomWalkSamplingResult sampleRandomWalkWithRestarts(
DatabaseId databaseId,
TaskRegistryFactory taskRegistryFactory,
UserLogRegistryFactory userLogRegistryFactory,
TerminationFlag terminationFlag,
String graphName,
String originGraphName,
Map<String, Object> configuration
Expand All @@ -265,6 +266,7 @@ RandomWalkSamplingResult sampleCommonNeighbourAwareRandomWalk(
DatabaseId databaseId,
TaskRegistryFactory taskRegistryFactory,
UserLogRegistryFactory userLogRegistryFactory,
TerminationFlag terminationFlag,
String graphName,
String originGraphName,
Map<String, Object> configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.neo4j.gds.graphsampling.GraphSampleConstructor;
import org.neo4j.gds.graphsampling.RandomWalkSamplerType;
import org.neo4j.gds.logging.Log;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.Map;

Expand All @@ -50,6 +51,7 @@ RandomWalkSamplingResult sample(
UserLogRegistryFactory userLogRegistryFactory,
GraphStore graphStore,
GraphProjectConfig graphProjectConfig,
TerminationFlag terminationFlag,
GraphName originGraphName,
GraphName graphName,
Map<String, Object> configuration,
Expand All @@ -73,7 +75,8 @@ RandomWalkSamplingResult sample(
samplerConfig,
graphStore,
samplerAlgorithm,
progressTracker
progressTracker,
terminationFlag
);
var sampledGraphStore = graphSampleConstructor.compute();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.neo4j.gds.extension.IdFunction;
import org.neo4j.gds.extension.Inject;
import org.neo4j.gds.logging.Log;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -118,6 +119,7 @@ void shouldSampleRWR(Map<String, Object> mapConfiguration, long expectedNodeCoun
EmptyUserLogRegistryFactory.INSTANCE,
graphStore,
GraphProjectConfig.emptyWithName("user", "graph"),
TerminationFlag.RUNNING_TRUE,
GraphName.parse("graph"),
GraphName.parse("sample"),
mapConfiguration,
Expand Down Expand Up @@ -157,6 +159,7 @@ void shouldSampleCNARW(Map<String, Object> mapConfiguration, long expectedNodeCo
EmptyUserLogRegistryFactory.INSTANCE,
graphStore,
GraphProjectConfig.emptyWithName("user", "graph"),
TerminationFlag.RUNNING_TRUE,
GraphName.parse("graph"),
GraphName.parse("sample"),
mapConfiguration,
Expand Down Expand Up @@ -196,6 +199,7 @@ void shouldUseSingleStartNodeRWR(double samplingRatio, long expectedStartNodeCou
EmptyUserLogRegistryFactory.INSTANCE,
graphStore,
GraphProjectConfig.emptyWithName("user", "graph"),
TerminationFlag.RUNNING_TRUE,
GraphName.parse("graph"),
GraphName.parse("sample"),
Map.of(
Expand Down Expand Up @@ -239,6 +243,7 @@ void shouldUseSingleStartNodeCNARW(double samplingRatio, long expectedStartNodeC
EmptyUserLogRegistryFactory.INSTANCE,
graphStore,
GraphProjectConfig.emptyWithName("user", "graph"),
TerminationFlag.RUNNING_TRUE,
GraphName.parse("graph"),
GraphName.parse("sample"),
Map.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.List;
import java.util.Map;
Expand All @@ -57,18 +58,21 @@ public class GraphSampleConstructor {
private final GraphStore inputGraphStore;
private final NodesSampler nodesSampler;
private final ProgressTracker progressTracker;
private final TerminationFlag terminationFlag;

public GraphSampleConstructor(
GraphSampleAlgoConfig config,
GraphStore inputGraphStore,
NodesSampler nodesSampler,
ProgressTracker progressTracker
ProgressTracker progressTracker,
TerminationFlag terminationFlag
) {
this.config = config;
this.concurrency = config.concurrency();
this.inputGraphStore = inputGraphStore;
this.nodesSampler = nodesSampler;
this.progressTracker = progressTracker;
this.terminationFlag = terminationFlag;
}

public GraphStore compute() {
Expand All @@ -79,6 +83,8 @@ public GraphStore compute() {
config.internalRelationshipTypes(inputGraphStore),
config.relationshipWeightProperty()
);
nodesSampler.setTerminationFlag(terminationFlag);

var sampledNodesBitSet = nodesSampler.compute(inputGraph, progressTracker);

progressTracker.beginSubTask("Construct graph");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,25 @@
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.termination.TerminationFlag;

public interface NodesSampler {
HugeAtomicBitSet compute(Graph inputGraph, ProgressTracker progressTracker);
public abstract class NodesSampler {
protected abstract HugeAtomicBitSet compute(
Graph inputGraph,
ProgressTracker progressTracker
);

Task progressTask(GraphStore graphStore);
protected abstract Task progressTask(GraphStore graphStore);

String progressTaskName();
protected abstract String progressTaskName();

// Termination flag held by this sampler. It starts out as TerminationFlag.RUNNING_TRUE and is
// replaced through setTerminationFlag. Declared protected, so subclasses can read it directly.
protected volatile TerminationFlag terminationFlag = TerminationFlag.RUNNING_TRUE;

/**
 * Replaces the termination flag held by this sampler.
 *
 * @param terminationFlag the flag to store; it overwrites the current field value as-is
 *                        (no validation is performed here)
 */
public void setTerminationFlag(TerminationFlag terminationFlag) {
this.terminationFlag = terminationFlag;
}

/**
 * Returns the termination flag currently held by this sampler.
 *
 * @return the value of the {@code terminationFlag} field
 */
public TerminationFlag getTerminationFlag() {
return terminationFlag;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
*/
package org.neo4j.gds.graphsampling;

public interface RandomWalkBasedNodesSampler extends NodesSampler {
public abstract class RandomWalkBasedNodesSampler extends NodesSampler {

long startNodesCount();
public abstract long startNodesCount();

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.Comparator;
import java.util.Optional;
Expand All @@ -45,7 +46,9 @@ interface Result {
LongLongHashMap histogram();
}

public static Result compute(Graph inputGraph, Concurrency concurrency, ProgressTracker progressTracker) {
public static Result compute(Graph inputGraph, Concurrency concurrency, ProgressTracker progressTracker,
TerminationFlag terminationFlag
) {
progressTracker.beginSubTask("Count node labels");
progressTracker.setSteps(inputGraph.nodeCount());

Expand All @@ -61,6 +64,7 @@ public static Result compute(Graph inputGraph, Concurrency concurrency, Progress
concurrency,
inputGraph.nodeCount(),
partition -> (Runnable) () -> {
terminationFlag.assertRunning();
var labelCount = new LongLongHashMap();
partition.consume(nodeId -> {
labelCount.addTo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.Arrays;

Expand All @@ -40,14 +41,16 @@ public interface SeenNodes {
long totalExpectedNodes();

static SeenNodes create(
Graph inputGraph, ProgressTracker progressTracker, boolean nodeLabelStratification,
Graph inputGraph, ProgressTracker progressTracker, TerminationFlag terminationFlag,
boolean nodeLabelStratification,
Concurrency concurrency, double samplingRatio
) {
if (nodeLabelStratification) {
var nodeLabelHistogram = NodeLabelHistogram.compute(
inputGraph,
concurrency,
progressTracker
progressTracker,
terminationFlag
);

return new SeenNodes.SeenNodesByLabelSet(inputGraph, nodeLabelHistogram, samplingRatio);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.graphsampling.samplers.SeenNodes;
import org.neo4j.gds.graphsampling.samplers.rw.rwr.RandomWalkWithRestarts;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.Optional;
import java.util.SplittableRandom;
Expand All @@ -41,6 +42,7 @@ public class Walker implements Runnable {
protected final Graph inputGraph;
private final double restartProbability;
protected final ProgressTracker progressTracker;
private final TerminationFlag terminationFlag;

private final LongSet startNodesUsed;

Expand All @@ -55,6 +57,7 @@ public Walker(
Graph inputGraph,
double restartProbability,
ProgressTracker progressTracker,
TerminationFlag terminationFlag,
NextNodeStrategy nextNodeStrategy
) {
this.seenNodes = seenNodes;
Expand All @@ -65,6 +68,7 @@ public Walker(
this.inputGraph = inputGraph;
this.restartProbability = restartProbability;
this.progressTracker = progressTracker;
this.terminationFlag = terminationFlag;
this.startNodesUsed = new LongHashSet();
this.nextNodeStrategy = nextNodeStrategy;
}
Expand All @@ -78,7 +82,7 @@ public void run() {
int nodesConsidered = 1;
int walksLeft = (int) Math.round(walkQualities.nodeQuality(currentStartNodePosition) * RandomWalkWithRestarts.MAX_WALKS_PER_START);

while (!seenNodes.hasSeenEnough()) {
while (!seenNodes.hasSeenEnough() && terminationFlag.running()) {
if (seenNodes.addNode(currentNode)) {
addedNodes++;
}
Expand Down Expand Up @@ -118,6 +122,7 @@ public void run() {
nodesConsidered++;
}
}
terminationFlag.assertRunning();
}

private double computeDegree(long currentNode) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.neo4j.gds.graphsampling.samplers.SeenNodes;
import org.neo4j.gds.graphsampling.samplers.rw.cnarw.CNARWNodeSamplingStrategySupplier;
import org.neo4j.gds.graphsampling.samplers.rw.rwr.RWRNodeSamplingStrategySupplier;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.Optional;
import java.util.SplittableRandom;
Expand All @@ -47,7 +48,8 @@ public Runnable getWalker(
SplittableRandom split,
Graph concurrentCopy,
RandomWalkWithRestartsConfig config,
ProgressTracker progressTracker
ProgressTracker progressTracker,
TerminationFlag terminationFlag
) {
return new Walker(
seenNodes,
Expand All @@ -58,6 +60,7 @@ public Runnable getWalker(
concurrentCopy,
config.restartProbability(),
progressTracker,
terminationFlag,
nodeSamplingStrategySupplier.apply(concurrentCopy, split, totalWeights)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
Expand All @@ -41,13 +39,15 @@
import org.neo4j.gds.graphsampling.samplers.rw.WalkQualities;
import org.neo4j.gds.graphsampling.samplers.rw.Walker;
import org.neo4j.gds.graphsampling.samplers.rw.WalkerProducer;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;

import java.util.Optional;
import java.util.SplittableRandom;

import static org.neo4j.gds.graphsampling.samplers.rw.RandomWalkCompanion.initializeTotalWeights;

public class CommonNeighbourAwareRandomWalk implements RandomWalkBasedNodesSampler {
public class CommonNeighbourAwareRandomWalk extends RandomWalkBasedNodesSampler {
private LongHashSet startNodesUsed;

private static final double QUALITY_THRESHOLD_BASE = 0.05;
Expand All @@ -71,6 +71,7 @@ public HugeAtomicBitSet compute(Graph inputGraph, ProgressTracker progressTracke
var seenNodes = SeenNodes.create(
inputGraph,
progressTracker,
terminationFlag,
config.nodeLabelStratification(),
concurrency,
config.samplingRatio()
Expand All @@ -93,7 +94,8 @@ public HugeAtomicBitSet compute(Graph inputGraph, ProgressTracker progressTracke
rng.split(),
inputGraph.concurrentCopy(),
config,
progressTracker
progressTracker,
terminationFlag
)
);

Expand Down
Loading

0 comments on commit 2c19e8c

Please sign in to comment.