From 67e9d85124df48b68fa385f55f891e2ad7b94b7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petri=20H=C3=A4kkinen?= Date: Fri, 17 Jan 2025 18:45:03 +0200 Subject: [PATCH] Add 2-component vector constructor (#1569) Implement RFC: 2-component vector constructor. This includes 2-component overload for `vector.create` and associated fastcall function, and its type definition. These features are controlled by a new feature flag `LuauVector2Constructor`. Additionally constant folding now supports two components when `LuauVector2Constants` feature flag is set. Note: this work does not include changes to CodeGen. Thus calls to `vector.create` with only two arguments are not natively compiled currently. This is left for future work. --- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 75 +++++++++++++++++++-- Compiler/src/BuiltinFolding.cpp | 9 ++- VM/src/lbuiltins.cpp | 62 +++++++++++++---- VM/src/lveclib.cpp | 9 ++- bench/micro_tests/test_vector_lib.lua | 14 ++++ tests/Compiler.test.cpp | 26 +++++-- tests/Conformance.test.cpp | 3 + tests/Fixture.cpp | 3 + tests/NonStrictTypeChecker.test.cpp | 4 +- tests/conformance/vector_library.lua | 7 ++ 10 files changed, 182 insertions(+), 30 deletions(-) create mode 100644 bench/micro_tests/test_vector_lib.lua diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index caff137d5..e794588f1 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -3,6 +3,7 @@ LUAU_FASTFLAGVARIABLE(LuauVectorDefinitionsExtra) LUAU_FASTFLAG(LuauBufferBitMethods) +LUAU_FASTFLAG(LuauVector2Constructor) namespace Luau { @@ -265,7 +266,7 @@ declare buffer: { )BUILTIN_SRC"; -static const std::string kBuiltinDefinitionVectorSrc_DEPRECATED = R"BUILTIN_SRC( +static const std::string kBuiltinDefinitionVectorSrc_NoExtra_NoVector2Ctor_DEPRECATED = R"BUILTIN_SRC( -- TODO: this will be replaced with a built-in primitive type declare class vector end @@ -291,7 +292,33 @@ declare vector: { )BUILTIN_SRC"; -static const std::string kBuiltinDefinitionVectorSrc = R"BUILTIN_SRC( +static const std::string kBuiltinDefinitionVectorSrc_NoExtra_DEPRECATED = R"BUILTIN_SRC( + +-- TODO: this will be replaced with a built-in primitive type +declare class vector end + +declare vector: { + create: @checked (x: number, y: number, z: number?) -> vector, + magnitude: @checked (vec: vector) -> number, + normalize: @checked (vec: vector) -> vector, + cross: @checked (vec1: vector, vec2: vector) -> vector, + dot: @checked (vec1: vector, vec2: vector) -> number, + angle: @checked (vec1: vector, vec2: vector, axis: vector?) -> number, + floor: @checked (vec: vector) -> vector, + ceil: @checked (vec: vector) -> vector, + abs: @checked (vec: vector) -> vector, + sign: @checked (vec: vector) -> vector, + clamp: @checked (vec: vector, min: vector, max: vector) -> vector, + max: @checked (vector, ...vector) -> vector, + min: @checked (vector, ...vector) -> vector, + + zero: vector, + one: vector, +} + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionVectorSrc_NoVector2Ctor_DEPRECATED = R"BUILTIN_SRC( -- While vector would have been better represented as a built-in primitive type, type solver class handling covers most of the properties declare class vector @@ -321,6 +348,36 @@ declare vector: { )BUILTIN_SRC"; +static const std::string kBuiltinDefinitionVectorSrc = R"BUILTIN_SRC( + +-- While vector would have been better represented as a built-in primitive type, type solver class handling covers most of the properties +declare class vector + x: number + y: number + z: number +end + +declare vector: { + create: @checked (x: number, y: number, z: number?) -> vector, + magnitude: @checked (vec: vector) -> number, + normalize: @checked (vec: vector) -> vector, + cross: @checked (vec1: vector, vec2: vector) -> vector, + dot: @checked (vec1: vector, vec2: vector) -> number, + angle: @checked (vec1: vector, vec2: vector, axis: vector?) -> number, + floor: @checked (vec: vector) -> vector, + ceil: @checked (vec: vector) -> vector, + abs: @checked (vec: vector) -> vector, + sign: @checked (vec: vector) -> vector, + clamp: @checked (vec: vector, min: vector, max: vector) -> vector, + max: @checked (vector, ...vector) -> vector, + min: @checked (vector, ...vector) -> vector, + + zero: vector, + one: vector, +} + +)BUILTIN_SRC"; + std::string getBuiltinDefinitionSource() { std::string result = kBuiltinDefinitionLuaSrcChecked; @@ -328,9 +385,19 @@ std::string getBuiltinDefinitionSource() result += FFlag::LuauBufferBitMethods ? kBuiltinDefinitionBufferSrc : kBuiltinDefinitionBufferSrc_DEPRECATED; if (FFlag::LuauVectorDefinitionsExtra) - result += kBuiltinDefinitionVectorSrc; + { + if (FFlag::LuauVector2Constructor) + result += kBuiltinDefinitionVectorSrc; + else + result += kBuiltinDefinitionVectorSrc_NoVector2Ctor_DEPRECATED; + } else - result += kBuiltinDefinitionVectorSrc_DEPRECATED; + { + if (FFlag::LuauVector2Constructor) + result += kBuiltinDefinitionVectorSrc_NoExtra_DEPRECATED; + else + result += kBuiltinDefinitionVectorSrc_NoExtra_NoVector2Ctor_DEPRECATED; + } return result; } diff --git a/Compiler/src/BuiltinFolding.cpp b/Compiler/src/BuiltinFolding.cpp index 916021a66..d6aeb3ddc 100644 --- a/Compiler/src/BuiltinFolding.cpp +++ b/Compiler/src/BuiltinFolding.cpp @@ -5,6 +5,7 @@ #include +LUAU_FASTFLAGVARIABLE(LuauVector2Constants) LUAU_FASTFLAG(LuauCompileMathLerp) namespace Luau @@ -473,11 +474,13 @@ Constant foldBuiltin(int bfid, const Constant* args, size_t count) break; case LBF_VECTOR: - if (count >= 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number) + if (count >= 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number) { - if (count == 3) + if (count == 2 && FFlag::LuauVector2Constants) + return cvector(args[0].valueNumber, args[1].valueNumber, 0.0, 0.0); + else if (count == 3 && args[2].type == Constant::Type_Number) return cvector(args[0].valueNumber, args[1].valueNumber, args[2].valueNumber, 0.0); - else if (count == 4 && args[3].type == Constant::Type_Number) + else if (count == 4 && args[2].type == Constant::Type_Number && args[3].type == Constant::Type_Number) return cvector(args[0].valueNumber, args[1].valueNumber, args[2].valueNumber, args[3].valueNumber); } break; diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index 6d71836e7..92234a0f7 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -25,6 +25,8 @@ #endif #endif +LUAU_FASTFLAG(LuauVector2Constructor) + // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path @@ -1055,26 +1057,60 @@ static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, St static int luauF_vector(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { - if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) + if (FFlag::LuauVector2Constructor) { - double x = nvalue(arg0); - double y = nvalue(args); - double z = nvalue(args + 1); + if (nparams >= 2 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args)) + { + float x = (float)nvalue(arg0); + float y = (float)nvalue(args); + float z = 0.0f; + + if (nparams >= 3) + { + if (!ttisnumber(args + 1)) + return -1; + z = (float)nvalue(args + 1); + } #if LUA_VECTOR_SIZE == 4 - double w = 0.0; - if (nparams >= 4) - { - if (!ttisnumber(args + 2)) - return -1; - w = nvalue(args + 2); + float w = 0.0f; + if (nparams >= 4) + { + if (!ttisnumber(args + 2)) + return -1; + w = (float)nvalue(args + 2); + } + setvvalue(res, x, y, z, w); +#else + setvvalue(res, x, y, z, 0.0f); +#endif + + return 1; } - setvvalue(res, float(x), float(y), float(z), float(w)); + } + else + { + if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) + { + double x = nvalue(arg0); + double y = nvalue(args); + double z = nvalue(args + 1); + +#if LUA_VECTOR_SIZE == 4 + double w = 0.0; + if (nparams >= 4) + { + if (!ttisnumber(args + 2)) + return -1; + w = nvalue(args + 2); + } + setvvalue(res, float(x), float(y), float(z), float(w)); #else - setvvalue(res, float(x), float(y), float(z), 0.0f); + setvvalue(res, float(x), float(y), float(z), 0.0f); #endif - return 1; + return 1; + } } return -1; diff --git a/VM/src/lveclib.cpp b/VM/src/lveclib.cpp index 2a4e58c60..ff1fd2691 100644 --- a/VM/src/lveclib.cpp +++ b/VM/src/lveclib.cpp @@ -7,16 +7,19 @@ #include LUAU_FASTFLAGVARIABLE(LuauVectorMetatable) +LUAU_FASTFLAGVARIABLE(LuauVector2Constructor) static int vector_create(lua_State* L) { + // checking argument count to avoid accepting 'nil' as a valid value + int count = lua_gettop(L); + double x = luaL_checknumber(L, 1); double y = luaL_checknumber(L, 2); - double z = luaL_checknumber(L, 3); + double z = FFlag::LuauVector2Constructor ? (count >= 3 ? luaL_checknumber(L, 3) : 0.0) : luaL_checknumber(L, 3); #if LUA_VECTOR_SIZE == 4 - // checking argument count to avoid accepting 'nil' as a valid value - double w = lua_gettop(L) >= 4 ? luaL_checknumber(L, 4) : 0.0; + double w = count >= 4 ? luaL_checknumber(L, 4) : 0.0; lua_pushvector(L, float(x), float(y), float(z), float(w)); #else diff --git a/bench/micro_tests/test_vector_lib.lua b/bench/micro_tests/test_vector_lib.lua new file mode 100644 index 000000000..59bddc045 --- /dev/null +++ b/bench/micro_tests/test_vector_lib.lua @@ -0,0 +1,14 @@ +local function prequire(name) local success, result = pcall(require, name); return success and result end +local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") + +bench.runCode(function() + for i=1,1000000 do + vector.create(i, 2, 3) + vector.create(i, 2, 3) + vector.create(i, 2, 3) + vector.create(i, 2, 3) + vector.create(i, 2, 3) + end +end, "vector: create") + +-- TODO: add more tests \ No newline at end of file diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index e98926ac1..8256b24ac 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -26,6 +26,7 @@ LUAU_FASTINT(LuauRecursionLimit) LUAU_FASTFLAG(LuauCompileOptimizeRevArith) LUAU_FASTFLAG(LuauCompileLibraryConstants) LUAU_FASTFLAG(LuauVectorFolding) +LUAU_FASTFLAG(LuauVector2Constants) LUAU_FASTFLAG(LuauCompileDisabledBuiltins) using namespace Luau; @@ -5098,36 +5099,49 @@ L0: RETURN R3 -1 )"); } -TEST_CASE("VectorLiterals") +TEST_CASE("VectorConstants") { - CHECK_EQ("\n" + compileFunction("return Vector3.new(1, 2, 3)", 0, 2, 0, /*enableVectors*/ true), R"( + ScopedFastFlag luauVector2Constants{FFlag::LuauVector2Constants, true}; + + CHECK_EQ("\n" + compileFunction("return vector.create(1, 2)", 0, 2, 0), R"( +LOADK R0 K0 [1, 2, 0] +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction("return vector.create(1, 2, 3)", 0, 2, 0), R"( LOADK R0 K0 [1, 2, 3] RETURN R0 1 )"); - CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3))", 0, 2, 0, /*enableVectors*/ true), R"( + CHECK_EQ("\n" + compileFunction("print(vector.create(1, 2, 3))", 0, 2, 0), R"( GETIMPORT R0 1 [print] LOADK R1 K2 [1, 2, 3] CALL R0 1 0 RETURN R0 0 )"); - CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3, 4))", 0, 2, 0, /*enableVectors*/ true), R"( + CHECK_EQ("\n" + compileFunction("print(vector.create(1, 2, 3, 4))", 0, 2, 0), R"( GETIMPORT R0 1 [print] LOADK R1 K2 [1, 2, 3, 4] CALL R0 1 0 RETURN R0 0 )"); - CHECK_EQ("\n" + compileFunction("return Vector3.new(0, 0, 0), Vector3.new(-0, 0, 0)", 0, 2, 0, /*enableVectors*/ true), R"( + CHECK_EQ("\n" + compileFunction("return vector.create(0, 0, 0), vector.create(-0, 0, 0)", 0, 2, 0), R"( LOADK R0 K0 [0, 0, 0] LOADK R1 K1 [-0, 0, 0] RETURN R0 2 )"); - CHECK_EQ("\n" + compileFunction("return type(Vector3.new(0, 0, 0))", 0, 2, 0, /*enableVectors*/ true), R"( + CHECK_EQ("\n" + compileFunction("return type(vector.create(0, 0, 0))", 0, 2, 0), R"( LOADK R0 K0 ['vector'] RETURN R0 1 +)"); + + // test legacy constructor + CHECK_EQ("\n" + compileFunction("return Vector3.new(1, 2, 3)", 0, 2, 0, /*enableVectors*/ true), R"( +LOADK R0 K0 [1, 2, 3] +RETURN R0 1 )"); } diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index e68ce2c79..9653397d4 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -39,6 +39,7 @@ LUAU_DYNAMIC_FASTFLAG(LuauDebugInfoInvArgLeftovers) LUAU_FASTFLAG(LuauVectorLibNativeCodegen) LUAU_FASTFLAG(LuauVectorLibNativeDot) LUAU_FASTFLAG(LuauVectorMetatable) +LUAU_FASTFLAG(LuauVector2Constructor) LUAU_FASTFLAG(LuauBufferBitMethods) LUAU_FASTFLAG(LuauCodeGenLimitLiveSlotReuse) @@ -896,6 +897,7 @@ TEST_CASE("VectorLibrary") ScopedFastFlag luauVectorLibNativeCodegen{FFlag::LuauVectorLibNativeCodegen, true}; ScopedFastFlag luauVectorLibNativeDot{FFlag::LuauVectorLibNativeDot, true}; ScopedFastFlag luauVectorMetatable{FFlag::LuauVectorMetatable, true}; + ScopedFastFlag luauVector2Constructor{FFlag::LuauVector2Constructor, true}; lua_CompileOptions copts = defaultOptions(); @@ -986,6 +988,7 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) TEST_CASE("Types") { + ScopedFastFlag luauVector2Constructor{FFlag::LuauVector2Constructor, true}; ScopedFastFlag luauMathLerp{FFlag::LuauMathLerp, false}; // waiting for math.lerp to be added to embedded type definitions runConformance( diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 5a2f9319f..e22e52b00 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -25,6 +25,7 @@ static const char* mainModuleName = "MainModule"; LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(LuauVector2Constructor) LUAU_FASTFLAG(DebugLuauLogSolverToJsonFile) LUAU_FASTFLAGVARIABLE(DebugLuauForceAllNewSolverTests); @@ -580,6 +581,8 @@ LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source, bool BuiltinsFixture::BuiltinsFixture(bool prepareAutocomplete) : Fixture(prepareAutocomplete) { + ScopedFastFlag luauVector2Constructor{FFlag::LuauVector2Constructor, true}; + Luau::unfreeze(frontend.globals.globalTypes); Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); diff --git a/tests/NonStrictTypeChecker.test.cpp b/tests/NonStrictTypeChecker.test.cpp index 8d13ebde9..f613e7502 100644 --- a/tests/NonStrictTypeChecker.test.cpp +++ b/tests/NonStrictTypeChecker.test.cpp @@ -15,6 +15,7 @@ #include LUAU_FASTFLAG(LuauCountSelfCallsNonstrict) +LUAU_FASTFLAG(LuauVector2Constructor) using namespace Luau; @@ -581,7 +582,8 @@ buffer.readi8(b, 0) TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_method_calls") { - ScopedFastFlag sff{FFlag::LuauCountSelfCallsNonstrict, true}; + ScopedFastFlag luauCountSelfCallsNonstrict{FFlag::LuauCountSelfCallsNonstrict, true}; + ScopedFastFlag luauVector2Constructor{FFlag::LuauVector2Constructor, true}; Luau::unfreeze(frontend.globals.globalTypes); Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); diff --git a/tests/conformance/vector_library.lua b/tests/conformance/vector_library.lua index 3f30d900d..dd5f2d1b9 100644 --- a/tests/conformance/vector_library.lua +++ b/tests/conformance/vector_library.lua @@ -11,8 +11,15 @@ function ecall(fn, ...) end -- make sure we cover both builtin and C impl +assert(vector.create(1, 2) == vector.create("1", "2")) assert(vector.create(1, 2, 4) == vector.create("1", "2", "4")) +-- 'create' +local v12 = vector.create(1, 2) +local v123 = vector.create(1, 2, 3) +assert(v12.x == 1 and v12.y == 2 and v12.z == 0) +assert(v123.x == 1 and v123.y == 2 and v123.z == 3) + -- testing 'dot' with error handling and different call kinds to mostly check details in the codegen assert(vector.dot(vector.create(1, 2, 4), vector.create(5, 6, 7)) == 45) assert(ecall(function() vector.dot(vector.create(1, 2, 4)) end) == "missing argument #2 to 'dot' (vector expected)")