From 1c32ff0c07a5527e85c8ff14ef8129ad8507d8d7 Mon Sep 17 00:00:00 2001 From: Kevin Hoffman Date: Tue, 26 Sep 2023 08:45:26 -0400 Subject: [PATCH 1/2] Initial implementation of automatic service discovery Signed-off-by: Kevin Hoffman --- nats/Cargo.lock | 125 ++++++++++---------------- nats/Cargo.toml | 15 ++-- nats/src/config.rs | 145 ++++++++++++++++++++++++++++++ nats/src/main.rs | 210 +++++++++++++++++++++---------------------- nats/src/services.rs | 30 +++++++ 5 files changed, 333 insertions(+), 192 deletions(-) create mode 100644 nats/src/config.rs create mode 100644 nats/src/services.rs diff --git a/nats/Cargo.lock b/nats/Cargo.lock index c66ed7ec..811f3b3f 100644 --- a/nats/Cargo.lock +++ b/nats/Cargo.lock @@ -129,9 +129,9 @@ checksum = "ca6c635b3aa665c649ad1415f1573c85957dfa47690ec27aebe7ec17efe3c643" [[package]] name = "async-nats" -version = "0.30.0" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94e3e851ddf3b62be8a8085e1e453968df9cdbf990a37bbb589b5b4f587c68d7" +checksum = "8257238e2a3629ee5618502a75d1b91f8017c24638c75349fc8d2d80cf1f7c4c" dependencies = [ "base64 0.21.2", "bytes", @@ -139,7 +139,7 @@ dependencies = [ "http", "itoa", "memchr", - "nkeys 0.3.0", + "nkeys", "nuid 0.3.2", "once_cell", "rand", @@ -675,19 +675,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "env_logger" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a19187fea3ac7e84da7dacf48de0c45d63c6a76f9490dae389aead16c243fce3" -dependencies = [ - "atty", - "humantime", - "log", - "regex", - "termcolor", -] - [[package]] name = "env_logger" version = "0.10.0" @@ -1302,21 +1289,6 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" -[[package]] -name = "nkeys" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e66a7cd1358277b2a6f77078e70aea7315ff2f20db969cc61153103ec162594" -dependencies = [ - "byteorder 1.4.3", - "data-encoding", - "ed25519-dalek", - "getrandom", - "log", - "rand", - "signatory", -] - [[package]] name = "nkeys" version = "0.3.0" @@ -1478,12 +1450,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" -[[package]] -name = "parity-wasm" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be5e13c266502aadf83426d87d81a0f5d1ef45b8027f5a471c360abfe4bfae92" - [[package]] name = "parking_lot" version = "0.12.1" @@ -2518,7 +2484,19 @@ dependencies = [ "serde", "serde_spanned", "toml_datetime", - "toml_edit", + "toml_edit 0.19.14", +] + +[[package]] +name = "toml" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c226a7bba6d859b63c92c4b4fe69c5b6b72d0cb897dbc8e6012298e6154cb56e" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit 0.20.0", ] [[package]] @@ -2543,6 +2521,19 @@ dependencies = [ "winnow", ] +[[package]] +name = "toml_edit" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ff63e60a958cefbb518ae1fd6566af80d9d4be430a33f3723dfc47d1d411d95" +dependencies = [ + "indexmap 2.0.0", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + [[package]] name = "tonic" version = "0.6.2" @@ -2819,38 +2810,17 @@ dependencies = [ [[package]] name = "wascap" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32d1cfad67501627ac9344cbd89be80d2d5ebc98ef3b86862041cb26a280081f" -dependencies = [ - "base64 0.13.1", - "data-encoding", - "env_logger 0.8.4", - "humantime", - "lazy_static", - "log", - "nkeys 0.2.0", - "nuid 0.3.2", - "parity-wasm", - "ring", - "serde", - "serde_derive", - "serde_json", -] - -[[package]] -name = "wascap" -version = "0.11.0" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c38c03de923f80027cb62281f520bd97245a38eeb85a53f435a110509d044551" +checksum = "69a594087cfb6d023c1ac1e3ca3390b3461f7be344e3e291c2b6140486ba2221" dependencies = [ "base64 0.13.1", "data-encoding", - "env_logger 0.10.0", + "env_logger", "humantime", "lazy_static", "log", - "nkeys 0.3.0", + "nkeys", "nuid 0.4.1", "ring", "serde", @@ -2972,9 +2942,9 @@ dependencies = [ [[package]] name = "wasmbus-rpc" -version = "0.14.0" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2da4e25698e9e15af56e9bb510d233495c4ec9c6bbbf13d046205190dd4ab935" +checksum = "d17417841089a43a945a31ddfcc0062704a97625fb7ecbfbe91df043d930f05a" dependencies = [ "async-nats", "async-trait", @@ -2987,7 +2957,7 @@ dependencies = [ "lazy_static", "minicbor 0.17.1", "minicbor-ser", - "nkeys 0.3.0", + "nkeys", "once_cell", "opentelemetry", "opentelemetry-otlp", @@ -3005,16 +2975,16 @@ dependencies = [ "tracing-opentelemetry", "tracing-subscriber", "uuid", - "wascap 0.11.0", + "wascap", "wasmbus-macros", "weld-codegen", ] [[package]] name = "wasmcloud-interface-messaging" -version = "0.10.0" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11608f7ac598493d0d8ebbd089c93a7037e013dea108e282f47f1cf1096fa658" +checksum = "6fe861c87bc80f3670237aefa9530fb7b56ea7b98e37869886bdd5d35d3b6b3e" dependencies = [ "async-trait", "log", @@ -3027,9 +2997,9 @@ dependencies = [ [[package]] name = "wasmcloud-interface-testing" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f673d4742a3199016dd7941f776b1b1d8be42ca4809f7ef00e0a7bd18e93536f" +checksum = "c2eabf80d5df3891ad8dd7676decdc41366324abf0f44ef341ee6687c7369086" dependencies = [ "async-trait", "regex", @@ -3043,7 +3013,7 @@ dependencies = [ [[package]] name = "wasmcloud-provider-nats" -version = "0.17.3" +version = "0.18.0" dependencies = [ "anyhow", "async-nats", @@ -3054,6 +3024,7 @@ dependencies = [ "chrono", "crossbeam", "futures", + "lazy_static", "once_cell", "rmp-serde", "serde", @@ -3061,11 +3032,11 @@ dependencies = [ "serde_json", "thiserror", "tokio", - "toml 0.5.11", + "toml 0.8.0", "tracing", "tracing-futures", "tracing-subscriber", - "wascap 0.8.0", + "wascap", "wasmbus-rpc", "wasmcloud-interface-messaging", "wasmcloud-test-util", @@ -3073,16 +3044,16 @@ dependencies = [ [[package]] name = "wasmcloud-test-util" -version = "0.10.0" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90fdfd194000a791d207ee0932acba35dfca35b2bae584317b315a01550ae816" +checksum = "57068cab9cc6ae1d4984b8c46cd5ac073f17c3035ffa8a9ab419080efd581170" dependencies = [ "anyhow", "async-trait", "base64 0.21.2", "futures", "log", - "nkeys 0.3.0", + "nkeys", "regex", "serde", "serde_bytes", diff --git a/nats/Cargo.toml b/nats/Cargo.toml index 930a8516..4a8b18eb 100644 --- a/nats/Cargo.toml +++ b/nats/Cargo.toml @@ -1,10 +1,10 @@ [package] name = "wasmcloud-provider-nats" -version = "0.17.3" +version = "0.18.0" edition = "2021" [dependencies] -async-nats = "0.30" +async-nats = {version = "0.31", features =["service"]} async-trait = "0.1" atty = "0.2" base64 = "0.13" @@ -19,19 +19,20 @@ serde_json = "1.0" serde = {version = "1.0", features = ["derive"] } thiserror = "1.0" tokio = { version = "1", features = ["full"] } -toml = "0.5" +toml = "0.8.0" tracing = "0.1" tracing-futures = "0.2" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -wascap = "0.8.0" +wascap = "0.11.1" anyhow = "1.0.69" +lazy_static = "1.4.0" -wasmbus-rpc = { version = "0.14", features = [ "otel" ] } -wasmcloud-interface-messaging = "0.10" +wasmbus-rpc = { version = "0.15", features = [ "otel" ] } +wasmcloud-interface-messaging = "0.11" # test dependencies [dev-dependencies] -wasmcloud-test-util = "0.10" +wasmcloud-test-util = "0.11" [[bin]] name = "nats_messaging" diff --git a/nats/src/config.rs b/nats/src/config.rs new file mode 100644 index 00000000..cb8d7875 --- /dev/null +++ b/nats/src/config.rs @@ -0,0 +1,145 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use wasmbus_rpc::error::{RpcError, RpcResult}; + +const DEFAULT_NATS_URI: &str = "0.0.0.0:4222"; +const ENV_NATS_SUBSCRIPTION: &str = "SUBSCRIPTION"; +const ENV_NATS_URI: &str = "URI"; +const ENV_NATS_CLIENT_JWT: &str = "CLIENT_JWT"; +const ENV_NATS_CLIENT_SEED: &str = "CLIENT_SEED"; +const ENV_SERVICE_NAME: &str = "SERVICE_NAME"; +const ENV_SERVICE_ENDPOINTS: &str = "SERVICE_ENDPOINTS"; +const ENV_SERVICE_DESCRIPTION: &str = "SERVICE_DESCRIPTION"; +const ENV_SERVICE_VERSION: &str = "SERVICE_VERSION"; + +/// Configuration for connecting a nats client. +/// More options are available if you use the json than variables in the values string map. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub(crate) struct ConnectionConfig { + /// list of topics to subscribe to + #[serde(default)] + pub subscriptions: Vec, + #[serde(default)] + pub cluster_uris: Vec, + #[serde(default)] + pub auth_jwt: Option, + #[serde(default)] + pub auth_seed: Option, + + #[serde(default)] + pub service_name: Option, + #[serde(default)] + pub service_description: Option, + #[serde(default)] + pub service_endpoints: Option>, + #[serde(default)] + pub service_version: Option, + + /// ping interval in seconds + #[serde(default)] + pub ping_interval_sec: Option, +} + +impl ConnectionConfig { + pub fn merge(&self, extra: &ConnectionConfig) -> ConnectionConfig { + let mut out = self.clone(); + if !extra.subscriptions.is_empty() { + out.subscriptions = extra.subscriptions.clone(); + } + // If the default configuration has a URL in it, and then the link definition + // also provides a URL, the assumption is to replace/override rather than combine + // the two into a potentially incompatible set of URIs + if !extra.cluster_uris.is_empty() { + out.cluster_uris = extra.cluster_uris.clone(); + } + if extra.auth_jwt.is_some() { + out.auth_jwt = extra.auth_jwt.clone() + } + if extra.auth_seed.is_some() { + out.auth_seed = extra.auth_seed.clone() + } + if extra.ping_interval_sec.is_some() { + out.ping_interval_sec = extra.ping_interval_sec.clone() + } + if extra.service_name.is_some() { + out.service_name = extra.service_name.clone(); + out.service_description = extra.service_description.clone(); + out.service_endpoints = extra.service_endpoints.clone(); + out.service_version = extra.service_version.clone(); + } + out + } +} + +impl Default for ConnectionConfig { + fn default() -> ConnectionConfig { + ConnectionConfig { + subscriptions: vec![], + cluster_uris: vec![DEFAULT_NATS_URI.to_string()], + auth_jwt: None, + auth_seed: None, + ping_interval_sec: None, + service_description: None, + service_endpoints: None, + service_name: None, + service_version: None, + } + } +} + +impl ConnectionConfig { + pub fn new_from(vs: &HashMap) -> RpcResult { + let mut values = HashMap::::new(); + for (k, v) in vs { + values.insert(k.to_ascii_uppercase(), v.to_string()); + } + + let mut config = if let Some(config_b64) = values.get("config_b64") { + let bytes = base64::decode(config_b64.as_bytes()).map_err(|e| { + RpcError::InvalidParameter(format!("invalid base64 encoding: {}", e)) + })?; + serde_json::from_slice::(&bytes) + .map_err(|e| RpcError::InvalidParameter(format!("corrupt config_b64: {}", e)))? + } else if let Some(config) = values.get("config_json") { + serde_json::from_str::(config) + .map_err(|e| RpcError::InvalidParameter(format!("corrupt config_json: {}", e)))? + } else { + ConnectionConfig::default() + }; + + if let Some(sub) = values.get(ENV_NATS_SUBSCRIPTION) { + config + .subscriptions + .extend(sub.split(',').map(|s| s.to_string())); + } + if let Some(url) = values.get(ENV_NATS_URI) { + config.cluster_uris = url.split(',').map(String::from).collect(); + } + if let Some(jwt) = values.get(ENV_NATS_CLIENT_JWT) { + config.auth_jwt = Some(jwt.clone()); + } + if let Some(seed) = values.get(ENV_NATS_CLIENT_SEED) { + config.auth_seed = Some(seed.clone()); + } + config.service_name = values.get(ENV_SERVICE_NAME).cloned(); + config.service_description = values.get(ENV_SERVICE_DESCRIPTION).cloned(); + config.service_version = values.get(ENV_SERVICE_VERSION).cloned(); + config.service_endpoints = values + .get(ENV_SERVICE_ENDPOINTS) + .map(|es| es.split(',').map(|s| s.to_string()).collect()); + + if config.auth_jwt.is_some() && config.auth_seed.is_none() { + return Err(RpcError::InvalidParameter( + "if you specify jwt, you must also specify a seed".to_string(), + )); + } + + if config.cluster_uris.is_empty() { + config.cluster_uris.push(DEFAULT_NATS_URI.to_string()); + } + + eprintln!("{config:?}"); + + Ok(config) + } +} diff --git a/nats/src/main.rs b/nats/src/main.rs index f3eeda9c..ff2e7ce2 100644 --- a/nats/src/main.rs +++ b/nats/src/main.rs @@ -1,10 +1,12 @@ //! Nats implementation for wasmcloud:messaging. //! -use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration}; +use async_nats::service::{Service, ServiceExt}; +use bytes::Bytes; use futures::StreamExt; -use serde::{Deserialize, Serialize}; -use tokio::sync::{OwnedSemaphorePermit, RwLock, Semaphore}; +use services::is_request_waiting; +use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration}; +use tokio::sync::{oneshot, OwnedSemaphorePermit, RwLock, Semaphore}; use tokio::task::JoinHandle; use tracing::{debug, error, instrument, warn}; use tracing_futures::Instrument; @@ -19,11 +21,10 @@ use wasmcloud_interface_messaging::{ ReplyMessage, RequestMessage, SubMessage, }; -const DEFAULT_NATS_URI: &str = "0.0.0.0:4222"; -const ENV_NATS_SUBSCRIPTION: &str = "SUBSCRIPTION"; -const ENV_NATS_URI: &str = "URI"; -const ENV_NATS_CLIENT_JWT: &str = "CLIENT_JWT"; -const ENV_NATS_CLIENT_SEED: &str = "CLIENT_SEED"; +mod config; +mod services; + +use config::*; fn main() -> Result<(), Box> { // handle lattice control messages and forward rpc to the provider dispatch @@ -54,103 +55,6 @@ fn generate_provider(host_data: HostData) -> NatsMessagingProvider { } } -/// Configuration for connecting a nats client. -/// More options are available if you use the json than variables in the values string map. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -struct ConnectionConfig { - /// list of topics to subscribe to - #[serde(default)] - subscriptions: Vec, - #[serde(default)] - cluster_uris: Vec, - #[serde(default)] - auth_jwt: Option, - #[serde(default)] - auth_seed: Option, - - /// ping interval in seconds - #[serde(default)] - ping_interval_sec: Option, -} - -impl ConnectionConfig { - fn merge(&self, extra: &ConnectionConfig) -> ConnectionConfig { - let mut out = self.clone(); - if !extra.subscriptions.is_empty() { - out.subscriptions = extra.subscriptions.clone(); - } - // If the default configuration has a URL in it, and then the link definition - // also provides a URL, the assumption is to replace/override rather than combine - // the two into a potentially incompatible set of URIs - if !extra.cluster_uris.is_empty() { - out.cluster_uris = extra.cluster_uris.clone(); - } - if extra.auth_jwt.is_some() { - out.auth_jwt = extra.auth_jwt.clone() - } - if extra.auth_seed.is_some() { - out.auth_seed = extra.auth_seed.clone() - } - if extra.ping_interval_sec.is_some() { - out.ping_interval_sec = extra.ping_interval_sec.clone() - } - out - } -} - -impl Default for ConnectionConfig { - fn default() -> ConnectionConfig { - ConnectionConfig { - subscriptions: vec![], - cluster_uris: vec![DEFAULT_NATS_URI.to_string()], - auth_jwt: None, - auth_seed: None, - ping_interval_sec: None, - } - } -} - -impl ConnectionConfig { - fn new_from(values: &HashMap) -> RpcResult { - let mut config = if let Some(config_b64) = values.get("config_b64") { - let bytes = base64::decode(config_b64.as_bytes()).map_err(|e| { - RpcError::InvalidParameter(format!("invalid base64 encoding: {}", e)) - })?; - serde_json::from_slice::(&bytes) - .map_err(|e| RpcError::InvalidParameter(format!("corrupt config_b64: {}", e)))? - } else if let Some(config) = values.get("config_json") { - serde_json::from_str::(config) - .map_err(|e| RpcError::InvalidParameter(format!("corrupt config_json: {}", e)))? - } else { - ConnectionConfig::default() - }; - - if let Some(sub) = values.get(ENV_NATS_SUBSCRIPTION) { - config - .subscriptions - .extend(sub.split(',').map(|s| s.to_string())); - } - if let Some(url) = values.get(ENV_NATS_URI) { - config.cluster_uris = url.split(',').map(String::from).collect(); - } - if let Some(jwt) = values.get(ENV_NATS_CLIENT_JWT) { - config.auth_jwt = Some(jwt.clone()); - } - if let Some(seed) = values.get(ENV_NATS_CLIENT_SEED) { - config.auth_seed = Some(seed.clone()); - } - if config.auth_jwt.is_some() && config.auth_seed.is_none() { - return Err(RpcError::InvalidParameter( - "if you specify jwt, you must also specify a seed".to_string(), - )); - } - if config.cluster_uris.is_empty() { - config.cluster_uris.push(DEFAULT_NATS_URI.to_string()); - } - Ok(config) - } -} - /// NatsClientBundles hold a NATS client and information (subscriptions) /// related to it. /// @@ -187,8 +91,9 @@ impl NatsMessagingProvider { async fn connect( &self, cfg: ConnectionConfig, - ld: &LinkDefinition, + link_def: &LinkDefinition, ) -> Result { + eprintln!("CONNECTING"); let opts = match (cfg.auth_jwt, cfg.auth_seed) { (Some(jwt), Some(seed)) => { let key_pair = std::sync::Arc::new( @@ -217,8 +122,35 @@ impl NatsMessagingProvider { .await .map_err(|e| RpcError::ProviderInit(format!("NATS connection to {}: {}", url, e)))?; - // Connections let mut sub_handles = Vec::new(); + + // Every service subscribes on {service}.{endpoint} + if cfg.service_name.is_some() { + let service_name = cfg.service_name.unwrap_or("default".to_string()); + let mut svc = client + .service_builder() + .description(cfg.service_description.unwrap_or("Unknown".to_string())) + .start( + service_name.clone(), + cfg.service_version.unwrap_or("0.0.1".to_string()), + ) + .await + .map_err(|e| RpcError::ProviderInit(format!("service start failed: {}", e)))?; + eprintln!("MADE SERVICE CLIENT"); + if let Some(ref eps) = cfg.service_endpoints { + for ep in eps { + let subject = format!("{}.{}", service_name, ep); + sub_handles.push(( + subject.to_string(), + self.service_subscribe(&mut svc, link_def, subject, ep.to_string()) + .await?, + )); + } + } + } + + // Connections + for sub in cfg.subscriptions.iter().filter(|s| !s.is_empty()) { let (sub, queue) = match sub.split_once('|') { Some((sub, queue)) => (sub, Some(queue.to_string())), @@ -227,7 +159,8 @@ impl NatsMessagingProvider { sub_handles.push(( sub.to_string(), - self.subscribe(&client, ld, sub.to_string(), queue).await?, + self.subscribe(&client, link_def, sub.to_string(), queue) + .await?, )); } @@ -290,6 +223,61 @@ impl NatsMessagingProvider { Ok(join_handle) } + + async fn service_subscribe( + &self, + svc: &mut Service, + ld: &LinkDefinition, + subject: String, + endpoint: String, + ) -> RpcResult> { + eprintln!("STARTING SUBSCRIPTION FOR {endpoint}"); + + let mut endpoint = svc + .endpoint_builder() + .name(endpoint.to_string()) + .add(subject) + .await + .map_err(|e| RpcError::ProviderInit(format!("service start failed: {}", e)))?; + + let ld = ld.clone(); + let join_handle = tokio::spawn(async move { + let semaphore = Arc::new(Semaphore::new(75)); + while let Some(req) = endpoint.next().await { + let msg = req.message.clone(); + + let (tx, rx) = oneshot::channel::(); + services::add_request_waiter( + &req.message.reply.clone().unwrap_or("default".to_string()), + tx, + ) + .await; + + //Set up tracing context for the NATS message + let span = tracing::debug_span!("handle_service_request", actor_id = %ld.actor_id); + span.in_scope(|| { + wasmbus_rpc::otel::attach_span_context(&msg); + }); + + let permit = match semaphore.clone().acquire_owned().await { + Ok(p) => p, + Err(_) => { + warn!("Work pool has been closed, exiting queue subscribe"); + break; + } + }; + + tokio::spawn(dispatch_msg(ld.clone(), msg, permit).instrument(span)); + if let Ok(raw) = rx.await { + let _ = req.respond(Ok(raw)).await; + } else { + warn!("Sender for service request dropped without sending."); + } + } + }); + + Ok(join_handle) + } } #[instrument(level = "debug", skip_all, fields(actor_id = %link_def.actor_id, subject = %nats_msg.subject, reply_to = ?nats_msg.reply))] @@ -322,6 +310,7 @@ impl ProviderHandler for NatsMessagingProvider { #[instrument(level = "debug", skip(self, ld), fields(actor_id = %ld.actor_id))] async fn put_link(&self, ld: &LinkDefinition) -> RpcResult { // If the link definition values are empty, use the default connection configuration + eprintln!("CONFIG: {ld:?}"); let config = if ld.values.is_empty() { self.default_config.clone() } else { @@ -388,6 +377,11 @@ impl Messaging for NatsMessagingProvider { drop(_rd); let headers = OtelHeaderInjector::default_with_span().into(); + // Is this publish actually a reply to a request? + if is_request_waiting(&msg.subject).await { + let _ = services::dispatch_request_waiter(&msg.subject, msg.body.clone().into()).await; + return Ok(()); + } let res = match msg.reply_to.clone() { Some(reply_to) => if should_strip_headers(&msg.subject) { diff --git a/nats/src/services.rs b/nats/src/services.rs new file mode 100644 index 00000000..8f0aeb35 --- /dev/null +++ b/nats/src/services.rs @@ -0,0 +1,30 @@ +use std::{collections::HashMap, sync::Arc}; + +use bytes::Bytes; +use lazy_static::lazy_static; +use tokio::sync::{oneshot, RwLock}; +use tracing::warn; + +lazy_static! { + static ref REQUEST_WAITERS: Arc>>> = + Arc::new(RwLock::new(HashMap::new())); +} + +pub(crate) async fn add_request_waiter(subject: &str, sender: oneshot::Sender) { + let mut waiters = REQUEST_WAITERS.write().await; + waiters.insert(subject.to_string(), sender); +} + +pub(crate) async fn dispatch_request_waiter(subject: &str, bytes: Bytes) { + let mut waiters = REQUEST_WAITERS.write().await; + if let Some(sender) = waiters.remove(subject) { + if let Err(_) = sender.send(bytes) { + warn!("Receiver side of request waiter dropped"); + } + } +} + +pub(crate) async fn is_request_waiting(subject: &str) -> bool { + let waits = REQUEST_WAITERS.read().await; + waits.contains_key(subject) +} From 03a54632dc64e043b130dbfa0537b1938307dfce Mon Sep 17 00:00:00 2001 From: Kevin Hoffman Date: Tue, 26 Sep 2023 08:49:44 -0400 Subject: [PATCH 2/2] removing debug printlns Signed-off-by: Kevin Hoffman --- nats/src/main.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/nats/src/main.rs b/nats/src/main.rs index ff2e7ce2..ec2df27c 100644 --- a/nats/src/main.rs +++ b/nats/src/main.rs @@ -93,7 +93,6 @@ impl NatsMessagingProvider { cfg: ConnectionConfig, link_def: &LinkDefinition, ) -> Result { - eprintln!("CONNECTING"); let opts = match (cfg.auth_jwt, cfg.auth_seed) { (Some(jwt), Some(seed)) => { let key_pair = std::sync::Arc::new( @@ -136,7 +135,6 @@ impl NatsMessagingProvider { ) .await .map_err(|e| RpcError::ProviderInit(format!("service start failed: {}", e)))?; - eprintln!("MADE SERVICE CLIENT"); if let Some(ref eps) = cfg.service_endpoints { for ep in eps { let subject = format!("{}.{}", service_name, ep); @@ -231,8 +229,6 @@ impl NatsMessagingProvider { subject: String, endpoint: String, ) -> RpcResult> { - eprintln!("STARTING SUBSCRIPTION FOR {endpoint}"); - let mut endpoint = svc .endpoint_builder() .name(endpoint.to_string())