Skip to content

Commit

Permalink
Support yielding from hooks for Lua 5.3+
Browse files Browse the repository at this point in the history
  • Loading branch information
khvzak committed Sep 22, 2024
1 parent fce8538 commit fc1570d
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 31 deletions.
10 changes: 4 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ pub use crate::string::{BorrowedBytes, BorrowedStr, String};
pub use crate::table::{Table, TablePairs, TableSequence};
pub use crate::thread::{Thread, ThreadStatus};
pub use crate::traits::ObjectLike;
pub use crate::types::{AppDataRef, AppDataRefMut, Integer, LightUserData, MaybeSend, Number, RegistryKey};
pub use crate::types::{
AppDataRef, AppDataRefMut, Integer, LightUserData, MaybeSend, Number, RegistryKey, VmState,
};
pub use crate::userdata::{
AnyUserData, MetaMethod, UserData, UserDataFields, UserDataMetatable, UserDataMethods, UserDataRef,
UserDataRefMut, UserDataRegistry,
Expand All @@ -128,11 +130,7 @@ pub use crate::hook::HookTriggers;

#[cfg(any(feature = "luau", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
pub use crate::{
chunk::Compiler,
function::CoverageInfo,
types::{Vector, VmState},
};
pub use crate::{chunk::Compiler, function::CoverageInfo, types::Vector};

#[cfg(feature = "async")]
pub use crate::thread::AsyncThread;
Expand Down
4 changes: 2 additions & 2 deletions src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub use crate::{
ThreadStatus as LuaThreadStatus, UserData as LuaUserData, UserDataFields as LuaUserDataFields,
UserDataMetatable as LuaUserDataMetatable, UserDataMethods as LuaUserDataMethods,
UserDataRef as LuaUserDataRef, UserDataRefMut as LuaUserDataRefMut,
UserDataRegistry as LuaUserDataRegistry, Value as LuaValue,
UserDataRegistry as LuaUserDataRegistry, Value as LuaValue, VmState as LuaVmState,
};

#[cfg(not(feature = "luau"))]
Expand All @@ -21,7 +21,7 @@ pub use crate::HookTriggers as LuaHookTriggers;

#[cfg(feature = "luau")]
#[doc(no_inline)]
pub use crate::{CoverageInfo as LuaCoverageInfo, Vector as LuaVector, VmState as LuaVmState};
pub use crate::{CoverageInfo as LuaCoverageInfo, Vector as LuaVector};

#[cfg(feature = "async")]
#[doc(no_inline)]
Expand Down
10 changes: 5 additions & 5 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::table::Table;
use crate::thread::Thread;
use crate::types::{
AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, MaybeSend, Number, ReentrantMutex,
ReentrantMutexGuard, RegistryKey, XRc, XWeak,
ReentrantMutexGuard, RegistryKey, VmState, XRc, XWeak,
};
use crate::userdata::{AnyUserData, UserData, UserDataProxy, UserDataRegistry, UserDataStorage};
use crate::util::{
Expand All @@ -31,7 +31,7 @@ use crate::value::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, MultiValue, Nil
use crate::hook::HookTriggers;

#[cfg(any(feature = "luau", doc))]
use crate::{chunk::Compiler, types::VmState};
use crate::chunk::Compiler;

#[cfg(feature = "async")]
use {
Expand Down Expand Up @@ -499,12 +499,12 @@ impl Lua {
/// Shows each line number of code being executed by the Lua interpreter.
///
/// ```
/// # use mlua::{Lua, HookTriggers, Result};
/// # use mlua::{Lua, HookTriggers, Result, VmState};
/// # fn main() -> Result<()> {
/// let lua = Lua::new();
/// lua.set_hook(HookTriggers::EVERY_LINE, |_lua, debug| {
/// println!("line {}", debug.curr_line());
/// Ok(())
/// Ok(VmState::Continue)
/// });
///
/// lua.load(r#"
Expand All @@ -521,7 +521,7 @@ impl Lua {
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F)
where
F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static,
F: Fn(&Lua, Debug) -> Result<VmState> + MaybeSend + 'static,
{
let lua = self.lock();
unsafe { lua.set_thread_hook(lua.state(), triggers, callback) };
Expand Down
27 changes: 22 additions & 5 deletions src/state/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::table::Table;
use crate::thread::Thread;
use crate::types::{
AppDataRef, AppDataRefMut, Callback, CallbackUpvalue, DestructedUserdata, Integer, LightUserData,
MaybeSend, ReentrantMutex, RegistryKey, SubtypeId, ValueRef, XRc,
MaybeSend, ReentrantMutex, RegistryKey, SubtypeId, ValueRef, VmState, XRc,
};
use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataRegistry, UserDataStorage};
use crate::util::{
Expand Down Expand Up @@ -356,7 +356,7 @@ impl RawLua {
triggers: HookTriggers,
callback: F,
) where
F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static,
F: Fn(&Lua, Debug) -> Result<VmState> + MaybeSend + 'static,
{
unsafe extern "C-unwind" fn hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) {
let extra = ExtraData::get(state);
Expand All @@ -365,17 +365,34 @@ impl RawLua {
ffi::lua_sethook(state, None, 0, 0);
return;
}
callback_error_ext(state, extra, move |extra, _| {
let result = callback_error_ext(state, extra, move |extra, _| {
let hook_cb = (*extra).hook_callback.clone();
let hook_cb = mlua_expect!(hook_cb, "no hook callback set in hook_proc");
if std::rc::Rc::strong_count(&hook_cb) > 2 {
return Ok(()); // Don't allow recursion
return Ok(VmState::Continue); // Don't allow recursion
}
let rawlua = (*extra).raw_lua();
let _guard = StateGuard::new(rawlua, state);
let debug = Debug::new(rawlua, ar);
hook_cb((*extra).lua(), debug)
})
});
match result {
VmState::Continue => {}
VmState::Yield => {
// Only count and line events can yield
if (*ar).event == ffi::LUA_HOOKCOUNT || (*ar).event == ffi::LUA_HOOKLINE {
#[cfg(any(feature = "lua54", feature = "lua53"))]
if ffi::lua_isyieldable(state) != 0 {
ffi::lua_yield(state, 0);
}
#[cfg(any(feature = "lua52", feature = "lua51", feature = "luajit"))]
{
ffi::lua_pushliteral(state, "attempt to yield from a hook");
ffi::lua_error(state);
}
}
}
}
}

(*self.extra.get()).hook_callback = Some(std::rc::Rc::new(callback));
Expand Down
4 changes: 2 additions & 2 deletions src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::error::{Error, Result};
#[allow(unused)]
use crate::state::Lua;
use crate::state::RawLua;
use crate::types::ValueRef;
use crate::types::{ValueRef, VmState};
use crate::util::{check_stack, error_traceback_thread, pop_error, StackGuard};
use crate::value::{FromLuaMulti, IntoLuaMulti};

Expand Down Expand Up @@ -194,7 +194,7 @@ impl Thread {
#[cfg_attr(docsrs, doc(cfg(not(feature = "luau"))))]
pub fn set_hook<F>(&self, triggers: HookTriggers, callback: F)
where
F: Fn(&Lua, Debug) -> Result<()> + MaybeSend + 'static,
F: Fn(&Lua, Debug) -> Result<VmState> + MaybeSend + 'static,
{
let lua = self.0.lua.lock();
unsafe {
Expand Down
9 changes: 5 additions & 4 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,19 @@ pub(crate) type AsyncCallbackUpvalue = Upvalue<AsyncCallback>;
pub(crate) type AsyncPollUpvalue = Upvalue<BoxFuture<'static, Result<c_int>>>;

/// Type to set next Luau VM action after executing interrupt function.
#[cfg(any(feature = "luau", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
pub enum VmState {
Continue,
/// Yield the current thread.
///
/// Supported by Lua 5.3+ and Luau.
Yield,
}

#[cfg(all(feature = "send", not(feature = "luau")))]
pub(crate) type HookCallback = Rc<dyn Fn(&Lua, Debug) -> Result<()> + Send>;
pub(crate) type HookCallback = Rc<dyn Fn(&Lua, Debug) -> Result<VmState> + Send>;

#[cfg(all(not(feature = "send"), not(feature = "luau")))]
pub(crate) type HookCallback = Rc<dyn Fn(&Lua, Debug) -> Result<()>>;
pub(crate) type HookCallback = Rc<dyn Fn(&Lua, Debug) -> Result<VmState>>;

#[cfg(all(feature = "send", feature = "luau"))]
pub(crate) type InterruptCallback = Rc<dyn Fn(&Lua) -> Result<VmState> + Send>;
Expand Down
50 changes: 43 additions & 7 deletions tests/hooks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::ops::Deref;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::{Arc, Mutex};

use mlua::{DebugEvent, Error, HookTriggers, Lua, Result, Value};
use mlua::{DebugEvent, Error, HookTriggers, Lua, Result, ThreadStatus, Value, VmState};

#[test]
fn test_hook_triggers() {
Expand All @@ -26,7 +26,7 @@ fn test_line_counts() -> Result<()> {
lua.set_hook(HookTriggers::EVERY_LINE, move |_lua, debug| {
assert_eq!(debug.event(), DebugEvent::Line);
hook_output.lock().unwrap().push(debug.curr_line());
Ok(())
Ok(VmState::Continue)
});
lua.load(
r#"
Expand Down Expand Up @@ -61,7 +61,7 @@ fn test_function_calls() -> Result<()> {
let source = debug.source();
let name = names.name.map(|s| s.into_owned());
hook_output.lock().unwrap().push((name, source.what));
Ok(())
Ok(VmState::Continue)
});

lua.load(
Expand Down Expand Up @@ -120,7 +120,7 @@ fn test_limit_execution_instructions() -> Result<()> {
if max_instructions.fetch_sub(30, Ordering::Relaxed) <= 30 {
Err(Error::runtime("time's up"))
} else {
Ok(())
Ok(VmState::Continue)
}
},
);
Expand Down Expand Up @@ -191,10 +191,10 @@ fn test_hook_swap_within_hook() -> Result<()> {
TL_LUA.with(|tl| {
tl.borrow().as_ref().unwrap().remove_hook();
});
Ok(())
Ok(VmState::Continue)
})
});
Ok(())
Ok(VmState::Continue)
})
});

Expand Down Expand Up @@ -234,7 +234,7 @@ fn test_hook_threads() -> Result<()> {
co.set_hook(HookTriggers::EVERY_LINE, move |_lua, debug| {
assert_eq!(debug.event(), DebugEvent::Line);
hook_output.lock().unwrap().push(debug.curr_line());
Ok(())
Ok(VmState::Continue)
});

co.resume::<()>(())?;
Expand All @@ -249,3 +249,39 @@ fn test_hook_threads() -> Result<()> {

Ok(())
}

#[test]
fn test_hook_yield() -> Result<()> {
let lua = Lua::new();

let func = lua
.load(
r#"
local x = 2 + 3
local y = x * 63
local z = string.len(x..", "..y)
"#,
)
.into_function()?;
let co = lua.create_thread(func)?;

co.set_hook(HookTriggers::EVERY_LINE, move |_lua, _debug| Ok(VmState::Yield));

#[cfg(any(feature = "lua54", feature = "lua53"))]
{
assert!(co.resume::<()>(()).is_ok());
assert!(co.resume::<()>(()).is_ok());
assert!(co.resume::<()>(()).is_ok());
assert!(co.resume::<()>(()).is_ok());
assert!(co.status() == ThreadStatus::Finished);
}
#[cfg(any(feature = "lua51", feature = "lua52", feature = "luajit"))]
{
assert!(
matches!(co.resume::<()>(()), Err(Error::RuntimeError(err)) if err.contains("attempt to yield from a hook"))
);
assert!(co.status() == ThreadStatus::Error);
}

Ok(())
}

0 comments on commit fc1570d

Please sign in to comment.