From 48cbb11c111f9033a036c22a846fb471176f20ae Mon Sep 17 00:00:00 2001 From: Michael-F-Bryan Date: Mon, 21 Aug 2023 22:58:46 +0800 Subject: [PATCH] Stubbed out caching --- src/lib.rs | 1 - src/module_cache.rs | 177 ---------------------------- src/tasks/pool.rs | 12 +- src/tasks/{worker2.js => worker.js} | 0 src/tasks/worker.rs | 40 +++++-- 5 files changed, 40 insertions(+), 190 deletions(-) delete mode 100644 src/module_cache.rs rename src/tasks/{worker2.js => worker.js} (100%) diff --git a/src/lib.rs b/src/lib.rs index 828adf81..e781b96c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,6 @@ wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); mod facade; mod instance; -mod module_cache; mod net; mod run; mod runtime; diff --git a/src/module_cache.rs b/src/module_cache.rs deleted file mode 100644 index 7625668f..00000000 --- a/src/module_cache.rs +++ /dev/null @@ -1,177 +0,0 @@ -use std::{cell::RefCell, collections::HashMap}; - -use base64::{engine::general_purpose::STANDARD, Engine as _}; -use bytes::Bytes; -use wasm_bindgen::{JsCast, JsValue}; -use wasmer::{Engine, Module}; -use wasmer_wasix::runtime::module_cache::{CacheError, ModuleHash}; - -std::thread_local! { - static CACHED_MODULES: RefCell> - = RefCell::new(HashMap::new()); -} - -/// A cache that will automatically share cached modules with other web -/// workers. -#[derive(Debug, Default)] -pub(crate) struct ModuleCache {} - -impl ModuleCache { - fn cache_in_main(&self, key: ModuleHash, module: &Module, deterministic_id: &str) {} - - pub fn export() -> JsValue { - CACHED_MODULES.with(|m| { - // Annotation is here to prevent spurious IDE warnings. - #[allow(unused_unsafe)] - unsafe { - let entries = js_sys::Array::new_with_length(m.borrow().len() as u32); - - for (i, ((key, deterministic_id), module)) in m.borrow().iter().enumerate() { - let entry = js_sys::Object::new(); - - js_sys::Reflect::set( - &entry, - &"key".into(), - &JsValue::from(STANDARD.encode(key.as_bytes())), - ) - .unwrap(); - - js_sys::Reflect::set( - &entry, - &"deterministic_id".into(), - &JsValue::from(deterministic_id.clone()), - ) - .unwrap(); - - js_sys::Reflect::set(&entry, &"module".into(), &JsValue::from(module.clone())) - .unwrap(); - - let module_bytes = Box::new(module.serialize().unwrap()); - let module_bytes = Box::into_raw(module_bytes); - js_sys::Reflect::set( - &entry, - &"module_bytes".into(), - &JsValue::from(module_bytes as u32), - ) - .unwrap(); - - entries.set(i as u32, JsValue::from(entry)); - } - - JsValue::from(entries) - } - }) - } - - pub fn import(cache: JsValue) { - CACHED_MODULES.with(|m| { - // Annotation is here to prevent spurious IDE warnings. - #[allow(unused_unsafe)] - unsafe { - let entries = cache.dyn_into::().unwrap(); - - for i in 0..entries.length() { - let entry = entries.get(i); - - let key = js_sys::Reflect::get(&entry, &"key".into()).unwrap(); - let key = JsValue::as_string(&key).unwrap(); - let key = STANDARD.decode(key).unwrap(); - let key: [u8; 32] = key.try_into().unwrap(); - let key = ModuleHash::from_bytes(key); - - let deterministic_id = - js_sys::Reflect::get(&entry, &"deterministic_id".into()).unwrap(); - let deterministic_id = JsValue::as_string(&deterministic_id).unwrap(); - - let module_bytes = - js_sys::Reflect::get(&entry, &"module_bytes".into()).unwrap(); - let module_bytes: u32 = module_bytes.as_f64().unwrap() as u32; - let module_bytes = module_bytes as *mut Bytes; - let module_bytes = unsafe { Box::from_raw(module_bytes) }; - - let module = js_sys::Reflect::get(&entry, &"module".into()).unwrap(); - let module = module.dyn_into::().unwrap(); - let module: Module = (module, *module_bytes).into(); - - let key = (key, deterministic_id); - m.borrow_mut().insert(key, module.clone()); - } - } - }); - } - - pub fn lookup(&self, key: ModuleHash, deterministic_id: &str) -> Option { - let key = (key, deterministic_id.to_string()); - CACHED_MODULES.with(|m| m.borrow().get(&key).cloned()) - } - - /// Add an item to the cache, returning whether that item already exists. - pub fn insert(&self, key: ModuleHash, module: &Module, deterministic_id: &str) -> bool { - let key = (key, deterministic_id.to_string()); - let previous_value = CACHED_MODULES.with(|m| m.borrow_mut().insert(key, module.clone())); - previous_value.is_none() - } -} - -#[async_trait::async_trait] -impl wasmer_wasix::runtime::module_cache::ModuleCache for ModuleCache { - async fn load(&self, key: ModuleHash, engine: &Engine) -> Result { - match self.lookup(key, engine.deterministic_id()) { - Some(m) => { - tracing::debug!("Cache hit!"); - Ok(m) - } - None => Err(CacheError::NotFound), - } - } - - async fn save( - &self, - key: ModuleHash, - engine: &Engine, - module: &Module, - ) -> Result<(), CacheError> { - let already_exists = self.insert(key, module, engine.deterministic_id()); - - // We also send the module to the main thread via a postMessage - // which they relays it to all the web works - if !already_exists { - self.cache_in_main(key, module, engine.deterministic_id()); - } - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use wasmer_wasix::runtime::module_cache::ModuleCache as _; - - const ADD_WAT: &[u8] = br#"( - module - (func - (export "add") - (param $x i64) - (param $y i64) - (result i64) - (i64.add (local.get $x) (local.get $y))) - )"#; - - #[wasm_bindgen_test::wasm_bindgen_test] - async fn round_trip_via_cache() { - let engine = Engine::default(); - let module = Module::new(&engine, ADD_WAT).unwrap(); - let cache = ModuleCache::default(); - let key = ModuleHash::from_bytes([0; 32]); - - cache.save(key, &engine, &module).await.unwrap(); - let round_tripped = cache.load(key, &engine).await.unwrap(); - - let exports: Vec<_> = round_tripped - .exports() - .map(|export| export.name().to_string()) - .collect(); - assert_eq!(exports, ["add"]); - } -} diff --git a/src/tasks/pool.rs b/src/tasks/pool.rs index 55519229..73d6e246 100644 --- a/src/tasks/pool.rs +++ b/src/tasks/pool.rs @@ -8,6 +8,7 @@ use std::{ use anyhow::{Context, Error}; use futures::{future::LocalBoxFuture, Future}; use tokio::sync::mpsc::{self, UnboundedSender}; +use wasm_bindgen::{JsCast, JsValue}; use wasmer_wasix::{ runtime::{resolver::WebcHash, task_manager::TaskWasm}, WasiThreadError, @@ -164,7 +165,7 @@ impl Scheduler { move_worker(worker_id, &mut self.busy, &mut self.idle) } Message::CacheModule { hash, module } => { - let module = js_sys::WebAssembly::Module::from(module); + let module: js_sys::WebAssembly::Module = JsValue::from(module).unchecked_into(); self.cached_modules.insert(hash, module.clone()); for worker in self.idle.iter().chain(self.busy.iter()) { @@ -242,8 +243,13 @@ impl Scheduler { self.next_id += 1; let handle = WorkerHandle::spawn(id, self.mailbox.clone())?; - for (hash, module) in &self.cached_modules { - todo!(); + // Prime the worker's module cache + for (&hash, module) in &self.cached_modules { + let msg = PostMessagePayload::CacheModule { + hash, + module: module.clone(), + }; + handle.send(msg)?; } Ok(handle) diff --git a/src/tasks/worker2.js b/src/tasks/worker.js similarity index 100% rename from src/tasks/worker2.js rename to src/tasks/worker.js diff --git a/src/tasks/worker.rs b/src/tasks/worker.rs index 248394f8..71c83c0e 100644 --- a/src/tasks/worker.rs +++ b/src/tasks/worker.rs @@ -1,4 +1,4 @@ -use std::pin::Pin; +use std::{mem::ManuallyDrop, pin::Pin}; use anyhow::{Context, Error}; use futures::Future; @@ -9,6 +9,7 @@ use wasm_bindgen::{ prelude::{wasm_bindgen, Closure}, JsCast, JsValue, }; +use wasmer_wasix::runtime::resolver::WebcHash; use crate::tasks::pool::{Message, PostMessagePayload}; @@ -118,7 +119,7 @@ static WORKER_URL: Lazy = Lazy::new(|| { tracing::debug!(import_url = IMPORT_META_URL.as_str()); - let script = include_str!("worker2.js").replace("$IMPORT_META_URL", &IMPORT_META_URL); + let script = include_str!("worker.js").replace("$IMPORT_META_URL", &IMPORT_META_URL); let blob = web_sys::Blob::new_with_u8_array_sequence_and_options( Array::from_iter([Uint8Array::from(script.as_bytes())]).as_ref(), @@ -136,30 +137,45 @@ static WORKER_URL: Lazy = Lazy::new(|| { #[derive(serde::Serialize, serde::Deserialize)] #[serde(tag = "type", rename_all = "kebab-case")] pub(crate) enum PostMessagePayloadRepr { - SpawnAsync { ptr: usize }, - SpawnBlocking { ptr: usize }, + SpawnAsync { + ptr: usize, + }, + SpawnBlocking { + ptr: usize, + }, + #[serde(skip)] + CacheModule { + hash: WebcHash, + module: js_sys::WebAssembly::Module, + }, } impl PostMessagePayloadRepr { pub(crate) unsafe fn reconstitute(self) -> PostMessagePayload { - match self { + let this = ManuallyDrop::new(self); + + match &*this { PostMessagePayloadRepr::SpawnAsync { ptr } => { let boxed = Box::from_raw( - ptr as *mut Box< + *ptr as *mut Box< dyn FnOnce() -> Pin + 'static>> + Send + 'static, >, ); - std::mem::forget(self); PostMessagePayload::SpawnAsync(*boxed) } PostMessagePayloadRepr::SpawnBlocking { ptr } => { - let boxed = Box::from_raw(ptr as *mut Box); - std::mem::forget(self); + let boxed = Box::from_raw(*ptr as *mut Box); PostMessagePayload::SpawnBlocking(*boxed) } + PostMessagePayloadRepr::CacheModule { hash, ref module } => { + PostMessagePayload::CacheModule { + hash: std::ptr::read(hash), + module: std::ptr::read(module), + } + } } } } @@ -182,6 +198,9 @@ impl From for PostMessagePayloadRepr { ptr: Box::into_raw(boxed) as usize, } } + PostMessagePayload::CacheModule { hash, module } => { + PostMessagePayloadRepr::CacheModule { hash, module } + } } } } @@ -291,6 +310,9 @@ pub async fn __worker_handle_message(msg: JsValue) -> Result<(), crate::utils::E match msg.reconstitute() { PostMessagePayload::SpawnAsync(thunk) => thunk().await, PostMessagePayload::SpawnBlocking(thunk) => thunk(), + PostMessagePayload::CacheModule { hash, .. } => { + tracing::warn!(%hash, "XXX Caching module"); + } } }