From 980ffde2445a077252c781ca1d479ff709e38bb4 Mon Sep 17 00:00:00 2001 From: 007gzs <007gzs@gmail.com> Date: Wed, 7 Aug 2024 20:06:11 +0800 Subject: [PATCH] Optimize WASM Rust SDK's body caching logic. (#1181) --- plugins/wasm-rust/src/plugin_wrapper.rs | 91 ++++++++++++++----------- 1 file changed, 50 insertions(+), 41 deletions(-) diff --git a/plugins/wasm-rust/src/plugin_wrapper.rs b/plugins/wasm-rust/src/plugin_wrapper.rs index 8291d22e0a..8dcf03b81e 100644 --- a/plugins/wasm-rust/src/plugin_wrapper.rs +++ b/plugins/wasm-rust/src/plugin_wrapper.rs @@ -22,10 +22,10 @@ pub trait RootContextWrapper: RootContext where PluginConfig: Default + DeserializeOwned + 'static + Clone, { - // fn create_http_context(&self, _context_id: u32) -> Option> { - fn create_http_context_use_wrapper(&self, _context_id: u32) -> Option> { + // fn create_http_context(&self, context_id: u32) -> Option> { + fn create_http_context_use_wrapper(&self, context_id: u32) -> Option> { // trait 继承没法重写 RootContext 的 create_http_context,先写个函数让上层调下吧 - match self.create_http_context_wrapper(_context_id) { + match self.create_http_context_wrapper(context_id) { Some(http_context) => Some(Box::new(PluginHttpWrapper::new( self.rule_matcher(), http_context, @@ -58,11 +58,17 @@ pub trait HttpContextWrapper: HttpContext { fn on_http_response_body_ok(&mut self, _res_body: &Bytes) -> Action { Action::Continue } + fn replace_http_request_body(&mut self, body: &[u8]) { + self.set_http_request_body(0, i32::MAX as usize, body) + } + fn replace_http_response_body(&mut self, body: &[u8]) { + self.set_http_response_body(0, i32::MAX as usize, body) + } } pub struct PluginHttpWrapper { req_headers: MultiMap, - req_body: Bytes, - res_body: Bytes, + req_body_len: usize, + res_body_len: usize, config: Option, rule_matcher: SharedRuleMatcher, http_content: Box>, @@ -74,8 +80,8 @@ impl PluginHttpWrapper { ) -> Self { PluginHttpWrapper { req_headers: MultiMap::new(), - req_body: Bytes::new(), - res_body: Bytes::new(), + req_body_len: 0, + res_body_len: 0, config: None, rule_matcher: rule_matcher.clone(), http_content, @@ -123,12 +129,9 @@ impl HttpContext for PluginHttpWrapper where PluginConfig: Default + DeserializeOwned + Clone, { - fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { + fn on_http_request_headers(&mut self, num_headers: usize, end_of_stream: bool) -> Action { let binding = self.rule_matcher.borrow(); - self.config = match binding.get_match_config() { - None => None, - Some(config) => Some(config.1.clone()), - }; + self.config = binding.get_match_config().map(|config| config.1.clone()); for (k, v) in self.get_http_request_headers() { self.req_headers.insert(k, v); } @@ -137,7 +140,7 @@ where } let ret = self .http_content - .on_http_request_headers(_num_headers, _end_of_stream); + .on_http_request_headers(num_headers, end_of_stream); if ret != Action::Continue { return ret; } @@ -145,53 +148,59 @@ where .on_http_request_headers_ok(&self.req_headers) } - fn on_http_request_body(&mut self, _body_size: usize, _end_of_stream: bool) -> Action { - let mut ret = self + fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { + self.req_body_len += body_size; + if !end_of_stream { + return Action::Pause; + } + let ret = self .http_content - .on_http_request_body(_body_size, _end_of_stream); - if !self.http_content.cache_request_body() { + .on_http_request_body(self.req_body_len, end_of_stream); + if ret != Action::Continue || !self.http_content.cache_request_body() { return ret; } - if _body_size > 0 { - if let Some(body) = self.get_http_request_body(0, _body_size) { - self.req_body.extend(body) + let mut req_body = Bytes::new(); + if self.req_body_len > 0 { + if let Some(body) = self.get_http_request_body(0, self.req_body_len) { + req_body = body; } } - if _end_of_stream && ret == Action::Continue { - ret = self.http_content.on_http_request_body_ok(&self.req_body); - } - ret + self.http_content.on_http_request_body_ok(&req_body) } - fn on_http_request_trailers(&mut self, _num_trailers: usize) -> Action { - self.http_content.on_http_request_trailers(_num_trailers) + fn on_http_request_trailers(&mut self, num_trailers: usize) -> Action { + self.http_content.on_http_request_trailers(num_trailers) } - fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { + fn on_http_response_headers(&mut self, num_headers: usize, end_of_stream: bool) -> Action { self.http_content - .on_http_response_headers(_num_headers, _end_of_stream) + .on_http_response_headers(num_headers, end_of_stream) } - fn on_http_response_body(&mut self, _body_size: usize, _end_of_stream: bool) -> Action { - let mut ret = self + fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { + self.res_body_len += body_size; + + if !end_of_stream { + return Action::Pause; + } + let ret = self .http_content - .on_http_response_body(_body_size, _end_of_stream); - if !self.http_content.cache_response_body() { + .on_http_response_body(self.res_body_len, end_of_stream); + if ret != Action::Continue || !self.http_content.cache_response_body() { return ret; } - if _body_size > 0 { - if let Some(body) = self.get_http_response_body(0, _body_size) { - self.res_body.extend(body); + + let mut res_body = Bytes::new(); + if self.res_body_len > 0 { + if let Some(body) = self.get_http_response_body(0, self.res_body_len) { + res_body = body; } } - if _end_of_stream && ret == Action::Continue { - ret = self.http_content.on_http_response_body_ok(&self.res_body); - } - ret + self.http_content.on_http_response_body_ok(&res_body) } - fn on_http_response_trailers(&mut self, _num_trailers: usize) -> Action { - self.http_content.on_http_response_trailers(_num_trailers) + fn on_http_response_trailers(&mut self, num_trailers: usize) -> Action { + self.http_content.on_http_response_trailers(num_trailers) } fn on_log(&mut self) {