Skip to content

Commit

Permalink
Support setting metatable for Lua builtin types.
Browse files Browse the repository at this point in the history
Closes #445
  • Loading branch information
khvzak committed Sep 23, 2024
1 parent 16951e3 commit ca69be0
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 19 deletions.
6 changes: 5 additions & 1 deletion src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{mem, ptr, slice};
use crate::error::{Error, Result};
use crate::state::Lua;
use crate::table::Table;
use crate::types::{Callback, MaybeSend, ValueRef};
use crate::types::{Callback, LuaType, MaybeSend, ValueRef};
use crate::util::{
assert_stack, check_stack, linenumber_to_usize, pop_error, ptr_to_lossy_str, ptr_to_str, StackGuard,
};
Expand Down Expand Up @@ -588,6 +588,10 @@ impl IntoLua for WrappedAsyncFunction {
}
}

impl LuaType for Function {
const TYPE_ID: c_int = ffi::LUA_TFUNCTION;
}

#[cfg(test)]
mod assertions {
use super::*;
Expand Down
62 changes: 52 additions & 10 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::string::String;
use crate::table::Table;
use crate::thread::Thread;
use crate::types::{
AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, MaybeSend, Number, ReentrantMutex,
AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LuaType, MaybeSend, Number, ReentrantMutex,
ReentrantMutexGuard, RegistryKey, VmState, XRc, XWeak,
};
use crate::userdata::{AnyUserData, UserData, UserDataProxy, UserDataRegistry, UserDataStorage};
Expand Down Expand Up @@ -1337,24 +1337,66 @@ impl Lua {
unsafe { self.lock().make_userdata(UserDataStorage::new(ud)) }
}

/// Sets the metatable for a Luau builtin vector type.
#[cfg(any(feature = "luau", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
pub fn set_vector_metatable(&self, metatable: Option<Table>) {
/// Sets the metatable for a Lua builtin type.
///
/// The metatable will be shared by all values of the given type.
///
/// # Examples
///
/// Change metatable for Lua boolean type:
///
/// ```
/// # use mlua::{Lua, Result, Function};
/// # fn main() -> Result<()> {
/// # let lua = Lua::new();
/// let mt = lua.create_table()?;
/// mt.set("__tostring", lua.create_function(|_, b: bool| Ok(if b { 2 } else { 0 }))?)?;
/// lua.set_type_metatable::<bool>(Some(mt));
/// lua.load("assert(tostring(true) == '2')").exec()?;
/// # Ok(())
/// # }
/// ```
#[allow(private_bounds)]
pub fn set_type_metatable<T: LuaType>(&self, metatable: Option<Table>) {
let lua = self.lock();
let state = lua.state();
unsafe {
let _sg = StackGuard::new(state);
assert_stack(state, 2);

#[cfg(not(feature = "luau-vector4"))]
ffi::lua_pushvector(state, 0., 0., 0.);
#[cfg(feature = "luau-vector4")]
ffi::lua_pushvector(state, 0., 0., 0., 0.);
match T::TYPE_ID {
ffi::LUA_TBOOLEAN => {
ffi::lua_pushboolean(state, 0);
}
ffi::LUA_TLIGHTUSERDATA => {
ffi::lua_pushlightuserdata(state, ptr::null_mut());
}
ffi::LUA_TNUMBER => {
ffi::lua_pushnumber(state, 0.);
}
#[cfg(feature = "luau")]
ffi::LUA_TVECTOR => {
#[cfg(not(feature = "luau-vector4"))]
ffi::lua_pushvector(state, 0., 0., 0.);
#[cfg(feature = "luau-vector4")]
ffi::lua_pushvector(state, 0., 0., 0., 0.);
}
ffi::LUA_TSTRING => {
ffi::lua_pushstring(state, b"\0" as *const u8 as *const _);
}
ffi::LUA_TFUNCTION => match self.load("function() end").eval::<Function>() {
Ok(func) => lua.push_ref(&func.0),
Err(_) => return,
},
ffi::LUA_TTHREAD => {
ffi::lua_newthread(state);
}
_ => {}
}
match metatable {
Some(metatable) => lua.push_ref(&metatable.0),
None => ffi::lua_pushnil(state),
};
}
ffi::lua_setmetatable(state, -2);
}
}
Expand Down
8 changes: 6 additions & 2 deletions src/string.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::borrow::Borrow;
use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::os::raw::c_void;
use std::os::raw::{c_int, c_void};
use std::string::String as StdString;
use std::{cmp, fmt, slice, str};

Expand All @@ -13,7 +13,7 @@ use {

use crate::error::{Error, Result};
use crate::state::Lua;
use crate::types::ValueRef;
use crate::types::{LuaType, ValueRef};

/// Handle to an internal Lua string.
///
Expand Down Expand Up @@ -327,6 +327,10 @@ impl<'a> IntoIterator for BorrowedBytes<'a> {
}
}

impl LuaType for String {
const TYPE_ID: c_int = ffi::LUA_TSTRING;
}

#[cfg(test)]
mod assertions {
use super::*;
Expand Down
8 changes: 6 additions & 2 deletions src/table.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashSet;
use std::fmt;
use std::marker::PhantomData;
use std::os::raw::c_void;
use std::os::raw::{c_int, c_void};
use std::string::String as StdString;

#[cfg(feature = "serialize")]
Expand All @@ -15,7 +15,7 @@ use crate::error::{Error, Result};
use crate::function::Function;
use crate::state::{LuaGuard, RawLua};
use crate::traits::ObjectLike;
use crate::types::{Integer, ValueRef};
use crate::types::{Integer, LuaType, ValueRef};
use crate::util::{assert_stack, check_stack, StackGuard};
use crate::value::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, Nil, Value};

Expand Down Expand Up @@ -961,6 +961,10 @@ impl Serialize for Table {
}
}

