Skip to content

Commit

Permalink
Add 2-component vector constructor (#1569)
Browse files Browse the repository at this point in the history
Implement RFC: 2-component vector constructor. This includes 2-component
overload for `vector.create` and associated fastcall function, and its
type definition. These features are controlled by a new feature flag
`LuauVector2Constructor`. Additionally constant folding now supports two
components when `LuauVector2Constants` feature flag is set.

Note: this work does not include changes to CodeGen. Thus calls to
`vector.create` with only two arguments are not natively compiled
currently. This is left for future work.
  • Loading branch information
petrihakkinen authored Jan 17, 2025
1 parent 24cacc9 commit 67e9d85
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 30 deletions.
75 changes: 71 additions & 4 deletions Analysis/src/EmbeddedBuiltinDefinitions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

LUAU_FASTFLAGVARIABLE(LuauVectorDefinitionsExtra)
LUAU_FASTFLAG(LuauBufferBitMethods)
LUAU_FASTFLAG(LuauVector2Constructor)

namespace Luau
{
Expand Down Expand Up @@ -265,7 +266,7 @@ declare buffer: {
)BUILTIN_SRC";

static const std::string kBuiltinDefinitionVectorSrc_DEPRECATED = R"BUILTIN_SRC(
static const std::string kBuiltinDefinitionVectorSrc_NoExtra_NoVector2Ctor_DEPRECATED = R"BUILTIN_SRC(
-- TODO: this will be replaced with a built-in primitive type
declare class vector end
Expand All @@ -291,7 +292,33 @@ declare vector: {
)BUILTIN_SRC";

static const std::string kBuiltinDefinitionVectorSrc = R"BUILTIN_SRC(
static const std::string kBuiltinDefinitionVectorSrc_NoExtra_DEPRECATED = R"BUILTIN_SRC(
-- TODO: this will be replaced with a built-in primitive type
declare class vector end
declare vector: {
create: @checked (x: number, y: number, z: number?) -> vector,
magnitude: @checked (vec: vector) -> number,
normalize: @checked (vec: vector) -> vector,
cross: @checked (vec1: vector, vec2: vector) -> vector,
dot: @checked (vec1: vector, vec2: vector) -> number,
angle: @checked (vec1: vector, vec2: vector, axis: vector?) -> number,
floor: @checked (vec: vector) -> vector,
ceil: @checked (vec: vector) -> vector,
abs: @checked (vec: vector) -> vector,
sign: @checked (vec: vector) -> vector,
clamp: @checked (vec: vector, min: vector, max: vector) -> vector,
max: @checked (vector, ...vector) -> vector,
min: @checked (vector, ...vector) -> vector,
zero: vector,
one: vector,
}
)BUILTIN_SRC";

static const std::string kBuiltinDefinitionVectorSrc_NoVector2Ctor_DEPRECATED = R"BUILTIN_SRC(
-- While vector would have been better represented as a built-in primitive type, type solver class handling covers most of the properties
declare class vector
Expand Down Expand Up @@ -321,16 +348,56 @@ declare vector: {
)BUILTIN_SRC";

static const std::string kBuiltinDefinitionVectorSrc = R"BUILTIN_SRC(
-- While vector would have been better represented as a built-in primitive type, type solver class handling covers most of the properties
declare class vector
x: number
y: number
z: number
end
declare vector: {
create: @checked (x: number, y: number, z: number?) -> vector,
magnitude: @checked (vec: vector) -> number,
normalize: @checked (vec: vector) -> vector,
cross: @checked (vec1: vector, vec2: vector) -> vector,
dot: @checked (vec1: vector, vec2: vector) -> number,
angle: @checked (vec1: vector, vec2: vector, axis: vector?) -> number,
floor: @checked (vec: vector) -> vector,
ceil: @checked (vec: vector) -> vector,
abs: @checked (vec: vector) -> vector,
sign: @checked (vec: vector) -> vector,
clamp: @checked (vec: vector, min: vector, max: vector) -> vector,
max: @checked (vector, ...vector) -> vector,
min: @checked (vector, ...vector) -> vector,
zero: vector,
one: vector,
}
)BUILTIN_SRC";

