diff --git a/src/function.rs b/src/function.rs index 95cb43a7..fdf4c1a7 100644 --- a/src/function.rs +++ b/src/function.rs @@ -5,7 +5,7 @@ use std::{mem, ptr, slice}; use crate::error::{Error, Result}; use crate::state::Lua; use crate::table::Table; -use crate::types::{Callback, MaybeSend, ValueRef}; +use crate::types::{Callback, LuaType, MaybeSend, ValueRef}; use crate::util::{ assert_stack, check_stack, linenumber_to_usize, pop_error, ptr_to_lossy_str, ptr_to_str, StackGuard, }; @@ -588,6 +588,10 @@ impl IntoLua for WrappedAsyncFunction { } } +impl LuaType for Function { + const TYPE_ID: c_int = ffi::LUA_TFUNCTION; +} + #[cfg(test)] mod assertions { use super::*; diff --git a/src/state.rs b/src/state.rs index e10174b0..f695296b 100644 --- a/src/state.rs +++ b/src/state.rs @@ -18,7 +18,7 @@ use crate::string::String; use crate::table::Table; use crate::thread::Thread; use crate::types::{ - AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, MaybeSend, Number, ReentrantMutex, + AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LuaType, MaybeSend, Number, ReentrantMutex, ReentrantMutexGuard, RegistryKey, VmState, XRc, XWeak, }; use crate::userdata::{AnyUserData, UserData, UserDataProxy, UserDataRegistry, UserDataStorage}; @@ -1337,24 +1337,66 @@ impl Lua { unsafe { self.lock().make_userdata(UserDataStorage::new(ud)) } } - /// Sets the metatable for a Luau builtin vector type. - #[cfg(any(feature = "luau", doc))] - #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] - pub fn set_vector_metatable(&self, metatable: Option) { + /// Sets the metatable for a Lua builtin type. + /// + /// The metatable will be shared by all values of the given type. + /// + /// # Examples + /// + /// Change metatable for Lua boolean type: + /// + /// ``` + /// # use mlua::{Lua, Result, Function}; + /// # fn main() -> Result<()> { + /// # let lua = Lua::new(); + /// let mt = lua.create_table()?; + /// mt.set("__tostring", lua.create_function(|_, b: bool| Ok(if b { 2 } else { 0 }))?)?; + /// lua.set_type_metatable::(Some(mt)); + /// lua.load("assert(tostring(true) == '2')").exec()?; + /// # Ok(()) + /// # } + /// ``` + #[allow(private_bounds)] + pub fn set_type_metatable(&self, metatable: Option
) { let lua = self.lock(); let state = lua.state(); unsafe { let _sg = StackGuard::new(state); assert_stack(state, 2); - #[cfg(not(feature = "luau-vector4"))] - ffi::lua_pushvector(state, 0., 0., 0.); - #[cfg(feature = "luau-vector4")] - ffi::lua_pushvector(state, 0., 0., 0., 0.); + match T::TYPE_ID { + ffi::LUA_TBOOLEAN => { + ffi::lua_pushboolean(state, 0); + } + ffi::LUA_TLIGHTUSERDATA => { + ffi::lua_pushlightuserdata(state, ptr::null_mut()); + } + ffi::LUA_TNUMBER => { + ffi::lua_pushnumber(state, 0.); + } + #[cfg(feature = "luau")] + ffi::LUA_TVECTOR => { + #[cfg(not(feature = "luau-vector4"))] + ffi::lua_pushvector(state, 0., 0., 0.); + #[cfg(feature = "luau-vector4")] + ffi::lua_pushvector(state, 0., 0., 0., 0.); + } + ffi::LUA_TSTRING => { + ffi::lua_pushstring(state, b"\0" as *const u8 as *const _); + } + ffi::LUA_TFUNCTION => match self.load("function() end").eval::() { + Ok(func) => lua.push_ref(&func.0), + Err(_) => return, + }, + ffi::LUA_TTHREAD => { + ffi::lua_newthread(state); + } + _ => {} + } match metatable { Some(metatable) => lua.push_ref(&metatable.0), None => ffi::lua_pushnil(state), - }; + } ffi::lua_setmetatable(state, -2); } } diff --git a/src/string.rs b/src/string.rs index 39c1ddc7..5057b0be 100644 --- a/src/string.rs +++ b/src/string.rs @@ -1,7 +1,7 @@ use std::borrow::Borrow; use std::hash::{Hash, Hasher}; use std::ops::Deref; -use std::os::raw::c_void; +use std::os::raw::{c_int, c_void}; use std::string::String as StdString; use std::{cmp, fmt, slice, str}; @@ -13,7 +13,7 @@ use { use crate::error::{Error, Result}; use crate::state::Lua; -use crate::types::ValueRef; +use crate::types::{LuaType, ValueRef}; /// Handle to an internal Lua string. /// @@ -327,6 +327,10 @@ impl<'a> IntoIterator for BorrowedBytes<'a> { } } +impl LuaType for String { + const TYPE_ID: c_int = ffi::LUA_TSTRING; +} + #[cfg(test)] mod assertions { use super::*; diff --git a/src/table.rs b/src/table.rs index 3bf638c4..b231f0eb 100644 --- a/src/table.rs +++ b/src/table.rs @@ -1,7 +1,7 @@ use std::collections::HashSet; use std::fmt; use std::marker::PhantomData; -use std::os::raw::c_void; +use std::os::raw::{c_int, c_void}; use std::string::String as StdString; #[cfg(feature = "serialize")] @@ -15,7 +15,7 @@ use crate::error::{Error, Result}; use crate::function::Function; use crate::state::{LuaGuard, RawLua}; use crate::traits::ObjectLike; -use crate::types::{Integer, ValueRef}; +use crate::types::{Integer, LuaType, ValueRef}; use crate::util::{assert_stack, check_stack, StackGuard}; use crate::value::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, Nil, Value}; @@ -961,6 +961,10 @@ impl Serialize for Table { } } +impl LuaType for Table { + const TYPE_ID: c_int = ffi::LUA_TTABLE; +} + #[cfg(feature = "serialize")] impl<'a> SerializableTable<'a> { #[inline] diff --git a/src/thread.rs b/src/thread.rs index b4b56ee2..cbbd61fe 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -4,7 +4,7 @@ use crate::error::{Error, Result}; #[allow(unused)] use crate::state::Lua; use crate::state::RawLua; -use crate::types::{ValueRef, VmState}; +use crate::types::{LuaType, ValueRef, VmState}; use crate::util::{check_stack, error_traceback_thread, pop_error, StackGuard}; use crate::value::{FromLuaMulti, IntoLuaMulti}; @@ -372,6 +372,10 @@ impl PartialEq for Thread { } } +impl LuaType for Thread { + const TYPE_ID: c_int = ffi::LUA_TTHREAD; +} + #[cfg(feature = "async")] impl AsyncThread { #[inline] diff --git a/src/types.rs b/src/types.rs index ef90e805..be5b53b3 100644 --- a/src/types.rs +++ b/src/types.rs @@ -27,7 +27,7 @@ pub type Integer = ffi::lua_Integer; /// Type of Lua floating point numbers. pub type Number = ffi::lua_Number; -// Represents different subtypes wrapped to AnyUserData +// Represents different subtypes wrapped in AnyUserData #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub(crate) enum SubtypeId { None, @@ -115,6 +115,22 @@ impl MaybeSend for T {} pub(crate) struct DestructedUserdata; +pub(crate) trait LuaType { + const TYPE_ID: c_int; +} + +impl LuaType for bool { + const TYPE_ID: c_int = ffi::LUA_TBOOLEAN; +} + +impl LuaType for Number { + const TYPE_ID: c_int = ffi::LUA_TNUMBER; +} + +impl LuaType for LightUserData { + const TYPE_ID: c_int = ffi::LUA_TLIGHTUSERDATA; +} + mod app_data; mod registry_key; mod sync; diff --git a/src/types/vector.rs b/src/types/vector.rs index d4ac5c61..f65ba863 100644 --- a/src/types/vector.rs +++ b/src/types/vector.rs @@ -3,6 +3,8 @@ use std::fmt; #[cfg(all(any(feature = "luau", doc), feature = "serialize"))] use serde::ser::{Serialize, SerializeTupleStruct, Serializer}; +use super::LuaType; + /// A Luau vector type. /// /// By default vectors are 3-dimensional, but can be 4-dimensional @@ -84,3 +86,12 @@ impl PartialEq<[f32; Self::SIZE]> for Vector { self.0 == *other } } + +impl LuaType for Vector { + #[cfg(feature = "luau")] + const TYPE_ID: i32 = ffi::LUA_TVECTOR; + + // This is a dummy value, as `Vector` is supported only by Luau + #[cfg(not(feature = "luau"))] + const TYPE_ID: i32 = ffi::LUA_TNONE; +} diff --git a/tests/luau.rs b/tests/luau.rs index afcdcbdf..3d7268b8 100644 --- a/tests/luau.rs +++ b/tests/luau.rs @@ -194,7 +194,7 @@ fn test_vector_metatable() -> Result<()> { ) .eval::
()?; vector_mt.set_metatable(Some(vector_mt.clone())); - lua.set_vector_metatable(Some(vector_mt.clone())); + lua.set_type_metatable::(Some(vector_mt.clone())); lua.globals().set("Vector3", vector_mt)?; let compiler = Compiler::new().set_vector_lib("Vector3").set_vector_ctor("new"); diff --git a/tests/types.rs b/tests/types.rs index ffe39607..830851ae 100644 --- a/tests/types.rs +++ b/tests/types.rs @@ -1,6 +1,6 @@ use std::os::raw::c_void; -use mlua::{Function, LightUserData, Lua, Result}; +use mlua::{Function, LightUserData, Lua, Number, Result, String as LuaString, Thread}; #[test] fn test_lightuserdata() -> Result<()> { @@ -24,3 +24,114 @@ fn test_lightuserdata() -> Result<()> { Ok(()) } + +#[test] +fn test_boolean_type_metatable() -> Result<()> { + let lua = Lua::new(); + + let mt = lua.create_table()?; + mt.set("__add", Function::wrap(|_, (a, b): (bool, bool)| Ok(a || b)))?; + lua.set_type_metatable::(Some(mt)); + + lua.load(r#"assert(true + true == true)"#).exec().unwrap(); + lua.load(r#"assert(true + false == true)"#).exec().unwrap(); + lua.load(r#"assert(false + true == true)"#).exec().unwrap(); + lua.load(r#"assert(false + false == false)"#).exec().unwrap(); + + Ok(()) +} + +#[test] +fn test_lightuserdata_type_metatable() -> Result<()> { + let lua = Lua::new(); + + let mt = lua.create_table()?; + mt.set( + "__add", + Function::wrap(|_, (a, b): (LightUserData, LightUserData)| { + Ok(LightUserData((a.0 as usize + b.0 as usize) as *mut c_void)) + }), + )?; + lua.set_type_metatable::(Some(mt)); + + let res = lua + .load( + r#" + local a, b = ... + return a + b + "#, + ) + .call::(( + LightUserData(42 as *mut c_void), + LightUserData(100 as *mut c_void), + )) + .unwrap(); + assert_eq!(res, LightUserData(142 as *mut c_void)); + + Ok(()) +} + +#[test] +fn test_number_type_metatable() -> Result<()> { + let lua = Lua::new(); + + let mt = lua.create_table()?; + mt.set("__call", Function::wrap(|_, (n1, n2): (f64, f64)| Ok(n1 * n2)))?; + lua.set_type_metatable::(Some(mt)); + lua.load(r#"assert((1.5)(3.0) == 4.5)"#).exec().unwrap(); + lua.load(r#"assert((5)(5) == 25)"#).exec().unwrap(); + + Ok(()) +} + +#[test] +fn test_string_type_metatable() -> Result<()> { + let lua = Lua::new(); + + let mt = lua.create_table()?; + mt.set( + "__add", + Function::wrap(|_, (a, b): (LuaString, LuaString)| Ok(format!("{}{}", a.to_str()?, b.to_str()?))), + )?; + lua.set_type_metatable::(Some(mt)); + + lua.load(r#"assert(("foo" + "bar") == "foobar")"#).exec().unwrap(); + + Ok(()) +} + +#[test] +fn test_function_type_metatable() -> Result<()> { + let lua = Lua::new(); + + let mt = lua.create_table()?; + mt.set( + "__index", + Function::wrap(|_, (_, key): (Function, String)| Ok(format!("function.{key}"))), + )?; + lua.set_type_metatable::(Some(mt)); + + lua.load(r#"assert((function() end).foo == "function.foo")"#) + .exec() + .unwrap(); + + Ok(()) +} + +#[test] +fn test_thread_type_metatable() -> Result<()> { + let lua = Lua::new(); + + let mt = lua.create_table()?; + mt.set( + "__index", + Function::wrap(|_, (_, key): (Thread, String)| Ok(format!("thread.{key}"))), + )?; + lua.set_type_metatable::(Some(mt)); + + lua.load(r#"assert((coroutine.create(function() end)).foo == "thread.foo")"#) + .exec() + .unwrap(); + + Ok(()) +}