impl LuaType for Table {
const TYPE_ID: c_int = ffi::LUA_TTABLE;
}

#[cfg(feature = "serialize")]
impl<'a> SerializableTable<'a> {
#[inline]
Expand Down
6 changes: 5 additions & 1 deletion src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::error::{Error, Result};
#[allow(unused)]
use crate::state::Lua;
use crate::state::RawLua;
use crate::types::{ValueRef, VmState};
use crate::types::{LuaType, ValueRef, VmState};
use crate::util::{check_stack, error_traceback_thread, pop_error, StackGuard};
use crate::value::{FromLuaMulti, IntoLuaMulti};

Expand Down Expand Up @@ -372,6 +372,10 @@ impl PartialEq for Thread {
}
}

impl LuaType for Thread {
const TYPE_ID: c_int = ffi::LUA_TTHREAD;
}

#[cfg(feature = "async")]
impl<A, R> AsyncThread<A, R> {
#[inline]
Expand Down
18 changes: 17 additions & 1 deletion src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub type Integer = ffi::lua_Integer;
/// Type of Lua floating point numbers.
pub type Number = ffi::lua_Number;

// Represents different subtypes wrapped to AnyUserData
// Represents different subtypes wrapped in AnyUserData
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum SubtypeId {
None,
Expand Down Expand Up @@ -115,6 +115,22 @@ impl<T> MaybeSend for T {}

pub(crate) struct DestructedUserdata;

pub(crate) trait LuaType {
const TYPE_ID: c_int;
}

impl LuaType for bool {
const TYPE_ID: c_int = ffi::LUA_TBOOLEAN;
}

impl LuaType for Number {
const TYPE_ID: c_int = ffi::LUA_TNUMBER;
}

impl LuaType for LightUserData {
const TYPE_ID: c_int = ffi::LUA_TLIGHTUSERDATA;
}

mod app_data;
mod registry_key;
mod sync;
Expand Down
11 changes: 11 additions & 0 deletions src/types/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::fmt;
#[cfg(all(any(feature = "luau", doc), feature = "serialize"))]
use serde::ser::{Serialize, SerializeTupleStruct, Serializer};

use super::LuaType;

/// A Luau vector type.
///
/// By default vectors are 3-dimensional, but can be 4-dimensional
Expand Down Expand Up @@ -84,3 +86,12 @@ impl PartialEq<[f32; Self::SIZE]> for Vector {
self.0 == *other
}
}

impl LuaType for Vector {
#[cfg(feature = "luau")]
const TYPE_ID: i32 = ffi::LUA_TVECTOR;

// This is a dummy value, as `Vector` is supported only by Luau
#[cfg(not(feature = "luau"))]
const TYPE_ID: i32 = ffi::LUA_TNONE;
}
2 changes: 1 addition & 1 deletion tests/luau.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ fn test_vector_metatable() -> Result<()> {
)
.eval::<Table>()?;
vector_mt.set_metatable(Some(vector_mt.clone()));
lua.set_vector_metatable(Some(vector_mt.clone()));
lua.set_type_metatable::<Vector>(Some(vector_mt.clone()));
lua.globals().set("Vector3", vector_mt)?;

let compiler = Compiler::new().set_vector_lib("Vector3").set_vector_ctor("new");
Expand Down
113 changes: 112 additions & 1 deletion tests/types.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::os::raw::c_void;

use mlua::{Function, LightUserData, Lua, Result};
use mlua::{Function, LightUserData, Lua, Number, Result, String as LuaString, Thread};

#[test]
fn test_lightuserdata() -> Result<()> {
Expand All @@ -24,3 +24,114 @@ fn test_lightuserdata() -> Result<()> {

Ok(())
}

#[test]
fn test_boolean_type_metatable() -> Result<()> {
let lua = Lua::new();

let mt = lua.create_table()?;
mt.set("__add", Function::wrap(|_, (a, b): (bool, bool)| Ok(a || b)))?;
lua.set_type_metatable::<bool>(Some(mt));

lua.load(r#"assert(true + true == true)"#).exec().unwrap();
lua.load(r#"assert(true + false == true)"#).exec().unwrap();
lua.load(r#"assert(false + true == true)"#).exec().unwrap();
lua.load(r#"assert(false + false == false)"#).exec().unwrap();

Ok(())
}

#[test]
fn test_lightuserdata_type_metatable() -> Result<()> {
let lua = Lua::new();

let mt = lua.create_table()?;
mt.set(
"__add",
Function::wrap(|_, (a, b): (LightUserData, LightUserData)| {
Ok(LightUserData((a.0 as usize + b.0 as usize) as *mut c_void))
}),
)?;
lua.set_type_metatable::<LightUserData>(Some(mt));

let res = lua
.load(
r#"
local a, b = ...
return a + b
"#,
)
.call::<LightUserData>((
LightUserData(42 as *mut c_void),
LightUserData(100 as *mut c_void),
))
.unwrap();
assert_eq!(res, LightUserData(142 as *mut c_void));

Ok(())
}

#[test]
fn test_number_type_metatable() -> Result<()> {
let lua = Lua::new();

let mt = lua.create_table()?;
mt.set("__call", Function::wrap(|_, (n1, n2): (f64, f64)| Ok(n1 * n2)))?;
lua.set_type_metatable::<Number>(Some(mt));
lua.load(r#"assert((1.5)(3.0) == 4.5)"#).exec().unwrap();
lua.load(r#"assert((5)(5) == 25)"#).exec().unwrap();

Ok(())
}

#[test]
fn test_string_type_metatable() -> Result<()> {
let lua = Lua::new();

let mt = lua.create_table()?;
mt.set(
"__add",
Function::wrap(|_, (a, b): (LuaString, LuaString)| Ok(format!("{}{}", a.to_str()?, b.to_str()?))),
)?;
lua.set_type_metatable::<LuaString>(Some(mt));

lua.load(r#"assert(("foo" + "bar") == "foobar")"#).exec().unwrap();

Ok(())
}

#[test]
fn test_function_type_metatable() -> Result<()> {
let lua = Lua::new();

let mt = lua.create_table()?;
mt.set(
"__index",
Function::wrap(|_, (_, key): (Function, String)| Ok(format!("function.{key}"))),
)?;
lua.set_type_metatable::<Function>(Some(mt));

lua.load(r#"assert((function() end).foo == "function.foo")"#)
.exec()
.unwrap();

Ok(())
}

#[test]
fn test_thread_type_metatable() -> Result<()> {
let lua = Lua::new();

let mt = lua.create_table()?;
mt.set(
"__index",
Function::wrap(|_, (_, key): (Thread, String)| Ok(format!("thread.{key}"))),
)?;
lua.set_type_metatable::<Thread>(Some(mt));

lua.load(r#"assert((coroutine.create(function() end)).foo == "thread.foo")"#)
.exec()
.unwrap();

Ok(())
}

0 comments on commit ca69be0

Please sign in to comment.