From 95188b5b6b9e41aecbfa6e84fd3b34ae6327b3f0 Mon Sep 17 00:00:00 2001 From: Somtochi Onyekwere Date: Mon, 21 Oct 2024 23:56:44 +0100 Subject: [PATCH 1/4] Implemet cheaper update notifications in corrosion --- crates/corro-admin/src/lib.rs | 1 + crates/corro-agent/src/agent/setup.rs | 3 + crates/corro-agent/src/agent/util.rs | 40 +- crates/corro-agent/src/api/public/mod.rs | 2 + crates/corro-agent/src/api/public/pubsub.rs | 194 ++++++- crates/corro-agent/src/api/public/update.rs | 246 +++++++++ crates/corro-api-types/src/lib.rs | 1 + crates/corro-types/src/agent.rs | 9 + crates/corro-types/src/broadcast.rs | 22 +- crates/corro-types/src/lib.rs | 2 + crates/corro-types/src/pubsub.rs | 311 ++++------- crates/corro-types/src/updates.rs | 560 ++++++++++++++++++++ 12 files changed, 1151 insertions(+), 240 deletions(-) create mode 100644 crates/corro-agent/src/api/public/update.rs create mode 100644 crates/corro-types/src/updates.rs diff --git a/crates/corro-admin/src/lib.rs b/crates/corro-admin/src/lib.rs index 1552b849..25dcda24 100644 --- a/crates/corro-admin/src/lib.rs +++ b/crates/corro-admin/src/lib.rs @@ -11,6 +11,7 @@ use corro_types::{ broadcast::{FocaCmd, FocaInput, Timestamp}, sqlite::SqlitePoolError, sync::generate_sync, + updates::Handle, }; use futures::{SinkExt, TryStreamExt}; use rusqlite::{named_params, params, OptionalExtension}; diff --git a/crates/corro-agent/src/agent/setup.rs b/crates/corro-agent/src/agent/setup.rs index d67399cd..884721b7 100644 --- a/crates/corro-agent/src/agent/setup.rs +++ b/crates/corro-agent/src/agent/setup.rs @@ -41,6 +41,7 @@ use corro_types::{ schema::{init_schema, Schema}, sqlite::CrConn, }; +use corro_types::updates::UpdatesManager; /// Runtime state for the Corrosion agent pub struct AgentOptions { @@ -106,6 +107,7 @@ pub async fn setup(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Age let subs_manager = SubsManager::default(); + let updates_manager = UpdatesManager::default(); // Setup subscription handlers let subs_bcast_cache = setup_spawn_subscriptions( &subs_manager, @@ -210,6 +212,7 @@ pub async fn setup(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Age schema: RwLock::new(schema), cluster_id, subs_manager, + updates_manager, tripwire, }); diff --git a/crates/corro-agent/src/agent/util.rs b/crates/corro-agent/src/agent/util.rs index 0c1238f1..67a90b8a 100644 --- a/crates/corro-agent/src/agent/util.rs +++ b/crates/corro-agent/src/agent/util.rs @@ -25,6 +25,7 @@ use corro_types::{ channel::CorroReceiver, config::AuthzConfig, pubsub::SubsManager, + updates::{match_changes, match_changes_from_db_version}, }; use std::{ cmp, @@ -36,6 +37,7 @@ use std::{ time::{Duration, Instant}, }; +use crate::api::public::update::api_v1_updates; use axum::{ error_handling::HandleErrorLayer, extract::DefaultBodyLimit, @@ -216,6 +218,20 @@ pub async fn setup_http_api_handler( .layer(ConcurrencyLimitLayer::new(128)), ), ) + .route( + "/v1/updates/:table", + post(api_v1_updates).route_layer( + tower::ServiceBuilder::new() + .layer(HandleErrorLayer::new(|_error: BoxError| async { + Ok::<_, Infallible>(( + StatusCode::SERVICE_UNAVAILABLE, + "max concurrency limit reached".to_string(), + )) + })) + .layer(LoadShedLayer::new()) + .layer(ConcurrencyLimitLayer::new(128)), + ), + ) .route( "/v1/subscriptions/:id", get(api_v1_sub_by_id).route_layer( @@ -705,11 +721,16 @@ pub async fn process_fully_buffered_changes( if let Some(db_version) = db_version { let conn = agent.pool().read().await?; block_in_place(|| { - if let Err(e) = agent - .subs_manager() - .match_changes_from_db_version(&conn, db_version) + if let Err(e) = match_changes_from_db_version(agent.subs_manager(), &conn, db_version) { + error!(%db_version, "could not match changes for subs from db version: {e}"); + } + }); + + block_in_place(|| { + if let Err(e) = + match_changes_from_db_version(agent.updates_manager(), &conn, db_version) { - error!(%db_version, "could not match changes from db version: {e}"); + error!(%db_version, "could not match changes for updates from db version: {e}"); } }); } @@ -969,7 +990,6 @@ pub async fn process_multiple_changes( .snapshot() }; - snap.update_cleared_ts(&tx, ts) .map_err(|source| ChangeError::Rusqlite { source, @@ -988,12 +1008,11 @@ pub async fn process_multiple_changes( if let Some(ts) = last_cleared { let mut booked_writer = agent - .booked() - .blocking_write("process_multiple_changes(update_cleared_ts)"); + .booked() + .blocking_write("process_multiple_changes(update_cleared_ts)"); booked_writer.update_cleared_ts(ts); } - for (_, changeset, _, _) in changesets.iter() { if let Some(ts) = changeset.ts() { let dur = (agent.clock().new_timestamp().get_time() - ts.0).to_duration(); @@ -1050,9 +1069,8 @@ pub async fn process_multiple_changes( for (_actor_id, changeset, db_version, _src) in changesets { change_chunk_size += changeset.changes().len(); - agent - .subs_manager() - .match_changes(changeset.changes(), db_version); + match_changes(agent.subs_manager(), changeset.changes(), db_version); + match_changes(agent.updates_manager(), changeset.changes(), db_version); } histogram!("corro.agent.changes.processing.time.seconds").record(start.elapsed()); diff --git a/crates/corro-agent/src/api/public/mod.rs b/crates/corro-agent/src/api/public/mod.rs index af1b72fa..9112fb0b 100644 --- a/crates/corro-agent/src/api/public/mod.rs +++ b/crates/corro-agent/src/api/public/mod.rs @@ -33,6 +33,8 @@ use corro_types::broadcast::broadcast_changes; pub mod pubsub; +pub mod update; + pub async fn make_broadcastable_changes( agent: &Agent, f: F, diff --git a/crates/corro-agent/src/api/public/pubsub.rs b/crates/corro-agent/src/api/public/pubsub.rs index 181575e2..71e9d1b9 100644 --- a/crates/corro-agent/src/api/public/pubsub.rs +++ b/crates/corro-agent/src/api/public/pubsub.rs @@ -3,6 +3,7 @@ use std::{collections::HashMap, io::Write, sync::Arc, time::Duration}; use axum::{http::StatusCode, response::IntoResponse, Extension}; use bytes::{BufMut, Bytes, BytesMut}; use compact_str::{format_compact, ToCompactString}; +use corro_types::updates::Handle; use corro_types::{ agent::Agent, api::{ChangeId, QueryEvent, QueryEventMeta, Statement}, @@ -74,7 +75,9 @@ async fn sub_by_id( bcast_cache_write.remove(&id); if let Some(handle) = subs.remove(&id) { info!(sub_id = %id, "Removed subscription from sub_by_id"); - tokio::spawn(handle.cleanup()); + tokio::spawn(async move { + handle.cleanup().await; + }); } return hyper::Response::builder() @@ -883,12 +886,14 @@ mod tests { use corro_types::base::{CrsqlDbVersion, CrsqlSeq, Version}; use corro_types::broadcast::{ChangeSource, ChangeV1, Changeset}; use corro_types::pubsub::pack_columns; + use corro_types::updates::NotifyEvent; use corro_types::{ api::{ChangeId, RowId}, config::Config, pubsub::ChangeType, }; use http_body::Body; + use serde::de::DeserializeOwned; use spawn::wait_for_all_pending_handles; use std::ops::RangeInclusive; use std::time::Instant; @@ -898,6 +903,7 @@ mod tests { use super::*; use crate::agent::process_multiple_changes; + use crate::api::public::update::{api_v1_updates, SharedUpdateBroadcastCache}; use crate::{ agent::setup, api::public::{api_v1_db_schema, api_v1_transactions}, @@ -951,6 +957,7 @@ mod tests { assert!(body.0.results.len() == 2); let bcast_cache: SharedMatcherBroadcastCache = Default::default(); + let update_bcast_cache: SharedUpdateBroadcastCache = Default::default(); { let mut res = api_v1_subs( @@ -970,6 +977,23 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); + // only want notifications + let mut notify_res = api_v1_updates( + Extension(agent.clone()), + Extension(update_bcast_cache.clone()), + Extension(tripwire.clone()), + axum::extract::Path("tests".to_string()), + ) + .await + .into_response(); + + if !notify_res.status().is_success() { + let b = notify_res.body_mut().data().await.unwrap().unwrap(); + println!("body: {}", String::from_utf8_lossy(&b)); + } + + assert_eq!(notify_res.status(), StatusCode::OK); + let (status_code, _) = api_v1_transactions( Extension(agent.clone()), axum::Json(vec![Statement::WithParams( @@ -988,18 +1012,25 @@ mod tests { done: false, }; + let mut notify_rows = RowsIter { + body: notify_res.into_body(), + codec: LinesCodec::new(), + buf: BytesMut::new(), + done: false, + }; + assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Columns(vec!["id".into(), "text".into()]) ); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Row(RowId(1), vec!["service-id".into(), "service-name".into()]) ); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Row( RowId(2), vec!["service-id-2".into(), "service-name-2".into()] @@ -1007,12 +1038,12 @@ mod tests { ); assert!(matches!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::EndOfQuery { .. } )); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Change( ChangeType::Insert, RowId(3), @@ -1033,7 +1064,7 @@ mod tests { assert_eq!(status_code, StatusCode::OK); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Change( ChangeType::Insert, RowId(4), @@ -1042,6 +1073,16 @@ mod tests { ) ); + assert_eq!( + notify_rows.recv::().await.unwrap().unwrap(), + NotifyEvent::Notify(ChangeType::Update, vec!["service-id-3".into()],) + ); + + assert_eq!( + notify_rows.recv::().await.unwrap().unwrap(), + NotifyEvent::Notify(ChangeType::Update, vec!["service-id-4".into()],) + ); + let mut res = api_v1_subs( Extension(agent.clone()), Extension(bcast_cache.clone()), @@ -1070,7 +1111,7 @@ mod tests { }; assert_eq!( - rows_from.recv().await.unwrap().unwrap(), + rows_from.recv::().await.unwrap().unwrap(), QueryEvent::Change( ChangeType::Insert, RowId(4), @@ -1079,6 +1120,30 @@ mod tests { ) ); + // new subscriber for updates + let mut notify_res2 = api_v1_updates( + Extension(agent.clone()), + Extension(update_bcast_cache.clone()), + Extension(tripwire.clone()), + axum::extract::Path("tests".to_string()), + ) + .await + .into_response(); + + if !notify_res2.status().is_success() { + let b = notify_res2.body_mut().data().await.unwrap().unwrap(); + println!("body: {}", String::from_utf8_lossy(&b)); + } + + assert_eq!(notify_res2.status(), StatusCode::OK); + + let mut notify_rows2 = RowsIter { + body: notify_res2.into_body(), + codec: LinesCodec::new(), + buf: BytesMut::new(), + done: false, + }; + let (status_code, _) = api_v1_transactions( Extension(agent.clone()), axum::Json(vec![Statement::WithParams( @@ -1097,9 +1162,24 @@ mod tests { ChangeId(3), ); - assert_eq!(rows.recv().await.unwrap().unwrap(), query_evt); + assert_eq!(rows.recv::().await.unwrap().unwrap(), query_evt); + + assert_eq!( + rows_from.recv::().await.unwrap().unwrap(), + query_evt + ); + + let notify_evt = NotifyEvent::Notify(ChangeType::Update, vec!["service-id-5".into()]); - assert_eq!(rows_from.recv().await.unwrap().unwrap(), query_evt); + assert_eq!( + notify_rows.recv::().await.unwrap().unwrap(), + notify_evt + ); + + assert_eq!( + notify_rows2.recv::().await.unwrap().unwrap(), + notify_evt + ); // subscriber who arrives later! @@ -1128,17 +1208,17 @@ mod tests { }; assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Columns(vec!["id".into(), "text".into()]) ); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Row(RowId(1), vec!["service-id".into(), "service-name".into()]) ); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Row( RowId(2), vec!["service-id-2".into(), "service-name-2".into()] @@ -1146,7 +1226,7 @@ mod tests { ); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Row( RowId(3), vec!["service-id-3".into(), "service-name-3".into()], @@ -1154,7 +1234,7 @@ mod tests { ); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Row( RowId(4), vec!["service-id-4".into(), "service-name-4".into()], @@ -1162,12 +1242,44 @@ mod tests { ); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Row( RowId(5), vec!["service-id-5".into(), "service-name-5".into()] ) ); + + let (status_code, _) = api_v1_transactions( + Extension(agent.clone()), + axum::Json(vec![Statement::WithParams( + "insert into tests (id, text) values (?,?)".into(), + vec!["service-id-6".into(), "service-name-6".into()], + )]), + ) + .await; + + assert_eq!(status_code, StatusCode::OK); + + let (status_code, _) = api_v1_transactions( + Extension(agent.clone()), + axum::Json(vec![Statement::WithParams( + "delete from tests where id = ?".into(), + vec!["service-id-6".into()], + )]), + ) + .await; + + assert_eq!(status_code, StatusCode::OK); + + assert_eq!( + notify_rows.recv::().await.unwrap().unwrap(), + NotifyEvent::Notify(ChangeType::Update, vec!["service-id-6".into()],) + ); + + assert_eq!( + notify_rows.recv::().await.unwrap().unwrap(), + NotifyEvent::Notify(ChangeType::Delete, vec!["service-id-6".into()],) + ); } // previous subs have been dropped. @@ -1200,7 +1312,7 @@ mod tests { }; assert_eq!( - rows_from.recv().await.unwrap().unwrap(), + rows_from.recv::().await.unwrap().unwrap(), QueryEvent::Change( ChangeType::Insert, RowId(4), @@ -1210,7 +1322,7 @@ mod tests { ); assert_eq!( - rows_from.recv().await.unwrap().unwrap(), + rows_from.recv::().await.unwrap().unwrap(), QueryEvent::Change( ChangeType::Insert, RowId(5), @@ -1260,7 +1372,7 @@ mod tests { assert_eq!(status_code, StatusCode::OK); assert_eq!( - rows_from.recv().await.unwrap().unwrap(), + rows_from.recv::().await.unwrap().unwrap(), QueryEvent::Change( ChangeType::Insert, RowId(6), @@ -1299,7 +1411,7 @@ mod tests { }; assert_eq!( - rows_from.recv().await.unwrap().unwrap(), + rows_from.recv::().await.unwrap().unwrap(), QueryEvent::Change( ChangeType::Insert, RowId(6), @@ -1314,6 +1426,7 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn match_buffered_changes() -> eyre::Result<()> { _ = tracing_subscriber::fmt::try_init(); + let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple(); let ta1 = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?; @@ -1376,6 +1489,7 @@ mod tests { .await?; let bcast_cache: SharedMatcherBroadcastCache = Default::default(); + let update_bcast_cache: SharedUpdateBroadcastCache = Default::default(); let mut res = api_v1_subs( Extension(ta1.agent.clone()), Extension(bcast_cache.clone()), @@ -1393,6 +1507,30 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); + // only notifications + let mut notify_res = api_v1_updates( + Extension(ta1.agent.clone()), + Extension(update_bcast_cache.clone()), + Extension(tripwire.clone()), + axum::extract::Path("buftests".to_string()), + ) + .await + .into_response(); + + if !notify_res.status().is_success() { + let b = notify_res.body_mut().data().await.unwrap().unwrap(); + println!("body: {}", String::from_utf8_lossy(&b)); + } + + assert_eq!(notify_res.status(), StatusCode::OK); + + let mut notify_rows = RowsIter { + body: notify_res.into_body(), + codec: LinesCodec::new(), + buf: BytesMut::new(), + done: false, + }; + let mut rows = RowsIter { body: res.into_body(), codec: LinesCodec::new(), @@ -1401,17 +1539,17 @@ mod tests { }; assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Columns(vec!["pk".into(), "col1".into(), "col2".into()]) ); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Row(RowId(1), vec![Integer(1), "one".into(), "one line".into()]) ); assert!(matches!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::EndOfQuery { .. } )); @@ -1487,7 +1625,7 @@ mod tests { ) .await?; - let res = timeout(Duration::from_secs(5), rows.recv()).await?; + let res = timeout(Duration::from_secs(5), rows.recv::()).await?; assert_eq!( res.unwrap().unwrap(), @@ -1499,6 +1637,12 @@ mod tests { ) ); + let notify_res = timeout(Duration::from_secs(5), notify_rows.recv::()).await?; + assert_eq!( + notify_res.unwrap().unwrap(), + NotifyEvent::Notify(ChangeType::Update, vec![Integer(2)],) + ); + tripwire_tx.send(()).await.ok(); tripwire_worker.await; wait_for_all_pending_handles().await; @@ -1514,7 +1658,7 @@ mod tests { } impl RowsIter { - async fn recv(&mut self) -> Option> { + async fn recv(&mut self) -> Option> { if self.done { return None; } diff --git a/crates/corro-agent/src/api/public/update.rs b/crates/corro-agent/src/api/public/update.rs new file mode 100644 index 00000000..51c9deba --- /dev/null +++ b/crates/corro-agent/src/api/public/update.rs @@ -0,0 +1,246 @@ +use std::{collections::HashMap, io::Write, sync::Arc, time::Duration}; + +use axum::{http::StatusCode, response::IntoResponse, Extension}; +use bytes::{BufMut, Bytes, BytesMut}; +use compact_str::ToCompactString; +use corro_types::{ + agent::Agent, + updates::{Handle, NotifyEvent, UpdateCreated, UpdateHandle, UpdatesManager}, +}; +use futures::future::poll_fn; +use tokio::sync::{ + broadcast::{self, error::RecvError}, + mpsc, RwLock as TokioRwLock, +}; +use tracing::{debug, info, warn}; +use tripwire::Tripwire; +use uuid::Uuid; + +use crate::api::public::pubsub::MatcherUpsertError; + +pub type UpdateBroadcastCache = HashMap>; +pub type SharedUpdateBroadcastCache = Arc>; + +// this should be a fraction of the MAX_UNSUB_TIME +const RECEIVERS_CHECK_INTERVAL: Duration = Duration::from_secs(30); + +pub async fn api_v1_updates( + Extension(agent): Extension, + Extension(bcast_cache): Extension, + Extension(tripwire): Extension, + axum::extract::Path(table): axum::extract::Path, +) -> impl IntoResponse { + info!("Received update request for table: {table}"); + + let mut bcast_write = bcast_cache.write().await; + let updates = agent.updates_manager(); + + let upsert_res = updates.get_or_insert( + &table, + &agent.schema().read(), + agent.pool(), + tripwire.clone(), + ); + + let (handle, maybe_created) = match upsert_res { + Ok(res) => res, + Err(e) => return hyper::Response::::from(MatcherUpsertError::from(e)), + }; + + let (tx, body) = hyper::Body::channel(); + // let (forward_tx, forward_rx) = mpsc::channel(10240); + + let (update_id, sub_rx) = + match upsert_update(handle.clone(), maybe_created, updates, &mut bcast_write).await { + Ok(id) => id, + Err(e) => return hyper::Response::::from(e), + }; + + tokio::spawn(forward_update_bytes_to_body_sender( + handle, sub_rx, tx, tripwire, + )); + + hyper::Response::builder() + .status(StatusCode::OK) + .header("corro-query-id", update_id.to_string()) + .body(body) + .expect("could not generate ok http response for update request") +} + +pub async fn upsert_update( + handle: UpdateHandle, + maybe_created: Option, + updates: &UpdatesManager, + bcast_write: &mut UpdateBroadcastCache, +) -> Result<(Uuid, broadcast::Receiver), MatcherUpsertError> { + let sub_rx = if let Some(created) = maybe_created { + let (sub_tx, sub_rx) = broadcast::channel(10240); + bcast_write.insert(handle.id(), sub_tx.clone()); + tokio::spawn(process_update_channel( + updates.clone(), + handle.id(), + sub_tx, + created.evt_rx, + )); + + sub_rx + } else { + let id = handle.id(); + let sub_tx = bcast_write + .get(&id) + .cloned() + .ok_or(MatcherUpsertError::MissingBroadcaster)?; + debug!("found update handle"); + + sub_tx.subscribe() + }; + + Ok((handle.id(), sub_rx)) +} + +pub async fn process_update_channel( + updates: UpdatesManager, + id: Uuid, + tx: broadcast::Sender, + mut evt_rx: mpsc::Receiver, +) { + let mut buf = BytesMut::new(); + + // interval check for receivers + // useful for queries that don't change often so we can cleanup... + let mut subs_check = tokio::time::interval(RECEIVERS_CHECK_INTERVAL); + + loop { + tokio::select! { + biased; + Some(query_evt) = evt_rx.recv() => { + match make_query_event_bytes(&mut buf, &query_evt) { + Ok(b) => { + if tx.send(b).is_err() { + break; + } + }, + Err(e) => { + match make_query_event_bytes(&mut buf, &NotifyEvent::Error(e.to_compact_string())) { + Ok(b) => { + let _ = tx.send(b); + } + Err(e) => { + warn!(update_id = %id, "failed to send error in update channel: {e}"); + } + } + break; + } + }; + }, + _ = subs_check.tick() => { + if tx.receiver_count() == 0 { + break; + }; + }, + }; + } + + warn!(sub_id = %id, "updates channel done"); + + // remove and get handle from the agent's "matchers" + let handle = match updates.remove(&id) { + Some(h) => { + info!(update_id = %id, "Removed update handle from process_update_channel"); + h + } + None => { + warn!(update_id = %id, "update handle was already gone. odd!"); + return; + } + }; + + // clean up the subscription + handle.cleanup().await; +} + +fn make_query_event_bytes( + buf: &mut BytesMut, + query_evt: &NotifyEvent, +) -> serde_json::Result { + { + let mut writer = buf.writer(); + serde_json::to_writer(&mut writer, query_evt)?; + + // NOTE: I think that's infaillible... + writer + .write_all(b"\n") + .expect("could not write new line to BytesMut Writer"); + } + + Ok(buf.split().freeze()) +} + +async fn forward_update_bytes_to_body_sender( + update: UpdateHandle, + mut rx: broadcast::Receiver, + mut tx: hyper::body::Sender, + mut tripwire: Tripwire, +) { + let mut buf = BytesMut::new(); + + let send_deadline = tokio::time::sleep(Duration::from_millis(10)); + tokio::pin!(send_deadline); + + loop { + tokio::select! { + biased; + res = rx.recv() => { + match res { + Ok(event_buf) => { + buf.extend_from_slice(&event_buf); + if buf.len() >= 64 * 1024 { + if let Err(e) = tx.send_data(buf.split().freeze()).await { + warn!(update_id = %update.id(), "could not forward update query event to receiver: {e}"); + return; + } + }; + }, + Err(RecvError::Lagged(skipped)) => { + warn!(update_id = %update.id(), "update skipped {} events, aborting", skipped); + return; + }, + Err(RecvError::Closed) => { + info!(update_id = %update.id(), "events subcription ran out"); + return; + }, + } + }, + _ = &mut send_deadline => { + if !buf.is_empty() { + if let Err(e) = tx.send_data(buf.split().freeze()).await { + warn!(update_id = %update.id(), "could not forward subscription query event to receiver: {e}"); + return; + } + } else { + if let Err(e) = poll_fn(|cx| tx.poll_ready(cx)).await { + warn!(update_id = %update.id(), error = %e, "body sender was closed or errored, stopping event broadcast sends"); + return; + } + send_deadline.as_mut().reset(tokio::time::Instant::now() + Duration::from_millis(10)); + continue; + } + }, + _ = update.cancelled() => { + info!(update_id = %update.id(), "update cancelled, aborting forwarding bytes to subscriber"); + return; + }, + _ = &mut tripwire => { + break; + } + } + } + + while let Ok(event_buf) = rx.try_recv() { + buf.extend_from_slice(&event_buf); + if let Err(e) = tx.send_data(buf.split().freeze()).await { + warn!(update_id = %update.id(), "could not forward subscription query event to receiver: {e}"); + return; + } + } +} diff --git a/crates/corro-api-types/src/lib.rs b/crates/corro-api-types/src/lib.rs index 339d9e29..a2991cce 100644 --- a/crates/corro-api-types/src/lib.rs +++ b/crates/corro-api-types/src/lib.rs @@ -56,6 +56,7 @@ pub enum QueryEventMeta { EndOfQuery(Option), Change(ChangeId), Error, + Notify, } /// RowId newtype to differentiate from ChangeId diff --git a/crates/corro-types/src/agent.rs b/crates/corro-types/src/agent.rs index e199e53a..8850146a 100644 --- a/crates/corro-types/src/agent.rs +++ b/crates/corro-types/src/agent.rs @@ -34,6 +34,7 @@ use tokio_util::sync::{CancellationToken, DropGuard}; use tracing::{debug, error, trace, warn}; use tripwire::Tripwire; +use crate::updates::UpdatesManager; use crate::{ actor::{Actor, ActorId, ClusterId}, base::{CrsqlDbVersion, CrsqlSeq, Version}, @@ -76,6 +77,8 @@ pub struct AgentConfig { pub subs_manager: SubsManager, + pub updates_manager: UpdatesManager, + pub tripwire: Tripwire, } @@ -100,6 +103,7 @@ pub struct AgentInner { cluster_id: ArcSwap, limits: Limits, subs_manager: SubsManager, + updates_manager: UpdatesManager, } #[derive(Debug, Clone)] @@ -132,6 +136,7 @@ impl Agent { sync: Arc::new(Semaphore::new(3)), }, subs_manager: config.subs_manager, + updates_manager: config.updates_manager, })) } @@ -237,6 +242,10 @@ impl Agent { &self.0.subs_manager } + pub fn updates_manager(&self) -> &UpdatesManager { + &self.0.updates_manager + } + pub fn set_cluster_id(&self, cluster_id: ClusterId) { self.0.cluster_id.store(Arc::new(cluster_id)); } diff --git a/crates/corro-types/src/broadcast.rs b/crates/corro-types/src/broadcast.rs index d2ea9b61..e7f84ec3 100644 --- a/crates/corro-types/src/broadcast.rs +++ b/crates/corro-types/src/broadcast.rs @@ -1,4 +1,9 @@ -use std::{cmp, fmt, io, num::NonZeroU32, ops::{Deref, RangeInclusive}, time::Duration}; +use std::{ + cmp, fmt, io, + num::NonZeroU32, + ops::{Deref, RangeInclusive}, + time::Duration, +}; use bytes::{Bytes, BytesMut}; use corro_api_types::{row_to_change, Change}; @@ -27,6 +32,7 @@ use crate::{ channel::CorroSender, sqlite::SqlitePoolError, sync::SyncTraceContextV1, + updates::match_changes, }; #[derive(Debug, Clone, Readable, Writable)] @@ -165,9 +171,14 @@ impl Changeset { // determine the estimated resource cost of processing a change pub fn processing_cost(&self) -> usize { match self { - Changeset::Empty { versions, .. } => cmp::min((versions.end().0 - versions.start().0) as usize + 1, 20), - Changeset::EmptySet { versions, .. } => versions.iter().map(|versions| cmp::min((versions.end().0 - versions.start().0) as usize + 1, 20)).sum::(), - Changeset::Full { changes, ..} => changes.len(), + Changeset::Empty { versions, .. } => { + cmp::min((versions.end().0 - versions.start().0) as usize + 1, 20) + } + Changeset::EmptySet { versions, .. } => versions + .iter() + .map(|versions| cmp::min((versions.end().0 - versions.start().0) as usize + 1, 20)) + .sum::(), + Changeset::Full { changes, .. } => changes.len(), } } @@ -508,7 +519,8 @@ pub async fn broadcast_changes( trace!("broadcasting changes: {changes:?} for seq: {seqs:?}"); - agent.subs_manager().match_changes(&changes, db_version); + match_changes(agent.subs_manager(), &changes, db_version); + match_changes(agent.updates_manager(), &changes, db_version); let tx_bcast = agent.tx_bcast().clone(); tokio::spawn(async move { diff --git a/crates/corro-types/src/lib.rs b/crates/corro-types/src/lib.rs index 6a74a0f1..ee4cc9fe 100644 --- a/crates/corro-types/src/lib.rs +++ b/crates/corro-types/src/lib.rs @@ -12,4 +12,6 @@ pub mod schema; pub mod sqlite; pub mod sync; pub mod tls; +pub mod updates; + pub use corro_base_types as base; diff --git a/crates/corro-types/src/pubsub.rs b/crates/corro-types/src/pubsub.rs index b34e7d91..4d914805 100644 --- a/crates/corro-types/src/pubsub.rs +++ b/crates/corro-types/src/pubsub.rs @@ -5,6 +5,7 @@ use std::{ time::{Duration, Instant}, }; +use async_trait::async_trait; use bytes::{Buf, BufMut}; use camino::{Utf8Path, Utf8PathBuf}; use compact_str::{format_compact, ToCompactString}; @@ -45,8 +46,10 @@ use crate::{ base::CrsqlDbVersion, schema::{Schema, Table}, sqlite::CrConn, + updates::HandleMetrics, }; +use crate::updates::{Handle, Manager}; pub use corro_api_types::sqlite::ChangeType; #[derive(Debug, Default, Clone)] @@ -58,13 +61,32 @@ struct InnerSubsManager { queries: HashMap, } -// tools to bootstrap a new subscriber +// tools to bootstrap a new subscriber or notifier pub struct MatcherCreated { pub evt_rx: mpsc::Receiver, } const SUB_EVENT_CHANNEL_CAP: usize = 512; +impl Manager for SubsManager { + fn trait_type(&self) -> String { + "subs".to_string() + } + + fn get(&self, id: &Uuid) -> Option { + self.0.read().get(id) + } + + fn remove(&self, id: &Uuid) -> Option { + let mut inner = self.0.write(); + inner.remove(id) + } + + fn get_handles(&self) -> BTreeMap { + self.0.read().handles.clone() + } +} + impl SubsManager { pub fn get(&self, id: &Uuid) -> Option { self.0.read().get(id) @@ -166,153 +188,14 @@ impl SubsManager { let mut inner = self.0.write(); inner.remove(id) } - - pub fn match_changes_from_db_version( - &self, - conn: &Connection, - db_version: CrsqlDbVersion, - ) -> rusqlite::Result<()> { - let handles = { - let inner = self.0.read(); - if inner.handles.is_empty() { - return Ok(()); - } - inner.handles.clone() - }; - - let mut candidates = handles - .iter() - .map(|(id, handle)| (id, (MatchCandidates::new(), handle))) - .collect::>(); - - { - let mut prepped = conn.prepare_cached( - r#" - SELECT "table", pk, cid - FROM crsql_changes - WHERE db_version = ? - ORDER BY seq ASC - "#, - )?; - - let rows = prepped.query_map([db_version], |row| { - Ok(( - row.get::<_, TableName>(0)?, - row.get::<_, Vec>(1)?, - row.get::<_, ColumnName>(2)?, - )) - })?; - - for change_res in rows { - let (table, pk, column) = change_res?; - - for (_id, (candidates, handle)) in candidates.iter_mut() { - let change = MatchableChange { - table: &table, - pk: &pk, - column: &column, - }; - handle.filter_matchable_change(candidates, change); - } - } - } - - // metrics... - for (id, (candidates, handle)) in candidates { - let mut match_count = 0; - - for (table, pks) in candidates.iter() { - let count = pks.len(); - match_count += count; - counter!("corro.subs.changes.matched.count", "sql_hash" => handle.inner.hash.clone(), "table" => table.to_string()).increment(count as u64); - } - - trace!(sub_id = %id, %db_version, "found {match_count} candidates"); - - if let Err(e) = handle.inner.changes_tx.try_send((candidates, db_version)) { - error!(sub_id = %id, "could not send change candidates to subscription handler: {e}"); - match e { - mpsc::error::TrySendError::Full(item) => { - warn!("channel is full, falling back to async send"); - let changes_tx = handle.inner.changes_tx.clone(); - tokio::spawn(async move { - _ = changes_tx.send(item).await; - }); - } - mpsc::error::TrySendError::Closed(_) => { - if let Some(handle) = self.remove(id) { - tokio::spawn(handle.cleanup()); - } - } - } - } - } - - Ok(()) - } - - pub fn match_changes(&self, changes: &[Change], db_version: CrsqlDbVersion) { - trace!( - %db_version, - "trying to match changes to subscribers, len: {}", - changes.len() - ); - if changes.is_empty() { - return; - } - let handles = { - let inner = self.0.read(); - if inner.handles.is_empty() { - return; - } - inner.handles.clone() - }; - - for (id, handle) in handles.iter() { - trace!(sub_id = %id, %db_version, "attempting to match changes to a subscription"); - let mut candidates = MatchCandidates::new(); - let mut match_count = 0; - for change in changes.iter().map(MatchableChange::from) { - if handle.filter_matchable_change(&mut candidates, change) { - match_count += 1; - } - } - - // metrics... - for (table, pks) in candidates.iter() { - counter!("corro.subs.changes.matched.count", "sql_hash" => handle.inner.hash.clone(), "table" => table.to_string()).increment(pks.len() as u64); - } - - trace!(sub_id = %id, %db_version, "found {match_count} candidates"); - - if let Err(e) = handle.inner.changes_tx.try_send((candidates, db_version)) { - error!(sub_id = %id, "could not send change candidates to subscription handler: {e}"); - match e { - mpsc::error::TrySendError::Full(item) => { - warn!("channel is full, falling back to async send"); - - let changes_tx = handle.inner.changes_tx.clone(); - tokio::spawn(async move { - _ = changes_tx.send(item).await; - }); - counter!("corro.subs.changes.channel.async_fallbacks_count", "sql_hash" => handle.inner.hash.clone()).increment(1); - } - mpsc::error::TrySendError::Closed(_) => { - if let Some(handle) = self.remove(id) { - tokio::spawn(handle.cleanup()); - } - } - } - } - } - } } #[derive(Debug)] -struct MatchableChange<'a> { - table: &'a TableName, - pk: &'a [u8], - column: &'a ColumnName, +pub struct MatchableChange<'a> { + pub table: &'a TableName, + pub pk: &'a [u8], + pub column: &'a ColumnName, + pub cl: i64, } impl<'a> From<&'a Change> for MatchableChange<'a> { @@ -321,6 +204,7 @@ impl<'a> From<&'a Change> for MatchableChange<'a> { table: &value.table, pk: &value.pk, column: &value.cid, + cl: value.cl, } } } @@ -382,15 +266,81 @@ struct InnerMatcherHandle { // some state from the matcher so we can take a look later subs_path: String, cached_statements: HashMap, + metrics: HashMap, } -type MatchCandidates = IndexMap>>; +pub type MatchCandidates = IndexMap, i64>>; -impl MatcherHandle { - pub fn id(&self) -> Uuid { +#[async_trait] +impl Handle for MatcherHandle { + fn id(&self) -> Uuid { self.inner.id } + fn cancelled(&self) -> WaitForCancellationFuture { + self.inner.cancel.cancelled() + } + + fn changes_tx(&self) -> mpsc::Sender<(MatchCandidates, CrsqlDbVersion)> { + self.inner.changes_tx.clone() + } + + async fn cleanup(&self) { + self.inner.cancel.cancel(); + info!(sub_id = %self.inner.id, "Canceled subscription"); + } + + fn filter_matchable_change( + &self, + candidates: &mut MatchCandidates, + change: MatchableChange, + ) -> bool { + trace!("filtering change {change:?}"); + // don't double process the same pk + if candidates + .get(change.table) + .map(|pks| pks.contains_key(change.pk)) + .unwrap_or_default() + { + trace!("already contained key"); + return false; + } + + // don't consider changes that don't have both the table + col in the matcher query + if !self + .inner + .parsed + .table_columns + .get(change.table.as_str()) + .map(|cols| change.column.is_crsql_sentinel() || cols.contains(change.column.as_str())) + .unwrap_or_default() + { + trace!("could not match against parsed query table and columns"); + return false; + } + + if let Some(v) = candidates.get_mut(change.table) { + v.insert(change.pk.to_vec(), change.cl).is_none() + } else { + candidates.insert( + change.table.clone(), + [(change.pk.to_vec(), change.cl)].into(), + ); + true + } + } + + fn get_counter(&self, table: &str) -> &HandleMetrics { + self.inner.metrics.get(table).unwrap_or_else(|| { + panic!( + "metrics counter for table '{}' missing. subs hash {}!", + self.inner.hash, table + ) + }) + } +} + +impl MatcherHandle { pub fn sql(&self) -> &String { &self.inner.sql } @@ -415,11 +365,6 @@ impl MatcherHandle { &self.inner.cached_statements } - pub async fn cleanup(self) { - self.inner.cancel.cancel(); - info!(sub_id = %self.inner.id, "Canceled subscription"); - } - pub fn pool(&self) -> &RusqlitePool { &self.inner.pool } @@ -432,10 +377,6 @@ impl MatcherHandle { } } - pub fn cancelled(&self) -> WaitForCancellationFuture { - self.inner.cancel.cancelled() - } - pub fn max_change_id(&self, conn: &Connection) -> rusqlite::Result { self.wait_for_running_state(); let mut prepped = conn.prepare_cached("SELECT COALESCE(MAX(id), 0) FROM changes")?; @@ -558,43 +499,6 @@ impl MatcherHandle { Ok(max_change_id) } - - fn filter_matchable_change( - &self, - candidates: &mut MatchCandidates, - change: MatchableChange, - ) -> bool { - trace!("filtering change {change:?}"); - // don't double process the same pk - if candidates - .get(change.table) - .map(|pks| pks.contains(change.pk)) - .unwrap_or_default() - { - trace!("already contained key"); - return false; - } - - // don't consider changes that don't have both the table + col in the matcher query - if !self - .inner - .parsed - .table_columns - .get(change.table.as_str()) - .map(|cols| change.column.is_crsql_sentinel() || cols.contains(change.column.as_str())) - .unwrap_or_default() - { - trace!("could not match against parsed query table and columns"); - return false; - } - - if let Some(v) = candidates.get_mut(change.table) { - v.insert(change.pk.to_vec()) - } else { - candidates.insert(change.table.clone(), [change.pk.to_vec()].into()); - true - } - } } type StateLock = Arc<(Mutex, Condvar)>; @@ -840,6 +744,14 @@ impl Matcher { // big channel to not miss anything let (changes_tx, changes_rx) = mpsc::channel(20480); + // metrics counters + let mut counter_map = HashMap::new(); + for table in parsed.table_columns.keys() { + counter_map.insert(table.clone(), HandleMetrics{ + matched_count: counter!("corro.subs.changes.matched.count", "sql_hash" => sql_hash.clone(), "table" => table.to_string()), + }); + } + let handle = MatcherHandle { inner: Arc::new(InnerMatcherHandle { id, @@ -857,6 +769,7 @@ impl Matcher { changes_tx, cached_statements: statements.clone(), subs_path: sub_path.to_string(), + metrics: counter_map, }), state: state.clone(), }; @@ -1193,8 +1106,8 @@ impl Matcher { Some((candidates, db_version)) = self.changes_rx.recv() => { for (table, pks) in candidates { let buffed = buf.entry(table).or_default(); - for pk in pks { - if buffed.insert(pk) { + for (pk, cl) in pks { + if buffed.insert(pk, cl).is_none() { buf_count += 1; } } @@ -1537,7 +1450,7 @@ impl Matcher { for (table, pks) in candidates { let pks = pks .iter() - .map(|pk| unpack_columns(pk)) + .map(|(pk, _)| unpack_columns(pk)) .collect::>, _>>()?; let tmp_table_name = format!("temp_{table}"); @@ -1804,7 +1717,7 @@ impl Matcher { { let mut changes_prepped = state_conn.prepare_cached( r#" - SELECT DISTINCT "table", pk + SELECT DISTINCT "table", pk, cl FROM crsql_changes WHERE db_version > ? AND db_version <= ? -- TODO: allow going over? @@ -1818,7 +1731,7 @@ impl Matcher { candidates .entry(row.get(0)?) .or_default() - .insert(row.get(1)?); + .insert(row.get(1)?, row.get(2)?); } } diff --git a/crates/corro-types/src/updates.rs b/crates/corro-types/src/updates.rs new file mode 100644 index 00000000..59daca07 --- /dev/null +++ b/crates/corro-types/src/updates.rs @@ -0,0 +1,560 @@ +use crate::agent::SplitPool; +use crate::pubsub::{unpack_columns, MatchCandidates, MatchableChange, MatcherError}; +use crate::schema::Schema; +use async_trait::async_trait; +use compact_str::CompactString; +use corro_api_types::sqlite::ChangeType; +use corro_api_types::{Change, ColumnName, SqliteValue, SqliteValueRef, TableName}; +use corro_base_types::CrsqlDbVersion; +use metrics::{counter, histogram, Counter}; +use parking_lot::RwLock; +use rusqlite::Connection; +use serde::{Deserialize, Serialize}; +use spawn::spawn_counted; +use std::collections::BTreeMap; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio::task::block_in_place; +use tokio::time::Instant; +use tokio_util::sync::{CancellationToken, WaitForCancellationFuture}; +use tracing::{debug, error, info, trace, warn}; +use tripwire::Tripwire; +use uuid::Uuid; + +pub trait Manager { + fn trait_type(&self) -> String; + fn get(&self, id: &Uuid) -> Option; + fn remove(&self, id: &Uuid) -> Option; + fn get_handles(&self) -> BTreeMap; +} + +#[async_trait] +pub trait Handle { + fn id(&self) -> Uuid; + fn cancelled(&self) -> WaitForCancellationFuture; + fn filter_matchable_change( + &self, + candidates: &mut MatchCandidates, + change: MatchableChange, + ) -> bool; + fn changes_tx(&self) -> mpsc::Sender<(MatchCandidates, CrsqlDbVersion)>; + async fn cleanup(&self); + fn get_counter(&self, table: &str) -> &HandleMetrics; +} + +pub type NotifyEvent = TypedNotifyEvent>; + +pub enum NotifyType { + Upsert = 0, + Delete = 2, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum TypedNotifyEvent { + Notify(ChangeType, T), + Error(CompactString), +} + +#[derive(Clone)] +pub struct HandleMetrics { + pub matched_count: Counter, +} + +impl std::fmt::Debug for HandleMetrics { + fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> { + f.debug_struct("Counter").finish_non_exhaustive() + } +} + +#[derive(Default, Debug, Clone)] +pub struct UpdatesManager(Arc>); + +impl Manager for UpdatesManager { + fn trait_type(&self) -> String { + "updates".to_string() + } + + fn get(&self, id: &Uuid) -> Option { + self.0.read().get(id) + } + + fn remove(&self, id: &Uuid) -> Option { + let mut inner = self.0.write(); + inner.remove(id) + } + + fn get_handles(&self) -> BTreeMap { + self.0.read().handles.clone() + } +} + +#[async_trait] +impl Handle for UpdateHandle { + fn id(&self) -> Uuid { + self.inner.id + } + + fn cancelled(&self) -> WaitForCancellationFuture { + self.inner.cancel.cancelled() + } + + fn filter_matchable_change( + &self, + candidates: &mut MatchCandidates, + change: MatchableChange, + ) -> bool { + if change.table.to_string() != self.inner.name { + return false; + } + + trace!("filtering change {change:?}"); + // don't double process the same pk + if candidates + .get(change.table) + .map(|pks| pks.contains_key(change.pk)) + .unwrap_or_default() + { + trace!("already contained key"); + return false; + } + + if let Some(v) = candidates.get_mut(change.table) { + v.insert(change.pk.into(), change.cl).is_none() + } else { + candidates.insert( + change.table.clone(), + [(change.pk.to_vec(), change.cl)].into(), + ); + true + } + } + + fn changes_tx(&self) -> mpsc::Sender<(MatchCandidates, CrsqlDbVersion)> { + self.inner.changes_tx.clone() + } + + async fn cleanup(&self) { + self.inner.cancel.cancel(); + info!(sub_id = %self.inner.id, "Canceled subscription"); + } + + fn get_counter(&self, table: &str) -> &HandleMetrics { + // this should not happen + if table != self.inner.name { + warn!(update_tbl = %self.inner.name, "udpates handle get_counter method called for wrong table : {table}! This shouldn't happen!") + } + &self.inner.counters + } +} + +const UPDATE_EVENT_CHANNEL_CAP: usize = 512; + +// tools to bootstrap a new notifier +pub struct UpdateCreated { + pub evt_rx: mpsc::Receiver, +} + +#[derive(Debug, Default)] +struct InnerUpdatesManager { + tables: BTreeMap, + handles: BTreeMap, +} + +impl InnerUpdatesManager { + fn remove(&mut self, id: &Uuid) -> Option { + let handle = self.handles.remove(id)?; + Some(handle) + } + + fn get(&self, id: &Uuid) -> Option { + self.handles.get(id).cloned() + } +} + +impl UpdatesManager { + pub fn get(&self, table: &str) -> Option { + let id = self.0.read().tables.get(table).cloned(); + if let Some(id) = id { + return self.0.read().handles.get(&id).cloned(); + } + None + } + + pub fn get_or_insert( + &self, + tbl_name: &str, + schema: &Schema, + _pool: &SplitPool, + tripwire: Tripwire, + ) -> Result<(UpdateHandle, Option), MatcherError> { + if let Some(handle) = self.get(tbl_name) { + return Ok((handle, None)); + } + + let mut inner = self.0.write(); + let (evt_tx, evt_rx) = mpsc::channel(UPDATE_EVENT_CHANNEL_CAP); + + let id = Uuid::new_v4(); + let handle_res = UpdateHandle::create(id, tbl_name, schema, evt_tx, tripwire); + + let handle = match handle_res { + Ok(handle) => handle, + Err(e) => { + return Err(e); + } + }; + + inner.handles.insert(id, handle.clone()); + inner.tables.insert(tbl_name.to_string(), id); + + Ok((handle, Some(UpdateCreated { evt_rx }))) + } + + pub fn remove(&self, id: &Uuid) -> Option { + let mut inner = self.0.write(); + + inner.remove(id) + } +} + +#[derive(Clone, Debug)] +pub struct UpdateHandle { + inner: Arc, +} + +#[derive(Clone, Debug)] +pub struct InnerUpdateHandle { + id: Uuid, + name: String, + cancel: CancellationToken, + changes_tx: mpsc::Sender<(MatchCandidates, CrsqlDbVersion)>, + counters: HandleMetrics, +} + +impl UpdateHandle { + pub fn id(&self) -> Uuid { + self.inner.id + } + + pub fn create( + id: Uuid, + tbl_name: &str, + schema: &Schema, + evt_tx: mpsc::Sender, + tripwire: Tripwire, + ) -> Result { + // check for existing handles + match schema.tables.get(tbl_name) { + Some(table) => { + if table.pk.is_empty() { + return Err(MatcherError::MissingPrimaryKeys); + } + } + None => return Err(MatcherError::TableNotFound(tbl_name.to_string())), + }; + + let cancel = CancellationToken::new(); + let (changes_tx, changes_rx) = mpsc::channel(20480); + let handle = UpdateHandle { + inner: Arc::new(InnerUpdateHandle { + id, + name: tbl_name.to_owned(), + cancel: cancel.clone(), + changes_tx, + counters: HandleMetrics { + matched_count: counter!("corro.updates.changes.matched.count", "table" => tbl_name.to_owned()), + }, + }), + }; + spawn_counted(batch_candidates(id, cancel, evt_tx, changes_rx, tripwire)); + Ok(handle) + } + + pub async fn cleanup(self) { + self.inner.cancel.cancel(); + info!(update_id = %self.inner.id, "Canceled update"); + } +} + +fn handle_candidates( + evt_tx: mpsc::Sender, + candidates: MatchCandidates, +) -> Result<(), MatcherError> { + if candidates.is_empty() { + return Ok(()); + } + + trace!( + "got some candidates for updates! {:?}", + candidates.keys().collect::>() + ); + + for (_, pks) in candidates { + let pks = pks + .iter() + .map(|(pk, cl)| unpack_columns(pk).map(|x| (x, *cl))) + .collect::, i64)>, _>>()?; + + for (pk, cl) in pks { + let mut change_type = ChangeType::Update; + if cl % 2 == 0 { + change_type = ChangeType::Delete + } + if let Err(e) = evt_tx.blocking_send(NotifyEvent::Notify( + change_type, + pk.iter().map(|x| x.to_owned()).collect::>(), + )) { + debug!("could not send back row to matcher sub sender: {e}"); + return Err(MatcherError::EventReceiverClosed); + } + } + } + + Ok(()) +} + +async fn batch_candidates( + id: Uuid, + cancel: CancellationToken, + evt_tx: mpsc::Sender, + mut changes_rx: mpsc::Receiver<(MatchCandidates, CrsqlDbVersion)>, + mut tripwire: Tripwire, +) { + const PROCESS_CHANGES_THRESHOLD: usize = 1000; + const PROCESS_BUFFER_DEADLINE: Duration = Duration::from_millis(600); + + info!(sub_id = %id, "Starting loop to run the subscription"); + + let mut buf = MatchCandidates::new(); + let mut buf_count = 0; + + // max duration of aggregating candidates + let process_changes_deadline = tokio::time::sleep(PROCESS_BUFFER_DEADLINE); + tokio::pin!(process_changes_deadline); + + let mut process = false; + loop { + tokio::select! { + biased; + _ = cancel.cancelled() => { + info!(sub_id = %id, "Acknowledged updates cancellation, breaking loop."); + break; + } + Some((candidates, _)) = changes_rx.recv() => { + for (table, pk_map) in candidates { + let buffed = buf.entry(table).or_default(); + for (pk, cl) in pk_map { + if buffed.insert(pk, cl).is_none() { + buf_count += 1; + } + } + } + + if buf_count >= PROCESS_CHANGES_THRESHOLD { + process = true + } + }, + _ = process_changes_deadline.as_mut() => { + process_changes_deadline + .as_mut() + .reset(Instant::now() + PROCESS_BUFFER_DEADLINE); + if buf_count != 0 { + process = true + } + }, + _ = &mut tripwire => { + trace!(sub_id = %id, "tripped batch_candidates, returning"); + return; + } + else => { + return; + } + } + + if process { + let start = Instant::now(); + if let Err(e) = + block_in_place(|| handle_candidates(evt_tx.clone(), std::mem::take(&mut buf))) + { + if !matches!(e, MatcherError::EventReceiverClosed) { + error!(sub_id = %id, "could not handle change: {e}"); + } + break; + } + let elapsed = start.elapsed(); + + histogram!("corro.updates.changes.processing.duration.seconds", "table" => id.to_string()).record(elapsed); + + buf_count = 0; + + // reset the deadline + process_changes_deadline + .as_mut() + .reset(Instant::now() + PROCESS_BUFFER_DEADLINE); + } + } + + debug!(id = %id, "update loop is done"); +} + +pub fn match_changes(manager: &impl Manager, changes: &[Change], db_version: CrsqlDbVersion) +where + H: Handle + Send + 'static, +{ + let trait_type = manager.trait_type(); + trace!( + %db_version, + "trying to match changes to {trait_type}, len: {}", + changes.len() + ); + if changes.is_empty() { + return; + } + + let handles = manager.get_handles(); + if handles.is_empty() { + return; + } + + for (id, handle) in handles.iter() { + trace!(sub_id = %id, %db_version, "attempting to match changes to a subscription"); + let mut candidates = MatchCandidates::new(); + let mut match_count = 0; + for change in changes.iter().map(MatchableChange::from) { + if handle.filter_matchable_change(&mut candidates, change) { + match_count += 1; + } + } + + // metrics... + for (table, pks) in candidates.iter() { + handle + .get_counter(table) + .matched_count + .increment(pks.len() as u64); + } + + trace!(sub_id = %id, %db_version, "found {match_count} candidates"); + + if let Err(e) = handle + .changes_tx() + .try_send((candidates, db_version)) + { + error!(sub_id = %id, "could not send change candidates to {trait_type} handler: {e}"); + match e { + mpsc::error::TrySendError::Full(item) => { + warn!("channel is full, falling back to async send"); + + let changes_tx = handle.changes_tx(); + tokio::spawn(async move { + _ = changes_tx.send(item).await; + }); + } + mpsc::error::TrySendError::Closed(_) => { + if let Some(handle) = manager.remove(id) { + tokio::spawn(async move { + handle.cleanup().await; + }); + } + } + } + } + } +} + +pub fn match_changes_from_db_version( + manager: &impl Manager, + conn: &Connection, + db_version: CrsqlDbVersion, +) -> rusqlite::Result<()> +where + H: Handle + Send + 'static, +{ + let handles = manager.get_handles(); + if handles.is_empty() { + return Ok(()); + } + + let trait_type = manager.trait_type(); + let mut candidates = handles + .iter() + .map(|(id, handle)| (id, (MatchCandidates::new(), handle))) + .collect::>(); + + { + let mut prepped = conn.prepare_cached( + r#" + SELECT "table", pk, cid, cl + FROM crsql_changes + WHERE db_version = ? + ORDER BY seq ASC + "#, + )?; + + let rows = prepped.query_map([db_version], |row| { + Ok(( + row.get::<_, TableName>(0)?, + row.get::<_, Vec>(1)?, + row.get::<_, ColumnName>(2)?, + row.get::<_, i64>(3)?, + )) + })?; + + for change_res in rows { + let (table, pk, column, cl) = change_res?; + + for (_id, (candidates, handle)) in candidates.iter_mut() { + let change = MatchableChange { + table: &table, + pk: &pk, + column: &column, + cl, + }; + handle.filter_matchable_change(candidates, change); + } + } + } + + // metrics... + for (id, (candidates, handle)) in candidates { + let mut match_count = 0; + for (table, pks) in candidates.iter() { + let count = pks.len(); + match_count += count; + handle + .get_counter(table) + .matched_count + .increment(pks.len() as u64); + } + + trace!(sub_id = %id, %db_version, "found {match_count} candidates"); + + if let Err(e) = handle + .changes_tx() + .try_send((candidates, db_version)) + { + error!(sub_id = %id, "could not send change candidates to {trait_type} handler: {e}"); + match e { + mpsc::error::TrySendError::Full(item) => { + warn!("channel is full, falling back to async send"); + let changes_tx = handle.changes_tx(); + tokio::spawn(async move { + _ = changes_tx.send(item).await; + }); + } + mpsc::error::TrySendError::Closed(_) => { + if let Some(handle) = manager.remove(id) { + tokio::spawn(async move { + handle.cleanup().await; + }); + } + } + } + } + } + + Ok(()) +} From d2abe513366b394c28ce03fe0a390fd274155c2f Mon Sep 17 00:00:00 2001 From: Pavel Borzenkov Date: Wed, 20 Nov 2024 13:45:35 +0100 Subject: [PATCH 2/4] agent: pass SharedUpdateBroadcastCache to /updates handler It's required by the handler but not getting passed to Axum's middleware stack. --- crates/corro-agent/src/agent/run_root.rs | 2 ++ crates/corro-agent/src/agent/setup.rs | 11 +++++++++-- crates/corro-agent/src/agent/util.rs | 3 +++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/crates/corro-agent/src/agent/run_root.rs b/crates/corro-agent/src/agent/run_root.rs index c6b9730e..422dc675 100644 --- a/crates/corro-agent/src/agent/run_root.rs +++ b/crates/corro-agent/src/agent/run_root.rs @@ -49,6 +49,7 @@ async fn run(agent: Agent, opts: AgentOptions, pconf: PerfConfig) -> eyre::Resul rx_foca, subs_manager, subs_bcast_cache, + updates_bcast_cache, rtt_rx, } = opts; @@ -96,6 +97,7 @@ async fn run(agent: Agent, opts: AgentOptions, pconf: PerfConfig) -> eyre::Resul &agent, &tripwire, subs_bcast_cache, + updates_bcast_cache, &subs_manager, api_listeners, ) diff --git a/crates/corro-agent/src/agent/setup.rs b/crates/corro-agent/src/agent/setup.rs index 884721b7..7faf8d6a 100644 --- a/crates/corro-agent/src/agent/setup.rs +++ b/crates/corro-agent/src/agent/setup.rs @@ -25,10 +25,14 @@ use tripwire::Tripwire; use crate::{ api::{ peer::gossip_server_endpoint, - public::pubsub::{process_sub_channel, MatcherBroadcastCache, SharedMatcherBroadcastCache}, + public::{ + pubsub::{process_sub_channel, MatcherBroadcastCache, SharedMatcherBroadcastCache}, + update::SharedUpdateBroadcastCache, + }, }, transport::Transport, }; +use corro_types::updates::UpdatesManager; use corro_types::{ actor::ActorId, agent::{migrate, Agent, AgentConfig, Booked, BookedVersions, LockRegistry, SplitPool}, @@ -41,7 +45,6 @@ use corro_types::{ schema::{init_schema, Schema}, sqlite::CrConn, }; -use corro_types::updates::UpdatesManager; /// Runtime state for the Corrosion agent pub struct AgentOptions { @@ -58,6 +61,7 @@ pub struct AgentOptions { pub rtt_rx: TokioReceiver<(SocketAddr, Duration)>, pub subs_manager: SubsManager, pub subs_bcast_cache: SharedMatcherBroadcastCache, + pub updates_bcast_cache: SharedUpdateBroadcastCache, pub tripwire: Tripwire, } @@ -118,6 +122,8 @@ pub async fn setup(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Age ) .await?; + let updates_bcast_cache = SharedUpdateBroadcastCache::default(); + let cluster_id = { let conn = pool.read().await?; conn.query_row( @@ -189,6 +195,7 @@ pub async fn setup(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Age rtt_rx, subs_manager: subs_manager.clone(), subs_bcast_cache, + updates_bcast_cache, tripwire: tripwire.clone(), }; diff --git a/crates/corro-agent/src/agent/util.rs b/crates/corro-agent/src/agent/util.rs index 67a90b8a..054011a8 100644 --- a/crates/corro-agent/src/agent/util.rs +++ b/crates/corro-agent/src/agent/util.rs @@ -10,6 +10,7 @@ use crate::{ api::public::{ api_v1_db_schema, api_v1_queries, api_v1_table_stats, api_v1_transactions, pubsub::{api_v1_sub_by_id, api_v1_subs}, + update::SharedUpdateBroadcastCache, }, transport::Transport, }; @@ -170,6 +171,7 @@ pub async fn setup_http_api_handler( agent: &Agent, tripwire: &Tripwire, subs_bcast_cache: BcastCache, + updates_bcast_cache: SharedUpdateBroadcastCache, subs_manager: &SubsManager, api_listeners: Vec, ) -> eyre::Result<()> { @@ -280,6 +282,7 @@ pub async fn setup_http_api_handler( .layer(Extension(Arc::new(AtomicI64::new(0)))) .layer(Extension(agent.clone())) .layer(Extension(subs_bcast_cache)) + .layer(Extension(updates_bcast_cache)) .layer(Extension(subs_manager.clone())) .layer(Extension(tripwire.clone())), ) From df36ccc24ad8397b95c195037eec768c2c666aa8 Mon Sep 17 00:00:00 2001 From: Pavel Borzenkov Date: Wed, 20 Nov 2024 14:10:35 +0100 Subject: [PATCH 3/4] api: move NotifyEvent to corro-api-types As the client doesn't import "big" corro-types, move public facing types to the dedicated crate. --- crates/corro-agent/src/api/public/pubsub.rs | 2 +- crates/corro-agent/src/api/public/update.rs | 3 ++- crates/corro-api-types/src/lib.rs | 9 +++++++ crates/corro-types/src/updates.rs | 28 +++------------------ 4 files changed, 15 insertions(+), 27 deletions(-) diff --git a/crates/corro-agent/src/api/public/pubsub.rs b/crates/corro-agent/src/api/public/pubsub.rs index 71e9d1b9..425f3606 100644 --- a/crates/corro-agent/src/api/public/pubsub.rs +++ b/crates/corro-agent/src/api/public/pubsub.rs @@ -882,11 +882,11 @@ async fn forward_bytes_to_body_sender( #[cfg(test)] mod tests { use corro_types::actor::ActorId; + use corro_types::api::NotifyEvent; use corro_types::api::{Change, ColumnName, TableName}; use corro_types::base::{CrsqlDbVersion, CrsqlSeq, Version}; use corro_types::broadcast::{ChangeSource, ChangeV1, Changeset}; use corro_types::pubsub::pack_columns; - use corro_types::updates::NotifyEvent; use corro_types::{ api::{ChangeId, RowId}, config::Config, diff --git a/crates/corro-agent/src/api/public/update.rs b/crates/corro-agent/src/api/public/update.rs index 51c9deba..8f78fc10 100644 --- a/crates/corro-agent/src/api/public/update.rs +++ b/crates/corro-agent/src/api/public/update.rs @@ -5,7 +5,8 @@ use bytes::{BufMut, Bytes, BytesMut}; use compact_str::ToCompactString; use corro_types::{ agent::Agent, - updates::{Handle, NotifyEvent, UpdateCreated, UpdateHandle, UpdatesManager}, + api::NotifyEvent, + updates::{Handle, UpdateCreated, UpdateHandle, UpdatesManager}, }; use futures::future::poll_fn; use tokio::sync::{ diff --git a/crates/corro-api-types/src/lib.rs b/crates/corro-api-types/src/lib.rs index a2991cce..f67d7caa 100644 --- a/crates/corro-api-types/src/lib.rs +++ b/crates/corro-api-types/src/lib.rs @@ -59,6 +59,15 @@ pub enum QueryEventMeta { Notify, } +pub type NotifyEvent = TypedNotifyEvent>; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum TypedNotifyEvent { + Notify(ChangeType, T), + Error(CompactString), +} + /// RowId newtype to differentiate from ChangeId #[derive( Debug, diff --git a/crates/corro-types/src/updates.rs b/crates/corro-types/src/updates.rs index 59daca07..09810bc6 100644 --- a/crates/corro-types/src/updates.rs +++ b/crates/corro-types/src/updates.rs @@ -2,14 +2,12 @@ use crate::agent::SplitPool; use crate::pubsub::{unpack_columns, MatchCandidates, MatchableChange, MatcherError}; use crate::schema::Schema; use async_trait::async_trait; -use compact_str::CompactString; use corro_api_types::sqlite::ChangeType; -use corro_api_types::{Change, ColumnName, SqliteValue, SqliteValueRef, TableName}; +use corro_api_types::{Change, ColumnName, NotifyEvent, SqliteValueRef, TableName}; use corro_base_types::CrsqlDbVersion; use metrics::{counter, histogram, Counter}; use parking_lot::RwLock; use rusqlite::Connection; -use serde::{Deserialize, Serialize}; use spawn::spawn_counted; use std::collections::BTreeMap; use std::fmt::{Debug, Formatter}; @@ -44,20 +42,6 @@ pub trait Handle { fn get_counter(&self, table: &str) -> &HandleMetrics; } -pub type NotifyEvent = TypedNotifyEvent>; - -pub enum NotifyType { - Upsert = 0, - Delete = 2, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(rename_all = "snake_case")] -pub enum TypedNotifyEvent { - Notify(ChangeType, T), - Error(CompactString), -} - #[derive(Clone)] pub struct HandleMetrics { pub matched_count: Counter, @@ -439,10 +423,7 @@ where trace!(sub_id = %id, %db_version, "found {match_count} candidates"); - if let Err(e) = handle - .changes_tx() - .try_send((candidates, db_version)) - { + if let Err(e) = handle.changes_tx().try_send((candidates, db_version)) { error!(sub_id = %id, "could not send change candidates to {trait_type} handler: {e}"); match e { mpsc::error::TrySendError::Full(item) => { @@ -532,10 +513,7 @@ where trace!(sub_id = %id, %db_version, "found {match_count} candidates"); - if let Err(e) = handle - .changes_tx() - .try_send((candidates, db_version)) - { + if let Err(e) = handle.changes_tx().try_send((candidates, db_version)) { error!(sub_id = %id, "could not send change candidates to {trait_type} handler: {e}"); match e { mpsc::error::TrySendError::Full(item) => { From 9cbf35479023bf9ac239ac00d7c2e64d320575d8 Mon Sep 17 00:00:00 2001 From: Pavel Borzenkov Date: Wed, 20 Nov 2024 14:33:50 +0100 Subject: [PATCH 4/4] client: add Rust client for the new /updates endpoint The client is much simpler than the existing subscription client: - it doesn't handle I/O error and simply bubble it up (as the user still need to handle it) - the only possible errors are I/O and deserialization as the /updates stream doesn't have a "change id" concept --- crates/corro-client/src/lib.rs | 41 +++++++++++++++++++++- crates/corro-client/src/sub.rs | 62 +++++++++++++++++++++++++++++++++- 2 files changed, 101 insertions(+), 2 deletions(-) diff --git a/crates/corro-client/src/lib.rs b/crates/corro-client/src/lib.rs index b5019320..9c4e17cd 100644 --- a/crates/corro-client/src/lib.rs +++ b/crates/corro-client/src/lib.rs @@ -11,7 +11,7 @@ use std::{ sync::Arc, time::{self, Duration, Instant}, }; -use sub::{QueryStream, SubscriptionStream}; +use sub::{QueryStream, SubscriptionStream, UpdatesStream}; use tokio::{ sync::{RwLock, RwLockReadGuard}, time::timeout, @@ -218,6 +218,45 @@ impl CorrosionApiClient { self.subscription_typed(id, skip_rows, from).await } + pub async fn updates_typed( + &self, + table: &str, + ) -> Result, Error> { + let p_and_q: PathAndQuery = format!("/v1/updates/{}", table).try_into()?; + + let url = hyper::Uri::builder() + .scheme("http") + .authority(self.api_addr.to_string()) + .path_and_query(p_and_q) + .build()?; + + let req = hyper::Request::builder() + .method(hyper::Method::POST) + .uri(url) + .header(hyper::header::CONTENT_TYPE, "application/json") + .header(hyper::header::ACCEPT, "application/json") + .body(hyper::Body::empty())?; + + let res = self.api_client.request(req).await?; + + if !res.status().is_success() { + return Err(Error::UnexpectedStatusCode(res.status())); + } + + // TODO: make that header name a const in corro-types + let id = res + .headers() + .get(HeaderName::from_static("corro-query-id")) + .and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok())) + .ok_or(Error::ExpectedQueryId)?; + + Ok(UpdatesStream::new(id, res.into_body())) + } + + pub async fn updates(&self, table: &str) -> Result>, Error> { + self.updates_typed(table).await + } + pub async fn execute(&self, statements: &[Statement]) -> Result { let req = hyper::Request::builder() .method(hyper::Method::POST) diff --git a/crates/corro-client/src/sub.rs b/crates/corro-client/src/sub.rs index 6d21a475..264a62dc 100644 --- a/crates/corro-client/src/sub.rs +++ b/crates/corro-client/src/sub.rs @@ -8,7 +8,7 @@ use std::{ }; use bytes::{Buf, Bytes, BytesMut}; -use corro_api_types::{ChangeId, QueryEvent, TypedQueryEvent}; +use corro_api_types::{ChangeId, QueryEvent, TypedNotifyEvent, TypedQueryEvent}; use futures::{ready, Future, Stream}; use hyper::{client::HttpConnector, Body}; use pin_project_lite::pin_project; @@ -307,6 +307,66 @@ where } } +pub struct UpdatesStream { + id: Uuid, + stream: FramedBody, + _deser: std::marker::PhantomData, +} + +#[derive(Debug, thiserror::Error)] +pub enum UpdatesError { + #[error(transparent)] + Io(#[from] io::Error), + #[error(transparent)] + Deserialize(#[from] serde_json::Error), + #[error("max line length exceeded")] + MaxLineLengthExceeded, +} + +impl UpdatesStream +where + T: DeserializeOwned + Unpin, +{ + pub fn new(id: Uuid, body: hyper::Body) -> Self { + Self { + id, + stream: FramedRead::new( + StreamReader::new(IoBodyStream { body }), + LinesBytesCodec::default(), + ), + _deser: Default::default(), + } + } + + pub fn id(&self) -> Uuid { + self.id + } +} + +impl Stream for UpdatesStream +where + T: DeserializeOwned + Unpin, +{ + type Item = Result, UpdatesError>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let res = ready!(Pin::new(&mut self.stream).poll_next(cx)); + match res { + Some(Ok(b)) => match serde_json::from_slice(&b) { + Ok(evt) => Poll::Ready(Some(Ok(evt))), + Err(e) => Poll::Ready(Some(Err(e.into()))), + }, + Some(Err(e)) => match e { + LinesCodecError::MaxLineLengthExceeded => { + Poll::Ready(Some(Err(UpdatesError::MaxLineLengthExceeded))) + } + LinesCodecError::Io(io_err) => Poll::Ready(Some(Err(io_err.into()))), + }, + None => Poll::Ready(None), + } + } +} + pub struct QueryStream { stream: FramedBody, _deser: std::marker::PhantomData,