std::string getBuiltinDefinitionSource()
{
std::string result = kBuiltinDefinitionLuaSrcChecked;

result += FFlag::LuauBufferBitMethods ? kBuiltinDefinitionBufferSrc : kBuiltinDefinitionBufferSrc_DEPRECATED;

if (FFlag::LuauVectorDefinitionsExtra)
result += kBuiltinDefinitionVectorSrc;
{
if (FFlag::LuauVector2Constructor)
result += kBuiltinDefinitionVectorSrc;
else
result += kBuiltinDefinitionVectorSrc_NoVector2Ctor_DEPRECATED;
}
else
result += kBuiltinDefinitionVectorSrc_DEPRECATED;
{
if (FFlag::LuauVector2Constructor)
result += kBuiltinDefinitionVectorSrc_NoExtra_DEPRECATED;
else
result += kBuiltinDefinitionVectorSrc_NoExtra_NoVector2Ctor_DEPRECATED;
}

return result;
}
Expand Down
9 changes: 6 additions & 3 deletions Compiler/src/BuiltinFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <math.h>

LUAU_FASTFLAGVARIABLE(LuauVector2Constants)
LUAU_FASTFLAG(LuauCompileMathLerp)

namespace Luau
Expand Down Expand Up @@ -473,11 +474,13 @@ Constant foldBuiltin(int bfid, const Constant* args, size_t count)
break;

case LBF_VECTOR:
if (count >= 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number)
if (count >= 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
{
if (count == 3)
if (count == 2 && FFlag::LuauVector2Constants)
return cvector(args[0].valueNumber, args[1].valueNumber, 0.0, 0.0);
else if (count == 3 && args[2].type == Constant::Type_Number)
return cvector(args[0].valueNumber, args[1].valueNumber, args[2].valueNumber, 0.0);
else if (count == 4 && args[3].type == Constant::Type_Number)
else if (count == 4 && args[2].type == Constant::Type_Number && args[3].type == Constant::Type_Number)
return cvector(args[0].valueNumber, args[1].valueNumber, args[2].valueNumber, args[3].valueNumber);
}
break;
Expand Down
62 changes: 49 additions & 13 deletions VM/src/lbuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#endif
#endif

LUAU_FASTFLAG(LuauVector2Constructor)

// luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM
// The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack.
// If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path
Expand Down Expand Up @@ -1055,26 +1057,60 @@ static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, St

static int luauF_vector(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{
if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1))
if (FFlag::LuauVector2Constructor)
{
double x = nvalue(arg0);
double y = nvalue(args);
double z = nvalue(args + 1);
if (nparams >= 2 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args))
{
float x = (float)nvalue(arg0);
float y = (float)nvalue(args);
float z = 0.0f;

if (nparams >= 3)
{
if (!ttisnumber(args + 1))
return -1;
z = (float)nvalue(args + 1);
}

#if LUA_VECTOR_SIZE == 4
double w = 0.0;
if (nparams >= 4)
{
if (!ttisnumber(args + 2))
return -1;
w = nvalue(args + 2);
float w = 0.0f;
if (nparams >= 4)
{
if (!ttisnumber(args + 2))
return -1;
w = (float)nvalue(args + 2);
}
setvvalue(res, x, y, z, w);
#else
setvvalue(res, x, y, z, 0.0f);
#endif

return 1;
}
setvvalue(res, float(x), float(y), float(z), float(w));
}
else
{
if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1))
{
double x = nvalue(arg0);
double y = nvalue(args);
double z = nvalue(args + 1);

#if LUA_VECTOR_SIZE == 4
double w = 0.0;
if (nparams >= 4)
{
if (!ttisnumber(args + 2))
return -1;
w = nvalue(args + 2);
}
setvvalue(res, float(x), float(y), float(z), float(w));
#else
setvvalue(res, float(x), float(y), float(z), 0.0f);
setvvalue(res, float(x), float(y), float(z), 0.0f);
#endif

return 1;
return 1;
}
}

return -1;
Expand Down
9 changes: 6 additions & 3 deletions VM/src/lveclib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@
#include <math.h>

LUAU_FASTFLAGVARIABLE(LuauVectorMetatable)
LUAU_FASTFLAGVARIABLE(LuauVector2Constructor)

