Skip to content

Commit

Permalink
Run gargabe collection on main Lua instance drop
Browse files Browse the repository at this point in the history
This should help preventing leaking memory when capturing Lua in async block
and dropping future without finishing polling.
  • Loading branch information
khvzak committed Aug 29, 2024
1 parent ece66c4 commit 3774296
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 37 deletions.
36 changes: 27 additions & 9 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@ use util::{callback_error_ext, StateGuard};

/// Top level Lua struct which represents an instance of Lua VM.
#[derive(Clone)]
#[repr(transparent)]
pub struct Lua(XRc<ReentrantMutex<RawLua>>);
pub struct Lua {
pub(self) raw: XRc<ReentrantMutex<RawLua>>,
// Controls whether garbage collection should be run on drop
pub(self) collect_garbage: bool,
}

#[derive(Clone)]
#[repr(transparent)]
pub(crate) struct WeakLua(XWeak<ReentrantMutex<RawLua>>);

pub(crate) struct LuaGuard(ArcReentrantMutexGuard<RawLua>);
Expand Down Expand Up @@ -137,6 +139,14 @@ impl LuaOptions {
}
}

impl Drop for Lua {
fn drop(&mut self) {
if self.collect_garbage {
let _ = self.gc_collect();
}
}
}

impl fmt::Debug for Lua {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Lua({:p})", self.lock().state())
Expand Down Expand Up @@ -242,7 +252,10 @@ impl Lua {

/// Creates a new Lua state with required `libs` and `options`
unsafe fn inner_new(libs: StdLib, options: LuaOptions) -> Lua {
let lua = Lua(RawLua::new(libs, options));
let lua = Lua {
raw: RawLua::new(libs, options),
collect_garbage: true,
};

#[cfg(feature = "luau")]
mlua_expect!(lua.configure_luau(), "Error configuring Luau");
Expand All @@ -257,7 +270,10 @@ impl Lua {
#[allow(clippy::missing_safety_doc)]
#[inline]
pub unsafe fn init_from_ptr(state: *mut ffi::lua_State) -> Lua {
Lua(RawLua::init_from_ptr(state))
Lua {
raw: RawLua::init_from_ptr(state),
collect_garbage: true,
}
}

/// FIXME: Deprecated load_from_std_lib
Expand Down Expand Up @@ -1157,6 +1173,8 @@ impl Lua {
FR: Future<Output = Result<R>> + MaybeSend + 'static,
R: IntoLuaMulti,
{
// In future we should switch to async closures when they are stable to capture `&Lua`
// See https://rust-lang.github.io/rfcs/3668-async-closures.html
(self.lock()).create_async_callback(Box::new(move |rawlua, nargs| unsafe {
let args = match A::from_stack_args(nargs, 1, None, rawlua) {
Ok(args) => args,
Expand Down Expand Up @@ -1819,25 +1837,25 @@ impl Lua {

#[inline(always)]
pub(crate) fn lock(&self) -> ReentrantMutexGuard<RawLua> {
self.0.lock()
self.raw.lock()
}

#[inline(always)]
pub(crate) fn lock_arc(&self) -> LuaGuard {
LuaGuard(self.0.lock_arc())
LuaGuard(self.raw.lock_arc())
}

#[inline(always)]
pub(crate) fn weak(&self) -> WeakLua {
WeakLua(XRc::downgrade(&self.0))
WeakLua(XRc::downgrade(&self.raw))
}

/// Returns a handle to the unprotected Lua state without any synchronization.
///
/// This is useful where we know that the lock is already held by the caller.
#[inline(always)]
pub(crate) unsafe fn raw_lua(&self) -> &RawLua {
&*self.0.data_ptr()
&*self.raw.data_ptr()
}
}

Expand Down
27 changes: 14 additions & 13 deletions src/state/extra.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::any::TypeId;
use std::cell::UnsafeCell;
use std::mem::{self, MaybeUninit};
use std::mem::MaybeUninit;
use std::os::raw::{c_int, c_void};
use std::ptr;
use std::rc::Rc;
Expand All @@ -12,7 +12,7 @@ use rustc_hash::FxHashMap;
use crate::error::Result;
use crate::state::RawLua;
use crate::stdlib::StdLib;
use crate::types::{AppData, ReentrantMutex, XRc, XWeak};
use crate::types::{AppData, ReentrantMutex, XRc};
use crate::util::{get_internal_metatable, push_internal_userdata, TypeKey, WrappedFailure};

#[cfg(any(feature = "luau", doc))]
Expand All @@ -31,10 +31,8 @@ const REF_STACK_RESERVE: c_int = 1;

/// Data associated with the Lua state.
pub(crate) struct ExtraData {
// Same layout as `Lua`
pub(super) lua: MaybeUninit<XRc<ReentrantMutex<RawLua>>>,
// Same layout as `WeakLua`
pub(super) weak: MaybeUninit<XWeak<ReentrantMutex<RawLua>>>,
pub(super) lua: MaybeUninit<Lua>,
pub(super) weak: MaybeUninit<WeakLua>,

pub(super) registered_userdata: FxHashMap<TypeId, c_int>,
pub(super) registered_userdata_mt: FxHashMap<*const c_void, Option<TypeId>>,
Expand Down Expand Up @@ -185,12 +183,15 @@ impl ExtraData {
extra
}

pub(super) unsafe fn set_lua(&mut self, lua: &XRc<ReentrantMutex<RawLua>>) {
self.lua.write(XRc::clone(lua));
pub(super) unsafe fn set_lua(&mut self, raw: &XRc<ReentrantMutex<RawLua>>) {
self.lua.write(Lua {
raw: XRc::clone(raw),
collect_garbage: false,
});
if cfg!(not(feature = "module")) {
XRc::decrement_strong_count(XRc::as_ptr(lua));
XRc::decrement_strong_count(XRc::as_ptr(raw));
}
self.weak.write(XRc::downgrade(lua));
self.weak.write(WeakLua(XRc::downgrade(raw)));
}

pub(super) unsafe fn get(state: *mut ffi::lua_State) -> *mut Self {
Expand Down Expand Up @@ -228,16 +229,16 @@ impl ExtraData {

#[inline(always)]
pub(super) unsafe fn lua(&self) -> &Lua {
mem::transmute(self.lua.assume_init_ref())
self.lua.assume_init_ref()
}

#[inline(always)]
pub(super) unsafe fn raw_lua(&self) -> &RawLua {
&*self.lua.assume_init_ref().data_ptr()
&*self.lua.assume_init_ref().raw.data_ptr()
}

#[inline(always)]
pub(super) unsafe fn weak(&self) -> &WeakLua {
mem::transmute(self.weak.assume_init_ref())
self.weak.assume_init_ref()
}
}
4 changes: 2 additions & 2 deletions src/state/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,10 @@ impl RawLua {
rawlua
}

pub(super) unsafe fn try_from_ptr(state: *mut ffi::lua_State) -> Option<XRc<ReentrantMutex<Self>>> {
unsafe fn try_from_ptr(state: *mut ffi::lua_State) -> Option<XRc<ReentrantMutex<Self>>> {
match ExtraData::get(state) {
extra if extra.is_null() => None,
extra => Some(XRc::clone(&(*extra).lua().0)),
extra => Some(XRc::clone(&(*extra).lua().raw)),
}
}

Expand Down
27 changes: 14 additions & 13 deletions tests/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -498,21 +498,22 @@ async fn test_async_thread_error() -> Result<()> {

#[tokio::test]
async fn test_async_terminate() -> Result<()> {
let lua = Lua::new();

let mutex = Arc::new(Mutex::new(0u32));
let mutex2 = mutex.clone();
let func = lua.create_async_function(move |_, ()| {
let mutex = mutex2.clone();
async move {
let _guard = mutex.lock().await;
sleep_ms(100).await;
Ok(())
}
})?;
{
let lua = Lua::new();
let mutex2 = mutex.clone();
let func = lua.create_async_function(move |lua, ()| {
let mutex = mutex2.clone();
async move {
let _guard = mutex.lock().await;
sleep_ms(100).await;
drop(lua); // Move Lua to the future to test drop
Ok(())
}
})?;

let _ = tokio::time::timeout(Duration::from_millis(30), func.call_async::<()>(())).await;
lua.gc_collect()?;
let _ = tokio::time::timeout(Duration::from_millis(30), func.call_async::<()>(())).await;
}
assert!(mutex.try_lock().is_ok());

Ok(())
Expand Down

0 comments on commit 3774296

Please sign in to comment.