From fc1570d2d7e104d9992498dd925a707265be2bea Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Sun, 22 Sep 2024 19:04:32 +0100 Subject: [PATCH] Support yielding from hooks for Lua 5.3+ --- src/lib.rs | 10 ++++------ src/prelude.rs | 4 ++-- src/state.rs | 10 +++++----- src/state/raw.rs | 27 +++++++++++++++++++++----- src/thread.rs | 4 ++-- src/types.rs | 9 +++++---- tests/hooks.rs | 50 +++++++++++++++++++++++++++++++++++++++++------- 7 files changed, 83 insertions(+), 31 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a02a4a9a..c402e11c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -116,7 +116,9 @@ pub use crate::string::{BorrowedBytes, BorrowedStr, String}; pub use crate::table::{Table, TablePairs, TableSequence}; pub use crate::thread::{Thread, ThreadStatus}; pub use crate::traits::ObjectLike; -pub use crate::types::{AppDataRef, AppDataRefMut, Integer, LightUserData, MaybeSend, Number, RegistryKey}; +pub use crate::types::{ + AppDataRef, AppDataRefMut, Integer, LightUserData, MaybeSend, Number, RegistryKey, VmState, +}; pub use crate::userdata::{ AnyUserData, MetaMethod, UserData, UserDataFields, UserDataMetatable, UserDataMethods, UserDataRef, UserDataRefMut, UserDataRegistry, @@ -128,11 +130,7 @@ pub use crate::hook::HookTriggers; #[cfg(any(feature = "luau", doc))] #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] -pub use crate::{ - chunk::Compiler, - function::CoverageInfo, - types::{Vector, VmState}, -}; +pub use crate::{chunk::Compiler, function::CoverageInfo, types::Vector}; #[cfg(feature = "async")] pub use crate::thread::AsyncThread; diff --git a/src/prelude.rs b/src/prelude.rs index faf8f7c1..329f2ee7 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -12,7 +12,7 @@ pub use crate::{ ThreadStatus as LuaThreadStatus, UserData as LuaUserData, UserDataFields as LuaUserDataFields, UserDataMetatable as LuaUserDataMetatable, UserDataMethods as LuaUserDataMethods, UserDataRef as LuaUserDataRef, UserDataRefMut as LuaUserDataRefMut, - UserDataRegistry as LuaUserDataRegistry, Value as LuaValue, + UserDataRegistry as LuaUserDataRegistry, Value as LuaValue, VmState as LuaVmState, }; #[cfg(not(feature = "luau"))] @@ -21,7 +21,7 @@ pub use crate::HookTriggers as LuaHookTriggers; #[cfg(feature = "luau")] #[doc(no_inline)] -pub use crate::{CoverageInfo as LuaCoverageInfo, Vector as LuaVector, VmState as LuaVmState}; +pub use crate::{CoverageInfo as LuaCoverageInfo, Vector as LuaVector}; #[cfg(feature = "async")] #[doc(no_inline)] diff --git a/src/state.rs b/src/state.rs index ae05c72e..e10174b0 100644 --- a/src/state.rs +++ b/src/state.rs @@ -19,7 +19,7 @@ use crate::table::Table; use crate::thread::Thread; use crate::types::{ AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, MaybeSend, Number, ReentrantMutex, - ReentrantMutexGuard, RegistryKey, XRc, XWeak, + ReentrantMutexGuard, RegistryKey, VmState, XRc, XWeak, }; use crate::userdata::{AnyUserData, UserData, UserDataProxy, UserDataRegistry, UserDataStorage}; use crate::util::{ @@ -31,7 +31,7 @@ use crate::value::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, MultiValue, Nil use crate::hook::HookTriggers; #[cfg(any(feature = "luau", doc))] -use crate::{chunk::Compiler, types::VmState}; +use crate::chunk::Compiler; #[cfg(feature = "async")] use { @@ -499,12 +499,12 @@ impl Lua { /// Shows each line number of code being executed by the Lua interpreter. /// /// ``` - /// # use mlua::{Lua, HookTriggers, Result}; + /// # use mlua::{Lua, HookTriggers, Result, VmState}; /// # fn main() -> Result<()> { /// let lua = Lua::new(); /// lua.set_hook(HookTriggers::EVERY_LINE, |_lua, debug| { /// println!("line {}", debug.curr_line()); - /// Ok(()) + /// Ok(VmState::Continue) /// }); /// /// lua.load(r#" @@ -521,7 +521,7 @@ impl Lua { #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] pub fn set_hook(&self, triggers: HookTriggers, callback: F) where - F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static, + F: Fn(&Lua, Debug) -> Result + MaybeSend + 'static, { let lua = self.lock(); unsafe { lua.set_thread_hook(lua.state(), triggers, callback) }; diff --git a/src/state/raw.rs b/src/state/raw.rs index 4fbae99d..81d283e6 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -18,7 +18,7 @@ use crate::table::Table; use crate::thread::Thread; use crate::types::{ AppDataRef, AppDataRefMut, Callback, CallbackUpvalue, DestructedUserdata, Integer, LightUserData, - MaybeSend, ReentrantMutex, RegistryKey, SubtypeId, ValueRef, XRc, + MaybeSend, ReentrantMutex, RegistryKey, SubtypeId, ValueRef, VmState, XRc, }; use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataRegistry, UserDataStorage}; use crate::util::{ @@ -356,7 +356,7 @@ impl RawLua { triggers: HookTriggers, callback: F, ) where - F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static, + F: Fn(&Lua, Debug) -> Result + MaybeSend + 'static, { unsafe extern "C-unwind" fn hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) { let extra = ExtraData::get(state); @@ -365,17 +365,34 @@ impl RawLua { ffi::lua_sethook(state, None, 0, 0); return; } - callback_error_ext(state, extra, move |extra, _| { + let result = callback_error_ext(state, extra, move |extra, _| { let hook_cb = (*extra).hook_callback.clone(); let hook_cb = mlua_expect!(hook_cb, "no hook callback set in hook_proc"); if std::rc::Rc::strong_count(&hook_cb) > 2 { - return Ok(()); // Don't allow recursion + return Ok(VmState::Continue); // Don't allow recursion } let rawlua = (*extra).raw_lua(); let _guard = StateGuard::new(rawlua, state); let debug = Debug::new(rawlua, ar); hook_cb((*extra).lua(), debug) - }) + }); + match result { + VmState::Continue => {} + VmState::Yield => { + // Only count and line events can yield + if (*ar).event == ffi::LUA_HOOKCOUNT || (*ar).event == ffi::LUA_HOOKLINE { + #[cfg(any(feature = "lua54", feature = "lua53"))] + if ffi::lua_isyieldable(state) != 0 { + ffi::lua_yield(state, 0); + } + #[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))] + { + ffi::lua_pushliteral(state, "attempt to yield from a hook"); + ffi::lua_error(state); + } + } + } + } } (*self.extra.get()).hook_callback = Some(std::rc::Rc::new(callback)); diff --git a/src/thread.rs b/src/thread.rs index c1c0a042..b4b56ee2 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -4,7 +4,7 @@ use crate::error::{Error, Result}; #[allow(unused)] use crate::state::Lua; use crate::state::RawLua; -use crate::types::ValueRef; +use crate::types::{ValueRef, VmState}; use crate::util::{check_stack, error_traceback_thread, pop_error, StackGuard}; use crate::value::{FromLuaMulti, IntoLuaMulti}; @@ -194,7 +194,7 @@ impl Thread { #[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))] pub fn set_hook(&self, triggers: HookTriggers, callback: F) where - F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static, + F: Fn(&Lua, Debug) -> Result + MaybeSend + 'static, { let lua = self.0.lua.lock(); unsafe { diff --git a/src/types.rs b/src/types.rs index a62e56de..6095c828 100644 --- a/src/types.rs +++ b/src/types.rs @@ -76,18 +76,19 @@ pub(crate) type AsyncCallbackUpvalue = Upvalue; pub(crate) type AsyncPollUpvalue = Upvalue>>; /// Type to set next Luau VM action after executing interrupt function. -#[cfg(any(feature = "luau", doc))] -#[cfg_attr(docsrs, doc(cfg(feature = "luau")))] pub enum VmState { Continue, + /// Yield the current thread. + /// + /// Supported by Lua 5.3+ and Luau. Yield, } #[cfg(all(feature = "send", not(feature = "luau")))] -pub(crate) type HookCallback = Rc Result<()> + Send>; +pub(crate) type HookCallback = Rc Result + Send>; #[cfg(all(not(feature = "send"), not(feature = "luau")))] -pub(crate) type HookCallback = Rc Result<()>>; +pub(crate) type HookCallback = Rc Result>; #[cfg(all(feature = "send", feature = "luau"))] pub(crate) type InterruptCallback = Rc Result + Send>; diff --git a/tests/hooks.rs b/tests/hooks.rs index a7b3ff47..7231bdfd 100644 --- a/tests/hooks.rs +++ b/tests/hooks.rs @@ -4,7 +4,7 @@ use std::ops::Deref; use std::sync::atomic::{AtomicI64, Ordering}; use std::sync::{Arc, Mutex}; -use mlua::{DebugEvent, Error, HookTriggers, Lua, Result, Value}; +use mlua::{DebugEvent, Error, HookTriggers, Lua, Result, ThreadStatus, Value, VmState}; #[test] fn test_hook_triggers() { @@ -26,7 +26,7 @@ fn test_line_counts() -> Result<()> { lua.set_hook(HookTriggers::EVERY_LINE, move |_lua, debug| { assert_eq!(debug.event(), DebugEvent::Line); hook_output.lock().unwrap().push(debug.curr_line()); - Ok(()) + Ok(VmState::Continue) }); lua.load( r#" @@ -61,7 +61,7 @@ fn test_function_calls() -> Result<()> { let source = debug.source(); let name = names.name.map(|s| s.into_owned()); hook_output.lock().unwrap().push((name, source.what)); - Ok(()) + Ok(VmState::Continue) }); lua.load( @@ -120,7 +120,7 @@ fn test_limit_execution_instructions() -> Result<()> { if max_instructions.fetch_sub(30, Ordering::Relaxed) <= 30 { Err(Error::runtime("time's up")) } else { - Ok(()) + Ok(VmState::Continue) } }, ); @@ -191,10 +191,10 @@ fn test_hook_swap_within_hook() -> Result<()> { TL_LUA.with(|tl| { tl.borrow().as_ref().unwrap().remove_hook(); }); - Ok(()) + Ok(VmState::Continue) }) }); - Ok(()) + Ok(VmState::Continue) }) }); @@ -234,7 +234,7 @@ fn test_hook_threads() -> Result<()> { co.set_hook(HookTriggers::EVERY_LINE, move |_lua, debug| { assert_eq!(debug.event(), DebugEvent::Line); hook_output.lock().unwrap().push(debug.curr_line()); - Ok(()) + Ok(VmState::Continue) }); co.resume::<()>(())?; @@ -249,3 +249,39 @@ fn test_hook_threads() -> Result<()> { Ok(()) } + +#[test] +fn test_hook_yield() -> Result<()> { + let lua = Lua::new(); + + let func = lua + .load( + r#" + local x = 2 + 3 + local y = x * 63 + local z = string.len(x..", "..y) + "#, + ) + .into_function()?; + let co = lua.create_thread(func)?; + + co.set_hook(HookTriggers::EVERY_LINE, move |_lua, _debug| Ok(VmState::Yield)); + + #[cfg(any(feature = "lua54", feature = "lua53"))] + { + assert!(co.resume::<()>(()).is_ok()); + assert!(co.resume::<()>(()).is_ok()); + assert!(co.resume::<()>(()).is_ok()); + assert!(co.resume::<()>(()).is_ok()); + assert!(co.status() == ThreadStatus::Finished); + } + #[cfg(any(feature = "lua51", feature = "lua52", feature = "luajit"))] + { + assert!( + matches!(co.resume::<()>(()), Err(Error::RuntimeError(err)) if err.contains("attempt to yield from a hook")) + ); + assert!(co.status() == ThreadStatus::Error); + } + + Ok(()) +}