static int vector_create(lua_State* L)
{
// checking argument count to avoid accepting 'nil' as a valid value
int count = lua_gettop(L);

double x = luaL_checknumber(L, 1);
double y = luaL_checknumber(L, 2);
double z = luaL_checknumber(L, 3);
double z = FFlag::LuauVector2Constructor ? (count >= 3 ? luaL_checknumber(L, 3) : 0.0) : luaL_checknumber(L, 3);

#if LUA_VECTOR_SIZE == 4
// checking argument count to avoid accepting 'nil' as a valid value
double w = lua_gettop(L) >= 4 ? luaL_checknumber(L, 4) : 0.0;
double w = count >= 4 ? luaL_checknumber(L, 4) : 0.0;

lua_pushvector(L, float(x), float(y), float(z), float(w));
#else
Expand Down
14 changes: 14 additions & 0 deletions bench/micro_tests/test_vector_lib.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
local function prequire(name) local success, result = pcall(require, name); return success and result end
local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support")

bench.runCode(function()
for i=1,1000000 do
vector.create(i, 2, 3)
vector.create(i, 2, 3)
vector.create(i, 2, 3)
vector.create(i, 2, 3)
vector.create(i, 2, 3)
end
end, "vector: create")

-- TODO: add more tests
26 changes: 20 additions & 6 deletions tests/Compiler.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ LUAU_FASTINT(LuauRecursionLimit)
LUAU_FASTFLAG(LuauCompileOptimizeRevArith)
LUAU_FASTFLAG(LuauCompileLibraryConstants)
LUAU_FASTFLAG(LuauVectorFolding)
LUAU_FASTFLAG(LuauVector2Constants)
LUAU_FASTFLAG(LuauCompileDisabledBuiltins)

using namespace Luau;
Expand Down Expand Up @@ -5098,36 +5099,49 @@ L0: RETURN R3 -1
)");
}

