diff --git a/core/src/main/java/org/neo4j/gds/core/concurrency/SyncBarrier.java b/core/src/main/java/org/neo4j/gds/core/concurrency/SyncBarrier.java index 7946226552..c1408d9a7f 100644 --- a/core/src/main/java/org/neo4j/gds/core/concurrency/SyncBarrier.java +++ b/core/src/main/java/org/neo4j/gds/core/concurrency/SyncBarrier.java @@ -23,11 +23,13 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; public class SyncBarrier { private final AtomicInteger workerCount; private final AtomicBoolean isSyncing; + private final ReentrantLock lock; private final BackoffIdleStrategy idleStrategy; private final Runnable rejectAction; @@ -42,24 +44,41 @@ public static SyncBarrier create(Runnable rejectAction) { private SyncBarrier(Runnable rejectAction) { this.workerCount = new AtomicInteger(0); this.isSyncing = new AtomicBoolean(false); + this.lock = new ReentrantLock(true); this.idleStrategy = new BackoffIdleStrategy(); this.rejectAction = rejectAction; } public void startWorker() { - if (isSyncing.get()) { - this.rejectAction.run(); + try { + // Checking the sync flag and increment the worker count must be atomic. + // Otherwise, we could run into the situation where thread A passes + // the sync check, is paused and thread B is executing the sync() method. + // If thread A is resumed after sync() is complete, it will violate the + // sync boundary. + this.lock.lock(); + if (this.isSyncing.get()) { + this.rejectAction.run(); + } + this.workerCount.incrementAndGet(); + } finally { + this.lock.unlock(); } - workerCount.incrementAndGet(); } public void stopWorker() { - workerCount.decrementAndGet(); + this.workerCount.decrementAndGet(); } public void sync() { - this.isSyncing.set(true); - // wait for import processes to finish + try { + this.lock.lock(); + this.isSyncing.set(true); + } finally { + this.lock.unlock(); + } + + // Wait for all workers to finish. while (workerCount.get() > 0) { idleStrategy.idle(); }