diff --git a/Cargo.lock b/Cargo.lock index 0a4bd343..237a4c2b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -731,6 +731,7 @@ checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" name = "corro-admin" version = "0.1.0" dependencies = [ + "bytes", "camino", "corro-agent", "corro-types", diff --git a/crates/corro-admin/Cargo.toml b/crates/corro-admin/Cargo.toml index 3d83ff8a..11807a45 100644 --- a/crates/corro-admin/Cargo.toml +++ b/crates/corro-admin/Cargo.toml @@ -21,3 +21,4 @@ tracing = { workspace = true } tripwire = { path = "../tripwire" } rangemap = { workspace = true } uuid = { workspace = true } +bytes = { workspace = true } \ No newline at end of file diff --git a/crates/corro-admin/src/lib.rs b/crates/corro-admin/src/lib.rs index 38a5d53a..094e17de 100644 --- a/crates/corro-admin/src/lib.rs +++ b/crates/corro-admin/src/lib.rs @@ -1,13 +1,24 @@ -use std::{fmt::Display, time::Duration}; +use std::{fmt::Display, net::SocketAddr, time::Duration}; +use bytes::BytesMut; use camino::Utf8PathBuf; +use corro_agent::{ + api::peer::{ + encode_write_bipayload_msg, + follow::{read_follow_msg, FollowMessage, FollowMessageV1}, + }, + transport::Transport, +}; use corro_types::{ actor::{ActorId, ClusterId}, agent::{Agent, Bookie, LockKind, LockMeta, LockState}, + api::SqliteValueRef, base::{CrsqlDbVersion, CrsqlSeq, CrsqlSiteVersion}, - broadcast::{FocaCmd, FocaInput}, + broadcast::{BiPayload, Changeset, FocaCmd, FocaInput}, + pubsub::unpack_columns, sqlite::SqlitePoolError, sync::generate_sync, + updates::Handle, }; use futures::{SinkExt, TryStreamExt}; use rusqlite::{named_params, OptionalExtension}; @@ -21,7 +32,7 @@ use tokio::{ task::block_in_place, }; use tokio_serde::{formats::Json, Framed}; -use tokio_util::codec::LengthDelimitedCodec; +use tokio_util::codec::{FramedRead, LengthDelimitedCodec}; use tracing::{debug, error, info, warn}; use tripwire::Tripwire; use uuid::Uuid; @@ -41,6 +52,7 @@ pub struct AdminConfig { pub fn start_server( agent: Agent, bookie: Bookie, + transport: Transport, config: AdminConfig, mut tripwire: Tripwire, ) -> Result<(), AdminError> { @@ -74,8 +86,9 @@ pub fn start_server( let agent = agent.clone(); let bookie = bookie.clone(); let config = config.clone(); + let transport = transport.clone(); async move { - if let Err(e) = handle_conn(agent, &bookie, config, stream).await { + if let Err(e) = handle_conn(agent, &bookie, &transport, config, stream).await { error!("could not handle admin connection: {e}"); } } @@ -95,6 +108,16 @@ pub enum Command { Cluster(ClusterCommand), Actor(ActorCommand), Subs(SubsCommand), + Debug(DebugCommand), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum DebugCommand { + Follow { + peer_addr: SocketAddr, + from: Option, + local_only: bool, + }, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -193,6 +216,7 @@ impl From for LockMetaElapsed { async fn handle_conn( agent: Agent, bookie: &Bookie, + transport: &Transport, _config: AdminConfig, stream: UnixStream, ) -> Result<(), AdminError> { @@ -475,6 +499,132 @@ async fn handle_conn( } }; } + Command::Debug(DebugCommand::Follow { + peer_addr, + from, + local_only, + }) => match transport.open_bi(peer_addr).await { + Ok((mut tx, recv)) => { + let mut codec = LengthDelimitedCodec::builder() + .max_frame_length(100 * 1_024 * 1_024) + .new_codec(); + let mut encoding_buf = BytesMut::new(); + let mut buf = BytesMut::new(); + + if let Err(e) = encode_write_bipayload_msg( + &mut codec, + &mut encoding_buf, + &mut buf, + BiPayload::V1 { + data: corro_types::broadcast::BiPayloadV1::Follow { + from: from.map(CrsqlDbVersion), + local_only, + }, + cluster_id: agent.cluster_id(), + }, + &mut tx, + ) + .await + { + send_error( + &mut stream, + format!("could not send follow payload to {peer_addr}: {e}"), + ) + .await; + continue; + } + + let mut framed = FramedRead::new( + recv, + LengthDelimitedCodec::builder() + .max_frame_length(100 * 1_024 * 1_024) + .new_codec(), + ); + + 'msg: loop { + match read_follow_msg(&mut framed).await { + Ok(None) => { + send_success(&mut stream).await; + break; + } + Err(e) => { + send_error( + &mut stream, + format!("error receiving follow message: {e}"), + ) + .await; + break; + } + Ok(Some(msg)) => { + match msg { + FollowMessage::V1(FollowMessageV1::Change(change)) => { + let actor_id = change.actor_id; + match change.changeset { + Changeset::Full { + version, + changes, + ts, + .. + } => { + if let Err(e) = stream + .send(Response::Json(serde_json::json!({ + "actor_id": actor_id, + "type": "full", + "version": version, + "ts": ts.to_string(), + }))) + .await + { + warn!("could not send to steam, breaking ({e})"); + break; + } + + for change in changes { + if let Err(e) = stream.send( + Response::Json( + serde_json::json!({ + "table": change.table, + "pk": unpack_columns(&change.pk).unwrap().iter().map(SqliteValueRef::to_owned).collect::>(), + "cid": change.cid, + "val": change.val, + "col_version": change.col_version, + "db_version": change.db_version, + "seq": change.seq, + "site_id": ActorId::from_bytes(change.site_id), + "cl": change.cl, + }), + ), + ) + .await { + warn!("could not send to steam, breaking ({e})"); + break 'msg; + } + } + } + changeset => { + send_log( + &mut stream, + LogLevel::Warn, + format!("unknown change type received: {changeset:?}"), + ) + .await; + } + } + } + } + } + } + } + } + Err(e) => { + send_error( + &mut stream, + format!("could not open bi-directional stream with {peer_addr}: {e}"), + ) + .await; + continue; + } + }, }, Ok(None) => { debug!("done with admin conn"); diff --git a/crates/corro-agent/src/agent/bi.rs b/crates/corro-agent/src/agent/bi.rs index f4172f12..8b95ee7e 100644 --- a/crates/corro-agent/src/agent/bi.rs +++ b/crates/corro-agent/src/agent/bi.rs @@ -1,4 +1,4 @@ -use crate::api::peer::serve_sync; +use crate::api::peer::{follow::serve_follow, serve_sync}; use corro_types::{ agent::{Agent, Bookie}, broadcast::{BiPayload, BiPayloadV1}, @@ -56,7 +56,12 @@ pub fn spawn_bipayload_handler( let agent = agent.clone(); let bookie = bookie.clone(); async move { - let mut framed = FramedRead::new(rx, LengthDelimitedCodec::builder().max_frame_length(100 * 1_024 * 1_024).new_codec()); + let mut framed = FramedRead::new( + rx, + LengthDelimitedCodec::builder() + .max_frame_length(100 * 1_024 * 1_024) + .new_codec(), + ); loop { match timeout(Duration::from_secs(5), StreamExt::next(&mut framed)).await { @@ -72,30 +77,38 @@ pub fn spawn_bipayload_handler( match BiPayload::read_from_buffer(&b) { Ok(payload) => { match payload { - BiPayload::V1 { - data: - BiPayloadV1::SyncStart { - actor_id, - trace_ctx, - }, - cluster_id, - } => { - trace!( - "framed read buffer len: {}", - framed.read_buffer().len() - ); + BiPayload::V1 { data, cluster_id } => match data { + BiPayloadV1::SyncStart { + actor_id, + trace_ctx, + } => { + trace!( + "framed read buffer len: {}", + framed.read_buffer().len() + ); - // println!("got sync state: {state:?}"); - if let Err(e) = serve_sync( - &agent, &bookie, actor_id, trace_ctx, - cluster_id, framed, tx, - ) - .await - { - warn!("could not complete receiving sync: {e}"); + // println!("got sync state: {state:?}"); + if let Err(e) = serve_sync( + &agent, &bookie, actor_id, trace_ctx, + cluster_id, framed, tx, + ) + .await + { + warn!("could not complete receiving sync: {e}"); + } + break; } - break; - } + BiPayloadV1::Follow { from, local_only } => { + if let Err(e) = serve_follow( + &agent, from, local_only, tx, + ) + .await + { + warn!("could not complete follow: {e}"); + } + break; + } + }, } } diff --git a/crates/corro-agent/src/agent/run_root.rs b/crates/corro-agent/src/agent/run_root.rs index 6805f286..7842fb11 100644 --- a/crates/corro-agent/src/agent/run_root.rs +++ b/crates/corro-agent/src/agent/run_root.rs @@ -8,6 +8,7 @@ use crate::{ metrics, setup, util, AgentOptions, }, broadcast::runtime_loop, + transport::Transport, }; use corro_types::{ actor::ActorId, @@ -26,12 +27,16 @@ use tripwire::Tripwire; /// /// First initialise `AgentOptions` state via `setup()`, then spawn a /// new task that runs the main agent state machine -pub async fn start_with_config(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Bookie)> { +pub async fn start_with_config( + conf: Config, + tripwire: Tripwire, +) -> eyre::Result<(Agent, Bookie, Transport)> { let (agent, opts) = setup(conf.clone(), tripwire.clone()).await?; + let transport = opts.transport.clone(); let bookie = run(agent.clone(), opts, conf.perf).await?; - Ok((agent, bookie)) + Ok((agent, bookie, transport)) } async fn run(agent: Agent, opts: AgentOptions, pconf: PerfConfig) -> eyre::Result { @@ -49,6 +54,7 @@ async fn run(agent: Agent, opts: AgentOptions, pconf: PerfConfig) -> eyre::Resul rx_foca, subs_manager, subs_bcast_cache, + updates_bcast_cache, rtt_rx, } = opts; @@ -96,6 +102,7 @@ async fn run(agent: Agent, opts: AgentOptions, pconf: PerfConfig) -> eyre::Resul &agent, &tripwire, subs_bcast_cache, + updates_bcast_cache, &subs_manager, api_listeners, ) diff --git a/crates/corro-agent/src/agent/setup.rs b/crates/corro-agent/src/agent/setup.rs index 77540ddc..243ba1c4 100644 --- a/crates/corro-agent/src/agent/setup.rs +++ b/crates/corro-agent/src/agent/setup.rs @@ -2,6 +2,7 @@ // External crates use arc_swap::ArcSwap; +use backoff::Backoff; use camino::Utf8PathBuf; use indexmap::IndexMap; use metrics::counter; @@ -26,8 +27,12 @@ use tripwire::Tripwire; // Internals use crate::{ api::{ + self, peer::gossip_server_endpoint, - public::pubsub::{process_sub_channel, MatcherBroadcastCache, SharedMatcherBroadcastCache}, + public::{ + pubsub::{process_sub_channel, MatcherBroadcastCache, SharedMatcherBroadcastCache}, + update::SharedUpdateBroadcastCache, + }, }, transport::Transport, }; @@ -43,6 +48,7 @@ use corro_types::{ schema::{init_schema, Schema}, sqlite::CrConn, }; +use corro_types::{config::FollowFrom, updates::UpdatesManager}; /// Runtime state for the Corrosion agent pub struct AgentOptions { @@ -59,6 +65,7 @@ pub struct AgentOptions { pub rtt_rx: TokioReceiver<(SocketAddr, Duration)>, pub subs_manager: SubsManager, pub subs_bcast_cache: SharedMatcherBroadcastCache, + pub updates_bcast_cache: SharedUpdateBroadcastCache, pub tripwire: Tripwire, } @@ -108,6 +115,7 @@ pub async fn setup(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Age let subs_manager = SubsManager::default(); + let updates_manager = UpdatesManager::default(); // Setup subscription handlers let subs_bcast_cache = setup_spawn_subscriptions( &subs_manager, @@ -118,6 +126,8 @@ pub async fn setup(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Age ) .await?; + let updates_bcast_cache = SharedUpdateBroadcastCache::default(); + let cluster_id = { let conn = pool.read().await?; conn.query_row( @@ -223,7 +233,7 @@ pub async fn setup(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Age let opts = AgentOptions { gossip_server_endpoint, - transport, + transport: transport.clone(), api_listeners, lock_registry, rx_bcast, @@ -235,9 +245,12 @@ pub async fn setup(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Age rtt_rx, subs_manager: subs_manager.clone(), subs_bcast_cache, + updates_bcast_cache, tripwire: tripwire.clone(), }; + let follow = conf.follow.clone(); + let agent = Agent::new(AgentConfig { actor_id, pool, @@ -258,9 +271,75 @@ pub async fn setup(conf: Config, tripwire: Tripwire) -> eyre::Result<(Agent, Age schema: RwLock::new(schema), cluster_id, subs_manager, + updates_manager, tripwire, }); + if let Some(follow) = follow { + let agent = agent.clone(); + tokio::spawn(async move { + let boff = Backoff::new(0) + .timeout_range(Duration::from_millis(100), Duration::from_secs(2)) + .iter(); + + let addr = follow.addr; + let (mut last_from, specific_from) = if let FollowFrom::DbVersion(from) = follow.from { + (Some(from), true) + } else { + (None, false) + }; + + for dur in boff { + let from = { + if let Some(from) = last_from.take() { + from + } else { + let conn = agent.pool().read().await.unwrap(); + conn.query_row("SELECT crsql_db_version()", [], |row| row.get(0)) + .unwrap() + } + }; + + info!("following from db_version = {from}"); + + match transport.open_bi(addr).await { + Ok((tx, rx)) => { + match api::peer::follow::follow( + &agent, + tx, + rx, + Some(from), + false, + follow.broadcast.as_ref(), + ) + .await + { + Ok(dbv) => { + info!("following terminated, last db version: {dbv:?}"); + last_from = dbv; + } + Err(e) => { + error!("could not follow to the end: {e}"); + if specific_from { + last_from = Some(from); + } + } + } + } + Err(e) => { + error!("could not open bidirectional stream to {addr}: {e}"); + } + } + + warn!("follow broken, retrying in {dur:?}"); + + tokio::time::sleep(dur).await + } + + info!("follow loop done"); + }); + } + Ok((agent, opts)) } diff --git a/crates/corro-agent/src/agent/util.rs b/crates/corro-agent/src/agent/util.rs index 89375065..3cf29f0c 100644 --- a/crates/corro-agent/src/agent/util.rs +++ b/crates/corro-agent/src/agent/util.rs @@ -10,6 +10,7 @@ use crate::{ api::public::{ api_v1_db_schema, api_v1_queries, api_v1_table_stats, api_v1_transactions, pubsub::{api_v1_sub_by_id, api_v1_subs}, + update::SharedUpdateBroadcastCache, }, transport::Transport, }; @@ -22,6 +23,7 @@ use corro_types::{ channel::CorroReceiver, config::AuthzConfig, pubsub::SubsManager, + updates::{match_changes, match_changes_from_db_version}, }; use std::{ cmp, @@ -33,6 +35,7 @@ use std::{ time::{Duration, Instant}, }; +use crate::api::public::update::api_v1_updates; use axum::{ error_handling::HandleErrorLayer, extract::DefaultBodyLimit, @@ -165,6 +168,7 @@ pub async fn setup_http_api_handler( agent: &Agent, tripwire: &Tripwire, subs_bcast_cache: BcastCache, + updates_bcast_cache: SharedUpdateBroadcastCache, subs_manager: &SubsManager, api_listeners: Vec, ) -> eyre::Result<()> { @@ -213,6 +217,20 @@ pub async fn setup_http_api_handler( .layer(ConcurrencyLimitLayer::new(128)), ), ) + .route( + "/v1/updates/:table", + post(api_v1_updates).route_layer( + tower::ServiceBuilder::new() + .layer(HandleErrorLayer::new(|_error: BoxError| async { + Ok::<_, Infallible>(( + StatusCode::SERVICE_UNAVAILABLE, + "max concurrency limit reached".to_string(), + )) + })) + .layer(LoadShedLayer::new()) + .layer(ConcurrencyLimitLayer::new(128)), + ), + ) .route( "/v1/subscriptions/:id", get(api_v1_sub_by_id).route_layer( @@ -261,6 +279,7 @@ pub async fn setup_http_api_handler( .layer(Extension(Arc::new(AtomicI64::new(0)))) .layer(Extension(agent.clone())) .layer(Extension(subs_bcast_cache)) + .layer(Extension(updates_bcast_cache)) .layer(Extension(subs_manager.clone())) .layer(Extension(tripwire.clone())), ) @@ -641,11 +660,16 @@ pub async fn process_fully_buffered_changes( if let Some(db_version) = db_version { let conn = agent.pool().read().await?; block_in_place(|| { - if let Err(e) = agent - .subs_manager() - .match_changes_from_db_version(&conn, db_version) + if let Err(e) = match_changes_from_db_version(agent.subs_manager(), &conn, db_version) { + error!(%db_version, "could not match changes for subs from db version: {e}"); + } + }); + + block_in_place(|| { + if let Err(e) = + match_changes_from_db_version(agent.updates_manager(), &conn, db_version) { - error!(%db_version, "could not match changes from db version: {e}"); + error!(%db_version, "could not match changes for updates from db version: {e}"); } }); } @@ -918,9 +942,8 @@ pub async fn process_multiple_changes( for (_actor_id, changeset, db_version, _src) in changesets { change_chunk_size += changeset.changes().len(); - agent - .subs_manager() - .match_changes(changeset.changes(), db_version); + match_changes(agent.subs_manager(), changeset.changes(), db_version); + match_changes(agent.updates_manager(), changeset.changes(), db_version); } histogram!("corro.agent.changes.processing.time.seconds").record(start.elapsed()); diff --git a/crates/corro-agent/src/api/peer/follow.rs b/crates/corro-agent/src/api/peer/follow.rs new file mode 100644 index 00000000..dc42c315 --- /dev/null +++ b/crates/corro-agent/src/api/peer/follow.rs @@ -0,0 +1,332 @@ +use std::{io, time::Duration}; + +use bytes::{BufMut, BytesMut}; +use corro_types::{ + actor::ActorId, + agent::Agent, + api::row_to_change, + base::{CrsqlDbVersion, CrsqlSeq, CrsqlSiteVersion}, + broadcast::{BiPayload, ChangeSource, ChangeV1, Changeset, Timestamp}, + change::ChunkedChanges, + config::FollowBroadcast, + sqlite::SqlitePoolError, +}; +use futures::{Stream, StreamExt}; +use metrics::counter; +use quinn::{RecvStream, SendStream}; +use rand::{rngs::OsRng, Rng}; +use rusqlite::{params_from_iter, Row, ToSql}; +use speedy::{Readable, Writable}; +use tokio::{sync::mpsc, task::block_in_place}; +use tokio_util::codec::{Encoder, FramedRead, LengthDelimitedCodec}; +use tracing::{debug, error, trace}; + +use super::{encode_write_bipayload_msg, BiPayloadSendError}; + +#[derive(Debug, Clone, PartialEq, Readable, Writable)] +pub enum FollowMessage { + V1(FollowMessageV1), +} + +impl FollowMessage { + pub fn from_slice>(slice: S) -> Result { + Self::read_from_buffer(slice.as_ref()) + } + + pub fn from_buf(buf: &mut BytesMut) -> Result { + Ok(Self::from_slice(buf)?) + } +} + +#[derive(Debug, Clone, PartialEq, Readable, Writable)] +pub enum FollowMessageV1 { + Change(ChangeV1), +} + +#[derive(Debug, thiserror::Error)] +pub enum FollowError { + #[error(transparent)] + SqlitePool(#[from] SqlitePoolError), + #[error(transparent)] + Encode(#[from] FollowMessageEncodeError), + #[error(transparent)] + Decode(#[from] FollowMessageDecodeError), + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + Write(#[from] quinn::WriteError), + #[error(transparent)] + Rusqlite(#[from] rusqlite::Error), + #[error("follow send channel is closed")] + ChannelClosed, + #[error(transparent)] + BiPayloadSend(#[from] BiPayloadSendError), +} + +#[derive(Debug, thiserror::Error)] +pub enum FollowMessageEncodeError { + #[error(transparent)] + Encode(#[from] speedy::Error), + #[error(transparent)] + Io(#[from] io::Error), +} + +#[derive(Debug, thiserror::Error)] +pub enum FollowMessageDecodeError { + #[error(transparent)] + Decode(#[from] speedy::Error), + #[error(transparent)] + Io(#[from] io::Error), +} + +async fn encode_write_follow_msg( + codec: &mut LengthDelimitedCodec, + encode_buf: &mut BytesMut, + send_buf: &mut BytesMut, + msg: FollowMessage, + write: &mut SendStream, +) -> Result<(), FollowError> { + encode_follow_msg(codec, encode_buf, send_buf, msg)?; + + write_buf(send_buf, write).await +} + +fn encode_follow_msg( + codec: &mut LengthDelimitedCodec, + encode_buf: &mut BytesMut, + send_buf: &mut BytesMut, + msg: FollowMessage, +) -> Result<(), FollowError> { + msg.write_to_stream(encode_buf.writer()) + .map_err(FollowMessageEncodeError::from)?; + + let data = encode_buf.split().freeze(); + trace!("encoded sync message, len: {}", data.len()); + codec.encode(data, send_buf)?; + Ok(()) +} + +async fn write_buf(send_buf: &mut BytesMut, write: &mut SendStream) -> Result<(), FollowError> { + let len = send_buf.len(); + write.write_chunk(send_buf.split().freeze()).await?; + counter!("corro.follow.chunk.sent.bytes").increment(len as u64); + + Ok(()) +} + +pub async fn serve_follow( + agent: &Agent, + from: Option, + local_only: bool, + mut write: SendStream, +) -> Result<(), FollowError> { + let mut last_db_version = { + if let Some(db_version) = from { + db_version + } else { + let conn = agent.pool().read().await?; + conn.query_row("SELECT crsql_db_version()", [], |row| row.get(0))? + } + }; + + // channel provides backpressure + let (tx, mut rx) = mpsc::channel(128); + + tokio::spawn(async move { + let mut codec = LengthDelimitedCodec::builder() + .max_frame_length(100 * 1_024 * 1_024) + .new_codec(); + let mut send_buf = BytesMut::new(); + let mut encode_buf = BytesMut::new(); + + while let Some(msg) = rx.recv().await { + encode_write_follow_msg(&mut codec, &mut encode_buf, &mut send_buf, msg, &mut write) + .await?; + } + + Ok::<_, FollowError>(()) + }); + + let actor_id = agent.actor_id(); + + loop { + let conn = agent.pool().read().await?; + + block_in_place(|| { + let (extra_where_clause, query_params): (_, Vec<&dyn ToSql>) = if local_only { + ("AND actor_id = ?", vec![&last_db_version, &actor_id]) + } else { + ("", vec![&last_db_version]) + }; + + let mut bk_prepped = conn.prepare_cached(&format!("SELECT actor_id, start_version, db_version, last_seq, ts FROM __corro_bookkeeping WHERE db_version IS NOT NULL AND db_version > ? {extra_where_clause} ORDER BY db_version ASC"))?; + + let map = |row: &Row| { + Ok(( + row.get(0)?, + row.get(1)?, + row.get(2)?, + row.get(3)?, + row.get(4)?, + )) + }; + + // implicit read transaction + let bk_rows = bk_prepped.query_map(params_from_iter(query_params), map)?; + + for bk_res in bk_rows { + let (actor_id, version, db_version, last_seq, ts): ( + ActorId, + CrsqlSiteVersion, + CrsqlDbVersion, + CrsqlSeq, + Timestamp, + ) = bk_res?; + + debug!("sending changes for: {actor_id} v{version} (db_version: {db_version})"); + + let mut prepped = conn.prepare_cached( + "SELECT \"table\", pk, cid, val, col_version, db_version, seq, site_id, cl FROM crsql_changes WHERE db_version = ? ORDER BY db_version ASC, seq ASC", + )?; + // implicit read transaction + let rows = prepped.query_map([db_version], row_to_change)?; + + let chunked = ChunkedChanges::new(rows, CrsqlSeq(0), last_seq, 8192); + + for changes_seqs in chunked { + let (changes, seqs) = changes_seqs?; + tx.blocking_send(FollowMessage::V1(FollowMessageV1::Change(ChangeV1 { + actor_id, + changeset: Changeset::Full { + version, + changes, + seqs, + last_seq, + ts, + }, + }))) + .map_err(|_| FollowError::ChannelClosed)?; + } + + last_db_version = db_version; // record last db version processed for next go around + } + + Ok::<_, FollowError>(()) + })?; + + // prevents hot-looping + tokio::time::sleep(Duration::from_secs(1)).await; + } +} + +pub async fn read_follow_msg> + Unpin>( + read: &mut R, +) -> Result, FollowError> { + match read.next().await { + Some(buf_res) => match buf_res { + Ok(mut buf) => { + counter!("corro.follow.chunk.recv.bytes").increment(buf.len() as u64); + match FollowMessage::from_buf(&mut buf) { + Ok(msg) => Ok(Some(msg)), + Err(e) => Err(FollowError::from(e)), + } + } + Err(e) => Err(FollowError::from(e)), + }, + None => Ok(None), + } +} + +pub async fn recv_follow( + agent: &Agent, + mut read: FramedRead, + local_only: bool, + broadcast: Option<&FollowBroadcast>, +) -> Result, FollowError> { + let mut last_db_version = None; + let tx_changes = agent.tx_changes(); + loop { + match read_follow_msg(&mut read).await { + Ok(None) => break, + Err(e) => { + error!("could not receive follow message: {e}"); + break; + } + Ok(Some(msg)) => match msg { + FollowMessage::V1(FollowMessageV1::Change(changeset)) => { + let db_version = changeset.changes().first().map(|change| change.db_version); + debug!( + "received changeset for version(s) {:?} and db_version {db_version:?}", + changeset.versions() + ); + let change_src = if local_only + || broadcast + .map(|bcast| should_broadcast(&changeset.actor_id, bcast)) + .unwrap_or(false) + { + ChangeSource::Broadcast + } else { + ChangeSource::Follow + }; + tx_changes + .send((changeset, change_src)) + .await + .map_err(|_| FollowError::ChannelClosed)?; + if let Some(db_version) = db_version { + last_db_version = Some(db_version); + } + } + }, + } + } + + Ok(last_db_version) +} + +fn should_broadcast(actor_id: &ActorId, broadcast: &FollowBroadcast) -> bool { + match broadcast { + FollowBroadcast::ActorIds(set) => set.contains(actor_id), + FollowBroadcast::Percent(percent) => OsRng.gen_range(0..100) < *percent, + } +} + +pub async fn follow( + agent: &Agent, + mut tx: SendStream, + recv: RecvStream, + from: Option, + local_only: bool, + broadcast: Option<&FollowBroadcast>, +) -> Result, FollowError> { + let mut codec = LengthDelimitedCodec::builder() + .max_frame_length(100 * 1_024 * 1_024) + .new_codec(); + let mut encoding_buf = BytesMut::new(); + let mut buf = BytesMut::new(); + + encode_write_bipayload_msg( + &mut codec, + &mut encoding_buf, + &mut buf, + BiPayload::V1 { + data: corro_types::broadcast::BiPayloadV1::Follow { from, local_only }, + cluster_id: agent.cluster_id(), + }, + &mut tx, + ) + .await?; + + let framed = FramedRead::new( + recv, + LengthDelimitedCodec::builder() + .max_frame_length(100 * 1_024 * 1_024) + .new_codec(), + ); + + recv_follow(agent, framed, local_only, broadcast).await +} + +#[cfg(test)] +mod tests { + // use super::*; +} diff --git a/crates/corro-agent/src/api/peer.rs b/crates/corro-agent/src/api/peer/mod.rs similarity index 99% rename from crates/corro-agent/src/api/peer.rs rename to crates/corro-agent/src/api/peer/mod.rs index 8196fdfb..472ee788 100644 --- a/crates/corro-agent/src/api/peer.rs +++ b/crates/corro-agent/src/api/peer/mod.rs @@ -43,6 +43,8 @@ use crate::transport::{Transport, TransportError}; use corro_types::{actor::ActorId, agent::Bookie}; +pub mod follow; + #[derive(Debug, thiserror::Error)] pub enum SyncError { #[error(transparent)] @@ -913,7 +915,7 @@ fn encode_sync_msg( Ok(()) } -async fn encode_write_bipayload_msg( +pub async fn encode_write_bipayload_msg( codec: &mut LengthDelimitedCodec, encode_buf: &mut BytesMut, send_buf: &mut BytesMut, @@ -933,7 +935,8 @@ fn encode_bipayload_msg( send_buf: &mut BytesMut, msg: BiPayload, ) -> Result<(), BiPayloadEncodeError> { - msg.write_to_stream(encode_buf.writer())?; + msg.write_to_stream(encode_buf.writer()) + .map_err(BiPayloadEncodeError::from)?; codec.encode(encode_buf.split().freeze(), send_buf)?; Ok(()) diff --git a/crates/corro-agent/src/api/public/mod.rs b/crates/corro-agent/src/api/public/mod.rs index 9b40a604..32ee9f24 100644 --- a/crates/corro-agent/src/api/public/mod.rs +++ b/crates/corro-agent/src/api/public/mod.rs @@ -33,6 +33,8 @@ use corro_types::broadcast::broadcast_changes; pub mod pubsub; +pub mod update; + pub async fn make_broadcastable_changes( agent: &Agent, f: F, diff --git a/crates/corro-agent/src/api/public/pubsub.rs b/crates/corro-agent/src/api/public/pubsub.rs index 5c30038b..d051cd7b 100644 --- a/crates/corro-agent/src/api/public/pubsub.rs +++ b/crates/corro-agent/src/api/public/pubsub.rs @@ -3,6 +3,7 @@ use std::{collections::HashMap, io::Write, sync::Arc, time::Duration}; use axum::{http::StatusCode, response::IntoResponse, Extension}; use bytes::{BufMut, Bytes, BytesMut}; use compact_str::{format_compact, ToCompactString}; +use corro_types::updates::Handle; use corro_types::{ agent::Agent, api::{ChangeId, QueryEvent, QueryEventMeta, Statement}, @@ -74,7 +75,9 @@ async fn sub_by_id( bcast_cache_write.remove(&id); if let Some(handle) = subs.remove(&id) { info!(sub_id = %id, "Removed subscription from sub_by_id"); - tokio::spawn(handle.cleanup()); + tokio::spawn(async move { + handle.cleanup().await; + }); } return hyper::Response::builder() @@ -879,6 +882,7 @@ async fn forward_bytes_to_body_sender( #[cfg(test)] mod tests { use corro_types::actor::ActorId; + use corro_types::api::NotifyEvent; use corro_types::api::{Change, ColumnName, TableName}; use corro_types::base::{CrsqlDbVersion, CrsqlSeq, CrsqlSiteVersion}; use corro_types::broadcast::{ChangeSource, ChangeV1, Changeset}; @@ -889,6 +893,7 @@ mod tests { pubsub::ChangeType, }; use http_body::Body; + use serde::de::DeserializeOwned; use spawn::wait_for_all_pending_handles; use std::ops::RangeInclusive; use std::time::Instant; @@ -898,6 +903,7 @@ mod tests { use super::*; use crate::agent::process_multiple_changes; + use crate::api::public::update::{api_v1_updates, SharedUpdateBroadcastCache}; use crate::{ agent::setup, api::public::{api_v1_db_schema, api_v1_transactions}, @@ -951,6 +957,7 @@ mod tests { assert!(body.0.results.len() == 2); let bcast_cache: SharedMatcherBroadcastCache = Default::default(); + let update_bcast_cache: SharedUpdateBroadcastCache = Default::default(); { let mut res = api_v1_subs( @@ -970,6 +977,23 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); + // only want notifications + let mut notify_res = api_v1_updates( + Extension(agent.clone()), + Extension(update_bcast_cache.clone()), + Extension(tripwire.clone()), + axum::extract::Path("tests".to_string()), + ) + .await + .into_response(); + + if !notify_res.status().is_success() { + let b = notify_res.body_mut().data().await.unwrap().unwrap(); + println!("body: {}", String::from_utf8_lossy(&b)); + } + + assert_eq!(notify_res.status(), StatusCode::OK); + let (status_code, _) = api_v1_transactions( Extension(agent.clone()), axum::Json(vec![Statement::WithParams( @@ -988,18 +1012,25 @@ mod tests { done: false, }; + let mut notify_rows = RowsIter { + body: notify_res.into_body(), + codec: LinesCodec::new(), + buf: BytesMut::new(), + done: false, + }; + assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Columns(vec!["id".into(), "text".into()]) ); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Row(RowId(1), vec!["service-id".into(), "service-name".into()]) ); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Row( RowId(2), vec!["service-id-2".into(), "service-name-2".into()] @@ -1007,12 +1038,12 @@ mod tests { ); assert!(matches!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::EndOfQuery { .. } )); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Change( ChangeType::Insert, RowId(3), @@ -1033,7 +1064,7 @@ mod tests { assert_eq!(status_code, StatusCode::OK); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Change( ChangeType::Insert, RowId(4), @@ -1042,6 +1073,16 @@ mod tests { ) ); + assert_eq!( + notify_rows.recv::().await.unwrap().unwrap(), + NotifyEvent::Notify(ChangeType::Update, vec!["service-id-3".into()],) + ); + + assert_eq!( + notify_rows.recv::().await.unwrap().unwrap(), + NotifyEvent::Notify(ChangeType::Update, vec!["service-id-4".into()],) + ); + let mut res = api_v1_subs( Extension(agent.clone()), Extension(bcast_cache.clone()), @@ -1070,7 +1111,7 @@ mod tests { }; assert_eq!( - rows_from.recv().await.unwrap().unwrap(), + rows_from.recv::().await.unwrap().unwrap(), QueryEvent::Change( ChangeType::Insert, RowId(4), @@ -1079,6 +1120,30 @@ mod tests { ) ); + // new subscriber for updates + let mut notify_res2 = api_v1_updates( + Extension(agent.clone()), + Extension(update_bcast_cache.clone()), + Extension(tripwire.clone()), + axum::extract::Path("tests".to_string()), + ) + .await + .into_response(); + + if !notify_res2.status().is_success() { + let b = notify_res2.body_mut().data().await.unwrap().unwrap(); + println!("body: {}", String::from_utf8_lossy(&b)); + } + + assert_eq!(notify_res2.status(), StatusCode::OK); + + let mut notify_rows2 = RowsIter { + body: notify_res2.into_body(), + codec: LinesCodec::new(), + buf: BytesMut::new(), + done: false, + }; + let (status_code, _) = api_v1_transactions( Extension(agent.clone()), axum::Json(vec![Statement::WithParams( @@ -1097,9 +1162,24 @@ mod tests { ChangeId(3), ); - assert_eq!(rows.recv().await.unwrap().unwrap(), query_evt); + assert_eq!(rows.recv::().await.unwrap().unwrap(), query_evt); + + assert_eq!( + rows_from.recv::().await.unwrap().unwrap(), + query_evt + ); + + let notify_evt = NotifyEvent::Notify(ChangeType::Update, vec!["service-id-5".into()]); - assert_eq!(rows_from.recv().await.unwrap().unwrap(), query_evt); + assert_eq!( + notify_rows.recv::().await.unwrap().unwrap(), + notify_evt + ); + + assert_eq!( + notify_rows2.recv::().await.unwrap().unwrap(), + notify_evt + ); // subscriber who arrives later! @@ -1128,17 +1208,17 @@ mod tests { }; assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Columns(vec!["id".into(), "text".into()]) ); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Row(RowId(1), vec!["service-id".into(), "service-name".into()]) ); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Row( RowId(2), vec!["service-id-2".into(), "service-name-2".into()] @@ -1146,7 +1226,7 @@ mod tests { ); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Row( RowId(3), vec!["service-id-3".into(), "service-name-3".into()], @@ -1154,7 +1234,7 @@ mod tests { ); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Row( RowId(4), vec!["service-id-4".into(), "service-name-4".into()], @@ -1162,12 +1242,44 @@ mod tests { ); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Row( RowId(5), vec!["service-id-5".into(), "service-name-5".into()] ) ); + + let (status_code, _) = api_v1_transactions( + Extension(agent.clone()), + axum::Json(vec![Statement::WithParams( + "insert into tests (id, text) values (?,?)".into(), + vec!["service-id-6".into(), "service-name-6".into()], + )]), + ) + .await; + + assert_eq!(status_code, StatusCode::OK); + + let (status_code, _) = api_v1_transactions( + Extension(agent.clone()), + axum::Json(vec![Statement::WithParams( + "delete from tests where id = ?".into(), + vec!["service-id-6".into()], + )]), + ) + .await; + + assert_eq!(status_code, StatusCode::OK); + + assert_eq!( + notify_rows.recv::().await.unwrap().unwrap(), + NotifyEvent::Notify(ChangeType::Update, vec!["service-id-6".into()],) + ); + + assert_eq!( + notify_rows.recv::().await.unwrap().unwrap(), + NotifyEvent::Notify(ChangeType::Delete, vec!["service-id-6".into()],) + ); } // previous subs have been dropped. @@ -1200,7 +1312,7 @@ mod tests { }; assert_eq!( - rows_from.recv().await.unwrap().unwrap(), + rows_from.recv::().await.unwrap().unwrap(), QueryEvent::Change( ChangeType::Insert, RowId(4), @@ -1210,7 +1322,7 @@ mod tests { ); assert_eq!( - rows_from.recv().await.unwrap().unwrap(), + rows_from.recv::().await.unwrap().unwrap(), QueryEvent::Change( ChangeType::Insert, RowId(5), @@ -1260,7 +1372,7 @@ mod tests { assert_eq!(status_code, StatusCode::OK); assert_eq!( - rows_from.recv().await.unwrap().unwrap(), + rows_from.recv::().await.unwrap().unwrap(), QueryEvent::Change( ChangeType::Insert, RowId(6), @@ -1299,7 +1411,7 @@ mod tests { }; assert_eq!( - rows_from.recv().await.unwrap().unwrap(), + rows_from.recv::().await.unwrap().unwrap(), QueryEvent::Change( ChangeType::Insert, RowId(6), @@ -1314,6 +1426,7 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn match_buffered_changes() -> eyre::Result<()> { _ = tracing_subscriber::fmt::try_init(); + let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple(); let ta1 = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?; @@ -1376,6 +1489,7 @@ mod tests { .await?; let bcast_cache: SharedMatcherBroadcastCache = Default::default(); + let update_bcast_cache: SharedUpdateBroadcastCache = Default::default(); let mut res = api_v1_subs( Extension(ta1.agent.clone()), Extension(bcast_cache.clone()), @@ -1393,6 +1507,30 @@ mod tests { assert_eq!(res.status(), StatusCode::OK); + // only notifications + let mut notify_res = api_v1_updates( + Extension(ta1.agent.clone()), + Extension(update_bcast_cache.clone()), + Extension(tripwire.clone()), + axum::extract::Path("buftests".to_string()), + ) + .await + .into_response(); + + if !notify_res.status().is_success() { + let b = notify_res.body_mut().data().await.unwrap().unwrap(); + println!("body: {}", String::from_utf8_lossy(&b)); + } + + assert_eq!(notify_res.status(), StatusCode::OK); + + let mut notify_rows = RowsIter { + body: notify_res.into_body(), + codec: LinesCodec::new(), + buf: BytesMut::new(), + done: false, + }; + let mut rows = RowsIter { body: res.into_body(), codec: LinesCodec::new(), @@ -1401,17 +1539,17 @@ mod tests { }; assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Columns(vec!["pk".into(), "col1".into(), "col2".into()]) ); assert_eq!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::Row(RowId(1), vec![Integer(1), "one".into(), "one line".into()]) ); assert!(matches!( - rows.recv().await.unwrap().unwrap(), + rows.recv::().await.unwrap().unwrap(), QueryEvent::EndOfQuery { .. } )); @@ -1487,7 +1625,7 @@ mod tests { ) .await?; - let res = timeout(Duration::from_secs(5), rows.recv()).await?; + let res = timeout(Duration::from_secs(5), rows.recv::()).await?; assert_eq!( res.unwrap().unwrap(), @@ -1499,6 +1637,12 @@ mod tests { ) ); + let notify_res = timeout(Duration::from_secs(5), notify_rows.recv::()).await?; + assert_eq!( + notify_res.unwrap().unwrap(), + NotifyEvent::Notify(ChangeType::Update, vec![Integer(2)],) + ); + tripwire_tx.send(()).await.ok(); tripwire_worker.await; wait_for_all_pending_handles().await; @@ -1514,7 +1658,7 @@ mod tests { } impl RowsIter { - async fn recv(&mut self) -> Option> { + async fn recv(&mut self) -> Option> { if self.done { return None; } diff --git a/crates/corro-agent/src/api/public/update.rs b/crates/corro-agent/src/api/public/update.rs new file mode 100644 index 00000000..8f78fc10 --- /dev/null +++ b/crates/corro-agent/src/api/public/update.rs @@ -0,0 +1,247 @@ +use std::{collections::HashMap, io::Write, sync::Arc, time::Duration}; + +use axum::{http::StatusCode, response::IntoResponse, Extension}; +use bytes::{BufMut, Bytes, BytesMut}; +use compact_str::ToCompactString; +use corro_types::{ + agent::Agent, + api::NotifyEvent, + updates::{Handle, UpdateCreated, UpdateHandle, UpdatesManager}, +}; +use futures::future::poll_fn; +use tokio::sync::{ + broadcast::{self, error::RecvError}, + mpsc, RwLock as TokioRwLock, +}; +use tracing::{debug, info, warn}; +use tripwire::Tripwire; +use uuid::Uuid; + +use crate::api::public::pubsub::MatcherUpsertError; + +pub type UpdateBroadcastCache = HashMap>; +pub type SharedUpdateBroadcastCache = Arc>; + +// this should be a fraction of the MAX_UNSUB_TIME +const RECEIVERS_CHECK_INTERVAL: Duration = Duration::from_secs(30); + +pub async fn api_v1_updates( + Extension(agent): Extension, + Extension(bcast_cache): Extension, + Extension(tripwire): Extension, + axum::extract::Path(table): axum::extract::Path, +) -> impl IntoResponse { + info!("Received update request for table: {table}"); + + let mut bcast_write = bcast_cache.write().await; + let updates = agent.updates_manager(); + + let upsert_res = updates.get_or_insert( + &table, + &agent.schema().read(), + agent.pool(), + tripwire.clone(), + ); + + let (handle, maybe_created) = match upsert_res { + Ok(res) => res, + Err(e) => return hyper::Response::::from(MatcherUpsertError::from(e)), + }; + + let (tx, body) = hyper::Body::channel(); + // let (forward_tx, forward_rx) = mpsc::channel(10240); + + let (update_id, sub_rx) = + match upsert_update(handle.clone(), maybe_created, updates, &mut bcast_write).await { + Ok(id) => id, + Err(e) => return hyper::Response::::from(e), + }; + + tokio::spawn(forward_update_bytes_to_body_sender( + handle, sub_rx, tx, tripwire, + )); + + hyper::Response::builder() + .status(StatusCode::OK) + .header("corro-query-id", update_id.to_string()) + .body(body) + .expect("could not generate ok http response for update request") +} + +pub async fn upsert_update( + handle: UpdateHandle, + maybe_created: Option, + updates: &UpdatesManager, + bcast_write: &mut UpdateBroadcastCache, +) -> Result<(Uuid, broadcast::Receiver), MatcherUpsertError> { + let sub_rx = if let Some(created) = maybe_created { + let (sub_tx, sub_rx) = broadcast::channel(10240); + bcast_write.insert(handle.id(), sub_tx.clone()); + tokio::spawn(process_update_channel( + updates.clone(), + handle.id(), + sub_tx, + created.evt_rx, + )); + + sub_rx + } else { + let id = handle.id(); + let sub_tx = bcast_write + .get(&id) + .cloned() + .ok_or(MatcherUpsertError::MissingBroadcaster)?; + debug!("found update handle"); + + sub_tx.subscribe() + }; + + Ok((handle.id(), sub_rx)) +} + +pub async fn process_update_channel( + updates: UpdatesManager, + id: Uuid, + tx: broadcast::Sender, + mut evt_rx: mpsc::Receiver, +) { + let mut buf = BytesMut::new(); + + // interval check for receivers + // useful for queries that don't change often so we can cleanup... + let mut subs_check = tokio::time::interval(RECEIVERS_CHECK_INTERVAL); + + loop { + tokio::select! { + biased; + Some(query_evt) = evt_rx.recv() => { + match make_query_event_bytes(&mut buf, &query_evt) { + Ok(b) => { + if tx.send(b).is_err() { + break; + } + }, + Err(e) => { + match make_query_event_bytes(&mut buf, &NotifyEvent::Error(e.to_compact_string())) { + Ok(b) => { + let _ = tx.send(b); + } + Err(e) => { + warn!(update_id = %id, "failed to send error in update channel: {e}"); + } + } + break; + } + }; + }, + _ = subs_check.tick() => { + if tx.receiver_count() == 0 { + break; + }; + }, + }; + } + + warn!(sub_id = %id, "updates channel done"); + + // remove and get handle from the agent's "matchers" + let handle = match updates.remove(&id) { + Some(h) => { + info!(update_id = %id, "Removed update handle from process_update_channel"); + h + } + None => { + warn!(update_id = %id, "update handle was already gone. odd!"); + return; + } + }; + + // clean up the subscription + handle.cleanup().await; +} + +fn make_query_event_bytes( + buf: &mut BytesMut, + query_evt: &NotifyEvent, +) -> serde_json::Result { + { + let mut writer = buf.writer(); + serde_json::to_writer(&mut writer, query_evt)?; + + // NOTE: I think that's infaillible... + writer + .write_all(b"\n") + .expect("could not write new line to BytesMut Writer"); + } + + Ok(buf.split().freeze()) +} + +async fn forward_update_bytes_to_body_sender( + update: UpdateHandle, + mut rx: broadcast::Receiver, + mut tx: hyper::body::Sender, + mut tripwire: Tripwire, +) { + let mut buf = BytesMut::new(); + + let send_deadline = tokio::time::sleep(Duration::from_millis(10)); + tokio::pin!(send_deadline); + + loop { + tokio::select! { + biased; + res = rx.recv() => { + match res { + Ok(event_buf) => { + buf.extend_from_slice(&event_buf); + if buf.len() >= 64 * 1024 { + if let Err(e) = tx.send_data(buf.split().freeze()).await { + warn!(update_id = %update.id(), "could not forward update query event to receiver: {e}"); + return; + } + }; + }, + Err(RecvError::Lagged(skipped)) => { + warn!(update_id = %update.id(), "update skipped {} events, aborting", skipped); + return; + }, + Err(RecvError::Closed) => { + info!(update_id = %update.id(), "events subcription ran out"); + return; + }, + } + }, + _ = &mut send_deadline => { + if !buf.is_empty() { + if let Err(e) = tx.send_data(buf.split().freeze()).await { + warn!(update_id = %update.id(), "could not forward subscription query event to receiver: {e}"); + return; + } + } else { + if let Err(e) = poll_fn(|cx| tx.poll_ready(cx)).await { + warn!(update_id = %update.id(), error = %e, "body sender was closed or errored, stopping event broadcast sends"); + return; + } + send_deadline.as_mut().reset(tokio::time::Instant::now() + Duration::from_millis(10)); + continue; + } + }, + _ = update.cancelled() => { + info!(update_id = %update.id(), "update cancelled, aborting forwarding bytes to subscriber"); + return; + }, + _ = &mut tripwire => { + break; + } + } + } + + while let Ok(event_buf) = rx.try_recv() { + buf.extend_from_slice(&event_buf); + if let Err(e) = tx.send_data(buf.split().freeze()).await { + warn!(update_id = %update.id(), "could not forward subscription query event to receiver: {e}"); + return; + } + } +} diff --git a/crates/corro-api-types/src/lib.rs b/crates/corro-api-types/src/lib.rs index bf54bd07..ce98338d 100644 --- a/crates/corro-api-types/src/lib.rs +++ b/crates/corro-api-types/src/lib.rs @@ -56,6 +56,16 @@ pub enum QueryEventMeta { EndOfQuery(Option), Change(ChangeId), Error, + Notify, +} + +pub type NotifyEvent = TypedNotifyEvent>; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum TypedNotifyEvent { + Notify(ChangeType, T), + Error(CompactString), } /// RowId newtype to differentiate from ChangeId diff --git a/crates/corro-client/src/lib.rs b/crates/corro-client/src/lib.rs index b5019320..9c4e17cd 100644 --- a/crates/corro-client/src/lib.rs +++ b/crates/corro-client/src/lib.rs @@ -11,7 +11,7 @@ use std::{ sync::Arc, time::{self, Duration, Instant}, }; -use sub::{QueryStream, SubscriptionStream}; +use sub::{QueryStream, SubscriptionStream, UpdatesStream}; use tokio::{ sync::{RwLock, RwLockReadGuard}, time::timeout, @@ -218,6 +218,45 @@ impl CorrosionApiClient { self.subscription_typed(id, skip_rows, from).await } + pub async fn updates_typed( + &self, + table: &str, + ) -> Result, Error> { + let p_and_q: PathAndQuery = format!("/v1/updates/{}", table).try_into()?; + + let url = hyper::Uri::builder() + .scheme("http") + .authority(self.api_addr.to_string()) + .path_and_query(p_and_q) + .build()?; + + let req = hyper::Request::builder() + .method(hyper::Method::POST) + .uri(url) + .header(hyper::header::CONTENT_TYPE, "application/json") + .header(hyper::header::ACCEPT, "application/json") + .body(hyper::Body::empty())?; + + let res = self.api_client.request(req).await?; + + if !res.status().is_success() { + return Err(Error::UnexpectedStatusCode(res.status())); + } + + // TODO: make that header name a const in corro-types + let id = res + .headers() + .get(HeaderName::from_static("corro-query-id")) + .and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok())) + .ok_or(Error::ExpectedQueryId)?; + + Ok(UpdatesStream::new(id, res.into_body())) + } + + pub async fn updates(&self, table: &str) -> Result>, Error> { + self.updates_typed(table).await + } + pub async fn execute(&self, statements: &[Statement]) -> Result { let req = hyper::Request::builder() .method(hyper::Method::POST) diff --git a/crates/corro-client/src/sub.rs b/crates/corro-client/src/sub.rs index 6d21a475..264a62dc 100644 --- a/crates/corro-client/src/sub.rs +++ b/crates/corro-client/src/sub.rs @@ -8,7 +8,7 @@ use std::{ }; use bytes::{Buf, Bytes, BytesMut}; -use corro_api_types::{ChangeId, QueryEvent, TypedQueryEvent}; +use corro_api_types::{ChangeId, QueryEvent, TypedNotifyEvent, TypedQueryEvent}; use futures::{ready, Future, Stream}; use hyper::{client::HttpConnector, Body}; use pin_project_lite::pin_project; @@ -307,6 +307,66 @@ where } } +pub struct UpdatesStream { + id: Uuid, + stream: FramedBody, + _deser: std::marker::PhantomData, +} + +#[derive(Debug, thiserror::Error)] +pub enum UpdatesError { + #[error(transparent)] + Io(#[from] io::Error), + #[error(transparent)] + Deserialize(#[from] serde_json::Error), + #[error("max line length exceeded")] + MaxLineLengthExceeded, +} + +impl UpdatesStream +where + T: DeserializeOwned + Unpin, +{ + pub fn new(id: Uuid, body: hyper::Body) -> Self { + Self { + id, + stream: FramedRead::new( + StreamReader::new(IoBodyStream { body }), + LinesBytesCodec::default(), + ), + _deser: Default::default(), + } + } + + pub fn id(&self) -> Uuid { + self.id + } +} + +impl Stream for UpdatesStream +where + T: DeserializeOwned + Unpin, +{ + type Item = Result, UpdatesError>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let res = ready!(Pin::new(&mut self.stream).poll_next(cx)); + match res { + Some(Ok(b)) => match serde_json::from_slice(&b) { + Ok(evt) => Poll::Ready(Some(Ok(evt))), + Err(e) => Poll::Ready(Some(Err(e.into()))), + }, + Some(Err(e)) => match e { + LinesCodecError::MaxLineLengthExceeded => { + Poll::Ready(Some(Err(UpdatesError::MaxLineLengthExceeded))) + } + LinesCodecError::Io(io_err) => Poll::Ready(Some(Err(io_err.into()))), + }, + None => Poll::Ready(None), + } + } +} + pub struct QueryStream { stream: FramedBody, _deser: std::marker::PhantomData, diff --git a/crates/corro-tests/src/lib.rs b/crates/corro-tests/src/lib.rs index e77d8617..2dc8e800 100644 --- a/crates/corro-tests/src/lib.rs +++ b/crates/corro-tests/src/lib.rs @@ -79,7 +79,7 @@ pub async fn launch_test_agent Result, limits: Limits, subs_manager: SubsManager, + updates_manager: UpdatesManager, } #[derive(Debug, Clone)] @@ -136,6 +140,7 @@ impl Agent { sync: Arc::new(Semaphore::new(3)), }, subs_manager: config.subs_manager, + updates_manager: config.updates_manager, })) } @@ -241,6 +246,10 @@ impl Agent { &self.0.subs_manager } + pub fn updates_manager(&self) -> &UpdatesManager { + &self.0.updates_manager + } + pub fn set_cluster_id(&self, cluster_id: ClusterId) { self.0.cluster_id.store(Arc::new(cluster_id)); } diff --git a/crates/corro-types/src/broadcast.rs b/crates/corro-types/src/broadcast.rs index fcabf970..86749885 100644 --- a/crates/corro-types/src/broadcast.rs +++ b/crates/corro-types/src/broadcast.rs @@ -33,6 +33,7 @@ use crate::{ channel::CorroSender, sqlite::SqlitePoolError, sync::SyncTraceContextV1, + updates::match_changes, }; #[derive(Debug, Clone, Readable, Writable)] @@ -75,6 +76,10 @@ pub enum BiPayloadV1 { #[speedy(default_on_eof)] trace_ctx: SyncTraceContextV1, }, + Follow { + from: Option, + local_only: bool, + }, } #[derive(Debug)] @@ -147,6 +152,7 @@ pub struct ChangesetPerTablePk(IndexMap>); pub enum ChangeSource { Broadcast, Sync, + Follow, } #[derive(Debug, Clone, PartialEq, Readable, Writable)] @@ -578,7 +584,8 @@ pub async fn broadcast_changes( trace!("broadcasting changes: {changes:?} for seq: {seqs:?}"); - agent.subs_manager().match_changes(&changes, db_version); + match_changes(agent.subs_manager(), &changes, db_version); + match_changes(agent.updates_manager(), &changes, db_version); let tx_bcast = agent.tx_bcast().clone(); tokio::spawn(async move { diff --git a/crates/corro-types/src/config.rs b/crates/corro-types/src/config.rs index 00529ec3..63fbc5c1 100644 --- a/crates/corro-types/src/config.rs +++ b/crates/corro-types/src/config.rs @@ -1,52 +1,27 @@ -use std::net::{Ipv6Addr, SocketAddr, SocketAddrV6}; +use std::{ + collections::HashSet, + net::{Ipv6Addr, SocketAddr, SocketAddrV6}, +}; use camino::Utf8PathBuf; +use corro_base_types::CrsqlDbVersion; use serde::{Deserialize, Serialize}; use serde_with::{formats::PreferOne, serde_as, OneOrMany}; +use crate::actor::ActorId; + pub const DEFAULT_GOSSIP_PORT: u16 = 4001; const DEFAULT_GOSSIP_IDLE_TIMEOUT: u32 = 30; -const fn default_apply_queue() -> usize { - 100 -} - -const fn default_wal_threshold() -> usize { - 10 -} - -const fn default_processing_queue() -> usize { - 20000 -} - -/// Used for the apply channel -const fn default_huge_channel() -> usize { - 2048 -} - -// -const fn default_big_channel() -> usize { - 1024 -} - -const fn default_mid_channel() -> usize { - 512 -} - -const fn default_small_channel() -> usize { - 256 -} - -const fn default_apply_timeout() -> usize { - 50 -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Config { pub db: DbConfig, pub api: ApiConfig, pub gossip: GossipConfig, + #[serde(default)] + pub follow: Option, + #[serde(default)] pub perf: PerfConfig, @@ -62,6 +37,31 @@ pub struct Config { pub consul: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub struct FollowConfig { + pub addr: SocketAddr, + #[serde(default)] + pub from: FollowFrom, + #[serde(default)] + pub broadcast: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum FollowBroadcast { + ActorIds(HashSet), + Percent(u8), +} + +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +#[serde(untagged, rename_all = "kebab-case")] +pub enum FollowFrom { + #[default] + Latest, + DbVersion(CrsqlDbVersion), +} + #[derive(Debug, Default, Clone, Serialize, Deserialize)] #[serde(rename_all = "kebab-case")] pub struct TelemetryConfig { @@ -291,6 +291,40 @@ impl Config { } } +const fn default_apply_queue() -> usize { + 100 +} + +const fn default_wal_threshold() -> usize { + 10 +} + +const fn default_processing_queue() -> usize { + 20000 +} + +/// Used for the apply channel +const fn default_huge_channel() -> usize { + 2048 +} + +// +const fn default_big_channel() -> usize { + 1024 +} + +const fn default_mid_channel() -> usize { + 512 +} + +const fn default_small_channel() -> usize { + 256 +} + +const fn default_apply_timeout() -> usize { + 50 +} + #[derive(Debug, Default)] pub struct ConfigBuilder { pub db_path: Option, @@ -306,6 +340,7 @@ pub struct ConfigBuilder { consul: Option, tls: Option, perf: Option, + follow: Option, } impl ConfigBuilder { @@ -364,6 +399,20 @@ impl ConfigBuilder { self } + pub fn follow( + mut self, + addr: SocketAddr, + from: FollowFrom, + broadcast: Option, + ) -> Self { + self.follow = Some(FollowConfig { + addr, + from, + broadcast, + }); + self + } + pub fn build(self) -> Result { let db_path = self.db_path.ok_or(ConfigBuilderError::DbPathRequired)?; @@ -402,6 +451,7 @@ impl ConfigBuilder { max_mtu: None, // TODO: add a builder function for it disable_gso: false, }, + follow: self.follow, perf: self.perf.unwrap_or_default(), admin: AdminConfig { uds_path: self.admin_path.unwrap_or_else(default_admin_path), diff --git a/crates/corro-types/src/lib.rs b/crates/corro-types/src/lib.rs index 6a74a0f1..ee4cc9fe 100644 --- a/crates/corro-types/src/lib.rs +++ b/crates/corro-types/src/lib.rs @@ -12,4 +12,6 @@ pub mod schema; pub mod sqlite; pub mod sync; pub mod tls; +pub mod updates; + pub use corro_base_types as base; diff --git a/crates/corro-types/src/pubsub.rs b/crates/corro-types/src/pubsub.rs index f4e2eeb4..775f8641 100644 --- a/crates/corro-types/src/pubsub.rs +++ b/crates/corro-types/src/pubsub.rs @@ -5,6 +5,7 @@ use std::{ time::{Duration, Instant}, }; +use async_trait::async_trait; use bytes::{Buf, BufMut}; use camino::{Utf8Path, Utf8PathBuf}; use compact_str::{format_compact, ToCompactString}; @@ -45,8 +46,10 @@ use crate::{ base::CrsqlDbVersion, schema::{Schema, Table}, sqlite::CrConn, + updates::HandleMetrics, }; +use crate::updates::{Handle, Manager}; pub use corro_api_types::sqlite::ChangeType; #[derive(Debug, Default, Clone)] @@ -58,13 +61,32 @@ struct InnerSubsManager { queries: HashMap, } -// tools to bootstrap a new subscriber +// tools to bootstrap a new subscriber or notifier pub struct MatcherCreated { pub evt_rx: mpsc::Receiver, } const SUB_EVENT_CHANNEL_CAP: usize = 512; +impl Manager for SubsManager { + fn trait_type(&self) -> String { + "subs".to_string() + } + + fn get(&self, id: &Uuid) -> Option { + self.0.read().get(id) + } + + fn remove(&self, id: &Uuid) -> Option { + let mut inner = self.0.write(); + inner.remove(id) + } + + fn get_handles(&self) -> BTreeMap { + self.0.read().handles.clone() + } +} + impl SubsManager { pub fn get(&self, id: &Uuid) -> Option { self.0.read().get(id) @@ -166,153 +188,14 @@ impl SubsManager { let mut inner = self.0.write(); inner.remove(id) } - - pub fn match_changes_from_db_version( - &self, - conn: &Connection, - db_version: CrsqlDbVersion, - ) -> rusqlite::Result<()> { - let handles = { - let inner = self.0.read(); - if inner.handles.is_empty() { - return Ok(()); - } - inner.handles.clone() - }; - - let mut candidates = handles - .iter() - .map(|(id, handle)| (id, (MatchCandidates::new(), handle))) - .collect::>(); - - { - let mut prepped = conn.prepare_cached( - r#" - SELECT "table", pk, cid - FROM crsql_changes - WHERE db_version = ? - ORDER BY seq ASC - "#, - )?; - - let rows = prepped.query_map([db_version], |row| { - Ok(( - row.get::<_, TableName>(0)?, - row.get::<_, Vec>(1)?, - row.get::<_, ColumnName>(2)?, - )) - })?; - - for change_res in rows { - let (table, pk, column) = change_res?; - - for (_id, (candidates, handle)) in candidates.iter_mut() { - let change = MatchableChange { - table: &table, - pk: &pk, - column: &column, - }; - handle.filter_matchable_change(candidates, change); - } - } - } - - // metrics... - for (id, (candidates, handle)) in candidates { - let mut match_count = 0; - - for (table, pks) in candidates.iter() { - let count = pks.len(); - match_count += count; - counter!("corro.subs.changes.matched.count", "sql_hash" => handle.inner.hash.clone(), "table" => table.to_string()).increment(count as u64); - } - - trace!(sub_id = %id, %db_version, "found {match_count} candidates"); - - if let Err(e) = handle.inner.changes_tx.try_send((candidates, db_version)) { - error!(sub_id = %id, "could not send change candidates to subscription handler: {e}"); - match e { - mpsc::error::TrySendError::Full(item) => { - warn!("channel is full, falling back to async send"); - let changes_tx = handle.inner.changes_tx.clone(); - tokio::spawn(async move { - _ = changes_tx.send(item).await; - }); - } - mpsc::error::TrySendError::Closed(_) => { - if let Some(handle) = self.remove(id) { - tokio::spawn(handle.cleanup()); - } - } - } - } - } - - Ok(()) - } - - pub fn match_changes(&self, changes: &[Change], db_version: CrsqlDbVersion) { - trace!( - %db_version, - "trying to match changes to subscribers, len: {}", - changes.len() - ); - if changes.is_empty() { - return; - } - let handles = { - let inner = self.0.read(); - if inner.handles.is_empty() { - return; - } - inner.handles.clone() - }; - - for (id, handle) in handles.iter() { - trace!(sub_id = %id, %db_version, "attempting to match changes to a subscription"); - let mut candidates = MatchCandidates::new(); - let mut match_count = 0; - for change in changes.iter().map(MatchableChange::from) { - if handle.filter_matchable_change(&mut candidates, change) { - match_count += 1; - } - } - - // metrics... - for (table, pks) in candidates.iter() { - counter!("corro.subs.changes.matched.count", "sql_hash" => handle.inner.hash.clone(), "table" => table.to_string()).increment(pks.len() as u64); - } - - trace!(sub_id = %id, %db_version, "found {match_count} candidates"); - - if let Err(e) = handle.inner.changes_tx.try_send((candidates, db_version)) { - error!(sub_id = %id, "could not send change candidates to subscription handler: {e}"); - match e { - mpsc::error::TrySendError::Full(item) => { - warn!("channel is full, falling back to async send"); - - let changes_tx = handle.inner.changes_tx.clone(); - tokio::spawn(async move { - _ = changes_tx.send(item).await; - }); - counter!("corro.subs.changes.channel.async_fallbacks_count", "sql_hash" => handle.inner.hash.clone()).increment(1); - } - mpsc::error::TrySendError::Closed(_) => { - if let Some(handle) = self.remove(id) { - tokio::spawn(handle.cleanup()); - } - } - } - } - } - } } #[derive(Debug)] -struct MatchableChange<'a> { - table: &'a TableName, - pk: &'a [u8], - column: &'a ColumnName, +pub struct MatchableChange<'a> { + pub table: &'a TableName, + pub pk: &'a [u8], + pub column: &'a ColumnName, + pub cl: i64, } impl<'a> From<&'a Change> for MatchableChange<'a> { @@ -321,6 +204,7 @@ impl<'a> From<&'a Change> for MatchableChange<'a> { table: &value.table, pk: &value.pk, column: &value.cid, + cl: value.cl, } } } @@ -382,15 +266,81 @@ struct InnerMatcherHandle { // some state from the matcher so we can take a look later subs_path: String, cached_statements: HashMap, + metrics: HashMap, } -type MatchCandidates = IndexMap>>; +pub type MatchCandidates = IndexMap, i64>>; -impl MatcherHandle { - pub fn id(&self) -> Uuid { +#[async_trait] +impl Handle for MatcherHandle { + fn id(&self) -> Uuid { self.inner.id } + fn cancelled(&self) -> WaitForCancellationFuture { + self.inner.cancel.cancelled() + } + + fn changes_tx(&self) -> mpsc::Sender<(MatchCandidates, CrsqlDbVersion)> { + self.inner.changes_tx.clone() + } + + async fn cleanup(&self) { + self.inner.cancel.cancel(); + info!(sub_id = %self.inner.id, "Canceled subscription"); + } + + fn filter_matchable_change( + &self, + candidates: &mut MatchCandidates, + change: MatchableChange, + ) -> bool { + trace!("filtering change {change:?}"); + // don't double process the same pk + if candidates + .get(change.table) + .map(|pks| pks.contains_key(change.pk)) + .unwrap_or_default() + { + trace!("already contained key"); + return false; + } + + // don't consider changes that don't have both the table + col in the matcher query + if !self + .inner + .parsed + .table_columns + .get(change.table.as_str()) + .map(|cols| change.column.is_crsql_sentinel() || cols.contains(change.column.as_str())) + .unwrap_or_default() + { + trace!("could not match against parsed query table and columns"); + return false; + } + + if let Some(v) = candidates.get_mut(change.table) { + v.insert(change.pk.to_vec(), change.cl).is_none() + } else { + candidates.insert( + change.table.clone(), + [(change.pk.to_vec(), change.cl)].into(), + ); + true + } + } + + fn get_counter(&self, table: &str) -> &HandleMetrics { + self.inner.metrics.get(table).unwrap_or_else(|| { + panic!( + "metrics counter for table '{}' missing. subs hash {}!", + self.inner.hash, table + ) + }) + } +} + +impl MatcherHandle { pub fn sql(&self) -> &String { &self.inner.sql } @@ -415,11 +365,6 @@ impl MatcherHandle { &self.inner.cached_statements } - pub async fn cleanup(self) { - self.inner.cancel.cancel(); - info!(sub_id = %self.inner.id, "Canceled subscription"); - } - pub fn pool(&self) -> &RusqlitePool { &self.inner.pool } @@ -432,10 +377,6 @@ impl MatcherHandle { } } - pub fn cancelled(&self) -> WaitForCancellationFuture { - self.inner.cancel.cancelled() - } - pub fn max_change_id(&self, conn: &Connection) -> rusqlite::Result { self.wait_for_running_state(); let mut prepped = conn.prepare_cached("SELECT COALESCE(MAX(id), 0) FROM changes")?; @@ -558,43 +499,6 @@ impl MatcherHandle { Ok(max_change_id) } - - fn filter_matchable_change( - &self, - candidates: &mut MatchCandidates, - change: MatchableChange, - ) -> bool { - trace!("filtering change {change:?}"); - // don't double process the same pk - if candidates - .get(change.table) - .map(|pks| pks.contains(change.pk)) - .unwrap_or_default() - { - trace!("already contained key"); - return false; - } - - // don't consider changes that don't have both the table + col in the matcher query - if !self - .inner - .parsed - .table_columns - .get(change.table.as_str()) - .map(|cols| change.column.is_crsql_sentinel() || cols.contains(change.column.as_str())) - .unwrap_or_default() - { - trace!("could not match against parsed query table and columns"); - return false; - } - - if let Some(v) = candidates.get_mut(change.table) { - v.insert(change.pk.to_vec()) - } else { - candidates.insert(change.table.clone(), [change.pk.to_vec()].into()); - true - } - } } type StateLock = Arc<(Mutex, Condvar)>; @@ -842,6 +746,14 @@ impl Matcher { // big channel to not miss anything let (changes_tx, changes_rx) = mpsc::channel(20480); + // metrics counters + let mut counter_map = HashMap::new(); + for table in parsed.table_columns.keys() { + counter_map.insert(table.clone(), HandleMetrics{ + matched_count: counter!("corro.subs.changes.matched.count", "sql_hash" => sql_hash.clone(), "table" => table.to_string()), + }); + } + let handle = MatcherHandle { inner: Arc::new(InnerMatcherHandle { id, @@ -859,6 +771,7 @@ impl Matcher { changes_tx, cached_statements: statements.clone(), subs_path: sub_path.to_string(), + metrics: counter_map, }), state: state.clone(), }; @@ -1195,8 +1108,8 @@ impl Matcher { Some((candidates, db_version)) = self.changes_rx.recv() => { for (table, pks) in candidates { let buffed = buf.entry(table).or_default(); - for pk in pks { - if buffed.insert(pk) { + for (pk, cl) in pks { + if buffed.insert(pk, cl).is_none() { buf_count += 1; } } @@ -1539,7 +1452,7 @@ impl Matcher { for (table, pks) in candidates { let pks = pks .iter() - .map(|pk| unpack_columns(pk)) + .map(|(pk, _)| unpack_columns(pk)) .collect::>, _>>()?; let tmp_table_name = format!("temp_{table}"); @@ -1806,7 +1719,7 @@ impl Matcher { { let mut changes_prepped = state_conn.prepare_cached( r#" - SELECT DISTINCT "table", pk + SELECT DISTINCT "table", pk, cl FROM crsql_changes WHERE db_version > ? AND db_version <= ? -- TODO: allow going over? @@ -1820,7 +1733,7 @@ impl Matcher { candidates .entry(row.get(0)?) .or_default() - .insert(row.get(1)?); + .insert(row.get(1)?, row.get(2)?); } } diff --git a/crates/corro-types/src/updates.rs b/crates/corro-types/src/updates.rs new file mode 100644 index 00000000..09810bc6 --- /dev/null +++ b/crates/corro-types/src/updates.rs @@ -0,0 +1,538 @@ +use crate::agent::SplitPool; +use crate::pubsub::{unpack_columns, MatchCandidates, MatchableChange, MatcherError}; +use crate::schema::Schema; +use async_trait::async_trait; +use corro_api_types::sqlite::ChangeType; +use corro_api_types::{Change, ColumnName, NotifyEvent, SqliteValueRef, TableName}; +use corro_base_types::CrsqlDbVersion; +use metrics::{counter, histogram, Counter}; +use parking_lot::RwLock; +use rusqlite::Connection; +use spawn::spawn_counted; +use std::collections::BTreeMap; +use std::fmt::{Debug, Formatter}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio::task::block_in_place; +use tokio::time::Instant; +use tokio_util::sync::{CancellationToken, WaitForCancellationFuture}; +use tracing::{debug, error, info, trace, warn}; +use tripwire::Tripwire; +use uuid::Uuid; + +pub trait Manager { + fn trait_type(&self) -> String; + fn get(&self, id: &Uuid) -> Option; + fn remove(&self, id: &Uuid) -> Option; + fn get_handles(&self) -> BTreeMap; +} + +#[async_trait] +pub trait Handle { + fn id(&self) -> Uuid; + fn cancelled(&self) -> WaitForCancellationFuture; + fn filter_matchable_change( + &self, + candidates: &mut MatchCandidates, + change: MatchableChange, + ) -> bool; + fn changes_tx(&self) -> mpsc::Sender<(MatchCandidates, CrsqlDbVersion)>; + async fn cleanup(&self); + fn get_counter(&self, table: &str) -> &HandleMetrics; +} + +#[derive(Clone)] +pub struct HandleMetrics { + pub matched_count: Counter, +} + +impl std::fmt::Debug for HandleMetrics { + fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> { + f.debug_struct("Counter").finish_non_exhaustive() + } +} + +#[derive(Default, Debug, Clone)] +pub struct UpdatesManager(Arc>); + +impl Manager for UpdatesManager { + fn trait_type(&self) -> String { + "updates".to_string() + } + + fn get(&self, id: &Uuid) -> Option { + self.0.read().get(id) + } + + fn remove(&self, id: &Uuid) -> Option { + let mut inner = self.0.write(); + inner.remove(id) + } + + fn get_handles(&self) -> BTreeMap { + self.0.read().handles.clone() + } +} + +#[async_trait] +impl Handle for UpdateHandle { + fn id(&self) -> Uuid { + self.inner.id + } + + fn cancelled(&self) -> WaitForCancellationFuture { + self.inner.cancel.cancelled() + } + + fn filter_matchable_change( + &self, + candidates: &mut MatchCandidates, + change: MatchableChange, + ) -> bool { + if change.table.to_string() != self.inner.name { + return false; + } + + trace!("filtering change {change:?}"); + // don't double process the same pk + if candidates + .get(change.table) + .map(|pks| pks.contains_key(change.pk)) + .unwrap_or_default() + { + trace!("already contained key"); + return false; + } + + if let Some(v) = candidates.get_mut(change.table) { + v.insert(change.pk.into(), change.cl).is_none() + } else { + candidates.insert( + change.table.clone(), + [(change.pk.to_vec(), change.cl)].into(), + ); + true + } + } + + fn changes_tx(&self) -> mpsc::Sender<(MatchCandidates, CrsqlDbVersion)> { + self.inner.changes_tx.clone() + } + + async fn cleanup(&self) { + self.inner.cancel.cancel(); + info!(sub_id = %self.inner.id, "Canceled subscription"); + } + + fn get_counter(&self, table: &str) -> &HandleMetrics { + // this should not happen + if table != self.inner.name { + warn!(update_tbl = %self.inner.name, "udpates handle get_counter method called for wrong table : {table}! This shouldn't happen!") + } + &self.inner.counters + } +} + +const UPDATE_EVENT_CHANNEL_CAP: usize = 512; + +// tools to bootstrap a new notifier +pub struct UpdateCreated { + pub evt_rx: mpsc::Receiver, +} + +#[derive(Debug, Default)] +struct InnerUpdatesManager { + tables: BTreeMap, + handles: BTreeMap, +} + +impl InnerUpdatesManager { + fn remove(&mut self, id: &Uuid) -> Option { + let handle = self.handles.remove(id)?; + Some(handle) + } + + fn get(&self, id: &Uuid) -> Option { + self.handles.get(id).cloned() + } +} + +impl UpdatesManager { + pub fn get(&self, table: &str) -> Option { + let id = self.0.read().tables.get(table).cloned(); + if let Some(id) = id { + return self.0.read().handles.get(&id).cloned(); + } + None + } + + pub fn get_or_insert( + &self, + tbl_name: &str, + schema: &Schema, + _pool: &SplitPool, + tripwire: Tripwire, + ) -> Result<(UpdateHandle, Option), MatcherError> { + if let Some(handle) = self.get(tbl_name) { + return Ok((handle, None)); + } + + let mut inner = self.0.write(); + let (evt_tx, evt_rx) = mpsc::channel(UPDATE_EVENT_CHANNEL_CAP); + + let id = Uuid::new_v4(); + let handle_res = UpdateHandle::create(id, tbl_name, schema, evt_tx, tripwire); + + let handle = match handle_res { + Ok(handle) => handle, + Err(e) => { + return Err(e); + } + }; + + inner.handles.insert(id, handle.clone()); + inner.tables.insert(tbl_name.to_string(), id); + + Ok((handle, Some(UpdateCreated { evt_rx }))) + } + + pub fn remove(&self, id: &Uuid) -> Option { + let mut inner = self.0.write(); + + inner.remove(id) + } +} + +#[derive(Clone, Debug)] +pub struct UpdateHandle { + inner: Arc, +} + +#[derive(Clone, Debug)] +pub struct InnerUpdateHandle { + id: Uuid, + name: String, + cancel: CancellationToken, + changes_tx: mpsc::Sender<(MatchCandidates, CrsqlDbVersion)>, + counters: HandleMetrics, +} + +impl UpdateHandle { + pub fn id(&self) -> Uuid { + self.inner.id + } + + pub fn create( + id: Uuid, + tbl_name: &str, + schema: &Schema, + evt_tx: mpsc::Sender, + tripwire: Tripwire, + ) -> Result { + // check for existing handles + match schema.tables.get(tbl_name) { + Some(table) => { + if table.pk.is_empty() { + return Err(MatcherError::MissingPrimaryKeys); + } + } + None => return Err(MatcherError::TableNotFound(tbl_name.to_string())), + }; + + let cancel = CancellationToken::new(); + let (changes_tx, changes_rx) = mpsc::channel(20480); + let handle = UpdateHandle { + inner: Arc::new(InnerUpdateHandle { + id, + name: tbl_name.to_owned(), + cancel: cancel.clone(), + changes_tx, + counters: HandleMetrics { + matched_count: counter!("corro.updates.changes.matched.count", "table" => tbl_name.to_owned()), + }, + }), + }; + spawn_counted(batch_candidates(id, cancel, evt_tx, changes_rx, tripwire)); + Ok(handle) + } + + pub async fn cleanup(self) { + self.inner.cancel.cancel(); + info!(update_id = %self.inner.id, "Canceled update"); + } +} + +fn handle_candidates( + evt_tx: mpsc::Sender, + candidates: MatchCandidates, +) -> Result<(), MatcherError> { + if candidates.is_empty() { + return Ok(()); + } + + trace!( + "got some candidates for updates! {:?}", + candidates.keys().collect::>() + ); + + for (_, pks) in candidates { + let pks = pks + .iter() + .map(|(pk, cl)| unpack_columns(pk).map(|x| (x, *cl))) + .collect::, i64)>, _>>()?; + + for (pk, cl) in pks { + let mut change_type = ChangeType::Update; + if cl % 2 == 0 { + change_type = ChangeType::Delete + } + if let Err(e) = evt_tx.blocking_send(NotifyEvent::Notify( + change_type, + pk.iter().map(|x| x.to_owned()).collect::>(), + )) { + debug!("could not send back row to matcher sub sender: {e}"); + return Err(MatcherError::EventReceiverClosed); + } + } + } + + Ok(()) +} + +async fn batch_candidates( + id: Uuid, + cancel: CancellationToken, + evt_tx: mpsc::Sender, + mut changes_rx: mpsc::Receiver<(MatchCandidates, CrsqlDbVersion)>, + mut tripwire: Tripwire, +) { + const PROCESS_CHANGES_THRESHOLD: usize = 1000; + const PROCESS_BUFFER_DEADLINE: Duration = Duration::from_millis(600); + + info!(sub_id = %id, "Starting loop to run the subscription"); + + let mut buf = MatchCandidates::new(); + let mut buf_count = 0; + + // max duration of aggregating candidates + let process_changes_deadline = tokio::time::sleep(PROCESS_BUFFER_DEADLINE); + tokio::pin!(process_changes_deadline); + + let mut process = false; + loop { + tokio::select! { + biased; + _ = cancel.cancelled() => { + info!(sub_id = %id, "Acknowledged updates cancellation, breaking loop."); + break; + } + Some((candidates, _)) = changes_rx.recv() => { + for (table, pk_map) in candidates { + let buffed = buf.entry(table).or_default(); + for (pk, cl) in pk_map { + if buffed.insert(pk, cl).is_none() { + buf_count += 1; + } + } + } + + if buf_count >= PROCESS_CHANGES_THRESHOLD { + process = true + } + }, + _ = process_changes_deadline.as_mut() => { + process_changes_deadline + .as_mut() + .reset(Instant::now() + PROCESS_BUFFER_DEADLINE); + if buf_count != 0 { + process = true + } + }, + _ = &mut tripwire => { + trace!(sub_id = %id, "tripped batch_candidates, returning"); + return; + } + else => { + return; + } + } + + if process { + let start = Instant::now(); + if let Err(e) = + block_in_place(|| handle_candidates(evt_tx.clone(), std::mem::take(&mut buf))) + { + if !matches!(e, MatcherError::EventReceiverClosed) { + error!(sub_id = %id, "could not handle change: {e}"); + } + break; + } + let elapsed = start.elapsed(); + + histogram!("corro.updates.changes.processing.duration.seconds", "table" => id.to_string()).record(elapsed); + + buf_count = 0; + + // reset the deadline + process_changes_deadline + .as_mut() + .reset(Instant::now() + PROCESS_BUFFER_DEADLINE); + } + } + + debug!(id = %id, "update loop is done"); +} + +pub fn match_changes(manager: &impl Manager, changes: &[Change], db_version: CrsqlDbVersion) +where + H: Handle + Send + 'static, +{ + let trait_type = manager.trait_type(); + trace!( + %db_version, + "trying to match changes to {trait_type}, len: {}", + changes.len() + ); + if changes.is_empty() { + return; + } + + let handles = manager.get_handles(); + if handles.is_empty() { + return; + } + + for (id, handle) in handles.iter() { + trace!(sub_id = %id, %db_version, "attempting to match changes to a subscription"); + let mut candidates = MatchCandidates::new(); + let mut match_count = 0; + for change in changes.iter().map(MatchableChange::from) { + if handle.filter_matchable_change(&mut candidates, change) { + match_count += 1; + } + } + + // metrics... + for (table, pks) in candidates.iter() { + handle + .get_counter(table) + .matched_count + .increment(pks.len() as u64); + } + + trace!(sub_id = %id, %db_version, "found {match_count} candidates"); + + if let Err(e) = handle.changes_tx().try_send((candidates, db_version)) { + error!(sub_id = %id, "could not send change candidates to {trait_type} handler: {e}"); + match e { + mpsc::error::TrySendError::Full(item) => { + warn!("channel is full, falling back to async send"); + + let changes_tx = handle.changes_tx(); + tokio::spawn(async move { + _ = changes_tx.send(item).await; + }); + } + mpsc::error::TrySendError::Closed(_) => { + if let Some(handle) = manager.remove(id) { + tokio::spawn(async move { + handle.cleanup().await; + }); + } + } + } + } + } +} + +pub fn match_changes_from_db_version( + manager: &impl Manager, + conn: &Connection, + db_version: CrsqlDbVersion, +) -> rusqlite::Result<()> +where + H: Handle + Send + 'static, +{ + let handles = manager.get_handles(); + if handles.is_empty() { + return Ok(()); + } + + let trait_type = manager.trait_type(); + let mut candidates = handles + .iter() + .map(|(id, handle)| (id, (MatchCandidates::new(), handle))) + .collect::>(); + + { + let mut prepped = conn.prepare_cached( + r#" + SELECT "table", pk, cid, cl + FROM crsql_changes + WHERE db_version = ? + ORDER BY seq ASC + "#, + )?; + + let rows = prepped.query_map([db_version], |row| { + Ok(( + row.get::<_, TableName>(0)?, + row.get::<_, Vec>(1)?, + row.get::<_, ColumnName>(2)?, + row.get::<_, i64>(3)?, + )) + })?; + + for change_res in rows { + let (table, pk, column, cl) = change_res?; + + for (_id, (candidates, handle)) in candidates.iter_mut() { + let change = MatchableChange { + table: &table, + pk: &pk, + column: &column, + cl, + }; + handle.filter_matchable_change(candidates, change); + } + } + } + + // metrics... + for (id, (candidates, handle)) in candidates { + let mut match_count = 0; + for (table, pks) in candidates.iter() { + let count = pks.len(); + match_count += count; + handle + .get_counter(table) + .matched_count + .increment(pks.len() as u64); + } + + trace!(sub_id = %id, %db_version, "found {match_count} candidates"); + + if let Err(e) = handle.changes_tx().try_send((candidates, db_version)) { + error!(sub_id = %id, "could not send change candidates to {trait_type} handler: {e}"); + match e { + mpsc::error::TrySendError::Full(item) => { + warn!("channel is full, falling back to async send"); + let changes_tx = handle.changes_tx(); + tokio::spawn(async move { + _ = changes_tx.send(item).await; + }); + } + mpsc::error::TrySendError::Closed(_) => { + if let Some(handle) = manager.remove(id) { + tokio::spawn(async move { + handle.cleanup().await; + }); + } + } + } + } + } + + Ok(()) +} diff --git a/crates/corrosion/src/command/agent.rs b/crates/corrosion/src/command/agent.rs index e72fd769..80ca9e9d 100644 --- a/crates/corrosion/src/command/agent.rs +++ b/crates/corrosion/src/command/agent.rs @@ -53,13 +53,15 @@ pub async fn run(config: Config, config_path: &Utf8PathBuf) -> eyre::Result<()> let (tripwire, tripwire_worker) = tripwire::Tripwire::new_signals(); - let (agent, bookie) = corro_agent::agent::start_with_config(config.clone(), tripwire.clone()) - .await - .expect("could not start agent"); + let (agent, bookie, transport) = + corro_agent::agent::start_with_config(config.clone(), tripwire.clone()) + .await + .expect("could not start agent"); corro_admin::start_server( agent.clone(), bookie.clone(), + transport, AdminConfig { listen_path: config.admin.uds_path.clone(), config_path: config_path.clone(), diff --git a/crates/corrosion/src/main.rs b/crates/corrosion/src/main.rs index c65cb678..d21d0510 100644 --- a/crates/corrosion/src/main.rs +++ b/crates/corrosion/src/main.rs @@ -536,6 +536,21 @@ async fn process_cli(cli: Cli) -> eyre::Result<()> { conn.send_command(corro_admin::Command::Subs(corro_admin::SubsCommand::List)) .await?; } + Command::Debug(DebugCommand::Follow { + peer_addr, + from, + local_only, + }) => { + let mut conn = AdminConn::connect(cli.admin_path()).await?; + conn.send_command(corro_admin::Command::Debug( + corro_admin::DebugCommand::Follow { + peer_addr: *peer_addr, + from: *from, + local_only: *local_only, + }, + )) + .await?; + } } Ok(()) @@ -704,6 +719,10 @@ enum Command { /// Subscription related commands #[command(subcommand)] Subs(SubsCommand), + + /// Debug commands + #[command(subcommand)] + Debug(DebugCommand), } #[derive(Subcommand)] @@ -800,3 +819,15 @@ enum SubsCommand { id: Option, }, } + +#[derive(Subcommand)] +enum DebugCommand { + /// Follow a node's changes + Follow { + peer_addr: SocketAddr, + #[arg(long, default_value = None)] + from: Option, + #[arg(long, default_value_t = false)] + local_only: bool, + }, +}