diff --git a/agb/examples/output.rs b/agb/examples/output.rs index 8aabc69d2..110e3df6d 100644 --- a/agb/examples/output.rs +++ b/agb/examples/output.rs @@ -8,7 +8,7 @@ static COUNT: Static = Static::new(0); #[agb::entry] fn main(_gba: agb::Gba) -> ! { let _a = unsafe { - agb::interrupt::add_interrupt_handler(agb::interrupt::Interrupt::VBlank, |_| { + agb::interrupt::add_interrupt_handler(agb::interrupt::Interrupt::VBlank, || { let cur_count = COUNT.read(); agb::println!("Hello, world, frame = {}", cur_count); COUNT.write(cur_count + 1); diff --git a/agb/examples/wave.rs b/agb/examples/wave.rs index 1e9cf2236..01804bda8 100644 --- a/agb/examples/wave.rs +++ b/agb/examples/wave.rs @@ -11,7 +11,7 @@ use agb::{ fixnum::FixedNum, interrupt::{free, Interrupt}, }; -use bare_metal::{CriticalSection, Mutex}; +use bare_metal::Mutex; struct BackCosines { cosines: [u16; 32], @@ -36,11 +36,13 @@ fn main(mut gba: agb::Gba) -> ! { example_logo::display_logo(&mut background, &mut vram); let _a = unsafe { - agb::interrupt::add_interrupt_handler(Interrupt::HBlank, |key: CriticalSection| { - let mut back = BACK.borrow(key).borrow_mut(); - let deflection = back.cosines[back.row % 32]; - ((0x0400_0010) as *mut u16).write_volatile(deflection); - back.row += 1; + agb::interrupt::add_interrupt_handler(Interrupt::HBlank, || { + free(|key| { + let mut back = BACK.borrow(key).borrow_mut(); + let deflection = back.cosines[back.row % 32]; + ((0x0400_0010) as *mut u16).write_volatile(deflection); + back.row += 1; + }); }) }; diff --git a/agb/src/interrupt.rs b/agb/src/interrupt.rs index c30c4b459..6c52cdd6e 100644 --- a/agb/src/interrupt.rs +++ b/agb/src/interrupt.rs @@ -1,6 +1,6 @@ -use core::{cell::Cell, marker::PhantomPinned, pin::Pin}; +use core::cell::Cell; -use alloc::boxed::Box; +use alloc::{rc::Rc, vec::Vec}; use bare_metal::CriticalSection; use crate::{display::DISPLAY_STATUS, memory_mapped::MemoryMapped, sync::Static}; @@ -24,16 +24,14 @@ pub enum Interrupt { } impl Interrupt { - fn enable(self) { - let _interrupt_token = temporary_interrupt_disable(); + fn enable(self, _cs: CriticalSection) { self.other_things_to_enable_interrupt(); let interrupt = self as usize; let enabled = ENABLED_INTERRUPTS.get() | (1 << (interrupt as u16)); ENABLED_INTERRUPTS.set(enabled); } - fn disable(self) { - let _interrupt_token = temporary_interrupt_disable(); + fn disable(self, _cs: CriticalSection) { self.other_things_to_disable_interrupt(); let interrupt = self as usize; let enabled = ENABLED_INTERRUPTS.get() & !(1 << (interrupt as u16)); @@ -66,77 +64,39 @@ impl Interrupt { } const ENABLED_INTERRUPTS: MemoryMapped = unsafe { MemoryMapped::new(0x04000200) }; -const INTERRUPTS_ENABLED: MemoryMapped = unsafe { MemoryMapped::new(0x04000208) }; +const INTERRUPTS_ENABLED: MemoryMapped = unsafe { MemoryMapped::new(0x04000208) }; -struct Disable { - pre: u16, -} - -impl Drop for Disable { - fn drop(&mut self) { - INTERRUPTS_ENABLED.set(self.pre); - } -} - -fn temporary_interrupt_disable() -> Disable { - let d = Disable { - pre: INTERRUPTS_ENABLED.get(), - }; - disable_interrupts(); - d -} - -fn disable_interrupts() { - INTERRUPTS_ENABLED.set(0); +extern "C" { + static mut __INTERRUPT_NEST: u32; } struct InterruptRoot { - next: Cell<*const InterruptInner>, - count: Cell, - interrupt: Interrupt, + interrupts: Vec>, } impl InterruptRoot { - const fn new(interrupt: Interrupt) -> Self { + const fn new() -> Self { InterruptRoot { - next: Cell::new(core::ptr::null()), - count: Cell::new(0), - interrupt, + interrupts: Vec::new(), } } - - fn reduce(&self) { - let new_count = self.count.get() - 1; - if new_count == 0 { - self.interrupt.disable(); - } - self.count.set(new_count); - } - - fn add(&self) { - let count = self.count.get(); - if count == 0 { - self.interrupt.enable(); - } - self.count.set(count + 1); - } } static mut INTERRUPT_TABLE: [InterruptRoot; 14] = [ - InterruptRoot::new(Interrupt::VBlank), - InterruptRoot::new(Interrupt::HBlank), - InterruptRoot::new(Interrupt::VCounter), - InterruptRoot::new(Interrupt::Timer0), - InterruptRoot::new(Interrupt::Timer1), - InterruptRoot::new(Interrupt::Timer2), - InterruptRoot::new(Interrupt::Timer3), - InterruptRoot::new(Interrupt::Serial), - InterruptRoot::new(Interrupt::Dma0), - InterruptRoot::new(Interrupt::Dma1), - InterruptRoot::new(Interrupt::Dma2), - InterruptRoot::new(Interrupt::Dma3), - InterruptRoot::new(Interrupt::Keypad), - InterruptRoot::new(Interrupt::Gamepak), + InterruptRoot::new(), + InterruptRoot::new(), + InterruptRoot::new(), + InterruptRoot::new(), + InterruptRoot::new(), + InterruptRoot::new(), + InterruptRoot::new(), + InterruptRoot::new(), + InterruptRoot::new(), + InterruptRoot::new(), + InterruptRoot::new(), + InterruptRoot::new(), + InterruptRoot::new(), + InterruptRoot::new(), ]; #[no_mangle] @@ -150,81 +110,39 @@ extern "C" fn __RUST_INTERRUPT_HANDLER(interrupt: u16) -> u16 { interrupt } -struct InterruptInner { - next: Cell<*const InterruptInner>, - root: *const InterruptRoot, - closure: *const dyn Fn(CriticalSection), - _pin: PhantomPinned, -} - -unsafe fn create_interrupt_inner( - c: impl Fn(CriticalSection), - root: *const InterruptRoot, -) -> Pin> { - let c = Box::new(c); - let c: &dyn Fn(CriticalSection) = Box::leak(c); - let c: &dyn Fn(CriticalSection) = core::mem::transmute(c); - Box::pin(InterruptInner { - next: Cell::new(core::ptr::null()), - root, - closure: c, - _pin: PhantomPinned, - }) +pub struct InterruptHandler { + kind: Interrupt, + closure: Rc, } -impl Drop for InterruptInner { +impl Drop for InterruptHandler { fn drop(&mut self) { - inner_drop(unsafe { Pin::new_unchecked(self) }); - #[allow(clippy::needless_pass_by_value)] // needed for safety reasons - fn inner_drop(this: Pin<&mut InterruptInner>) { - // drop the closure allocation safely - let _closure_box = - unsafe { Box::from_raw(this.closure as *mut dyn Fn(&CriticalSection)) }; - - // perform the rest of the drop sequence - let root = unsafe { &*this.root }; - root.reduce(); - let mut c = root.next.get(); - let own_pointer = &*this as *const _; - if c == own_pointer { - unsafe { &*this.root }.next.set(this.next.get()); - return; - } - loop { - let p = unsafe { &*c }.next.get(); - if p == own_pointer { - unsafe { &*c }.next.set(this.next.get()); - return; - } - c = p; + free(|cs| { + let root = unsafe { interrupt_to_root(self.kind) }; + root.interrupts.retain(|x| { + !core::ptr::eq::(&**x, &*self.closure) + }); + if root.interrupts.is_empty() { + self.kind.disable(cs); } - } + }); } } -pub struct InterruptHandler { - _inner: Pin>, -} - impl InterruptRoot { fn trigger_interrupts(&self) { - let mut c = self.next.get(); - while !c.is_null() { - let closure_ptr = unsafe { &*c }.closure; - let closure_ref = unsafe { &*closure_ptr }; - closure_ref(unsafe { CriticalSection::new() }); - c = unsafe { &*c }.next.get(); + for interrupt in self.interrupts.iter() { + (interrupt)(); } } } -fn interrupt_to_root(interrupt: Interrupt) -> &'static InterruptRoot { - unsafe { &INTERRUPT_TABLE[interrupt as usize] } +unsafe fn interrupt_to_root(interrupt: Interrupt) -> &'static mut InterruptRoot { + unsafe { &mut INTERRUPT_TABLE[interrupt as usize] } } #[must_use] -/// Adds an interrupt handler as long as the returned value is alive. The -/// closure takes a [`CriticalSection`] which can be used for mutexes. +/// Adds an interrupt handler as long as the returned value is alive. /// /// # Safety /// * You *must not* allocate in an interrupt. @@ -234,19 +152,16 @@ fn interrupt_to_root(interrupt: Interrupt) -> &'static InterruptRoot { /// * The closure must be static because forgetting the interrupt handler would /// cause a use after free. /// -/// [`CriticalSection`]: bare_metal::CriticalSection -/// /// # Examples /// /// ```rust,no_run /// # #![no_std] /// # #![no_main] /// # fn foo() { -/// use bare_metal::CriticalSection; /// use agb::interrupt::{add_interrupt_handler, Interrupt}; /// // Safety: doesn't allocate /// let _a = unsafe { -/// add_interrupt_handler(Interrupt::VBlank, |_: CriticalSection| { +/// add_interrupt_handler(Interrupt::VBlank, || { /// agb::println!("Woah there! There's been a vblank!"); /// }) /// }; @@ -254,33 +169,28 @@ fn interrupt_to_root(interrupt: Interrupt) -> &'static InterruptRoot { /// ``` pub unsafe fn add_interrupt_handler( interrupt: Interrupt, - handler: impl Fn(CriticalSection) + Send + Sync + 'static, + handler: impl Fn() + Send + Sync + 'static, ) -> InterruptHandler { - fn do_with_inner(interrupt: Interrupt, inner: Pin>) -> InterruptHandler { - free(|_| { - let root = interrupt_to_root(interrupt); - root.add(); - let mut c = root.next.get(); - if c.is_null() { - root.next.set((&*inner) as *const _); - return; - } - loop { - let p = unsafe { &*c }.next.get(); - if p.is_null() { - unsafe { &*c }.next.set((&*inner) as *const _); - return; - } - - c = p; - } - }); + fn inner( + interrupt: Interrupt, + handle: Rc, + cs: CriticalSection, + ) -> InterruptHandler { + let interrupts = unsafe { interrupt_to_root(interrupt) }; + + if interrupts.interrupts.is_empty() { + interrupt.enable(cs); + } - InterruptHandler { _inner: inner } + interrupts.interrupts.push(handle.clone()); + + InterruptHandler { + kind: interrupt, + closure: handle, + } } - let root = interrupt_to_root(interrupt) as *const _; - let inner = unsafe { create_interrupt_inner(handler, root) }; - do_with_inner(interrupt, inner) + + free(|cs| inner(interrupt, Rc::new(handler), cs)) } /// How you can access mutexes outside of interrupts by being given a @@ -293,7 +203,7 @@ where { let enabled = INTERRUPTS_ENABLED.get(); - disable_interrupts(); + INTERRUPTS_ENABLED.set(0); // prevents the contents of the function from being reordered before IME is disabled. crate::sync::memory_write_hint(&mut f); @@ -323,7 +233,7 @@ impl VBlank { if !HAS_CREATED_INTERRUPT.read() { // safety: we don't allocate in the interrupt let handler = unsafe { - add_interrupt_handler(Interrupt::VBlank, |_| { + add_interrupt_handler(Interrupt::VBlank, || { NUM_VBLANKS.write(NUM_VBLANKS.read() + 1); }) }; @@ -365,7 +275,7 @@ pub fn profiler(timer: &mut crate::timer::Timer, period: u16) -> InterruptHandle timer.set_enabled(true); unsafe { - add_interrupt_handler(timer.interrupt(), |_key: CriticalSection| { + add_interrupt_handler(timer.interrupt(), || { crate::println!("{:#010x}", crate::program_counter_before_interrupt()); }) } diff --git a/agb/src/sound/mixer/sw_mixer.rs b/agb/src/sound/mixer/sw_mixer.rs index e1d6c743b..3ed3b71b1 100644 --- a/agb/src/sound/mixer/sw_mixer.rs +++ b/agb/src/sound/mixer/sw_mixer.rs @@ -165,8 +165,10 @@ impl Mixer<'_> { let buffer_pointer_for_interrupt_handler: &MixerBuffer = unsafe { core::mem::transmute(buffer_pointer_for_interrupt_handler) }; let interrupt_handler = unsafe { - add_interrupt_handler(interrupt_timer.interrupt(), |cs| { - buffer_pointer_for_interrupt_handler.swap(cs); + add_interrupt_handler(interrupt_timer.interrupt(), || { + free(|cs| { + buffer_pointer_for_interrupt_handler.swap(cs); + }); }) }; diff --git a/agb/src/sync/statics.rs b/agb/src/sync/statics.rs index e7828c6cc..f0f28630e 100644 --- a/agb/src/sync/statics.rs +++ b/agb/src/sync/statics.rs @@ -262,7 +262,6 @@ unsafe impl Sync for Static {} #[cfg(test)] mod test { - use crate::interrupt::Interrupt; use crate::sync::Static; use crate::timer::Divider; use crate::Gba; @@ -282,7 +281,7 @@ mod test { timer.set_enabled(true); let _int = unsafe { - crate::interrupt::add_interrupt_handler(Interrupt::Timer2, |_| { + crate::interrupt::add_interrupt_handler(timer.interrupt(), || { VALUE.write(SENTINEL); }) };