diff --git a/python/tvm/contrib/emcc.py b/python/tvm/contrib/emcc.py index fac204321586..07ff29205e10 100644 --- a/python/tvm/contrib/emcc.py +++ b/python/tvm/contrib/emcc.py @@ -42,7 +42,14 @@ def create_tvmjs_wasm(output, objects, options=None, cc="emcc"): cmd += ["-O3"] cmd += ["-std=c++17"] cmd += ["--no-entry"] - cmd += ["-fwasm-exceptions"] + # NOTE: asynctify conflicts with wasm-exception + # so we temp disable exception handling for now + # + # We also expect user to explicitly pass in + # -s ASYNCIFY=1 as it can increase wasm size by 2xq + # + # cmd += ["-s", "ASYNCIFY=1"] + # cmd += ["-fwasm-exceptions"] cmd += ["-s", "WASM_BIGINT=1"] cmd += ["-s", "ERROR_ON_UNDEFINED_SYMBOLS=0"] cmd += ["-s", "STANDALONE_WASM=1"] diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 799ef116ce8c..ea22b89dd771 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -569,7 +569,6 @@ int TVMByteArrayFree(TVMByteArray* arr) { int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args, TVMValue* ret_val, int* ret_type_code) { API_BEGIN(); - TVMRetValue rv; (static_cast(func)) ->CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv); diff --git a/web/Makefile b/web/Makefile index bd5e6cbf2bd9..317438842b23 100644 --- a/web/Makefile +++ b/web/Makefile @@ -27,10 +27,11 @@ all: dist/wasm/tvmjs_runtime.wasm dist/wasm/tvmjs_runtime.wasi.js src/tvmjs_runt EMCC = emcc -EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++17 -Wno-ignored-attributes -fwasm-exceptions +EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++17 -Wno-ignored-attributes EMCC_LDFLAGS = --no-entry -s WASM_BIGINT=1 -s ALLOW_MEMORY_GROWTH=1 -s STANDALONE_WASM=1\ - -s ERROR_ON_UNDEFINED_SYMBOLS=0 --pre-js emcc/preload.js + -s ERROR_ON_UNDEFINED_SYMBOLS=0 --pre-js emcc/preload.js\ + -s ASYNCIFY=1 dist/wasm/%.bc: emcc/%.cc @mkdir -p $(@D) diff --git a/web/apps/node/example.js b/web/apps/node/example.js index d17ec072fa21..580bbf57ab80 100644 --- a/web/apps/node/example.js +++ b/web/apps/node/example.js @@ -21,7 +21,7 @@ */ const path = require("path"); const fs = require("fs"); -const tvmjs = require("../../lib"); +const tvmjs = require("../../dist/tvmjs.bundle"); const wasmPath = tvmjs.wasmPath(); const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); diff --git a/web/emcc/decorate_as_wasi.py b/web/emcc/decorate_as_wasi.py index bce0dbb80e9f..6d6b0a7b82dc 100644 --- a/web/emcc/decorate_as_wasi.py +++ b/web/emcc/decorate_as_wasi.py @@ -20,6 +20,7 @@ template_head = """ function EmccWASI() { +var asyncifyStubs = {}; """ template_tail = """ diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index be9704eaef99..8543361340e7 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -100,6 +100,11 @@ TVM_REGISTER_GLOBAL("testing.echo").set_body([](TVMArgs args, TVMRetValue* ret) *ret = args[0]; }); +TVM_REGISTER_GLOBAL("testing.call").set_body([](TVMArgs args, TVMRetValue* ret) { + (args[0].operator PackedFunc()) + .CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, args.num_args - 1), ret); +}); + TVM_REGISTER_GLOBAL("testing.ret_string").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = args[0].operator String(); }); diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index ce2a7cadb68e..1d7dbe0787b2 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -112,7 +112,11 @@ class WebGPUDeviceAPI : public DeviceAPI { LOG(FATAL) << "Not implemented"; } - void StreamSync(Device dev, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; } + void StreamSync(Device dev, TVMStreamHandle stream) final { + static const PackedFunc* func = runtime::Registry::Get("__asyncify.WebGPUWaitForTasks"); + ICHECK(func != nullptr) << "Stream sync inside c++ only supported in asyncify mode"; + (*func)(); + } void SetStream(Device dev, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; } diff --git a/web/src/artifact_cache.ts b/web/src/artifact_cache.ts index 394cda83bc43..ffb5011324f5 100644 --- a/web/src/artifact_cache.ts +++ b/web/src/artifact_cache.ts @@ -1,19 +1,37 @@ /* - Common Interface for the artifact cache -*/ + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Common Interface for the artifact cache + */ export interface ArtifactCacheTemplate { - /** - * fetch key url from cache - */ - fetchWithCache(url: string); + /** + * fetch key url from cache + */ + fetchWithCache(url: string); - /** - * check if cache has all keys in Cache - */ - hasAllKeys(keys: string[]); + /** + * check if cache has all keys in Cache + */ + hasAllKeys(keys: string[]); - /** - * Delete url in cache if url exists - */ - deleteInCache(url: string); + /** + * Delete url in cache if url exists + */ + deleteInCache(url: string); } diff --git a/web/src/asyncify.ts b/web/src/asyncify.ts new file mode 100644 index 000000000000..703dbbf80a10 --- /dev/null +++ b/web/src/asyncify.ts @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +// Helper tools to enable asynctify handling +// Thie following code is used to support wrapping of +// functins that can have async await calls in the backend runtime +// reference +// - https://kripken.github.io/blog/wasm/2019/07/16/asyncify.html +// - https://github.com/GoogleChromeLabs/asyncify +import { assert, isPromise } from "./support"; + +/** + * enums to check the current state of asynctify + */ +const enum AsyncifyStateKind { + None = 0, + Unwinding = 1, + Rewinding = 2 +} + +/** The start location of asynctify stack data */ +const ASYNCIFY_DATA_ADDR = 16; +/** The data start of stack rewind/unwind */ +const ASYNCIFY_DATA_START = ASYNCIFY_DATA_ADDR + 8; +/** The data end of stack rewind/unwind */ +const ASYNCIFY_DATA_END = 1024; + +/** Hold asynctify handler instance that runtime can use */ +export class AsyncifyHandler { + /** exports from wasm */ + private exports: Record; + /** current state kind */ + private state: AsyncifyStateKind = AsyncifyStateKind.None; + /** The stored value before unwind */ + private storedPromiseBeforeUnwind : Promise = null; + // NOTE: asynctify do not work with exceptions + // this implementation here is mainly for possible future compact + /** The stored value that is resolved */ + private storedValueBeforeRewind: any = null; + /** The stored exception */ + private storedExceptionBeforeRewind: any = null; + + constructor(exports: Record, memory: WebAssembly.Memory) { + this.exports = exports; + this.initMemory(memory); + } + + // NOTE: wrapImport and wrapExport are closely related to each other + // We mark the logical jump pt in comments to increase the readability + /** + * Whether the wasm enables asynctify + * @returns Whether the wasm enables asynctify + */ + enabled(): boolean { + return this.exports.asyncify_stop_rewind !== undefined; + } + + /** + * Get the current asynctify state + * + * @returns The current asynctify state + */ + getState(): AsyncifyStateKind { + return this.state; + } + + /** + * Wrap a function that can be used as import of the wasm asynctify layer + * + * @param func The input import function + * @returns The wrapped function that can be registered to the system + */ + wrapImport(func: (...args: Array) => any): (...args: Array) => any { + return (...args: any) => { + // this is being called second time + // where we are rewinding the stack + if (this.getState() == AsyncifyStateKind.Rewinding) { + // JUMP-PT-REWIND: rewind will jump to this pt + // while rewinding the stack + this.stopRewind(); + // the value has been resolved + if (this.storedValueBeforeRewind !== null) { + assert(this.storedExceptionBeforeRewind === null); + const result = this.storedValueBeforeRewind; + this.storedValueBeforeRewind = null; + return result; + } else { + assert(this.storedValueBeforeRewind === null); + const error = this.storedExceptionBeforeRewind; + this.storedExceptionBeforeRewind = null; + throw error; + } + } + // this function is being called for the first time + assert(this.getState() == AsyncifyStateKind.None); + + // call the function + const value = func(...args); + // if the value is promise + // we need to unwind the stack + // so the caller will be able to evaluate the promise + if (isPromise(value)) { + // The next code step is JUMP-PT-UNWIND in wrapExport + // The value will be passed to that pt through storedPromiseBeforeUnwind + // getState() == Unwinding and we will enter the while loop in wrapExport + this.startUnwind(); + assert(this.storedPromiseBeforeUnwind == null); + this.storedPromiseBeforeUnwind = value; + return undefined; + } else { + // The next code step is JUMP-PT-UNWIND in wrapExport + // normal value, we don't have to do anything + // getState() == None and we will exit while loop there + return value; + } + }; + } + + /** + * Warp an exported asynctify function so it can return promise + * + * @param func The input function + * @returns The wrapped async function + */ + wrapExport(func: (...args: Array) => any): (...args: Array) => Promise { + return async (...args: Array) => { + assert(this.getState() == AsyncifyStateKind.None); + + // call the original function + let result = func(...args); + + // JUMP-PT-UNWIND + // after calling the function + // the caller may hit a unwinding point depending on + // the if (isPromise(value)) condition in wrapImport + while (this.getState() == AsyncifyStateKind.Unwinding) { + this.stopUnwind(); + // try to resolve the promise that the internal requested + // we then store it into the temp value in storedValueBeforeRewind + // which then get passed onto the function(see wrapImport) + // that can return the value + const storedPromiseBeforeUnwind = this.storedPromiseBeforeUnwind; + this.storedPromiseBeforeUnwind = null; + assert(this.storedExceptionBeforeRewind === null); + assert(this.storedValueBeforeRewind == null); + + try { + this.storedValueBeforeRewind = await storedPromiseBeforeUnwind; + } catch (error) { + // the store exception + this.storedExceptionBeforeRewind = error; + } + assert(!isPromise(this.storedValueBeforeRewind)); + // because we called asynctify_stop_unwind,the state is now none + assert(this.getState() == AsyncifyStateKind.None); + + // re-enter the function, jump to JUMP-PT-REWIND in wrapImport + // the value will be passed to that point via storedValueBeforeRewind + // + // NOTE: we guarantee that if exception is throw the asynctify state + // will already be at None, this is because we will goto JUMP-PT-REWIND + // which will call aynctify_stop_rewind + this.startRewind(); + result = func(...args); + } + return result; + }; + } + + private startRewind() : void { + if (this.exports.asyncify_start_rewind === undefined) { + throw Error("Asynctify is not enabled, please compile with -s ASYNCIFY=1 in emcc"); + } + this.exports.asyncify_start_rewind(ASYNCIFY_DATA_ADDR); + this.state = AsyncifyStateKind.Rewinding; + } + + private stopRewind() : void { + if (this.exports.asyncify_stop_rewind === undefined) { + throw Error("Asynctify is not enabled, please compile with -s ASYNCIFY=1 in emcc"); + } + this.exports.asyncify_stop_rewind(); + this.state = AsyncifyStateKind.None; + } + + private startUnwind() : void { + if (this.exports.asyncify_start_unwind === undefined) { + throw Error("Asynctify is not enabled, please compile with -s ASYNCIFY=1 in emcc"); + } + this.exports.asyncify_start_unwind(ASYNCIFY_DATA_ADDR); + this.state = AsyncifyStateKind.Unwinding; + } + + private stopUnwind() : void { + if (this.exports.asyncify_stop_unwind === undefined) { + throw Error("Asynctify is not enabled, please compile with -s ASYNCIFY=1 in emcc"); + } + this.exports.asyncify_stop_unwind(); + this.state = AsyncifyStateKind.None; + } + /** + * Initialize the wasm memory to setup necessary meta-data + * for asynctify handling + * @param memory The memory ti + */ + private initMemory(memory: WebAssembly.Memory): void { + // Set the meta-data at address ASYNCTIFY_DATA_ADDR + new Int32Array(memory.buffer, ASYNCIFY_DATA_ADDR, 2).set( + [ASYNCIFY_DATA_START, ASYNCIFY_DATA_END] + ); + } +} diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 6ef225526324..8df48c43a5f9 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -25,6 +25,7 @@ import { Disposable } from "./types"; import { Memory, CachedCallStack } from "./memory"; import { assert, StringToUint8Array } from "./support"; import { Environment } from "./environment"; +import { AsyncifyHandler } from "./asyncify"; import { FunctionInfo, WebGPUContext } from "./webgpu"; import { ArtifactCacheTemplate } from "./artifact_cache"; @@ -32,11 +33,18 @@ import * as compact from "./compact"; import * as ctypes from "./ctypes"; /** - * Type for PackedFunc inthe TVMRuntime. + * Type for PackedFunc in the TVMRuntime. */ export type PackedFunc = ((...args: any) => any) & Disposable & { _tvmPackedCell: PackedFuncCell }; +/** + * Type for AyncPackedFunc in TVMRuntime + * possibly may contain stack unwinding through Asynctify + */ +export type AsyncPackedFunc = ((...args: any) => Promise) & + Disposable & { _tvmPackedCell: PackedFuncCell }; + /** * @internal * FFI Library wrapper, maintains most runtime states. @@ -79,7 +87,6 @@ class FFILibrary implements Disposable { if (code != 0) { const msgPtr = (this.exports .TVMGetLastError as ctypes.FTVMGetLastError)(); - console.log("Here"); throw new Error("TVMError: " + this.memory.loadCString(msgPtr)); } } @@ -1057,6 +1064,7 @@ export class Instance implements Disposable { private env: Environment; private objFactory: Map; private ctx: RuntimeContext; + private asyncifyHandler: AsyncifyHandler; private initProgressCallback: Array = []; /** @@ -1099,6 +1107,7 @@ export class Instance implements Disposable { this.lib = new FFILibrary(wasmInstance, env.imports); this.memory = this.lib.memory; this.exports = this.lib.exports; + this.asyncifyHandler = new AsyncifyHandler(this.exports, this.memory.memory); this.objFactory = new Map(); this.ctx = new RuntimeContext( (name: string) => { @@ -1140,6 +1149,14 @@ export class Instance implements Disposable { return results; } + /** + * Check whether we enabled asyncify mode + * @returns The asynctify mode toggle + */ + asyncifyEnabled(): boolean { + return this.asyncifyHandler.enabled(); + } + dispose(): void { // order matters // ctx release goes back into lib. @@ -1922,13 +1939,55 @@ export class Instance implements Disposable { } this.objFactory.set(typeIndex, func); } + + /** + * Wrap a function obtained from tvm runtime as AsyncPackedFunc + * through the asyncify mechanism + * + * You only need to call it if the function may contain callback into async + * JS function via asynctify. A common one can be GPU synchronize. + * + * It is always safe to wrap any function as Asynctify, however you do need + * to make sure you use await when calling the funciton. + * + * @param func The PackedFunc. + * @returns The wrapped AsyncPackedFunc + */ + wrapAsyncifyPackedFunc(func: PackedFunc): AsyncPackedFunc { + const asyncFunc = this.asyncifyHandler.wrapExport(func) as AsyncPackedFunc; + asyncFunc.dispose = func.dispose; + asyncFunc._tvmPackedCell = func._tvmPackedCell; + return asyncFunc; + } + + /** + * Register async function as asynctify callable in global environment. + * + * @param name The name of the function. + * @param func function to be registered. + * @param override Whether overwrite function in existing registry. + * + * @note This function is handled via asynctify mechanism + * The wasm needs to be compiled with Asynctify + */ + registerAsyncifyFunc( + name: string, + func: (...args: Array) => Promise, + override = false + ): void { + const asyncWrapped = this.asyncifyHandler.wrapImport(func); + this.registerFunc(name, asyncWrapped, override); + } + /** * Register an asyncfunction to be global function in the server. + * * @param name The name of the function. * @param func function to be registered. * @param override Whether overwrite function in existing registry. * - * @note The async function will only be used for serving remote calls in the rpc. + * @note The async function will only be used for serving remote calls in the rpc + * These functions contains explicit continuation */ registerAsyncServerFunc( name: string, @@ -2036,6 +2095,11 @@ export class Instance implements Disposable { this.registerAsyncServerFunc("wasm.WebGPUWaitForTasks", async () => { await webGPUContext.sync(); }); + if (this.asyncifyHandler.enabled()) { + this.registerAsyncifyFunc("__asyncify.WebGPUWaitForTasks", async () => { + await webGPUContext.sync(); + }); + } this.lib.webGPUContext = webGPUContext; } @@ -2281,7 +2345,6 @@ export class Instance implements Disposable { // normal return path // recycle all js object value in function unless we want to retain them. this.ctx.endScope(); - if (rv !== undefined && rv !== null) { const stack = lib.getOrAllocCallStack(); const valueOffset = stack.allocRawBytes(SizeOf.TVMValue); @@ -2320,8 +2383,10 @@ export class Instance implements Disposable { const rvaluePtr = stack.ptrFromOffset(rvalueOffset); const rcodePtr = stack.ptrFromOffset(rcodeOffset); - // commit to wasm memory, till rvalueOffset (the return value don't need to be committed) - stack.commitToWasmMemory(rvalueOffset); + // pre-store the rcode to be null, in case caller unwind + // and not have chance to reset this rcode. + stack.storeI32(rcodeOffset, ArgTypeCode.Null); + stack.commitToWasmMemory(); this.lib.checkCall( (this.exports.TVMFuncCall as ctypes.FTVMFuncCall)( diff --git a/web/src/support.ts b/web/src/support.ts index 18748c2c85ba..b03fa363cdce 100644 --- a/web/src/support.ts +++ b/web/src/support.ts @@ -17,6 +17,18 @@ * under the License. */ + +/** + * Check if value is a promise type + * + * @param value The input value + * @returns Whether value is promise + */ +export function isPromise(value: any): boolean { + return value !== undefined && ( + typeof value == "object" || typeof value == "function" + ) && typeof value.then == "function"; +} /** * Convert string to Uint8array. * @param str The string. diff --git a/web/tests/node/test_packed_func.js b/web/tests/node/test_packed_func.js index f5c0ac6c2fad..e1d070f0e473 100644 --- a/web/tests/node/test_packed_func.js +++ b/web/tests/node/test_packed_func.js @@ -22,6 +22,9 @@ const fs = require("fs"); const assert = require("assert"); const tvmjs = require("../../dist/tvmjs.bundle") +// for now skip exception testing +// as it may not be compatible with asyncify +const exceptionEnabled = false; const wasmPath = tvmjs.wasmPath(); const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); @@ -127,6 +130,8 @@ test("RegisterGlobal", () => { }); test("ExceptionPassing", () => { + if (!exceptionEnabled) return; + tvm.beginScope(); tvm.registerFunc("throw_error", function (msg) { throw Error(msg); @@ -141,6 +146,31 @@ test("ExceptionPassing", () => { tvm.endScope(); }); + +test("AsyncifyFunc", async () => { + if (!tvm.asyncifyEnabled()) { + console.log("Skip asyncify tests as it is not enabled.."); + return; + } + tvm.beginScope(); + tvm.registerAsyncifyFunc("async_sleep_echo", async function (x) { + await new Promise(resolve => setTimeout(resolve, 10)); + return x; + }); + let fecho = tvm.wrapAsyncifyPackedFunc( + tvm.getGlobalFunc("async_sleep_echo") + ); + let fcall = tvm.wrapAsyncifyPackedFunc( + tvm.getGlobalFunc("testing.call") + ); + assert((await fecho(1)) == 1); + assert((await fecho(2)) == 2); + assert((await fcall(fecho, 2) == 2)); + tvm.endScope(); + assert(fecho._tvmPackedCell.getHandle(false) == 0); + assert(fcall._tvmPackedCell.getHandle(false) == 0); +}); + test("NDArrayCbArg", () => { tvm.beginScope(); let use_count = tvm.getGlobalFunc("testing.object_use_count");