diff --git a/tokio/src/runtime/tests/loom_current_thread.rs b/tokio/src/runtime/tests/loom_current_thread.rs index edda6e49954..fd0a44314f8 100644 --- a/tokio/src/runtime/tests/loom_current_thread.rs +++ b/tokio/src/runtime/tests/loom_current_thread.rs @@ -1,6 +1,6 @@ mod yield_now; -use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::atomic::{AtomicUsize, Ordering}; use crate::loom::sync::Arc; use crate::loom::thread; use crate::runtime::{Builder, Runtime}; @@ -9,7 +9,7 @@ use crate::task; use std::future::Future; use std::pin::Pin; use std::sync::atomic::Ordering::{Acquire, Release}; -use std::task::{Context, Poll}; +use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; fn assert_at_most_num_polls(rt: Arc, at_most_polls: usize) { let (tx, rx) = oneshot::channel(); @@ -106,6 +106,60 @@ fn assert_no_unnecessary_polls() { }); } +#[test] +fn drop_jh_during_schedule() { + unsafe fn waker_clone(ptr: *const ()) -> RawWaker { + let atomic = unsafe { &*(ptr as *const AtomicUsize) }; + atomic.fetch_add(1, Ordering::Relaxed); + RawWaker::new(ptr, &VTABLE) + } + unsafe fn waker_drop(ptr: *const ()) { + let atomic = unsafe { &*(ptr as *const AtomicUsize) }; + atomic.fetch_sub(1, Ordering::Relaxed); + } + unsafe fn waker_nop(_ptr: *const ()) {} + + static VTABLE: RawWakerVTable = + RawWakerVTable::new(waker_clone, waker_drop, waker_nop, waker_drop); + + loom::model(|| { + let rt = Builder::new_current_thread().build().unwrap(); + + let mut jh = rt.spawn(async {}); + // Using AbortHandle to increment task refcount. This ensures that the waker is not + // destroyed due to the refcount hitting zero. + let task_refcnt = jh.abort_handle(); + + let waker_refcnt = AtomicUsize::new(1); + { + // Set up the join waker. + use std::future::Future; + use std::pin::Pin; + + // SAFETY: Before `waker_refcnt` goes out of scope, this test asserts that the refcnt + // has dropped to zero. + let join_waker = unsafe { + Waker::from_raw(RawWaker::new( + (&waker_refcnt) as *const AtomicUsize as *const (), + &VTABLE, + )) + }; + + assert!(Pin::new(&mut jh) + .poll(&mut Context::from_waker(&join_waker)) + .is_pending()); + } + assert_eq!(waker_refcnt.load(Ordering::Relaxed), 1); + + let bg_thread = loom::thread::spawn(move || drop(jh)); + rt.block_on(crate::task::yield_now()); + bg_thread.join().unwrap(); + + assert_eq!(waker_refcnt.load(Ordering::Relaxed), 0); + drop(task_refcnt); + }); +} + struct BlockedFuture { rx: Receiver<()>, num_polls: Arc, diff --git a/tokio/src/runtime/tests/loom_multi_thread.rs b/tokio/src/runtime/tests/loom_multi_thread.rs index e2706e65c65..ddd14b7fb3f 100644 --- a/tokio/src/runtime/tests/loom_multi_thread.rs +++ b/tokio/src/runtime/tests/loom_multi_thread.rs @@ -10,7 +10,6 @@ mod yield_now; /// In order to speed up the C use crate::runtime::tests::loom_oneshot as oneshot; use crate::runtime::{self, Runtime}; -use crate::sync::mpsc::channel; use crate::{spawn, task}; use tokio_test::assert_ok; @@ -460,32 +459,3 @@ impl Future for Track { }) } } - -#[test] -fn drop_tasks_with_reference_cycle() { - loom::model(|| { - let pool = mk_pool(2); - - pool.block_on(async move { - let (tx, mut rx) = channel(1); - - let (a_closer, mut wait_for_close_a) = channel::<()>(1); - let (b_closer, mut wait_for_close_b) = channel::<()>(1); - - let a = spawn(async move { - let b = rx.recv().await.unwrap(); - - futures::future::select(std::pin::pin!(b), std::pin::pin!(a_closer.send(()))).await; - }); - - let b = spawn(async move { - let _ = a.await; - let _ = b_closer.send(()).await; - }); - - tx.send(b).await.unwrap(); - - futures::future::join(wait_for_close_a.recv(), wait_for_close_b.recv()).await; - }); - }); -} diff --git a/tokio/src/runtime/tests/task.rs b/tokio/src/runtime/tests/task.rs index ea48b8e5199..66d4b8c2773 100644 --- a/tokio/src/runtime/tests/task.rs +++ b/tokio/src/runtime/tests/task.rs @@ -1,7 +1,9 @@ use crate::runtime::task::{ self, unowned, Id, JoinHandle, OwnedTasks, Schedule, Task, TaskHarnessScheduleHooks, }; -use crate::runtime::tests::NoopSchedule; +use crate::runtime::{self, tests::NoopSchedule}; +use crate::spawn; +use crate::sync::{mpsc, Barrier}; use std::collections::VecDeque; use std::future::Future; @@ -45,6 +47,41 @@ impl Drop for AssertDrop { } } +#[test] +fn drop_tasks_with_reference_cycle() { + let rt = runtime::Builder::new_current_thread().build().unwrap(); + + rt.block_on(async { + let (tx, mut rx) = mpsc::channel(1); + + let barrier = Arc::new(Barrier::new(3)); + let barrier_a = barrier.clone(); + let barrier_b = barrier.clone(); + + let a = spawn(async move { + let b = rx.recv().await.unwrap(); + + // Poll the JoinHandle once. This registers the waker. + // The other task cannot have finished at this point due to the barrier below. + futures::future::select(b, std::future::ready(())).await; + + barrier_a.wait().await; + }); + + let b = spawn(async move { + // Poll the JoinHandle once. This registers the waker. + // The other task cannot have finished at this point due to the barrier below. + futures::future::select(a, std::future::ready(())).await; + + barrier_b.wait().await; + }); + + tx.send(b).await.unwrap(); + + barrier.wait().await; + }); +} + // A Notified does not shut down on drop, but it is dropped once the ref-count // hits zero. #[test]