diff --git a/Cargo.lock b/Cargo.lock index d68f63fe72..8d0fb1dc29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3900,6 +3900,7 @@ dependencies = [ "extractors", "futures", "insta", + "json", "json-patch", "librocksdb-sys", "locate-bin", diff --git a/crates/flowctl/src/raw/capture.rs b/crates/flowctl/src/raw/capture.rs index 7786694fa6..2d3864eab7 100644 --- a/crates/flowctl/src/raw/capture.rs +++ b/crates/flowctl/src/raw/capture.rs @@ -127,10 +127,6 @@ pub async fn do_capture( .map(|i| i.clone().into()) .unwrap_or(std::time::Duration::from_secs(1)); - let mut ticker = tokio::time::interval(interval); - ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - _ = ticker.tick().await; // First tick is immediate. - let mut output = std::io::stdout(); // TODO(johnny): This is currently only partly implemented, but is awaiting @@ -147,15 +143,18 @@ pub async fn do_capture( // Upon a checkpoint, wait until the next tick interval has elapsed before acknowledging. if let Some(_checkpoint) = response.checkpoint { - _ = ticker.tick().await; - - request_tx - .send(Ok(capture::Request { - acknowledge: Some(capture::request::Acknowledge { checkpoints: 1 }), - ..Default::default() - })) - .await - .unwrap(); + let mut request_tx = request_tx.clone(); + tokio::spawn(async move { + () = tokio::time::sleep(interval).await; + + request_tx + .feed(Ok(capture::Request { + acknowledge: Some(capture::request::Acknowledge { checkpoints: 1 }), + ..Default::default() + })) + .await + .unwrap(); + }); } } diff --git a/crates/ops/src/lib.rs b/crates/ops/src/lib.rs index bd86573d61..9271cb6a73 100644 --- a/crates/ops/src/lib.rs +++ b/crates/ops/src/lib.rs @@ -5,9 +5,9 @@ use std::io::Write; pub mod decode; pub mod tracing; -pub use proto_flow::ops::log::Level as LogLevel; -pub use proto_flow::ops::Log; -pub use proto_flow::ops::TaskType; +// Re-export many types from proto_flow::ops, so that users of this crate +// don't also have to use that module. +pub use proto_flow::ops::{log::Level as LogLevel, stats, Log, Meta, ShardRef, Stats, TaskType}; #[derive(Serialize, Deserialize, Clone, Debug)] #[serde(rename_all = "camelCase")] @@ -22,7 +22,7 @@ pub struct Shard { r_clock_begin: HexU32, } -impl From for proto_flow::ops::ShardRef { +impl From for ShardRef { fn from( Shard { kind, diff --git a/crates/proto-flow/src/runtime.rs b/crates/proto-flow/src/runtime.rs index ba20c81d6e..0788e700b5 100644 --- a/crates/proto-flow/src/runtime.rs +++ b/crates/proto-flow/src/runtime.rs @@ -152,12 +152,47 @@ pub struct Container { pub struct CaptureRequestExt { #[prost(message, optional, tag = "1")] pub labels: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub open: ::core::option::Option, +} +/// Nested message and enum types in `CaptureRequestExt`. +pub mod capture_request_ext { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct Open { + /// RocksDB descriptor which should be opened. + #[prost(message, optional, tag = "1")] + pub rocksdb_descriptor: ::core::option::Option, + } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CaptureResponseExt { #[prost(message, optional, tag = "1")] pub container: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub captured: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub checkpoint: ::core::option::Option, +} +/// Nested message and enum types in `CaptureResponseExt`. 
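+/// `Captured` carries the packed key and partition values which the runtime's
+/// combiner extracts from each captured document, and `Checkpoint` carries the
+/// per-transaction `ops.Stats` which it accumulates.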
+pub mod capture_response_ext { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct Captured { + /// Packed key extracted from the captured document. + #[prost(bytes = "bytes", tag = "1")] + pub key_packed: ::prost::bytes::Bytes, + /// Packed partition values extracted from the captured document. + #[prost(bytes = "bytes", tag = "2")] + pub partitions_packed: ::prost::bytes::Bytes, + } + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct Checkpoint { + #[prost(message, optional, tag = "1")] + pub stats: ::core::option::Option, + } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -227,6 +262,18 @@ pub mod derive_response_ext { pub struct MaterializeRequestExt { #[prost(message, optional, tag = "1")] pub labels: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub open: ::core::option::Option, +} +/// Nested message and enum types in `MaterializeRequestExt`. +pub mod materialize_request_ext { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct Open { + /// RocksDB descriptor which should be opened. + #[prost(message, optional, tag = "1")] + pub rocksdb_descriptor: ::core::option::Option, + } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/crates/runtime/Cargo.toml b/crates/runtime/Cargo.toml index c141832b33..a669ecf430 100644 --- a/crates/runtime/Cargo.toml +++ b/crates/runtime/Cargo.toml @@ -15,6 +15,7 @@ coroutines = { path = "../coroutines" } derive-sqlite = { path = "../derive-sqlite" } doc = { path = "../doc" } extractors = { path = "../extractors" } +json = { path = "../json" } locate-bin = { path = "../locate-bin" } models = { path = "../models" } ops = { path = "../ops" } diff --git a/crates/runtime/src/capture/combine.rs b/crates/runtime/src/capture/combine.rs new file mode 100644 index 0000000000..193d4ca56d --- /dev/null +++ b/crates/runtime/src/capture/combine.rs @@ -0,0 +1,275 @@ +use crate::{task_state::RocksDB, TaskCombiner}; +use anyhow::Context; +use futures::{channel::mpsc, SinkExt, Stream, StreamExt, TryStreamExt}; +use proto_flow::capture::{self, request, response, Request, Response}; +use proto_flow::flow; +use proto_flow::ops; +use proto_flow::runtime::capture_response_ext; +use std::sync::Arc; +use std::time::SystemTime; + +pub fn adapt_requests( + peek_request: &Request, + request_rx: R, +) -> anyhow::Result<(impl Stream>, ResponseArgs)> +where + R: Stream>, +{ + // Open RocksDB based on the request::Open internal descriptor. + let db = Arc::new(RocksDB::open( + peek_request + .get_internal()? + .open + .and_then(|o| o.rocksdb_descriptor), + )?); + let response_db = db.clone(); + + // Channel for receiving checkpoints of each request::Acknowledge from the response stream. + let (ack_tx, mut ack_rx) = mpsc::channel(0); + // Channel for passing request::Open to the response stream. + let (mut open_tx, open_rx) = mpsc::channel(0); + + let request_rx = coroutines::try_coroutine(move |mut co| async move { + let mut request_rx = std::pin::pin!(request_rx); + + while let Some(mut request) = request_rx.try_next().await? { + if let Some(open) = &mut request.open { + // Fetch out the connector's current state. + if let Some(state) = db.load_connector_state()? { + open.state_json = state.to_string(); + } + + // Tell the response loop about the request::Open. 
+ // It will inspect it upon a future response::Opened message. + open_tx + .feed(open.clone()) + .await + .context("failed to send request::Open to response stream")?; + + co.yield_(request).await; + } else if let Some(_ack) = &request.acknowledge { + // Receive the actual request::Acknowledge to forward from the response stream. + let ack: request::Acknowledge = ack_rx.next().await.context( + "failed to receive on request::Acknowledge from the response stream", + )?; + + co.yield_(Request { + acknowledge: Some(ack), + ..Default::default() + }) + .await; + } else { + co.yield_(request).await; // Forward everything else. + } + } + Ok(()) + }); + + Ok(( + request_rx, + ResponseArgs { + ack_tx, + db: response_db, + open_rx, + }, + )) +} + +pub struct ResponseArgs { + ack_tx: mpsc::Sender, + db: Arc, + open_rx: mpsc::Receiver, +} + +pub fn adapt_responses( + args: ResponseArgs, + response_rx: R, +) -> impl Stream> +where + R: Stream>, +{ + let ResponseArgs { + mut ack_tx, + db, + mut open_rx, + } = args; + + // Combiner of an opened capture. + let mut maybe_opened: Option = None; + + let mut txn_bytes: usize = 0; + let mut txn_checkpoints: u32 = 0; + let mut in_checkpoint: bool = false; + let mut started_at = SystemTime::UNIX_EPOCH; + let ser_policy = doc::SerPolicy::default(); + + let response_rx = coroutines::try_coroutine(move |mut co| async move { + let mut response_rx = std::pin::pin!(response_rx.fuse()); + + loop { + let response = tokio::select! { + // Option 1: tell the request loop of a ready checkpoint. + // If this branch is taken, we'll then go on to drain the checkpoint. + _ = ack_tx.feed(request::Acknowledge { checkpoints: txn_checkpoints }), if !in_checkpoint && txn_checkpoints > 0 => None, + // Option 2: read a next connector response, which may merge in a next checkpoint from the delegate. + Some(response) = response_rx.next(), if txn_bytes < COMBINER_BYTE_THRESHOLD => Some(response?), + // Option 3: `checkpoints` is zero and `response_rx` polls to None. All done. + else => { return Ok(()) }, + }; + + let Some(response) = response else { + let mut combiner = maybe_opened.unwrap(); + + // Drain Combiner into Captured responses. + let doc::Combiner::Accumulator(accumulator) = combiner.inner else { + unreachable!() + }; + + let mut drainer = accumulator + .into_drainer() + .context("preparing to drain combiner")?; + let mut buf = bytes::BytesMut::new(); + + let mut checkpoint = Response::default(); + + loop { + let doc::combine::DrainedDoc { meta, root } = drainer + .next() + .expect("drainer cannot EOF before state checkpoint")?; + + // Loop exit condition: the final item is the combined transaction checkpoint. + if meta.binding() == combiner.bindings.len() { + assert!(drainer.next().is_none()); + + checkpoint.checkpoint = Some(response::Checkpoint { + state: Some(flow::ConnectorState { + merge_patch: false, + updated_json: serde_json::to_string(&ser_policy.on_owned(&root)) + .expect("checkpoint serialization cannot fail"), + }), + }); + break; + } + + let (key_packed, partitions_packed, doc_json) = + combiner.bindings[meta.binding()].drained(&root, &mut buf); + + let captured = Response { + captured: Some(response::Captured { + binding: meta.binding() as u32, + doc_json, + }), + ..Default::default() + } + .with_internal_buf(&mut buf, |internal| { + internal.captured = Some(capture_response_ext::Captured { + key_packed, + partitions_packed, + }); + }); + co.yield_(captured).await; + } + // Combiner is now drained and is ready to accumulate again. 
+ combiner.inner = doc::Combiner::Accumulator(drainer.into_new_accumulator()?); + + // Next we build up statistics to yield with our own response::Checkpoint. + let stats = ops::Stats { + capture: combiner.build_binding_stats(), + derive: None, + interval: None, + materialize: Default::default(), + meta: Some(ops::Meta { + uuid: crate::UUID_PLACEHOLDER.to_string(), + }), + open_seconds_total: started_at.elapsed().unwrap().as_secs_f64(), + shard: Some(combiner.shard_ref.clone()), + timestamp: Some(proto_flow::as_timestamp(started_at)), + txn_count: 1, + }; + + // Now send the Checkpoint response extended with accumulated stats. + co.yield_(checkpoint.with_internal(|internal| { + internal.checkpoint = + Some(capture_response_ext::Checkpoint { stats: Some(stats) }); + })) + .await; + + // If inferred schemas were updated, log them out for continuous schema inference. + combiner.log_updated_schemas(); + + txn_bytes = 0; + txn_checkpoints = 0; + maybe_opened = Some(combiner); + continue; + }; + + if let Some(opened) = &response.opened { + let open = open_rx + .next() + .await + .context("failed to receive request::Open from request loop")?; + + maybe_opened = Some(TaskCombiner::open_capture(&open, opened)?); + + co.yield_(response).await; // Forward. + } else if let Some(capture::response::Captured { binding, doc_json }) = + &response.captured + { + let combiner = maybe_opened + .as_mut() + .context("connector sent Captured before Opened")?; + + if !in_checkpoint { + if txn_checkpoints == 0 { + started_at = SystemTime::now(); + } + in_checkpoint = true; + } + + combiner.combine(*binding, doc_json, false)?; + txn_bytes += doc_json.len(); + + // Not forwarded. + } else if let Some(capture::response::Checkpoint { state }) = &response.checkpoint { + let combiner = maybe_opened + .as_mut() + .context("connector sent Checkpoint before Opened")?; + + let flow::ConnectorState { + updated_json, + merge_patch, + } = state + .as_ref() + .context("connector sent Checkpoint without `state` field")?; + + if !in_checkpoint && txn_checkpoints == 0 { + started_at = SystemTime::now(); + } + txn_checkpoints += 1; + in_checkpoint = false; + + // Combine over the checkpoint state. + if *merge_patch && combiner.has_state_schema { + anyhow::bail!( + "connector sent Checkpoint with `mergePatch` but defines a state JSON-Schema", + ); + } else if !merge_patch { + combiner.combine(combiner.bindings.len() as u32, "null", false)?; + } + + combiner.combine(combiner.bindings.len() as u32, updated_json, false)?; + } else { + // Simply forward everything else. + () = co.yield_(response).await; + } + } + }); + + response_rx +} + +// COMBINER_BYTE_THRESHOLD is a coarse target on the documents which can be +// optimistically combined within a capture transaction, while awaiting the +// commit of a previous transaction. Upon reaching this threshold, further +// documents and checkpoints will not be folded into the transaction. +const COMBINER_BYTE_THRESHOLD: usize = 1 << 25; // 32MB. diff --git a/crates/runtime/src/capture/mod.rs b/crates/runtime/src/capture/mod.rs index 887748c94a..014ed19c0f 100644 --- a/crates/runtime/src/capture/mod.rs +++ b/crates/runtime/src/capture/mod.rs @@ -25,6 +25,7 @@ use std::sync::Arc; // Drain combiner into forwarded Captured. // Forward Checkpoint enriched with stats. +mod combine; mod image; mod local; @@ -89,6 +90,8 @@ where // Request interceptor which adjusts the dynamic log level based on internal shard labels. 
let request_rx = adjust_log_level(request_rx, self.set_log_level); + // Request interceptor for combining over response documents and checkpoints. + let (request_rx, combine_args) = combine::adapt_requests(&peek_request, request_rx)?; let response_rx = match endpoint { models::CaptureEndpoint::Connector(_) => image::connector( @@ -108,7 +111,9 @@ where } }; - Ok(response_rx) + let response_rx = combine::adapt_responses(combine_args, response_rx); + + Ok(response_rx.boxed()) } } @@ -122,6 +127,7 @@ where request_rx.inspect_ok(move |request| { let Ok(CaptureRequestExt { labels: Some(ops::ShardLabeling { log_level, .. }), + .. }) = request.get_internal() else { return; diff --git a/crates/runtime/src/derive/combine.rs b/crates/runtime/src/derive/combine.rs index c3951733be..9ed4cfd11b 100644 --- a/crates/runtime/src/derive/combine.rs +++ b/crates/runtime/src/derive/combine.rs @@ -1,8 +1,8 @@ +use crate::TaskCombiner; use anyhow::Context; use futures::{channel::mpsc, SinkExt, Stream, StreamExt, TryStreamExt}; use proto_flow::derive::{request, response, Request, Response}; -use proto_flow::flow::{self, collection_spec, CollectionSpec}; -use proto_flow::ops; +use proto_flow::flow; use proto_flow::runtime::derive_response_ext; use std::time::SystemTime; @@ -15,7 +15,7 @@ where { // Maximum UUID Clock value observed in request::Read documents. let mut max_clock = 0; - // Statistics for read documents, passed to the response stream on flush. + // Statistics for read documents, indexed on transform and passed to the response stream on flush. let mut read_stats: Vec = Vec::new(); // Time at which the current transaction was started. let mut started_at: Option = None; @@ -91,16 +91,8 @@ where mut open_rx, } = args; - // Statistics for documents published by us when draining. - let mut drain_stats: ops::stats::DocsAndBytes = Default::default(); - // Inferred shape of published documents. - let mut inferred_shape: doc::Shape = doc::Shape::nothing(); - // Did `inferred_shape` change during the current transaction? - let mut inferred_shape_changed: bool = false; - // State of an opened derivation. - let mut maybe_opened: Option = None; - // Statistics for documents published by the wrapped delegate. - let mut publish_stats: ops::stats::DocsAndBytes = Default::default(); + // Combiner and indexed read transforms of an opened derivation. + let mut maybe_opened: Option<(TaskCombiner, Vec<(String, models::Collection)>)> = None; coroutines::try_coroutine(move |mut co| async move { let mut response_rx = std::pin::pin!(response_rx); @@ -112,29 +104,27 @@ where .await .context("failed to receive request::Open from request loop")?; - maybe_opened = Some(Opened::build(open, opened)?); + maybe_opened = Some(TaskCombiner::open_derivation(&open, opened)?); co.yield_(response).await; // Forward. } else if let Some(published) = &response.published { - let opened = maybe_opened + let (combiner, _transforms) = maybe_opened .as_mut() .context("connector sent Published before Opened")?; - opened.combine_right(&published)?; - publish_stats.docs_total += 1; - publish_stats.bytes_total += published.doc_json.len() as u64; + combiner.combine(0, &published.doc_json, false)?; // Not forwarded. 
} else if let Some(_flushed) = &response.flushed { - let mut opened = maybe_opened + let (mut combiner, transforms) = maybe_opened .take() .context("connector sent Flushed before Opened")?; let (max_clock, read_stats, started_at) = flush_rx .next() .await - .context("failed to receive on request::Flush from request loop")?; + .context("failed to receive from request stream on response::Flushed")?; // Drain Combiner into Published responses. - let doc::Combiner::Accumulator(accumulator) = opened.combiner else { + let doc::Combiner::Accumulator(accumulator) = combiner.inner else { unreachable!() }; @@ -143,29 +133,14 @@ where .context("preparing to drain combiner")?; let mut buf = bytes::BytesMut::new(); + // Derivations have only one binding. + let binding = &mut combiner.bindings[0]; + while let Some(drained) = drainer.next() { let doc::combine::DrainedDoc { meta: _, root } = drained?; - if inferred_shape.widen_owned(&root) { - doc::shape::limits::enforce_shape_complexity_limit( - &mut inferred_shape, - doc::shape::limits::DEFAULT_SCHEMA_COMPLEXITY_LIMIT, - ); - inferred_shape_changed = true; - } - - let key_packed = - doc::Extractor::extract_all_owned(&root, &opened.key_extractors, &mut buf); - let partitions_packed = doc::Extractor::extract_all_owned( - &root, - &opened.partition_extractors, - &mut buf, - ); - - let doc_json = serde_json::to_string(&opened.ser_policy.on_owned(&root)) - .expect("document serialization cannot fail"); - drain_stats.docs_total += 1; - drain_stats.bytes_total += doc_json.len() as u64; + let (key_packed, partitions_packed, doc_json) = + binding.drained(&root, &mut buf); let published = Response { published: Some(response::Published { doc_json }), @@ -181,13 +156,12 @@ where co.yield_(published).await; } // Combiner is now drained and is ready to accumulate again. - opened.combiner = doc::Combiner::Accumulator(drainer.into_new_accumulator()?); + combiner.inner = doc::Combiner::Accumulator(drainer.into_new_accumulator()?); // Next we build up statistics to yield with our own response::Flushed. - let duration = started_at.elapsed().unwrap_or_default(); + let duration = started_at.elapsed().unwrap(); - let transforms = opened - .transforms + let transform_stats = transforms .iter() .zip(read_stats.into_iter()) .filter_map(|((name, source), read_stats)| { @@ -198,7 +172,7 @@ where name.clone(), ops::stats::derive::Transform { input: Some(read_stats), - source: source.clone(), + source: source.to_string(), }, )) } @@ -208,9 +182,9 @@ where let stats = ops::Stats { capture: Default::default(), derive: Some(ops::stats::Derive { - transforms, - published: maybe_counts(&mut publish_stats), - out: maybe_counts(&mut drain_stats), + transforms: transform_stats, + published: maybe_counts(&mut binding.stats_combined), + out: maybe_counts(&mut binding.stats_drained), }), interval: None, materialize: Default::default(), @@ -218,7 +192,7 @@ where uuid: crate::UUID_PLACEHOLDER.to_string(), }), open_seconds_total: duration.as_secs_f64(), - shard: Some(opened.shard.clone()), + shard: Some(combiner.shard_ref.clone()), timestamp: Some(proto_flow::as_timestamp(started_at)), txn_count: 1, }; @@ -229,23 +203,10 @@ where })) .await; - // If the inferred doc::Shape was updated, log it out for continuous schema inference. 
- if inferred_shape_changed { - inferred_shape_changed = false; - - let serialized = serde_json::to_value(&doc::shape::schema::to_schema( - inferred_shape.clone(), - )) - .expect("shape serialization should never fail"); + // If inferred schemas were updated, log them out for continuous schema inference. + combiner.log_updated_schemas(); - tracing::info!( - schema = ?::ops::DebugJson(serialized), - collection_name = %opened.shard.name, - "inferred schema updated" - ); - } - - maybe_opened = Some(opened); + maybe_opened = Some((combiner, transforms)); } else { // All other request types are forwarded. co.yield_(response).await; @@ -255,142 +216,6 @@ where }) } -pub struct Opened { - // Combiner of published documents. - combiner: doc::Combiner, - // JSON pointer to the derived document UUID. - document_uuid_ptr: Option, - // Key components of derived documents. - key_extractors: Vec, - // Partitions to extract when draining the Combiner. - partition_extractors: Vec, - // Document serialization policy. - ser_policy: doc::SerPolicy, - // Shard of this derivation. - shard: ops::ShardRef, - // Ordered transform (transform-name, source-collection). - transforms: Vec<(String, String)>, -} - -impl Opened { - pub fn build(open: request::Open, _opened: &response::Opened) -> anyhow::Result { - let request::Open { - collection, - range, - state_json: _, - version: _, - } = open; - - let CollectionSpec { - ack_template_json: _, - derivation, - key, - name, - partition_fields, - partition_template: _, - projections, - read_schema_json: _, - uuid_ptr: document_uuid_ptr, - write_schema_json, - } = collection.as_ref().context("missing collection")?; - - let collection_spec::Derivation { - connector_type: _, - config_json: _, - transforms, - .. - } = derivation.as_ref().context("missing derivation")?; - - // TODO(johnny): Expose to connector protocol and extract from Open/Opened. - let ser_policy = doc::SerPolicy::default(); - - let range = range.as_ref().context("missing range")?; - - if key.is_empty() { - return Err(anyhow::anyhow!("derived collection key cannot be empty").into()); - } - let key_extractors = extractors::for_key(&key, &projections, &ser_policy)?; - - let document_uuid_ptr = if document_uuid_ptr.is_empty() { - None - } else { - Some(doc::Pointer::from(&document_uuid_ptr)) - }; - - let write_schema_json = doc::validation::build_bundle(&write_schema_json) - .context("collection write_schema_json is not a JSON schema")?; - let validator = - doc::Validator::new(write_schema_json).context("could not build a schema validator")?; - - let combiner = doc::Combiner::new( - doc::combine::Spec::with_one_binding( - false, // Derivations use partial reductions. - key_extractors.clone(), - None, - validator, - ), - tempfile::tempfile().context("opening temporary spill file")?, - )?; - - // Identify ordered, partitioned projections to extract on combiner drain. 
- let partition_extractors = - extractors::for_fields(partition_fields, projections, &ser_policy)?; - - let transforms = transforms - .iter() - .map(|transform| { - ( - transform.name.clone(), - transform.collection.as_ref().unwrap().name.clone(), - ) - }) - .collect(); - - let shard = ops::ShardRef { - kind: ops::TaskType::Derivation as i32, - name: name.clone(), - key_begin: format!("{:08x}", range.key_begin), - r_clock_begin: format!("{:08x}", range.r_clock_begin), - }; - - Ok(Self { - combiner, - document_uuid_ptr, - key_extractors, - partition_extractors, - ser_policy, - shard, - transforms, - }) - } - - pub fn combine_right(&mut self, published: &response::Published) -> anyhow::Result<()> { - let memtable = match &mut self.combiner { - doc::Combiner::Accumulator(accumulator) => accumulator.memtable()?, - _ => panic!("implementation error: combiner is draining, not accumulating"), - }; - let alloc = memtable.alloc(); - - let mut deser = serde_json::Deserializer::from_str(&published.doc_json); - let mut doc = doc::HeapNode::from_serde(&mut deser, alloc).with_context(|| { - format!( - "couldn't parse published document as JSON: {}", - &published.doc_json - ) - })?; - - if let Some(ptr) = &self.document_uuid_ptr { - if let Some(node) = ptr.create_heap_node(&mut doc, alloc) { - *node = - doc::HeapNode::String(doc::BumpStr::from_str(crate::UUID_PLACEHOLDER, alloc)); - } - } - memtable.add(0, doc, false)?; - - Ok(()) - } -} - fn maybe_counts(s: &mut ops::stats::DocsAndBytes) -> Option { if s.bytes_total != 0 { Some(std::mem::take(s)) diff --git a/crates/runtime/src/derive/rocksdb.rs b/crates/runtime/src/derive/rocksdb.rs index a76fba3e84..be43c321ee 100644 --- a/crates/runtime/src/derive/rocksdb.rs +++ b/crates/runtime/src/derive/rocksdb.rs @@ -1,11 +1,11 @@ +use crate::task_state::RocksDB; use anyhow::Context; use futures::SinkExt; use futures::{channel::mpsc, Stream, StreamExt, TryStreamExt}; use prost::Message; use proto_flow::derive::{request, Request, Response}; use proto_flow::flow; -use proto_flow::runtime::{derive_response_ext, RocksDbDescriptor}; -use proto_gazette::consumer::Checkpoint; +use proto_flow::runtime::derive_response_ext; use std::sync::Arc; pub fn adapt_requests( @@ -144,118 +144,3 @@ where Ok(()) }) } - -struct RocksDB { - db: rocksdb::DB, - _path: std::path::PathBuf, - _tmp: Option, -} - -impl std::ops::Deref for RocksDB { - type Target = rocksdb::DB; - - fn deref(&self) -> &Self::Target { - &self.db - } -} - -impl RocksDB { - pub fn open(desc: Option) -> anyhow::Result { - let (mut opts, path, _tmp) = match desc { - Some(RocksDbDescriptor { - rocksdb_path, - rocksdb_env_memptr, - }) => { - tracing::debug!( - ?rocksdb_path, - ?rocksdb_env_memptr, - "opening hooked RocksDB database" - ); - - // Re-hydrate the provided memory address into rocksdb::Env wrapping - // an owned *mut librocksdb_sys::rocksdb_env_t. 
- let env = unsafe { - rocksdb::Env::from_raw(rocksdb_env_memptr as *mut librocksdb_sys::rocksdb_env_t) - }; - - let mut opts = rocksdb::Options::default(); - opts.set_env(&env); - - (opts, std::path::PathBuf::from(rocksdb_path), None) - } - _ => { - let dir = tempfile::TempDir::new().unwrap(); - let opts = rocksdb::Options::default(); - - tracing::debug!( - rocksdb_path = ?dir.path(), - "opening temporary RocksDB database" - ); - - (opts, dir.path().to_owned(), Some(dir)) - } - }; - - opts.create_if_missing(true); - opts.create_missing_column_families(true); - - let column_families = match rocksdb::DB::list_cf(&opts, &path) { - Ok(cf) => cf, - // Listing column families will fail if the DB doesn't exist. - // Assume as such, as we'll otherwise fail when we attempt to open. - Err(_) => vec![rocksdb::DEFAULT_COLUMN_FAMILY_NAME.to_string()], - }; - let mut db = rocksdb::DB::open_cf(&opts, &path, column_families.iter()) - .context("failed to open RocksDB")?; - - for column_family in column_families { - // We used to use a `registers` column family for derivations, but we no longer do - // and they were never actually used in production. Rocks requires that all existing - // column families are opened, so we just open and drop any of these legacy "registers" - // column families. - if column_family == "registers" { - tracing::warn!(%column_family, "dropping legacy rocksdb column family"); - db.drop_cf(&column_family) - .context("dropping legacy column family")?; - } - } - - Ok(Self { - db, - _path: path, - _tmp, - }) - } - - pub fn load_checkpoint(&self) -> anyhow::Result { - match self.db.get_pinned(Self::CHECKPOINT_KEY)? { - Some(v) => { - Ok(Checkpoint::decode(v.as_ref()) - .context("failed to decode consumer checkpoint")?) - } - None => Ok(Checkpoint::default()), - } - } - - pub fn load_connector_state(&self) -> anyhow::Result> { - let state = self - .db - .get_pinned(Self::CONNECTOR_STATE_KEY) - .context("failed to load connector state")?; - - // If found, decode and attach to `open`. - if let Some(state) = state { - let state: serde_json::Value = - serde_json::from_slice(&state).context("failed to decode connector state")?; - - Ok(Some(state)) - } else { - Ok(None) - } - } - - // Key encoding under which a marshalled checkpoint is stored. - pub const CHECKPOINT_KEY: &[u8] = b"checkpoint"; - // Key encoding under which a connector state is stored. - pub const CONNECTOR_STATE_KEY: &[u8] = b"connector-state"; -} diff --git a/crates/runtime/src/lib.rs b/crates/runtime/src/lib.rs index 493b6d0ab6..b8d595184c 100644 --- a/crates/runtime/src/lib.rs +++ b/crates/runtime/src/lib.rs @@ -7,11 +7,14 @@ mod derive; mod image_connector; mod local_connector; mod materialize; +mod task_combiner; mod task_service; +mod task_state; mod tokio_context; mod unary; mod unseal; +use task_combiner::TaskCombiner; pub use task_service::TaskService; pub use tokio_context::TokioContext; diff --git a/crates/runtime/src/materialize/mod.rs b/crates/runtime/src/materialize/mod.rs index b4e06fe5cf..cdc8e5f5e8 100644 --- a/crates/runtime/src/materialize/mod.rs +++ b/crates/runtime/src/materialize/mod.rs @@ -130,6 +130,7 @@ where request_rx.inspect_ok(move |request| { let Ok(MaterializeRequestExt { labels: Some(ops::ShardLabeling { log_level, .. }), + .. 
}) = request.get_internal() else { return; diff --git a/crates/runtime/src/task_combiner.rs b/crates/runtime/src/task_combiner.rs new file mode 100644 index 0000000000..6e294b18ff --- /dev/null +++ b/crates/runtime/src/task_combiner.rs @@ -0,0 +1,467 @@ +use anyhow::Context; +use ops::stats::DocsAndBytes; +use proto_flow::flow::{self, CollectionSpec, FieldSelection}; +use std::collections::BTreeMap; + +pub struct TaskCombiner { + // Descriptions for each binding of the task. + pub bindings: Vec, + // Does this task use a custom state JSON-schema? + pub has_state_schema: bool, + // Wrapped Combiner. + pub inner: doc::Combiner, + // ShardRef of this task. + pub shard_ref: ops::ShardRef, +} + +pub struct TaskBinding { + // Name of the collection bound to this binding. + pub collection: models::Collection, + // JSON pointer at which a placeholder document UUID is added. + pub document_uuid_ptr: Option, + // Inferred shape of binding documents. + pub inferred_shape: Option, + // Did `inferred_shape` change during the current transaction? + pub inferred_shape_changed: bool, + // Key components which are extracted from binding documents. + pub key_extractors: Vec, + // Serialization policy for documents. + pub ser_policy: doc::SerPolicy, + // Statistics for combined documents of the binding fed into the combiner. + pub stats_combined: DocsAndBytes, + // Statistics for documents of the binding drained from the combiner. + pub stats_drained: DocsAndBytes, + // Statistics for reduced documents of the binding fed into the combiner. + pub stats_reduced: DocsAndBytes, + // Partition or field values which are extracted from binding documents. + pub value_extractors: Vec, +} + +impl TaskCombiner { + pub fn open_capture( + open: &proto_flow::capture::request::Open, + _opened: &proto_flow::capture::response::Opened, + ) -> anyhow::Result { + let proto_flow::capture::request::Open { + capture, + range, + state_json: _, + version: _, + } = open; + + let capture = capture.as_ref().context("missing capture")?; + let range = range.as_ref().context("missing range")?; + let ser_policy = doc::SerPolicy::default(); + + let bindings = capture + .bindings + .iter() + .enumerate() + .map(|(index, spec)| { + TaskBinding::for_capture(ser_policy.clone(), spec) + .with_context(|| format!("binding {index} is invalid")) + }) + .collect::, _>>()?; + + let shard_ref = ops::ShardRef { + kind: ops::TaskType::Capture as i32, + name: capture.name.clone(), + key_begin: format!("{:08x}", range.key_begin), + r_clock_begin: format!("{:08x}", range.r_clock_begin), + }; + + // TODO(johnny): Pass in from built CaptureSpec. + let state_schema = ""; + + Self::build(bindings, shard_ref, state_schema) + } + + pub fn open_derivation( + open: &proto_flow::derive::request::Open, + _opened: &proto_flow::derive::response::Opened, + ) -> anyhow::Result<(Self, Vec<(String, models::Collection)>)> { + let proto_flow::derive::request::Open { + collection, + range, + state_json: _, + version: _, + } = open; + + let collection = collection.as_ref().context("missing collection")?; + let range = range.as_ref().context("missing range")?; + let ser_policy = doc::SerPolicy::default(); + + let bindings = vec![TaskBinding::for_written_collection(ser_policy, collection)?]; + + let shard_ref = ops::ShardRef { + kind: ops::TaskType::Derivation as i32, + name: collection.name.clone(), + key_begin: format!("{:08x}", range.key_begin), + r_clock_begin: format!("{:08x}", range.r_clock_begin), + }; + + // TODO(johnny): Pass in from built collection_spec::Derivation. 
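+ // Until then, an empty schema causes `build` to fall back to the default
+ // merge-patch reduction for connector state.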
+ let state_schema = ""; + + // In addition to derived and combined documents, derivations must also track + // and record statistics for read collection documents. + let transforms = collection + .derivation + .as_ref() + .context("missing derivation")? + .transforms + .iter() + .map(|transform| { + ( + transform.name.clone(), + models::Collection::new(&transform.collection.as_ref().unwrap().name), + ) + }) + .collect(); + + Ok((Self::build(bindings, shard_ref, state_schema)?, transforms)) + } + + pub fn open_materialization( + open: &proto_flow::materialize::request::Open, + _opened: &proto_flow::materialize::response::Opened, + ) -> anyhow::Result { + let proto_flow::materialize::request::Open { + materialization, + range, + state_json: _, + version: _, + } = open; + + let materialization = materialization + .as_ref() + .context("missing materialization")?; + let range = range.as_ref().context("missing range")?; + + // TODO(johnny): Hack to address string truncation for these common materialization connectors + // that don't handle large strings very well. This should be negotiated via connector protocol. + // See go/runtime/materialize.go:135 + let ser_policy = if [ + "ghcr.io/estuary/materialize-snowflake", + "ghcr.io/estuary/materialize-redshift", + "ghcr.io/estuary/materialize-sqlite", + ] + .iter() + .any(|image| materialization.config_json.contains(image)) + { + doc::SerPolicy { + str_truncate_after: 1 << 16, // Truncate at 64KB. + } + } else { + doc::SerPolicy::default() + }; + + let bindings = materialization + .bindings + .iter() + .enumerate() + .map(|(index, spec)| { + TaskBinding::for_materialization(ser_policy.clone(), spec) + .with_context(|| format!("binding {index} is invalid")) + }) + .collect::, _>>()?; + + let shard_ref = ops::ShardRef { + kind: ops::TaskType::Materialization as i32, + name: materialization.name.clone(), + key_begin: format!("{:08x}", range.key_begin), + r_clock_begin: format!("{:08x}", range.r_clock_begin), + }; + + // TODO(johnny): Pass in from built MaterializationSpec. 
+ let state_schema = ""; + + Self::build(bindings, shard_ref, state_schema) + } + + pub fn combine( + &mut self, + binding_index: u32, + doc_json: &str, + front: bool, + ) -> anyhow::Result<()> { + let binding = self + .bindings + .get_mut(binding_index as usize) + .with_context(|| "invalid combine-right binding {binding}")?; + + let memtable = match &mut self.inner { + doc::Combiner::Accumulator(accumulator) => accumulator.memtable()?, + _ => panic!("implementation error: combiner is draining, not accumulating"), + }; + let alloc = memtable.alloc(); + + let mut de = serde_json::Deserializer::from_str(doc_json); + let mut doc = doc::HeapNode::from_serde(&mut de, alloc) + .with_context(|| format!("couldn't parse published document as JSON: {doc_json}"))?; + + if let Some(ptr) = &binding.document_uuid_ptr { + if let Some(node) = ptr.create_heap_node(&mut doc, alloc) { + *node = + doc::HeapNode::String(doc::BumpStr::from_str(crate::UUID_PLACEHOLDER, alloc)); + } + } + memtable.add(binding_index, doc, front)?; + + if front { + binding.stats_reduced.docs_total += 1; + binding.stats_reduced.bytes_total += doc_json.len() as u64; + } else { + binding.stats_combined.docs_total += 1; + binding.stats_combined.bytes_total += doc_json.len() as u64; + } + + Ok(()) + } + + pub fn build_binding_stats(&mut self) -> BTreeMap { + let mut stats = BTreeMap::::new(); + + let merge = |from: &mut DocsAndBytes, to: &mut Option| { + if from.docs_total != 0 { + let entry = to.get_or_insert_with(Default::default); + entry.docs_total += from.docs_total; + entry.bytes_total += from.bytes_total; + *from = DocsAndBytes::default(); // Clear. + } + }; + + for binding in self.bindings.iter_mut() { + if binding.stats_drained.docs_total == 0 { + continue; // Skip creating a stats map entry. + } + let entry = stats.entry(binding.collection.to_string()).or_default(); + + merge(&mut binding.stats_combined, &mut entry.right); + merge(&mut binding.stats_drained, &mut entry.out); + merge(&mut binding.stats_reduced, &mut entry.left); + } + + stats + } + + pub fn log_updated_schemas(&mut self) { + for binding in self.bindings.iter_mut() { + if binding.inferred_shape_changed { + binding.inferred_shape_changed = false; + + let serialized = serde_json::to_value(&doc::shape::schema::to_schema( + binding.inferred_shape.clone().unwrap(), + )) + .expect("shape serialization should never fail"); + + tracing::info!( + schema = ?::ops::DebugJson(serialized), + collection_name = %binding.collection, + "inferred schema updated" + ); + } + } + } + + fn build( + bindings: Vec<(TaskBinding, (bool, Vec, doc::Validator))>, + shard_ref: ops::ShardRef, + state_schema: &str, + ) -> anyhow::Result { + let (bindings, combiner_spec): (Vec, Vec<_>) = bindings.into_iter().unzip(); + + let (has_state_schema, state_validator) = if state_schema.is_empty() { + let state_schema = doc::reduce::merge_patch_schema().to_string(); + ( + false, + doc::Validator::new(doc::validation::build_bundle(&state_schema).unwrap()).unwrap(), + ) + } else { + ( + true, + doc::Validator::new(doc::validation::build_bundle(state_schema)?)?, + ) + }; + + // Initialize combiner with all bindings, plus one extra for state reductions. 
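+ // The trailing state binding has an empty key and is validated by `state_validator`,
+ // so connector checkpoint states are reduced through the same combiner as documents.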
+ let combiner_spec = doc::combine::Spec::with_bindings( + combiner_spec + .into_iter() + .map(|(is_full, key, validator)| (is_full, key, None, validator)) + .chain(std::iter::once((false, Vec::new(), None, state_validator))), + ); + let combiner = doc::Combiner::new( + combiner_spec, + tempfile::tempfile().context("opening temporary spill file")?, + )?; + + Ok(Self { + bindings, + has_state_schema, + inner: combiner, + shard_ref, + }) + } +} + +impl TaskBinding { + // Map a drained document into accumulated statistics and a + // (packed-key, packed-values, encoded-json) tuple. + pub fn drained( + &mut self, + root: &doc::OwnedNode, + buf: &mut bytes::BytesMut, + ) -> (bytes::Bytes, bytes::Bytes, String) { + if let Some(inferred_shape) = &mut self.inferred_shape { + if inferred_shape.widen_owned(root) { + doc::shape::limits::enforce_shape_complexity_limit( + inferred_shape, + doc::shape::limits::DEFAULT_SCHEMA_COMPLEXITY_LIMIT, + ); + self.inferred_shape_changed = true; + } + } + + let key_packed = doc::Extractor::extract_all_owned(root, &self.key_extractors, buf); + let val_packed = doc::Extractor::extract_all_owned(root, &self.value_extractors, buf); + let doc_json = serde_json::to_string(&self.ser_policy.on_owned(root)) + .expect("document serialization cannot fail"); + + self.stats_drained.docs_total += 1; + self.stats_drained.bytes_total += doc_json.len() as u64; + + (key_packed, val_packed, doc_json) + } + + fn for_capture( + ser_policy: doc::SerPolicy, + spec: &flow::capture_spec::Binding, + ) -> anyhow::Result<(Self, (bool, Vec, doc::Validator))> { + Self::for_written_collection( + ser_policy, + spec.collection + .as_ref() + .context("missing required collection")?, + ) + } + + fn for_materialization( + ser_policy: doc::SerPolicy, + spec: &flow::materialization_spec::Binding, + ) -> anyhow::Result<(Self, (bool, Vec, doc::Validator))> { + let flow::materialization_spec::Binding { + collection, + field_selection, + .. + } = spec; + + let CollectionSpec { + ack_template_json: _, + derivation: _, + key, + name, + partition_fields: _, + partition_template: _, + projections, + read_schema_json, + uuid_ptr: _, + write_schema_json, + } = collection.as_ref().context("missing required collection")?; + + // We always combine over the collection key. + if key.is_empty() { + anyhow::bail!("collection key cannot be empty"); + } + let combiner_extractors = extractors::for_key(&key, &projections, &ser_policy)?; + + // Materializations are allowed to choose a subset of key and value fields. + // Usually this matches the collection key, but that isn't required. + let FieldSelection { + keys: selected_keys, + values: selected_values, + .. + } = field_selection + .as_ref() + .context("missing required field selection")?; + + let key_extractors = extractors::for_fields(&selected_keys, &projections, &ser_policy)?; + let value_extractors = extractors::for_fields(&selected_values, projections, &ser_policy)?; + + let read_schema_json = if !read_schema_json.is_empty() { + read_schema_json + } else { + write_schema_json + }; + + let built_schema = doc::validation::build_bundle(&read_schema_json) + .context("collection read schema is not a JSON schema")?; + let validator = + doc::Validator::new(built_schema).context("could not build a schema validator")?; + + Ok(( + Self { + collection: models::Collection::new(name), + document_uuid_ptr: None, // Not added. + inferred_shape: None, // Not inferred. 
+ inferred_shape_changed: false, + key_extractors, + ser_policy, + stats_combined: Default::default(), + stats_drained: Default::default(), + stats_reduced: Default::default(), + value_extractors, + }, + (!spec.delta_updates, combiner_extractors, validator), + )) + } + + fn for_written_collection( + ser_policy: doc::SerPolicy, + spec: &flow::CollectionSpec, + ) -> anyhow::Result<(Self, (bool, Vec, doc::Validator))> { + let CollectionSpec { + ack_template_json: _, + derivation: _, + key, + name, + partition_fields, + partition_template: _, + projections, + read_schema_json: _, + uuid_ptr, + write_schema_json, + } = spec; + + if uuid_ptr.is_empty() { + anyhow::bail!("uuid_ptr cannot be empty"); + } else if key.is_empty() { + anyhow::bail!("collection key cannot be empty"); + } + + let document_uuid_ptr = Some(doc::Pointer::from(uuid_ptr)); + let key_extractors = extractors::for_key(&key, &projections, &ser_policy)?; + let value_extractors = extractors::for_fields(partition_fields, projections, &ser_policy)?; + + let built_schema = doc::validation::build_bundle(&write_schema_json) + .context("collection write_schema_json is not a JSON schema")?; + let validator = + doc::Validator::new(built_schema).context("could not build a schema validator")?; + + Ok(( + Self { + collection: models::Collection::new(name), + document_uuid_ptr, + inferred_shape: Some(doc::Shape::nothing()), + inferred_shape_changed: false, + key_extractors: key_extractors.clone(), + ser_policy, + stats_combined: Default::default(), + stats_drained: Default::default(), + stats_reduced: Default::default(), + value_extractors, + }, + (false, key_extractors, validator), + )) + } +} diff --git a/crates/runtime/src/task_state.rs b/crates/runtime/src/task_state.rs new file mode 100644 index 0000000000..27e642c5ab --- /dev/null +++ b/crates/runtime/src/task_state.rs @@ -0,0 +1,280 @@ +use anyhow::Context; +use prost::Message; +use proto_flow::runtime::RocksDbDescriptor; +use proto_gazette::consumer::Checkpoint; + +pub struct RocksDB { + db: rocksdb::DB, + _path: std::path::PathBuf, + _tmp: Option, +} + +impl std::ops::Deref for RocksDB { + type Target = rocksdb::DB; + + fn deref(&self) -> &Self::Target { + &self.db + } +} + +impl RocksDB { + pub fn open(desc: Option) -> anyhow::Result { + let (mut opts, path, _tmp) = match desc { + Some(RocksDbDescriptor { + rocksdb_path, + rocksdb_env_memptr, + }) => { + tracing::debug!( + ?rocksdb_path, + ?rocksdb_env_memptr, + "opening hooked RocksDB database" + ); + + // Re-hydrate the provided memory address into rocksdb::Env wrapping + // an owned *mut librocksdb_sys::rocksdb_env_t. + let env = unsafe { + rocksdb::Env::from_raw(rocksdb_env_memptr as *mut librocksdb_sys::rocksdb_env_t) + }; + + let mut opts = rocksdb::Options::default(); + opts.set_env(&env); + + (opts, std::path::PathBuf::from(rocksdb_path), None) + } + None => { + let dir = tempfile::TempDir::new().unwrap(); + let opts = rocksdb::Options::default(); + + tracing::debug!( + rocksdb_path = ?dir.path(), + "opening temporary RocksDB database" + ); + + (opts, dir.path().to_owned(), Some(dir)) + } + }; + + opts.create_if_missing(true); + opts.create_missing_column_families(true); + + let column_families = match rocksdb::DB::list_cf(&opts, &path) { + Ok(cf) => cf, + // Listing column families will fail if the DB doesn't exist. + // Assume as such, as we'll otherwise fail when we attempt to open. 
+ Err(_) => vec![rocksdb::DEFAULT_COLUMN_FAMILY_NAME.to_string()], + }; + + let mut cf_descriptors = Vec::with_capacity(column_families.len()); + for name in column_families { + let mut cf_opts = rocksdb::Options::default(); + + if name == rocksdb::DEFAULT_COLUMN_FAMILY_NAME { + let state_schema = doc::reduce::merge_patch_schema().to_string(); + + cf_opts.set_merge_operator( + "task-state", + build_merge_fn(true, &state_schema)?, + build_merge_fn(false, &state_schema)?, + ); + } + + cf_descriptors.push(rocksdb::ColumnFamilyDescriptor::new(name, cf_opts)); + } + + let db = rocksdb::DB::open_cf_descriptors(&opts, &path, cf_descriptors) + .context("failed to open RocksDB")?; + + // TODO(johnny): Handle migration from a JSON state file here. + + Ok(Self { + db, + _path: path, + _tmp, + }) + } + + pub fn load_checkpoint(&self) -> anyhow::Result { + match self.db.get_pinned(Self::CHECKPOINT_KEY)? { + Some(v) => { + Ok(Checkpoint::decode(v.as_ref()) + .context("failed to decode consumer checkpoint")?) + } + None => Ok(Checkpoint::default()), + } + } + + pub fn load_connector_state(&self) -> anyhow::Result> { + let state = self + .db + .get_pinned(Self::CONNECTOR_STATE_KEY) + .context("failed to load connector state")?; + + // If found, decode and attach to `open`. + if let Some(state) = state { + let state: serde_json::Value = + serde_json::from_slice(&state).context("failed to decode connector state")?; + + Ok(Some(state)) + } else { + Ok(None) + } + } + + // Key encoding under which a marshalled checkpoint is stored. + pub const CHECKPOINT_KEY: &[u8] = b"checkpoint"; + // Key encoding under which a connector state is stored. + pub const CONNECTOR_STATE_KEY: &[u8] = b"connector-state"; +} + +fn build_merge_fn( + full: bool, + state_schema: &str, +) -> anyhow::Result { + let schema = doc::validation::build_bundle(&state_schema)?; + let validator = doc::Validator::new(schema)?; + + let foo = move |_key: &[u8], + initial: Option<&[u8]>, + mut operands: &mut rocksdb::merge_operator::MergeOperands| + -> anyhow::Result>> { + let alloc = doc::HeapNode::new_allocator(); + + let parse_doc = |bytes: &[u8]| { + let mut de = serde_json::Deserializer::from_slice(bytes); + doc::HeapNode::from_serde(&mut de, &alloc).with_context(|| { + format!( + "couldn't parse document as JSON: {}", + String::from_utf8_lossy(bytes) + ) + }) + }; + + // Initialize the reduction stack. + let mut stack = vec![parse_doc( + operands + .next() + .context("expected at least one merge operand")?, + )?]; + + // We directly use a RawValidator to avoid a mutable borrow over `validator`'s inner instance. + // This makes the operator parallel-safe. + let mut raw_validator = doc::RawValidator::new(validator.schema_index()); + let schema_uri = &validator.schemas()[0].curi; + + // Perform a pass of associative reductions. + for rhs in operands { + let rhs = parse_doc(rhs)?; + raw_validator.prepare(schema_uri)?; + + let root = json::Location::Root; + let span = doc::walker::walk_document(&rhs, &mut raw_validator, &root, 0); + + let rhs_valid = doc::Validation { + document: &rhs, + schema: schema_uri, + span, + validator: &mut raw_validator, + } + .ok()?; + + // Attempt to associatively-reduce `doc` into the top of the stack. + let top = stack.last_mut().unwrap(); + + match doc::reduce::reduce( + doc::LazyNode::Heap::(top), + doc::LazyNode::Heap(&rhs), + rhs_valid, + &alloc, + false, + ) { + Err(doc::reduce::Error::NotAssociative) => { + // Push `rhs` to the top of the stack. 
+ stack.push(rhs); + } + Ok((doc, _delete)) => { + // Replace the stack tip with reduced `doc`. + *top = doc; + } + Err(err) => panic!("{err:#}"), + } + } + + if !full { + assert!(initial.is_none()); + + if stack.len() == 1 { + return Ok(Some( + serde_json::to_vec(&doc::SerPolicy::default().on(&stack[0])).unwrap(), + )); + } else { + return Ok(None); + } + } + + let mut reduced = parse_doc(initial.context("initial value is present")?)?; + + for rhs in stack { + raw_validator.prepare(schema_uri)?; + + let root = json::Location::Root; + let span = doc::walker::walk_document(&rhs, &mut raw_validator, &root, 0); + + let rhs_valid = doc::Validation { + document: &rhs, + schema: schema_uri, + span, + validator: &mut raw_validator, + } + .ok()?; + + (reduced, _) = doc::reduce::reduce( + doc::LazyNode::Heap::(&reduced), + doc::LazyNode::Heap(&rhs), + rhs_valid, + &alloc, + true, + )?; + } + + Ok(Some( + serde_json::to_vec(&doc::SerPolicy::default().on(&reduced)).unwrap(), + )) + }; + + let foo = move |key: &[u8], + initial: Option<&[u8]>, + operands: &mut rocksdb::merge_operator::MergeOperands| + -> Option> { + match foo(key, initial, operands) { + Ok(ok) => ok, + Err(err) => { + eprintln!("MERGE FAILED: {:#}", err); + tracing::error!(%err); + None + } + } + }; + + Ok(foo) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn foo_the_bar() { + let db = RocksDB::open(None).unwrap(); + + let mut batch = rocksdb::WriteBatch::default(); + + batch.merge(RocksDB::CONNECTOR_STATE_KEY, r#"{"a":"b","n":null}"#); + batch.merge(RocksDB::CONNECTOR_STATE_KEY, r#"{"a":"c","nn":null}"#); + batch.merge(RocksDB::CONNECTOR_STATE_KEY, r#"{"d":"e","ans":42}"#); + + db.write(batch).unwrap(); + + let foo = db.get(RocksDB::CONNECTOR_STATE_KEY).unwrap(); + assert_eq!(foo, Some(vec![])) + } +} diff --git a/crates/runtime/src/unseal/mod.rs b/crates/runtime/src/unseal/mod.rs index 24dacb0fb4..4e1483813d 100644 --- a/crates/runtime/src/unseal/mod.rs +++ b/crates/runtime/src/unseal/mod.rs @@ -24,8 +24,8 @@ pub async fn decrypt_sops(config: &models::RawValue) -> anyhow::Result anyhow::Result anyhow::Result 0 { + i -= len(m.PartitionsPacked) + copy(dAtA[i:], m.PartitionsPacked) + i = encodeVarintRuntime(dAtA, i, uint64(len(m.PartitionsPacked))) + i-- + dAtA[i] = 0x12 + } + if len(m.KeyPacked) > 0 { + i -= len(m.KeyPacked) + copy(dAtA[i:], m.KeyPacked) + i = encodeVarintRuntime(dAtA, i, uint64(len(m.KeyPacked))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *CaptureResponseExt_Checkpoint) Marshal() (dAtA []byte, err error) { + size := m.ProtoSize() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *CaptureResponseExt_Checkpoint) MarshalTo(dAtA []byte) (int, error) { + size := m.ProtoSize() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *CaptureResponseExt_Checkpoint) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.XXX_unrecognized != nil { + i -= len(m.XXX_unrecognized) + copy(dAtA[i:], m.XXX_unrecognized) + } + if m.Stats != nil { + { + size, err := m.Stats.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintRuntime(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + func (m *DeriveRequestExt) Marshal() (dAtA []byte, err error) { size := m.ProtoSize() dAtA = make([]byte, size) @@ -2049,6 +2244,50 @@ func (m *CaptureResponseExt) ProtoSize() (n int) { l 
= m.Container.ProtoSize() n += 1 + l + sovRuntime(uint64(l)) } + if m.Captured != nil { + l = m.Captured.ProtoSize() + n += 1 + l + sovRuntime(uint64(l)) + } + if m.Checkpoint != nil { + l = m.Checkpoint.ProtoSize() + n += 1 + l + sovRuntime(uint64(l)) + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func (m *CaptureResponseExt_Captured) ProtoSize() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.KeyPacked) + if l > 0 { + n += 1 + l + sovRuntime(uint64(l)) + } + l = len(m.PartitionsPacked) + if l > 0 { + n += 1 + l + sovRuntime(uint64(l)) + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + +func (m *CaptureResponseExt_Checkpoint) ProtoSize() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Stats != nil { + l = m.Stats.ProtoSize() + n += 1 + l + sovRuntime(uint64(l)) + } if m.XXX_unrecognized != nil { n += len(m.XXX_unrecognized) } @@ -3630,6 +3869,284 @@ func (m *CaptureResponseExt) Unmarshal(dAtA []byte) error { return err } iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Captured", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRuntime + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthRuntime + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthRuntime + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Captured == nil { + m.Captured = &CaptureResponseExt_Captured{} + } + if err := m.Captured.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Checkpoint", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRuntime + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthRuntime + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthRuntime + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Checkpoint == nil { + m.Checkpoint = &CaptureResponseExt_Checkpoint{} + } + if err := m.Checkpoint.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipRuntime(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthRuntime + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...) 
+ iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *CaptureResponseExt_Captured) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRuntime + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Captured: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Captured: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field KeyPacked", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRuntime + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthRuntime + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthRuntime + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.KeyPacked = append(m.KeyPacked[:0], dAtA[iNdEx:postIndex]...) + if m.KeyPacked == nil { + m.KeyPacked = []byte{} + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field PartitionsPacked", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRuntime + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthRuntime + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthRuntime + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.PartitionsPacked = append(m.PartitionsPacked[:0], dAtA[iNdEx:postIndex]...) + if m.PartitionsPacked == nil { + m.PartitionsPacked = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipRuntime(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthRuntime + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...) 
+ iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *CaptureResponseExt_Checkpoint) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRuntime + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Checkpoint: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Checkpoint: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Stats", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRuntime + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthRuntime + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthRuntime + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Stats == nil { + m.Stats = &ops.Stats{} + } + if err := m.Stats.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipRuntime(dAtA[iNdEx:]) diff --git a/go/protocols/runtime/runtime.proto b/go/protocols/runtime/runtime.proto index a280b78dab..b7abf51f4c 100644 --- a/go/protocols/runtime/runtime.proto +++ b/go/protocols/runtime/runtime.proto @@ -140,10 +140,29 @@ message Container { message CaptureRequestExt { ops.ShardLabeling labels = 1; + + message Open { + // RocksDB descriptor which should be opened. + RocksDBDescriptor rocksdb_descriptor = 1; + } + Open open = 2; } message CaptureResponseExt { Container container = 1; + + message Captured { + // Packed key extracted from the captured document. + bytes key_packed = 1; + // Packed partition values extracted from the captured document. + bytes partitions_packed = 2; + } + Captured captured = 2; + + message Checkpoint { + ops.Stats stats = 1; + } + Checkpoint checkpoint = 4; } message DeriveRequestExt { @@ -184,6 +203,12 @@ message DeriveResponseExt { message MaterializeRequestExt { ops.ShardLabeling labels = 1; + + message Open { + // RocksDB descriptor which should be opened. + RocksDBDescriptor rocksdb_descriptor = 1; + } + Open open = 2; } message MaterializeResponseExt {