Skip to content

Commit

Permalink
Allow to temporarily set the current registry even if it is not assoc…
Browse files Browse the repository at this point in the history
…iated with a worker thread
  • Loading branch information
adamreichold committed May 12, 2024
1 parent b3bd4bc commit 301b603
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 12 deletions.
71 changes: 67 additions & 4 deletions rayon-core/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ static THE_REGISTRY_SET: Once = Once::new();
/// Starts the worker threads (if that has not already happened). If
/// initialization has not already occurred, use the default
/// configuration.
pub(super) fn global_registry() -> &'static Arc<Registry> {
fn global_registry() -> &'static Arc<Registry> {
set_global_registry(default_global_registry)
.or_else(|err| unsafe { THE_REGISTRY.as_ref().ok_or(err) })
.expect("The global thread pool has not been initialized.")
Expand Down Expand Up @@ -217,6 +217,36 @@ fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
result
}

// This is used to temporarily overwrite the current registry.
//
// This either null, a pointer to the global registry if it was
// ever used to access the global registry or a pointer to a
// registry which is temporarily made current because the current
// thread is not a worker thread but is running a scope associated
// to a specific thread pool.
thread_local! {
static CURRENT_REGISTRY: Cell<*const Arc<Registry>> = const { Cell::new(ptr::null()) };
}

#[cold]
fn set_current_registry_to_global_registry() -> *const Arc<Registry> {
let global = global_registry();

CURRENT_REGISTRY.with(|current_registry| current_registry.set(global));

global
}

pub(super) fn current_registry() -> *const Arc<Registry> {
let mut current = CURRENT_REGISTRY.with(Cell::get);

if current.is_null() {
current = set_current_registry_to_global_registry();
}

current
}

struct Terminator<'a>(&'a Arc<Registry>);

impl<'a> Drop for Terminator<'a> {
Expand Down Expand Up @@ -315,22 +345,55 @@ impl Registry {
unsafe {
let worker_thread = WorkerThread::current();
let registry = if worker_thread.is_null() {
global_registry()
&*current_registry()
} else {
&(*worker_thread).registry
};
Arc::clone(registry)
}
}

/// Optionally install a specific registry as the current one.
///
/// This is used when a thread which is not a worker executes
/// a scope which should use the specific thread pool instead of
/// the global one.
pub(super) fn with_current<F, R>(registry: Option<&Arc<Registry>>, f: F) -> R
where
F: FnOnce() -> R,
{
struct Guard {
current: *const Arc<Registry>,
}

impl Guard {
fn new(registry: &Arc<Registry>) -> Self {
let current =
CURRENT_REGISTRY.with(|current_registry| current_registry.replace(registry));

Self { current }
}
}

impl Drop for Guard {
fn drop(&mut self) {
CURRENT_REGISTRY.with(|current_registry| current_registry.set(self.current));
}
}

let _guard = registry.map(Guard::new);

f()
}

/// Returns the number of threads in the current registry. This
/// is better than `Registry::current().num_threads()` because it
/// avoids incrementing the `Arc`.
pub(super) fn current_num_threads() -> usize {
unsafe {
let worker_thread = WorkerThread::current();
if worker_thread.is_null() {
global_registry().num_threads()
(*current_registry()).num_threads()
} else {
(*worker_thread).registry.num_threads()
}
Expand Down Expand Up @@ -946,7 +1009,7 @@ where
// invalidated until we return.
op(&*owner_thread, false)
} else {
global_registry().in_worker(op)
(*current_registry()).in_worker(op)
}
}
}
Expand Down
20 changes: 12 additions & 8 deletions rayon-core/src/scope/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
use crate::broadcast::BroadcastContext;
use crate::job::{ArcJob, HeapJob, JobFifo, JobRef};
use crate::latch::{CountLatch, Latch};
use crate::registry::{global_registry, in_worker, Registry, WorkerThread};
use crate::registry::{current_registry, in_worker, Registry, WorkerThread};
use crate::unwind;
use std::any::Any;
use std::fmt;
Expand Down Expand Up @@ -416,9 +416,11 @@ pub(crate) fn do_in_place_scope<'scope, OP, R>(registry: Option<&Arc<Registry>>,
where
OP: FnOnce(&Scope<'scope>) -> R,
{
let thread = unsafe { WorkerThread::current().as_ref() };
let scope = Scope::<'scope>::new(thread, registry);
scope.base.complete(thread, || op(&scope))
Registry::with_current(registry, || {
let thread = unsafe { WorkerThread::current().as_ref() };
let scope = Scope::<'scope>::new(thread, registry);
scope.base.complete(thread, || op(&scope))
})
}

/// Creates a "fork-join" scope `s` with FIFO order, and invokes the
Expand Down Expand Up @@ -453,9 +455,11 @@ pub(crate) fn do_in_place_scope_fifo<'scope, OP, R>(registry: Option<&Arc<Regist
where
OP: FnOnce(&ScopeFifo<'scope>) -> R,
{
let thread = unsafe { WorkerThread::current().as_ref() };
let scope = ScopeFifo::<'scope>::new(thread, registry);
scope.base.complete(thread, || op(&scope))
Registry::with_current(registry, || {
let thread = unsafe { WorkerThread::current().as_ref() };
let scope = ScopeFifo::<'scope>::new(thread, registry);
scope.base.complete(thread, || op(&scope))
})
}

impl<'scope> Scope<'scope> {
Expand Down Expand Up @@ -625,7 +629,7 @@ impl<'scope> ScopeBase<'scope> {
fn new(owner: Option<&WorkerThread>, registry: Option<&Arc<Registry>>) -> Self {
let registry = registry.unwrap_or_else(|| match owner {
Some(owner) => owner.registry(),
None => global_registry(),
None => unsafe { &*current_registry() },
});

ScopeBase {
Expand Down

0 comments on commit 301b603

Please sign in to comment.