TEST_CASE("VectorLiterals")
TEST_CASE("VectorConstants")
{
CHECK_EQ("\n" + compileFunction("return Vector3.new(1, 2, 3)", 0, 2, 0, /*enableVectors*/ true), R"(
ScopedFastFlag luauVector2Constants{FFlag::LuauVector2Constants, true};

CHECK_EQ("\n" + compileFunction("return vector.create(1, 2)", 0, 2, 0), R"(
LOADK R0 K0 [1, 2, 0]
RETURN R0 1
)");

CHECK_EQ("\n" + compileFunction("return vector.create(1, 2, 3)", 0, 2, 0), R"(
LOADK R0 K0 [1, 2, 3]
RETURN R0 1
)");

CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3))", 0, 2, 0, /*enableVectors*/ true), R"(
CHECK_EQ("\n" + compileFunction("print(vector.create(1, 2, 3))", 0, 2, 0), R"(
GETIMPORT R0 1 [print]
LOADK R1 K2 [1, 2, 3]
CALL R0 1 0
RETURN R0 0
)");

CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3, 4))", 0, 2, 0, /*enableVectors*/ true), R"(
CHECK_EQ("\n" + compileFunction("print(vector.create(1, 2, 3, 4))", 0, 2, 0), R"(
GETIMPORT R0 1 [print]
LOADK R1 K2 [1, 2, 3, 4]
CALL R0 1 0
RETURN R0 0
)");

CHECK_EQ("\n" + compileFunction("return Vector3.new(0, 0, 0), Vector3.new(-0, 0, 0)", 0, 2, 0, /*enableVectors*/ true), R"(
CHECK_EQ("\n" + compileFunction("return vector.create(0, 0, 0), vector.create(-0, 0, 0)", 0, 2, 0), R"(
LOADK R0 K0 [0, 0, 0]
LOADK R1 K1 [-0, 0, 0]
RETURN R0 2
)");

CHECK_EQ("\n" + compileFunction("return type(Vector3.new(0, 0, 0))", 0, 2, 0, /*enableVectors*/ true), R"(
CHECK_EQ("\n" + compileFunction("return type(vector.create(0, 0, 0))", 0, 2, 0), R"(
LOADK R0 K0 ['vector']
RETURN R0 1
)");

// test legacy constructor
CHECK_EQ("\n" + compileFunction("return Vector3.new(1, 2, 3)", 0, 2, 0, /*enableVectors*/ true), R"(
LOADK R0 K0 [1, 2, 3]
RETURN R0 1
)");
}

Expand Down
3 changes: 3 additions & 0 deletions tests/Conformance.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ LUAU_DYNAMIC_FASTFLAG(LuauDebugInfoInvArgLeftovers)
LUAU_FASTFLAG(LuauVectorLibNativeCodegen)
LUAU_FASTFLAG(LuauVectorLibNativeDot)
LUAU_FASTFLAG(LuauVectorMetatable)
LUAU_FASTFLAG(LuauVector2Constructor)
LUAU_FASTFLAG(LuauBufferBitMethods)
LUAU_FASTFLAG(LuauCodeGenLimitLiveSlotReuse)

Expand Down Expand Up @@ -896,6 +897,7 @@ TEST_CASE("VectorLibrary")
ScopedFastFlag luauVectorLibNativeCodegen{FFlag::LuauVectorLibNativeCodegen, true};
ScopedFastFlag luauVectorLibNativeDot{FFlag::LuauVectorLibNativeDot, true};
ScopedFastFlag luauVectorMetatable{FFlag::LuauVectorMetatable, true};
ScopedFastFlag luauVector2Constructor{FFlag::LuauVector2Constructor, true};

lua_CompileOptions copts = defaultOptions();

Expand Down Expand Up @@ -986,6 +988,7 @@ static void populateRTTI(lua_State* L, Luau::TypeId type)

TEST_CASE("Types")
{
ScopedFastFlag luauVector2Constructor{FFlag::LuauVector2Constructor, true};
ScopedFastFlag luauMathLerp{FFlag::LuauMathLerp, false}; // waiting for math.lerp to be added to embedded type definitions

runConformance(
Expand Down
3 changes: 3 additions & 0 deletions tests/Fixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
static const char* mainModuleName = "MainModule";

LUAU_FASTFLAG(LuauSolverV2);
LUAU_FASTFLAG(LuauVector2Constructor)
LUAU_FASTFLAG(DebugLuauLogSolverToJsonFile)

LUAU_FASTFLAGVARIABLE(DebugLuauForceAllNewSolverTests);
Expand Down Expand Up @@ -580,6 +581,8 @@ LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source, bool
BuiltinsFixture::BuiltinsFixture(bool prepareAutocomplete)
: Fixture(prepareAutocomplete)
{
ScopedFastFlag luauVector2Constructor{FFlag::LuauVector2Constructor, true};

Luau::unfreeze(frontend.globals.globalTypes);
Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes);

Expand Down
4 changes: 3 additions & 1 deletion tests/NonStrictTypeChecker.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <iostream>

LUAU_FASTFLAG(LuauCountSelfCallsNonstrict)
LUAU_FASTFLAG(LuauVector2Constructor)

using namespace Luau;

Expand Down Expand Up @@ -581,7 +582,8 @@ buffer.readi8(b, 0)

TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_method_calls")
{
ScopedFastFlag sff{FFlag::LuauCountSelfCallsNonstrict, true};
ScopedFastFlag luauCountSelfCallsNonstrict{FFlag::LuauCountSelfCallsNonstrict, true};
ScopedFastFlag luauVector2Constructor{FFlag::LuauVector2Constructor, true};

Luau::unfreeze(frontend.globals.globalTypes);
Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes);
Expand Down
7 changes: 7 additions & 0 deletions tests/conformance/vector_library.lua
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,15 @@ function ecall(fn, ...)
end

-- make sure we cover both builtin and C impl
assert(vector.create(1, 2) == vector.create("1", "2"))
assert(vector.create(1, 2, 4) == vector.create("1", "2", "4"))

-- 'create'
local v12 = vector.create(1, 2)
local v123 = vector.create(1, 2, 3)
assert(v12.x == 1 and v12.y == 2 and v12.z == 0)
assert(v123.x == 1 and v123.y == 2 and v123.z == 3)

-- testing 'dot' with error handling and different call kinds to mostly check details in the codegen
assert(vector.dot(vector.create(1, 2, 4), vector.create(5, 6, 7)) == 45)
assert(ecall(function() vector.dot(vector.create(1, 2, 4)) end) == "missing argument #2 to 'dot' (vector expected)")
Expand Down

0 comments on commit 67e9d85

Please sign in to comment.