diff --git a/src/function.rs b/src/function.rs
index 95cb43a7..fdf4c1a7 100644
--- a/src/function.rs
+++ b/src/function.rs
@@ -5,7 +5,7 @@ use std::{mem, ptr, slice};
use crate::error::{Error, Result};
use crate::state::Lua;
use crate::table::Table;
-use crate::types::{Callback, MaybeSend, ValueRef};
+use crate::types::{Callback, LuaType, MaybeSend, ValueRef};
use crate::util::{
assert_stack, check_stack, linenumber_to_usize, pop_error, ptr_to_lossy_str, ptr_to_str, StackGuard,
};
@@ -588,6 +588,10 @@ impl IntoLua for WrappedAsyncFunction {
}
}
+impl LuaType for Function {
+ const TYPE_ID: c_int = ffi::LUA_TFUNCTION;
+}
+
#[cfg(test)]
mod assertions {
use super::*;
diff --git a/src/state.rs b/src/state.rs
index e10174b0..f695296b 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -18,7 +18,7 @@ use crate::string::String;
use crate::table::Table;
use crate::thread::Thread;
use crate::types::{
- AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, MaybeSend, Number, ReentrantMutex,
+ AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LuaType, MaybeSend, Number, ReentrantMutex,
ReentrantMutexGuard, RegistryKey, VmState, XRc, XWeak,
};
use crate::userdata::{AnyUserData, UserData, UserDataProxy, UserDataRegistry, UserDataStorage};
@@ -1337,24 +1337,66 @@ impl Lua {
unsafe { self.lock().make_userdata(UserDataStorage::new(ud)) }
}
- /// Sets the metatable for a Luau builtin vector type.
- #[cfg(any(feature = "luau", doc))]
- #[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
- pub fn set_vector_metatable(&self, metatable: Option
) {
+ /// Sets the metatable for a Lua builtin type.
+ ///
+ /// The metatable will be shared by all values of the given type.
+ ///
+ /// # Examples
+ ///
+ /// Change metatable for Lua boolean type:
+ ///
+ /// ```
+ /// # use mlua::{Lua, Result, Function};
+ /// # fn main() -> Result<()> {
+ /// # let lua = Lua::new();
+ /// let mt = lua.create_table()?;
+ /// mt.set("__tostring", lua.create_function(|_, b: bool| Ok(if b { 2 } else { 0 }))?)?;
+ /// lua.set_type_metatable::(Some(mt));
+ /// lua.load("assert(tostring(true) == '2')").exec()?;
+ /// # Ok(())
+ /// # }
+ /// ```
+ #[allow(private_bounds)]
+ pub fn set_type_metatable(&self, metatable: Option) {
let lua = self.lock();
let state = lua.state();
unsafe {
let _sg = StackGuard::new(state);
assert_stack(state, 2);
- #[cfg(not(feature = "luau-vector4"))]
- ffi::lua_pushvector(state, 0., 0., 0.);
- #[cfg(feature = "luau-vector4")]
- ffi::lua_pushvector(state, 0., 0., 0., 0.);
+ match T::TYPE_ID {
+ ffi::LUA_TBOOLEAN => {
+ ffi::lua_pushboolean(state, 0);
+ }
+ ffi::LUA_TLIGHTUSERDATA => {
+ ffi::lua_pushlightuserdata(state, ptr::null_mut());
+ }
+ ffi::LUA_TNUMBER => {
+ ffi::lua_pushnumber(state, 0.);
+ }
+ #[cfg(feature = "luau")]
+ ffi::LUA_TVECTOR => {
+ #[cfg(not(feature = "luau-vector4"))]
+ ffi::lua_pushvector(state, 0., 0., 0.);
+ #[cfg(feature = "luau-vector4")]
+ ffi::lua_pushvector(state, 0., 0., 0., 0.);
+ }
+ ffi::LUA_TSTRING => {
+ ffi::lua_pushstring(state, b"\0" as *const u8 as *const _);
+ }
+ ffi::LUA_TFUNCTION => match self.load("function() end").eval::() {
+ Ok(func) => lua.push_ref(&func.0),
+ Err(_) => return,
+ },
+ ffi::LUA_TTHREAD => {
+ ffi::lua_newthread(state);
+ }
+ _ => {}
+ }
match metatable {
Some(metatable) => lua.push_ref(&metatable.0),
None => ffi::lua_pushnil(state),
- };
+ }
ffi::lua_setmetatable(state, -2);
}
}
diff --git a/src/string.rs b/src/string.rs
index 39c1ddc7..5057b0be 100644
--- a/src/string.rs
+++ b/src/string.rs
@@ -1,7 +1,7 @@
use std::borrow::Borrow;
use std::hash::{Hash, Hasher};
use std::ops::Deref;
-use std::os::raw::c_void;
+use std::os::raw::{c_int, c_void};
use std::string::String as StdString;
use std::{cmp, fmt, slice, str};
@@ -13,7 +13,7 @@ use {
use crate::error::{Error, Result};
use crate::state::Lua;
-use crate::types::ValueRef;
+use crate::types::{LuaType, ValueRef};
/// Handle to an internal Lua string.
///
@@ -327,6 +327,10 @@ impl<'a> IntoIterator for BorrowedBytes<'a> {
}
}
+impl LuaType for String {
+ const TYPE_ID: c_int = ffi::LUA_TSTRING;
+}
+
#[cfg(test)]
mod assertions {
use super::*;
diff --git a/src/table.rs b/src/table.rs
index 3bf638c4..b231f0eb 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -1,7 +1,7 @@
use std::collections::HashSet;
use std::fmt;
use std::marker::PhantomData;
-use std::os::raw::c_void;
+use std::os::raw::{c_int, c_void};
use std::string::String as StdString;
#[cfg(feature = "serialize")]
@@ -15,7 +15,7 @@ use crate::error::{Error, Result};
use crate::function::Function;
use crate::state::{LuaGuard, RawLua};
use crate::traits::ObjectLike;
-use crate::types::{Integer, ValueRef};
+use crate::types::{Integer, LuaType, ValueRef};
use crate::util::{assert_stack, check_stack, StackGuard};
use crate::value::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, Nil, Value};
@@ -961,6 +961,10 @@ impl Serialize for Table {
}
}
+impl LuaType for Table {
+ const TYPE_ID: c_int = ffi::LUA_TTABLE;
+}
+
#[cfg(feature = "serialize")]
impl<'a> SerializableTable<'a> {
#[inline]
diff --git a/src/thread.rs b/src/thread.rs
index b4b56ee2..cbbd61fe 100644
--- a/src/thread.rs
+++ b/src/thread.rs
@@ -4,7 +4,7 @@ use crate::error::{Error, Result};
#[allow(unused)]
use crate::state::Lua;
use crate::state::RawLua;
-use crate::types::{ValueRef, VmState};
+use crate::types::{LuaType, ValueRef, VmState};
use crate::util::{check_stack, error_traceback_thread, pop_error, StackGuard};
use crate::value::{FromLuaMulti, IntoLuaMulti};
@@ -372,6 +372,10 @@ impl PartialEq for Thread {
}
}
+impl LuaType for Thread {
+ const TYPE_ID: c_int = ffi::LUA_TTHREAD;
+}
+
#[cfg(feature = "async")]
impl AsyncThread {
#[inline]
diff --git a/src/types.rs b/src/types.rs
index ef90e805..be5b53b3 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -27,7 +27,7 @@ pub type Integer = ffi::lua_Integer;
/// Type of Lua floating point numbers.
pub type Number = ffi::lua_Number;
-// Represents different subtypes wrapped to AnyUserData
+// Represents different subtypes wrapped in AnyUserData
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum SubtypeId {
None,
@@ -115,6 +115,22 @@ impl MaybeSend for T {}
pub(crate) struct DestructedUserdata;
+pub(crate) trait LuaType {
+ const TYPE_ID: c_int;
+}
+
+impl LuaType for bool {
+ const TYPE_ID: c_int = ffi::LUA_TBOOLEAN;
+}
+
+impl LuaType for Number {
+ const TYPE_ID: c_int = ffi::LUA_TNUMBER;
+}
+
+impl LuaType for LightUserData {
+ const TYPE_ID: c_int = ffi::LUA_TLIGHTUSERDATA;
+}
+
mod app_data;
mod registry_key;
mod sync;
diff --git a/src/types/vector.rs b/src/types/vector.rs
index d4ac5c61..f65ba863 100644
--- a/src/types/vector.rs
+++ b/src/types/vector.rs
@@ -3,6 +3,8 @@ use std::fmt;
#[cfg(all(any(feature = "luau", doc), feature = "serialize"))]
use serde::ser::{Serialize, SerializeTupleStruct, Serializer};
+use super::LuaType;
+
/// A Luau vector type.
///
/// By default vectors are 3-dimensional, but can be 4-dimensional
@@ -84,3 +86,12 @@ impl PartialEq<[f32; Self::SIZE]> for Vector {
self.0 == *other
}
}
+
+impl LuaType for Vector {
+ #[cfg(feature = "luau")]
+ const TYPE_ID: i32 = ffi::LUA_TVECTOR;
+
+ // This is a dummy value, as `Vector` is supported only by Luau
+ #[cfg(not(feature = "luau"))]
+ const TYPE_ID: i32 = ffi::LUA_TNONE;
+}
diff --git a/tests/luau.rs b/tests/luau.rs
index afcdcbdf..3d7268b8 100644
--- a/tests/luau.rs
+++ b/tests/luau.rs
@@ -194,7 +194,7 @@ fn test_vector_metatable() -> Result<()> {
)
.eval::()?;
vector_mt.set_metatable(Some(vector_mt.clone()));
- lua.set_vector_metatable(Some(vector_mt.clone()));
+ lua.set_type_metatable::(Some(vector_mt.clone()));
lua.globals().set("Vector3", vector_mt)?;
let compiler = Compiler::new().set_vector_lib("Vector3").set_vector_ctor("new");
diff --git a/tests/types.rs b/tests/types.rs
index ffe39607..830851ae 100644
--- a/tests/types.rs
+++ b/tests/types.rs
@@ -1,6 +1,6 @@
use std::os::raw::c_void;
-use mlua::{Function, LightUserData, Lua, Result};
+use mlua::{Function, LightUserData, Lua, Number, Result, String as LuaString, Thread};
#[test]
fn test_lightuserdata() -> Result<()> {
@@ -24,3 +24,114 @@ fn test_lightuserdata() -> Result<()> {
Ok(())
}
+
+#[test]
+fn test_boolean_type_metatable() -> Result<()> {
+ let lua = Lua::new();
+
+ let mt = lua.create_table()?;
+ mt.set("__add", Function::wrap(|_, (a, b): (bool, bool)| Ok(a || b)))?;
+ lua.set_type_metatable::(Some(mt));
+
+ lua.load(r#"assert(true + true == true)"#).exec().unwrap();
+ lua.load(r#"assert(true + false == true)"#).exec().unwrap();
+ lua.load(r#"assert(false + true == true)"#).exec().unwrap();
+ lua.load(r#"assert(false + false == false)"#).exec().unwrap();
+
+ Ok(())
+}
+
+#[test]
+fn test_lightuserdata_type_metatable() -> Result<()> {
+ let lua = Lua::new();
+
+ let mt = lua.create_table()?;
+ mt.set(
+ "__add",
+ Function::wrap(|_, (a, b): (LightUserData, LightUserData)| {
+ Ok(LightUserData((a.0 as usize + b.0 as usize) as *mut c_void))
+ }),
+ )?;
+ lua.set_type_metatable::(Some(mt));
+
+ let res = lua
+ .load(
+ r#"
+ local a, b = ...
+ return a + b
+ "#,
+ )
+ .call::((
+ LightUserData(42 as *mut c_void),
+ LightUserData(100 as *mut c_void),
+ ))
+ .unwrap();
+ assert_eq!(res, LightUserData(142 as *mut c_void));
+
+ Ok(())
+}
+
+#[test]
+fn test_number_type_metatable() -> Result<()> {
+ let lua = Lua::new();
+
+ let mt = lua.create_table()?;
+ mt.set("__call", Function::wrap(|_, (n1, n2): (f64, f64)| Ok(n1 * n2)))?;
+ lua.set_type_metatable::(Some(mt));
+ lua.load(r#"assert((1.5)(3.0) == 4.5)"#).exec().unwrap();
+ lua.load(r#"assert((5)(5) == 25)"#).exec().unwrap();
+
+ Ok(())
+}
+
+#[test]
+fn test_string_type_metatable() -> Result<()> {
+ let lua = Lua::new();
+
+ let mt = lua.create_table()?;
+ mt.set(
+ "__add",
+ Function::wrap(|_, (a, b): (LuaString, LuaString)| Ok(format!("{}{}", a.to_str()?, b.to_str()?))),
+ )?;
+ lua.set_type_metatable::(Some(mt));
+
+ lua.load(r#"assert(("foo" + "bar") == "foobar")"#).exec().unwrap();
+
+ Ok(())
+}
+
+#[test]
+fn test_function_type_metatable() -> Result<()> {
+ let lua = Lua::new();
+
+ let mt = lua.create_table()?;
+ mt.set(
+ "__index",
+ Function::wrap(|_, (_, key): (Function, String)| Ok(format!("function.{key}"))),
+ )?;
+ lua.set_type_metatable::(Some(mt));
+
+ lua.load(r#"assert((function() end).foo == "function.foo")"#)
+ .exec()
+ .unwrap();
+
+ Ok(())
+}
+
+#[test]
+fn test_thread_type_metatable() -> Result<()> {
+ let lua = Lua::new();
+
+ let mt = lua.create_table()?;
+ mt.set(
+ "__index",
+ Function::wrap(|_, (_, key): (Thread, String)| Ok(format!("thread.{key}"))),
+ )?;
+ lua.set_type_metatable::(Some(mt));
+
+ lua.load(r#"assert((coroutine.create(function() end)).foo == "thread.foo")"#)
+ .exec()
+ .unwrap();
+
+ Ok(())
+}