diff --git a/rayon-core/src/registry.rs b/rayon-core/src/registry.rs index d30f815bd..0e86bd40f 100644 --- a/rayon-core/src/registry.rs +++ b/rayon-core/src/registry.rs @@ -217,6 +217,36 @@ fn default_global_registry() -> Result, ThreadPoolBuildError> { result } +// This is used to temporarily overwrite the current registry. +// +// This either null, a pointer to the global registry if it was +// ever used to access the global registry or a pointer to a +// registry which is temporarily made current because the current +// thread is not a worker thread but is running a scope associated +// to a specific thread pool. +thread_local! { + static CURRENT_REGISTRY: Cell<*const Arc> = const { Cell::new(ptr::null()) }; +} + +#[cold] +fn set_current_registry_to_global_registry() -> *const Arc { + let global = global_registry(); + + CURRENT_REGISTRY.with(|current_registry| current_registry.set(global)); + + global +} + +fn current_registry() -> *const Arc { + let mut current = CURRENT_REGISTRY.with(Cell::get); + + if current.is_null() { + current = set_current_registry_to_global_registry(); + } + + current +} + struct Terminator<'a>(&'a Arc); impl<'a> Drop for Terminator<'a> { @@ -315,7 +345,7 @@ impl Registry { unsafe { let worker_thread = WorkerThread::current(); let registry = if worker_thread.is_null() { - global_registry() + &*current_registry() } else { &(*worker_thread).registry }; @@ -323,6 +353,39 @@ impl Registry { } } + /// Optionally install a specific registry as the current one. + /// + /// This is used when a thread which is not a worker executes + /// a scope which should use the specific thread pool instead of + /// the global one. + pub(super) fn with_current(registry: Option<&Arc>, f: F) -> R + where + F: FnOnce() -> R, + { + struct Guard { + current: *const Arc, + } + + impl Guard { + fn new(registry: &Arc) -> Self { + let current = + CURRENT_REGISTRY.with(|current_registry| current_registry.replace(registry)); + + Self { current } + } + } + + impl Drop for Guard { + fn drop(&mut self) { + CURRENT_REGISTRY.with(|current_registry| current_registry.set(self.current)); + } + } + + let _guard = registry.map(Guard::new); + + f() + } + /// Returns the number of threads in the current registry. This /// is better than `Registry::current().num_threads()` because it /// avoids incrementing the `Arc`. diff --git a/rayon-core/src/scope/mod.rs b/rayon-core/src/scope/mod.rs index 1d8732fea..fa0b36d22 100644 --- a/rayon-core/src/scope/mod.rs +++ b/rayon-core/src/scope/mod.rs @@ -416,9 +416,11 @@ pub(crate) fn do_in_place_scope<'scope, OP, R>(registry: Option<&Arc>, where OP: FnOnce(&Scope<'scope>) -> R, { - let thread = unsafe { WorkerThread::current().as_ref() }; - let scope = Scope::<'scope>::new(thread, registry); - scope.base.complete(thread, || op(&scope)) + Registry::with_current(registry, || { + let thread = unsafe { WorkerThread::current().as_ref() }; + let scope = Scope::<'scope>::new(thread, registry); + scope.base.complete(thread, || op(&scope)) + }) } /// Creates a "fork-join" scope `s` with FIFO order, and invokes the @@ -453,9 +455,11 @@ pub(crate) fn do_in_place_scope_fifo<'scope, OP, R>(registry: Option<&Arc) -> R, { - let thread = unsafe { WorkerThread::current().as_ref() }; - let scope = ScopeFifo::<'scope>::new(thread, registry); - scope.base.complete(thread, || op(&scope)) + Registry::with_current(registry, || { + let thread = unsafe { WorkerThread::current().as_ref() }; + let scope = ScopeFifo::<'scope>::new(thread, registry); + scope.base.complete(thread, || op(&scope)) + }) } impl<'scope> Scope<'scope> {