Skip to content

Commit

Permalink
fix(runtime): incorrect min_timeout (#246)
Browse files Browse the repository at this point in the history
* fix(runtime): incorrect min_timeout

* refactor(runtime): use Reverse in BinaryHeap
  • Loading branch information
Mivik authored May 1, 2024
1 parent 925460c commit c0164a9
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions compio-runtime/src/runtime/time.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
cmp::Reverse,
collections::BinaryHeap,
future::Future,
pin::Pin,
Expand Down Expand Up @@ -39,7 +40,7 @@ impl Ord for TimerEntry {
pub struct TimerRuntime {
time: Instant,
tasks: Slab<FutureState>,
wheel: BinaryHeap<TimerEntry>,
wheel: BinaryHeap<Reverse<TimerEntry>>,
}

impl TimerRuntime {
Expand All @@ -66,7 +67,7 @@ impl TimerRuntime {
let key = self.tasks.insert(FutureState::Active(None));
delay += elapsed;
let entry = TimerEntry { key, delay };
self.wheel.push(entry);
self.wheel.push(Reverse(entry));
Some(key)
}

Expand All @@ -83,8 +84,8 @@ impl TimerRuntime {
pub fn min_timeout(&self) -> Option<Duration> {
let elapsed = self.time.elapsed();
self.wheel.peek().map(|entry| {
if entry.delay > elapsed {
entry.delay - elapsed
if entry.0.delay > elapsed {
entry.0.delay - elapsed
} else {
Duration::ZERO
}
Expand All @@ -94,8 +95,8 @@ impl TimerRuntime {
pub fn wake(&mut self) {
let elapsed = self.time.elapsed();
while let Some(entry) = self.wheel.pop() {
if entry.delay <= elapsed {
if let Some(state) = self.tasks.get_mut(entry.key) {
if entry.0.delay <= elapsed {
if let Some(state) = self.tasks.get_mut(entry.0.key) {
let old_state = std::mem::replace(state, FutureState::Completed);
if let FutureState::Active(Some(waker)) = old_state {
waker.wake();
Expand Down Expand Up @@ -132,3 +133,15 @@ impl Drop for TimerFuture {
Runtime::current().inner().cancel_timer(self.key);
}
}

#[test]
fn timer_min_timeout() {
let mut runtime = TimerRuntime::new();
assert_eq!(runtime.min_timeout(), None);

runtime.insert(Duration::from_secs(1));
runtime.insert(Duration::from_secs(10));
let min_timeout = runtime.min_timeout().unwrap().as_secs_f32();

assert!(min_timeout < 1.);
}

0 comments on commit c0164a9

Please sign in to comment.