Skip to content

Commit

Permalink
Implement WASM imported function return value
Browse files Browse the repository at this point in the history
  • Loading branch information
TooTallNate committed Oct 8, 2023
1 parent b433831 commit a4fb327
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 18 deletions.
5 changes: 5 additions & 0 deletions .changeset/wicked-keys-brush.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'nxjs-runtime': patch
---

Implement WASM imported function return value
Binary file added apps/tests/romfs/compute.wasm
Binary file not shown.
49 changes: 49 additions & 0 deletions apps/tests/src/wasm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,55 @@ test('grow.wasm', async () => {
// TODO: `mem.grow(1)` test, once that function is implemented
});

test('compute.wasm', async () => {
let aVal = -1;
let bVal = -1;
const { module, instance } = await WebAssembly.instantiateStreaming(
fetch('compute.wasm'),
{
env: {
compute(a: number, b: number) {
aVal = a;
bVal = b;
return Math.round(a * b);
},
},
}
);
assert.equal(WebAssembly.Module.imports(module), [
{ module: 'env', name: 'compute', kind: 'function' },
]);

assert.equal(WebAssembly.Module.exports(module), [
{ name: 'invoke', kind: 'function' },
{ name: 'val', kind: 'global' },
]);

const invoke = instance.exports.invoke as Function;
assert.type(invoke, 'function');

const val = instance.exports.val as WebAssembly.Global<'i32'>;
assert.instance(val, WebAssembly.Global);

assert.equal(val.value, 0, 'Global value starts at `0`');
invoke(1.23, 3.45);
assert.equal(
aVal,
2.46,
'`compute()` should have been invoked with values doubled - a'
);
assert.equal(
bVal,
6.9,
'`compute()` should have been invoked with values doubled - b'
);
assert.equal(
val.value,
17,
'Global value should be result of `a * b` rounded to nearest'
);
});

test('Imported function throws an Error is propagated', async () => {
const e = new Error('will be thrown');
const { instance } = await WebAssembly.instantiateStreaming(
Expand Down
63 changes: 45 additions & 18 deletions source/wasm.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,36 @@
static M3Result nx_wasm_js_error = "JS error was thrown";

// https://webassembly.github.io/spec/js-api/index.html#towebassemblyvalue
// static void nx__wasm_towebassemblyvalue(JSContext *ctx, M3ValueType type, const void *stack) {}
static int nx__wasm_towebassemblyvalue(JSContext *ctx, JSValueConst val, M3ValueType type, void *stack)
{
int r = 0;
switch (type)
{
case c_m3Type_i32:
{
r = JS_ToInt32(ctx, (int32_t *)stack, val);
break;
};
case c_m3Type_i64:
{
r = JS_ToInt64(ctx, (int64_t *)stack, val);
break;
};
case c_m3Type_f32:
case c_m3Type_f64:
{
r = JS_ToFloat64(ctx, (double *)stack, val);
break;
};
case c_m3Type_none:
case c_m3Type_unknown:
{
/* shrug */
break;
}
}
return r;
}

// https://webassembly.github.io/spec/js-api/index.html#tojsvalue
static JSValue nx__wasm_tojsvalue(JSContext *ctx, M3ValueType type, const void *stack)
Expand Down Expand Up @@ -308,9 +337,8 @@ m3ApiRawFunction(nx_wasm_imported_func)
IM3FuncType funcType = func->funcType;
nx_wasm_imported_func_t *js = _ctx->userdata;

for (int i = 0; i < funcType->numRets; i++)
{
}
uint64_t *retValAddr = _sp;
_sp += funcType->numRets;

// Map the WASM arguments to JS values
JSValue args[funcType->numArgs];
Expand All @@ -329,12 +357,16 @@ m3ApiRawFunction(nx_wasm_imported_func)
return nx_wasm_js_error;
}

// TODO: map JS return value back to WASM return value(s)
// m3ApiMultiValueReturnType(int32_t, one);
// m3ApiGetArg(int32_t, param);
// m3ApiGetArg(int64_t, param)
// m3ApiGetArg(float, param)
// m3ApiMultiValueReturn(one, 1);
// Map the JS return value to WASM
if (funcType->numRets > 0)
{
if (nx__wasm_towebassemblyvalue(js->ctx, ret_val, funcType->types[0], retValAddr))
{
JS_FreeValue(js->ctx, ret_val);
return nx_wasm_js_error;
}
// TODO: handle multi-return values when JS returns an Array?
}

JS_FreeValue(js->ctx, ret_val);
m3ApiSuccess();
Expand Down Expand Up @@ -465,8 +497,6 @@ static JSValue nx_wasm_new_instance(JSContext *ctx, JSValueConst this_val, int a
return JS_EXCEPTION;
}

// TODO: validate "kind === 'function'"

JSValue v = JS_GetPropertyStr(ctx, matching_import, "val");
if (JS_IsFunction(ctx, v))
{
Expand Down Expand Up @@ -495,6 +525,8 @@ static JSValue nx_wasm_new_instance(JSContext *ctx, JSValueConst this_val, int a
{
JS_FreeValue(ctx, v);
JS_FreeValue(ctx, matching_import);
JS_FreeValue(ctx, js->func);
js_free(ctx, js);
return nx_throw_wasm_error(ctx, "LinkError", r);
}
}
Expand Down Expand Up @@ -527,10 +559,7 @@ static JSValue nx_wasm_new_instance(JSContext *ctx, JSValueConst this_val, int a
return JS_EXCEPTION;
}

// TODO: validate "kind === 'global'"

JSValue v = JS_GetPropertyStr(ctx, matching_import, "val");

nx_wasm_global_t *nx_g = nx_wasm_global_get(ctx, v);
nx_g->global = g;

Expand Down Expand Up @@ -735,10 +764,9 @@ static JSValue nx_wasm_call_func(JSContext *ctx, JSValueConst this_val, int argc
JS_FreeCString(ctx, fname);
return nx_throw_wasm_error(ctx, "RuntimeError", r);
}

JS_FreeCString(ctx, fname);

int nargs = argc - 2;
int nargs = m3_GetArgCount(func);
if (nargs == 0)
{
r = m3_Call(func, 0, NULL);
Expand Down Expand Up @@ -773,7 +801,6 @@ static JSValue nx_wasm_call_func(JSContext *ctx, JSValueConst this_val, int argc
}

int ret_count = m3_GetRetCount(func);

if (ret_count == 0)
{
return JS_UNDEFINED;
Expand Down

0 comments on commit a4fb327

Please sign in to comment.