diff --git a/src/client.rs b/src/client.rs
index e9b5d98..2ce6598 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -30,14 +30,29 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
         since_time: Some(1234),
         sdk_key: "1234".into(),
     });
-    let mut stream = client.stream_config_spec(request).await?.into_inner();
-    while let Some(value) = stream.message().await? {
-        println!(
-            "STREAMING={:?}, CURRENT TIME={}",
-            value.last_updated,
-            Local::now()
-        );
-    }
+    let response = client.stream_config_spec(request).await?;
+    println!("Metadata={:?}", response.metadata());
+    let mut stream = response.into_inner();
+    loop {
+        match stream.message().await {
+            Ok(Some(value)) => {
+                println!(
+                    "STREAMING={:?}, CURRENT TIME={}",
+                    value.last_updated,
+                    Local::now()
+                );
+            }
+            Ok(None) => {
+                println!("STREAMING DONE");
+                break;
+            }
+            Err(e) => {
+                println!("CURRENT TIME={}", Local::now());
+                println!("Error={:?}", e);
+                break;
+            }
+        }
+    }
     Ok(())
 }
diff --git a/src/datastore/caching/redis_cache.rs b/src/datastore/caching/redis_cache.rs
index d8ae490..ed6de85 100644
--- a/src/datastore/caching/redis_cache.rs
+++ b/src/datastore/caching/redis_cache.rs
@@ -121,26 +121,41 @@ impl HttpDataProviderObserverTrait for RedisCache {
         path: &str,
     ) {
         if result == &DataProviderRequestResult::DataAvailable {
-            let connection = self.connection.get().await;
+            let connection: Result<
+                bb8::PooledConnection<'_, RedisConnectionManager>,
+                bb8::RunError<redis::RedisError>,
+            > = self.connection.get().await;
             let redis_key = self.hash_key(key).await;
             match connection {
                 Ok(mut conn) => {
                     let mut pipe = redis::pipe();
                     pipe.atomic();
                     let should_update = match pipe
+                        .ttl(REDIS_LEADER_KEY)
                         .set_nx(REDIS_LEADER_KEY, self.uuid.clone())
                         .get(REDIS_LEADER_KEY)
                         .hget(&redis_key, "lcut")
-                        .query_async::<_, (bool, String, Option<String>)>(
+                        .query_async::<_, (i32, bool, String, Option<String>)>(
                             &mut *conn,
                         )
                         .await
                     {
                         Ok(query_result) => {
-                            let is_leader = query_result.1 == self.uuid;
-                            if self.check_lcut && query_result.2.is_some() {
+                            let is_leader = query_result.2 == self.uuid;
+
+                            // In case there was a crash without cleaning up the leader key,
+                            // validate on startup, and set expiry if needed.
+                            // This is best effort, so we don't check the result.
+                            if query_result.0 == -1 && !is_leader {
+                                pipe.expire::<&str>(REDIS_LEADER_KEY, self.leader_key_ttl)
+                                    .query_async::<_, ()>(&mut *conn)
+                                    .await
+                                    .ok();
+                            }
+
+                            if self.check_lcut && query_result.3.is_some() {
                                 let should_update =
-                                    query_result.2.expect("exists").parse().unwrap_or(0) < lcut;
+                                    query_result.3.expect("exists").parse().unwrap_or(0) < lcut;
                                 is_leader && should_update
                             } else {
                                 is_leader
diff --git a/src/datastore/data_providers/http_data_provider.rs b/src/datastore/data_providers/http_data_provider.rs
index ba58a51..30b9cc2 100644
--- a/src/datastore/data_providers/http_data_provider.rs
+++ b/src/datastore/data_providers/http_data_provider.rs
@@ -85,8 +85,7 @@ impl DataProviderTrait for HttpDataProvider {
             Err(_err) => -2,
         };
         if err_msg.is_empty() {
-            // TODO: This should be more robust
-            if body == "{\"has_updates\":false}" {
+            if !request_builder.is_an_update(&body, key).await {
                 ProxyEventObserver::publish_event(
                     ProxyEvent::new(ProxyEventType::HttpDataProviderNoData, key.to_string())
                         .with_path(request_builder.get_path())
diff --git a/src/datastore/data_providers/request_builder.rs b/src/datastore/data_providers/request_builder.rs
index 941182b..80e7754 100644
--- a/src/datastore/data_providers/request_builder.rs
+++ b/src/datastore/data_providers/request_builder.rs
@@ -1,5 +1,6 @@
 use async_trait::async_trait;
-use std::{sync::Arc, time::Duration};
+use sha2::{Digest, Sha256};
+use std::{collections::HashMap, sync::Arc, time::Duration};
 use tokio::{sync::RwLock, time::Instant};
 
 use crate::{
@@ -18,6 +19,7 @@ pub trait RequestBuilderTrait: Send + Sync + 'static {
         lcut: u64,
     ) -> Result<reqwest::Response, reqwest::Error>;
     fn get_path(&self) -> String;
+    async fn is_an_update(&self, body: &str, sdk_key: &str) -> bool;
     fn get_observers(&self) -> Arc<HttpDataProviderObserver>;
     fn get_backup_cache(&self) -> Arc<dyn HttpDataProviderObserverTrait + Send + Sync>;
     fn get_sdk_key_store(&self) -> Arc<SdkKeyStore>;
@@ -71,6 +73,11 @@ impl RequestBuilderTrait for DcsRequestBuilder {
         "/v1/download_config_specs".to_string()
     }
 
+    async fn is_an_update(&self, body: &str, _sdk_key: &str) -> bool {
+        // TODO: This should be more robust
+        !body.eq("{\"has_updates\":false}")
+    }
+
     fn get_observers(&self) -> Arc<HttpDataProviderObserver> {
         Arc::clone(&self.http_observers)
     }
@@ -97,6 +104,7 @@ pub struct IdlistRequestBuilder {
     pub backup_cache: Arc<dyn HttpDataProviderObserverTrait + Send + Sync>,
     pub sdk_key_store: Arc<SdkKeyStore>,
     last_request: RwLock<Instant>,
+    last_response_hash: RwLock<HashMap<String, String>>,
 }
 
 impl IdlistRequestBuilder {
@@ -110,6 +118,7 @@ impl IdlistRequestBuilder {
             backup_cache,
             sdk_key_store,
             last_request: RwLock::new(Instant::now()),
+            last_response_hash: RwLock::new(HashMap::new()),
         }
     }
 }
@@ -134,6 +143,21 @@ impl RequestBuilderTrait for IdlistRequestBuilder {
         "/v1/get_id_lists".to_string()
     }
 
+    async fn is_an_update(&self, body: &str, sdk_key: &str) -> bool {
+        let hash = format!("{:x}", Sha256::digest(body));
+        let mut wlock = self.last_response_hash.write().await;
+        let mut is_an_update = true;
+        if let Some(old_hash) = wlock.get(sdk_key) {
+            is_an_update = hash != *old_hash;
+        }
+
+        if is_an_update {
+            wlock.insert(sdk_key.to_string(), hash);
+        }
+
+        is_an_update
+    }
+
     fn get_observers(&self) -> Arc<HttpDataProviderObserver> {
         Arc::clone(&self.http_observers)
     }
diff --git a/src/server.rs b/src/server.rs
index 898084c..10696ff 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -10,6 +10,7 @@ use datastore::{
     data_providers::{background_data_provider, http_data_provider},
     sdk_key_store,
 };
+use futures::join;
 use loggers::datadog_logger;
 use loggers::debug_logger;
 use observers::http_data_provider_observer::HttpDataProviderObserver;
@@ -49,6 +50,8 @@ struct Cli {
     redis_leader_key_ttl: i64,
     #[clap(long, action)]
     force_gcp_profiling_enabled: bool,
+    #[clap(short, long, default_value = "500")]
+    grpc_max_concurrent_streams: u32,
 }
 
 #[derive(Deserialize, Debug)]
@@ -258,32 +261,27 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     match cli.mode {
         TransportMode::Grpc => {
-            servers::grpc_server::GrpcServer::start_server(config_spec_store, config_spec_observer)
-                .await?
+            servers::grpc_server::GrpcServer::start_server(
+                cli.grpc_max_concurrent_streams,
+                config_spec_store,
+                config_spec_observer,
+            )
+            .await?
         }
         TransportMode::Http => {
             servers::http_server::HttpServer::start_server(config_spec_store, id_list_store).await?
         }
         TransportMode::GrpcAndHttp => {
             let grpc_server = servers::grpc_server::GrpcServer::start_server(
+                cli.grpc_max_concurrent_streams,
                 config_spec_store.clone(),
                 config_spec_observer.clone(),
             );
             let http_server =
                 servers::http_server::HttpServer::start_server(config_spec_store, id_list_store);
-
-            tokio::select! {
-                res = grpc_server => {
-                    if let Err(err) = res {
-                        eprintln!("gRPC server failed: {}, terminating server...", err);
-                    }
-                }
-                res = http_server => {
-                    if let Err(err) = res {
-                        eprintln!("HTTP server failed: {}, terminating server...", err);
-                    }
-                }
-            }
+            join!(async { grpc_server.await.ok() }, async {
+                http_server.await.ok()
+            },);
         }
     }
     Ok(())
diff --git a/src/servers/grpc_server.rs b/src/servers/grpc_server.rs
index 5e78096..93c22d7 100644
--- a/src/servers/grpc_server.rs
+++ b/src/servers/grpc_server.rs
@@ -1,4 +1,5 @@
 use tokio::sync::{mpsc, RwLock};
+
 use tokio_stream::wrappers::ReceiverStream;
 use tonic::{transport::Server, Request, Response, Status};
 
@@ -15,6 +16,9 @@ use statsig_forward_proxy::statsig_forward_proxy_server::{
     StatsigForwardProxy, StatsigForwardProxyServer,
 };
 use statsig_forward_proxy::{ConfigSpecRequest, ConfigSpecResponse};
+use std::env;
+use std::str::FromStr;
+use tonic::metadata::MetadataValue;
 
 use std::collections::HashMap;
 use std::sync::Arc;
@@ -33,11 +37,12 @@ impl StatsigForwardProxyServerImpl {
     fn new(
         config_spec_store: Arc<ConfigSpecStore>,
         dcs_observer: Arc<HttpDataProviderObserver>,
+        update_broadcast_cache: Arc<RwLock<HashMap<String, Arc<StreamingChannel>>>>,
     ) -> Self {
         StatsigForwardProxyServerImpl {
             config_spec_store,
             dcs_observer,
-            update_broadcast_cache: Arc::new(RwLock::new(HashMap::new())),
+            update_broadcast_cache,
         }
     }
 }
@@ -115,63 +120,41 @@ impl StatsigForwardProxy for StatsigForwardProxyServerImpl {
         // After initial response, then start listening for updates
         let ubc_ref = Arc::clone(&self.update_broadcast_cache);
         tokio::spawn(async move {
-            match tx.send(Ok(init_value)).await {
-                Ok(_) => {
-                    ProxyEventObserver::publish_event(
-                        ProxyEvent::new(
-                            ProxyEventType::GrpcStreamingStreamedInitialized,
-                            sdk_key.to_string(),
-                        )
-                        .with_stat(EventStat {
-                            operation_type: OperationType::IncrByValue,
-                            value: 1,
-                        }),
-                    )
-                    .await;
-                }
-                Err(_e) => {
-                    ProxyEventObserver::publish_event(
-                        ProxyEvent::new(
-                            ProxyEventType::GrpcStreamingStreamDisconnected,
-                            sdk_key.to_string(),
-                        )
-                        .with_stat(EventStat {
-                            operation_type: OperationType::IncrByValue,
-                            value: 1,
-                        }),
-                    )
-                    .await;
-                }
-            }
+            tx.send(Ok(init_value)).await.ok();
+            ProxyEventObserver::publish_event(
+                ProxyEvent::new(
+                    ProxyEventType::GrpcStreamingStreamedInitialized,
+                    sdk_key.to_string(),
+                )
+                .with_stat(EventStat {
+                    operation_type: OperationType::IncrByValue,
+                    value: 1,
+                }),
+            )
+            .await;
             loop {
                 match rc.recv().await {
-                    Ok(csr) => match tx.send(Ok(csr)).await {
-                        Ok(_) => {
-                            ProxyEventObserver::publish_event(
-                                ProxyEvent::new(
-                                    ProxyEventType::GrpcStreamingStreamedResponse,
-                                    sdk_key.to_string(),
-                                )
-                                .with_stat(EventStat {
-                                    operation_type: OperationType::IncrByValue,
-                                    value: 1,
-                                }),
-                            )
-                            .await;
-                        }
-                        Err(_e) => {
-                            ProxyEventObserver::publish_event(
-                                ProxyEvent::new(
-                                    ProxyEventType::GrpcStreamingStreamDisconnected,
-                                    sdk_key.to_string(),
-                                )
-                                .with_stat(EventStat {
-                                    operation_type: OperationType::IncrByValue,
-                                    value: 1,
-                                }),
-                            )
-                            .await;
+                    Ok(maybe_csr) => match maybe_csr {
+                        Some(csr) => match tx.send(Ok(csr)).await {
+                            Ok(_) => {
+                                ProxyEventObserver::publish_event(
+                                    ProxyEvent::new(
+                                        ProxyEventType::GrpcStreamingStreamedResponse,
+                                        sdk_key.to_string(),
+                                    )
+                                    .with_stat(EventStat {
+                                        operation_type: OperationType::IncrByValue,
+                                        value: 1,
+                                    }),
+                                )
+                                .await;
+                            }
+                            Err(_e) => {
+                                break;
+                            }
+                        },
+                        None => {
                             break;
                         }
                     },
@@ -189,26 +172,77 @@
                     }
                 }
             }
+
+            ProxyEventObserver::publish_event(
+                ProxyEvent::new(
+                    ProxyEventType::GrpcStreamingStreamDisconnected,
+                    sdk_key.to_string(),
+                )
+                .with_stat(EventStat {
+                    operation_type: OperationType::IncrByValue,
+                    value: 1,
+                }),
+            )
+            .await;
         });
 
-        Ok(Response::new(ReceiverStream::new(rx)))
+        let mut response = Response::new(ReceiverStream::new(rx));
+        response.metadata_mut().insert(
+            "x-sfp-hostname",
+            MetadataValue::from_str(&env::var("HOSTNAME").unwrap_or("no_hostname_set".to_string()))
+                .unwrap(),
+        );
+        Ok(response)
     }
 }
 
 pub struct GrpcServer {}
 
 impl GrpcServer {
+    async fn shutdown_signal(
+        update_broadcast_cache: Arc<RwLock<HashMap<String, Arc<StreamingChannel>>>>,
+    ) {
+        let mut int_stream =
+            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())
+                .expect("Failed to install SIGINT handler");
+        let mut term_stream =
+            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
+                .expect("Failed to install SIGTERM handler");
+        tokio::select! {
+            _ = int_stream.recv() => {
+                println!("Received SIGINT, terminating...");
+            }
+            _ = term_stream.recv() => {
+                println!("Received SIGTERM, terminating...");
+            }
+        };
+
+        let wlock = update_broadcast_cache.write().await;
+
+        for (_, sc) in wlock.iter() {
+            sc.sender.write().await.send(None).ok();
+        }
+        println!("All grpc streams terminated, shutting down...");
+    }
+
     pub async fn start_server(
+        max_concurrent_streams: u32,
         config_spec_store: Arc<ConfigSpecStore>,
         shared_dcs_observer: Arc<HttpDataProviderObserver>,
     ) -> Result<(), Box<dyn std::error::Error>> {
         let addr = "0.0.0.0:50051".parse().unwrap();
-        let greeter = StatsigForwardProxyServerImpl::new(config_spec_store, shared_dcs_observer);
-        println!("GrpcServer listening on {}", addr);
+        let update_broadcast_cache = Arc::new(RwLock::new(HashMap::new()));
+        let sfp_server = StatsigForwardProxyServerImpl::new(
+            config_spec_store,
+            shared_dcs_observer,
+            update_broadcast_cache.clone(),
+        );
+        println!("GrpcServer listening on {}", addr);
 
         Server::builder()
-            .add_service(StatsigForwardProxyServer::new(greeter))
-            .serve(addr)
+            .max_concurrent_streams(max_concurrent_streams)
+            .add_service(StatsigForwardProxyServer::new(sfp_server))
+            .serve_with_shutdown(addr, GrpcServer::shutdown_signal(update_broadcast_cache))
             .await?;
 
         Ok(())
diff --git a/src/servers/streaming_channel.rs b/src/servers/streaming_channel.rs
index 6e1af92..8b5c04c 100644
--- a/src/servers/streaming_channel.rs
+++ b/src/servers/streaming_channel.rs
@@ -15,7 +15,7 @@ use crate::observers::{ProxyEvent, ProxyEventType};
 pub struct StreamingChannel {
     key: String,
     last_updated: Arc<RwLock<u64>>,
-    pub sender: Arc<RwLock<broadcast::Sender<ConfigSpecResponse>>>,
+    pub sender: Arc<RwLock<broadcast::Sender<Option<ConfigSpecResponse>>>>,
 }
 
 impl StreamingChannel {
@@ -59,10 +59,10 @@ impl HttpDataProviderObserverTrait for StreamingChannel {
             .sender
             .write()
             .await
-            .send(ConfigSpecResponse {
+            .send(Some(ConfigSpecResponse {
                 spec: data.to_string(),
                 last_updated: lcut,
-            })
+            }))
             .is_err()
         {
             // TODO: Optimize code, no receivers are listening
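
Note on is_an_update: /v1/download_config_specs responses still short-circuit on the literal {"has_updates":false} body, while /v1/get_id_lists has no such marker, so IdlistRequestBuilder deduplicates by comparing a SHA-256 of each response body against the last one seen per SDK key. The core idea, reduced to a synchronous sketch (this assumes only the sha2 crate; the real method in the diff guards the map with tokio::sync::RwLock):

use sha2::{Digest, Sha256};
use std::collections::HashMap;

// Returns true when `body` differs from the previously seen response for
// `sdk_key`; the stored hash is only refreshed when a change is detected.
fn is_an_update(seen: &mut HashMap<String, String>, sdk_key: &str, body: &str) -> bool {
    let hash = format!("{:x}", Sha256::digest(body));
    let changed = seen.get(sdk_key).map_or(true, |old| *old != hash);
    if changed {
        seen.insert(sdk_key.to_string(), hash);
    }
    changed
}

A first response for a key has no previous entry and therefore counts as an update, matching the wlock.get(sdk_key) fallback in the diff.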
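Note on the Option<ConfigSpecResponse> sender type: broadcasting None is what lets shutdown_signal wake every per-stream task, since tokio's broadcast channel fans each value out to all receivers and every streaming loop treats None as "stop". Once a task breaks, its mpsc tx drops and the client's stream should end with Ok(None) ("STREAMING DONE" in src/client.rs). A minimal standalone sketch of the sentinel pattern (Payload is a hypothetical stand-in for ConfigSpecResponse, not the proxy's actual type):

use tokio::sync::broadcast;

// Hypothetical stand-in for ConfigSpecResponse; broadcast payloads must be Clone.
#[derive(Clone, Debug)]
struct Payload(u64);

#[tokio::main]
async fn main() {
    let (tx, mut rx) = broadcast::channel::<Option<Payload>>(16);

    let consumer = tokio::spawn(async move {
        loop {
            match rx.recv().await {
                Ok(Some(p)) => println!("update: {:?}", p), // normal streamed update
                Ok(None) => break, // shutdown sentinel, mirrors send(None) in shutdown_signal
                Err(_) => break,   // channel closed or receiver lagged
            }
        }
    });

    tx.send(Some(Payload(1))).ok();
    tx.send(None).ok(); // graceful shutdown: unblocks every subscribed receiver
    consumer.await.unwrap();
}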