Skip to content

Commit

Permalink
Fix UB in TokenCell
Browse files Browse the repository at this point in the history
  • Loading branch information
ryoqun committed Mar 18, 2024
1 parent e7d2035 commit 9e07dcd
Showing 1 changed file with 102 additions and 94 deletions.
196 changes: 102 additions & 94 deletions unified-scheduler-logic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ mod utils {
/// [`TokenCell`]s. And `&mut Token` is needed to access one of them as if the one is of
/// [`Token`]'s `*_mut()` getters. Thus, the Rust aliasing rule for `UnsafeCell` can
/// transitively be proven to be satisfied simply based on the usual borrow checking of the
/// `&mut` reference of [`Token`] itself via [`::borrow_mut()`](TokenCell::borrow_mut).
/// `&mut` reference of [`Token`] itself via
/// [`::with_borrow_mut()`](TokenCell::with_borrow_mut).
///
/// By extension, it's allowed to create _multiple_ tokens in a _single_ process as long as no
/// instance of [`TokenCell`] is shared by multiple instances of [`Token`].
Expand Down Expand Up @@ -213,15 +214,22 @@ mod utils {
Self(UnsafeCell::new(value))
}

/// Returns a mutable reference with its lifetime bound to the mutable reference of the
/// given token.
/// Acquires a mutable reference inside a given closure, while borrowing the mutable
/// reference of the given token.
///
/// In this way, no additional reborrow can ever happen at the same time across all
/// instances of [`TokenCell<V>`] conceptually owned by the instance of [`Token<V>`] (a
/// particular thread), unless the previous borrow is released. After the release, the used
/// singleton token should be free to be reused for reborrows.
pub(super) fn borrow_mut<'t>(&self, _token: &'t mut Token<V>) -> &'t mut V {
unsafe { &mut *self.0.get() }
///
/// Note that the lifetime of the acquired reference is still restricted to 'self, not
/// 'token, in order to avoid use-after-free undefined behavior.
pub(super) fn with_borrow_mut<R>(
&self,
_token: &mut Token<V>,
f: impl FnOnce(&mut V) -> R,
) -> R {
f(unsafe { &mut *self.0.get() })
}
}

Expand Down Expand Up @@ -332,23 +340,19 @@ impl TaskInner {
&self.lock_attempts
}

fn blocked_usage_count_mut<'t>(
&self,
token: &'t mut BlockedUsageCountToken,
) -> &'t mut ShortCounter {
self.blocked_usage_count.borrow_mut(token)
}

fn set_blocked_usage_count(&self, token: &mut BlockedUsageCountToken, count: ShortCounter) {
*self.blocked_usage_count_mut(token) = count;
self.blocked_usage_count
.with_borrow_mut(token, |usage_count| {
*usage_count = count;
})
}

#[must_use]
fn try_unblock(self: Task, token: &mut BlockedUsageCountToken) -> Option<Task> {
self.blocked_usage_count_mut(token)
.decrement_self()
.is_zero()
.then_some(self)
let did_unblock = self
.blocked_usage_count
.with_borrow_mut(token, |usage_count| usage_count.decrement_self().is_zero());
did_unblock.then_some(self)
}
}

Expand All @@ -369,11 +373,12 @@ impl LockAttempt {
}
}

fn usage_queue_mut<'t>(
fn with_usage_queue_mut<R>(
&self,
usage_queue_token: &'t mut UsageQueueToken,
) -> &'t mut UsageQueueInner {
self.usage_queue.0.borrow_mut(usage_queue_token)
usage_queue_token: &mut UsageQueueToken,
f: impl FnOnce(&mut UsageQueueInner) -> R,
) -> R {
self.usage_queue.0.with_borrow_mut(usage_queue_token, f)
}
}

Expand Down Expand Up @@ -594,23 +599,24 @@ impl SchedulingStateMachine {
let mut blocked_usage_count = ShortCounter::zero();

for attempt in task.lock_attempts() {
let usage_queue = attempt.usage_queue_mut(&mut self.usage_queue_token);
let lock_result = if usage_queue.has_no_blocked_usage() {
Self::try_lock_usage_queue(usage_queue, attempt.requested_usage)
} else {
LockResult::Err(())
};
match lock_result {
LockResult::Ok(Usage::Unused) => unreachable!(),
LockResult::Ok(new_usage) => {
usage_queue.current_usage = new_usage;
}
LockResult::Err(()) => {
blocked_usage_count.increment_self();
let usage_from_task = (attempt.requested_usage, task.clone());
usage_queue.push_blocked_usage_from_task(usage_from_task);
attempt.with_usage_queue_mut(&mut self.usage_queue_token, |usage_queue| {
let lock_result = if usage_queue.has_no_blocked_usage() {
Self::try_lock_usage_queue(usage_queue, attempt.requested_usage)
} else {
LockResult::Err(())
};
match lock_result {
LockResult::Ok(Usage::Unused) => unreachable!(),
LockResult::Ok(new_usage) => {
usage_queue.current_usage = new_usage;
}
LockResult::Err(()) => {
blocked_usage_count.increment_self();
let usage_from_task = (attempt.requested_usage, task.clone());
usage_queue.push_blocked_usage_from_task(usage_from_task);
}
}
}
});
}

// no blocked usage count means success
Expand All @@ -624,30 +630,33 @@ impl SchedulingStateMachine {

fn unlock_for_task(&mut self, task: &Task) {
for attempt in task.lock_attempts() {
let usage_queue = attempt.usage_queue_mut(&mut self.usage_queue_token);
let mut unblocked_task_from_queue = Self::unlock_usage_queue(usage_queue, attempt);

while let Some((requested_usage, task_with_unblocked_queue)) = unblocked_task_from_queue
{
if let Some(task) = task_with_unblocked_queue.try_unblock(&mut self.count_token) {
self.unblocked_task_queue.push_back(task);
}
attempt.with_usage_queue_mut(&mut self.usage_queue_token, |usage_queue| {
let mut unblocked_task_from_queue = Self::unlock_usage_queue(usage_queue, attempt);

while let Some((requested_usage, task_with_unblocked_queue)) =
unblocked_task_from_queue
{
if let Some(task) = task_with_unblocked_queue.try_unblock(&mut self.count_token)
{
self.unblocked_task_queue.push_back(task);
}

match Self::try_lock_usage_queue(usage_queue, requested_usage) {
LockResult::Ok(Usage::Unused) => unreachable!(),
LockResult::Ok(new_usage) => {
usage_queue.current_usage = new_usage;
// Try to further schedule blocked task for parallelism in the case of
// readonly usages
unblocked_task_from_queue = if matches!(new_usage, Usage::Readonly(_)) {
usage_queue.pop_unblocked_readonly_usage_from_task()
} else {
None
};
match Self::try_lock_usage_queue(usage_queue, requested_usage) {
LockResult::Ok(Usage::Unused) => unreachable!(),
LockResult::Ok(new_usage) => {
usage_queue.current_usage = new_usage;
// Try to further schedule blocked task for parallelism in the case of
// readonly usages
unblocked_task_from_queue = if matches!(new_usage, Usage::Readonly(_)) {
usage_queue.pop_unblocked_readonly_usage_from_task()
} else {
None
};
}
LockResult::Err(_) => panic!("should never fail in this context"),
}
LockResult::Err(_) => panic!("should never fail in this context"),
}
}
});
}
}

Expand Down Expand Up @@ -1215,24 +1224,20 @@ mod tests {
assert_matches!(state_machine.schedule_task(task2.clone()), None);
let usage_queues = usage_queues.borrow_mut();
let usage_queue = usage_queues.get(&conflicting_address).unwrap();
assert_matches!(
usage_queue
.0
.borrow_mut(&mut state_machine.usage_queue_token)
.current_usage,
Usage::Writable
);
usage_queue
.0
.with_borrow_mut(&mut state_machine.usage_queue_token, |usage_queue| {
assert_matches!(usage_queue.current_usage, Usage::Writable);
});
// task2's fee payer should have been locked already even if task2 is still blocked, via
// the above schedule_task(task2) call
let fee_payer = task2.transaction().message().fee_payer();
let usage_queue = usage_queues.get(fee_payer).unwrap();
assert_matches!(
usage_queue
.0
.borrow_mut(&mut state_machine.usage_queue_token)
.current_usage,
Usage::Writable
);
usage_queue
.0
.with_borrow_mut(&mut state_machine.usage_queue_token, |usage_queue| {
assert_matches!(usage_queue.current_usage, Usage::Writable);
});
state_machine.deschedule_task(&task1);
assert_matches!(
state_machine
Expand All @@ -1251,12 +1256,15 @@ mod tests {
SchedulingStateMachine::exclusively_initialize_current_thread_for_scheduling()
};
let usage_queue = UsageQueue::default();
let _ = SchedulingStateMachine::unlock_usage_queue(
usage_queue
.0
.borrow_mut(&mut state_machine.usage_queue_token),
&LockAttempt::new(usage_queue, RequestedUsage::Writable),
);
let usage_queue_for_lock_attempt = UsageQueue::default();
usage_queue
.0
.with_borrow_mut(&mut state_machine.usage_queue_token, |usage_queue| {
let _ = SchedulingStateMachine::unlock_usage_queue(
usage_queue,
&LockAttempt::new(usage_queue_for_lock_attempt, RequestedUsage::Writable),
);
});
}

#[test]
Expand All @@ -1268,14 +1276,14 @@ mod tests {
let usage_queue = UsageQueue::default();
usage_queue
.0
.borrow_mut(&mut state_machine.usage_queue_token)
.current_usage = Usage::Writable;
let _ = SchedulingStateMachine::unlock_usage_queue(
usage_queue
.0
.borrow_mut(&mut state_machine.usage_queue_token),
&LockAttempt::new(usage_queue, RequestedUsage::Readonly),
);
.with_borrow_mut(&mut state_machine.usage_queue_token, |usage_queue| {
usage_queue.current_usage = Usage::Writable;
let usage_queue_for_lock_attempt = UsageQueue::default();
let _ = SchedulingStateMachine::unlock_usage_queue(
usage_queue,
&LockAttempt::new(usage_queue_for_lock_attempt, RequestedUsage::Readonly),
);
});
}

#[test]
Expand All @@ -1287,13 +1295,13 @@ mod tests {
let usage_queue = UsageQueue::default();
usage_queue
.0
.borrow_mut(&mut state_machine.usage_queue_token)
.current_usage = Usage::Readonly(ShortCounter::one());
let _ = SchedulingStateMachine::unlock_usage_queue(
usage_queue
.0
.borrow_mut(&mut state_machine.usage_queue_token),
&LockAttempt::new(usage_queue, RequestedUsage::Writable),
);
.with_borrow_mut(&mut state_machine.usage_queue_token, |usage_queue| {
usage_queue.current_usage = Usage::Readonly(ShortCounter::one());
let usage_queue_for_lock_attempt = UsageQueue::default();
let _ = SchedulingStateMachine::unlock_usage_queue(
usage_queue,
&LockAttempt::new(usage_queue_for_lock_attempt, RequestedUsage::Writable),
);
});
}
}

0 comments on commit 9e07dcd

Please sign in to comment.