diff --git a/rust-arroyo/Cargo.toml b/rust-arroyo/Cargo.toml new file mode 100644 index 00000000..e6f1d1da --- /dev/null +++ b/rust-arroyo/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "rust_arroyo" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +chrono = "0.4.26" +coarsetime = "0.1.33" +once_cell = "1.18.0" +rand = "0.8.5" +rdkafka = { version = "0.36.1", features = ["cmake-build", "tracing"] } +sentry = { version = "0.32.0" } +serde = { version = "1.0.137", features = ["derive"] } +serde_json = "1.0.81" +thiserror = "1.0" +tokio = { version = "1.19.2", features = ["full"] } +tracing = "0.1.40" +uuid = { version = "1.5.0", features = ["v4"] } +parking_lot = "0.12.1" + +[dev-dependencies] +tracing-subscriber = "0.3.18" diff --git a/rust-arroyo/examples/base_consumer.rs b/rust-arroyo/examples/base_consumer.rs new file mode 100644 index 00000000..3c46d524 --- /dev/null +++ b/rust-arroyo/examples/base_consumer.rs @@ -0,0 +1,40 @@ +extern crate rust_arroyo; + +use rust_arroyo::backends::kafka::config::KafkaConfig; +use rust_arroyo::backends::kafka::InitialOffset; +use rust_arroyo::backends::kafka::KafkaConsumer; +use rust_arroyo::backends::AssignmentCallbacks; +use rust_arroyo::backends::CommitOffsets; +use rust_arroyo::backends::Consumer; +use rust_arroyo::types::{Partition, Topic}; +use std::collections::HashMap; + +struct EmptyCallbacks {} +impl AssignmentCallbacks for EmptyCallbacks { + fn on_assign(&self, _: HashMap) {} + fn on_revoke(&self, _: C, _: Vec) {} +} + +fn main() { + tracing_subscriber::fmt::init(); + + let config = KafkaConfig::new_consumer_config( + vec!["127.0.0.1:9092".to_string()], + "my_group".to_string(), + InitialOffset::Latest, + false, + 30_000, + None, + ); + + let topic = Topic::new("test_static"); + let mut consumer = KafkaConsumer::new(config, &[topic], EmptyCallbacks {}).unwrap(); + println!("Subscribed"); + for _ in 0..20 { + println!("Polling"); + let res = consumer.poll(None); + if let Some(x) = res.unwrap() { + println!("MSG {:?}", x) + } + } +} diff --git a/rust-arroyo/examples/base_processor.rs b/rust-arroyo/examples/base_processor.rs new file mode 100644 index 00000000..2b652fe7 --- /dev/null +++ b/rust-arroyo/examples/base_processor.rs @@ -0,0 +1,37 @@ +extern crate rust_arroyo; + +use chrono::Duration; +use rust_arroyo::backends::kafka::config::KafkaConfig; +use rust_arroyo::backends::kafka::types::KafkaPayload; +use rust_arroyo::backends::kafka::InitialOffset; +use rust_arroyo::processing::strategies::commit_offsets::CommitOffsets; +use rust_arroyo::processing::strategies::{ProcessingStrategy, ProcessingStrategyFactory}; +use rust_arroyo::processing::StreamProcessor; +use rust_arroyo::types::Topic; + +struct TestFactory {} +impl ProcessingStrategyFactory for TestFactory { + fn create(&self) -> Box> { + Box::new(CommitOffsets::new(Duration::seconds(1))) + } +} + +fn main() { + tracing_subscriber::fmt::init(); + + let config = KafkaConfig::new_consumer_config( + vec!["127.0.0.1:9092".to_string()], + "my_group".to_string(), + InitialOffset::Latest, + false, + 30_000, + None, + ); + + let mut processor = + StreamProcessor::with_kafka(config, TestFactory {}, Topic::new("test_static"), None); + + for _ in 0..20 { + processor.run_once().unwrap(); + } +} diff --git a/rust-arroyo/examples/transform_and_produce.rs b/rust-arroyo/examples/transform_and_produce.rs new file mode 100644 index 00000000..85720884 --- /dev/null +++ 
b/rust-arroyo/examples/transform_and_produce.rs @@ -0,0 +1,92 @@ +// An example of using the RunTask and Produce strategies together. +// inspired by https://github.com/getsentry/arroyo/blob/main/examples/transform_and_produce/script.py +// This creates a consumer that reads from a topic test_in, reverses the string message, +// and then produces it to topic test_out. +extern crate rust_arroyo; + +use rdkafka::message::ToBytes; +use rust_arroyo::backends::kafka::config::KafkaConfig; +use rust_arroyo::backends::kafka::producer::KafkaProducer; +use rust_arroyo::backends::kafka::types::KafkaPayload; +use rust_arroyo::backends::kafka::InitialOffset; +use rust_arroyo::processing::strategies::produce::Produce; +use rust_arroyo::processing::strategies::run_task::RunTask; +use rust_arroyo::processing::strategies::run_task_in_threads::ConcurrencyConfig; +use rust_arroyo::processing::strategies::{ + CommitRequest, InvalidMessage, ProcessingStrategy, ProcessingStrategyFactory, StrategyError, + SubmitError, +}; +use rust_arroyo::processing::StreamProcessor; +use rust_arroyo::types::{Message, Topic, TopicOrPartition}; + +use std::time::Duration; + +fn reverse_string(value: KafkaPayload) -> Result { + let payload = value.payload().unwrap(); + let str_payload = std::str::from_utf8(payload).unwrap(); + let result_str = str_payload.chars().rev().collect::(); + + println!("transforming value: {:?} -> {:?}", str_payload, &result_str); + + let result = KafkaPayload::new( + value.key().cloned(), + value.headers().cloned(), + Some(result_str.to_bytes().to_vec()), + ); + Ok(result) +} +struct Noop {} +impl ProcessingStrategy for Noop { + fn poll(&mut self) -> Result, StrategyError> { + Ok(None) + } + fn submit(&mut self, _message: Message) -> Result<(), SubmitError> { + Ok(()) + } + fn close(&mut self) {} + fn terminate(&mut self) {} + fn join(&mut self, _timeout: Option) -> Result, StrategyError> { + Ok(None) + } +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + + struct ReverseStringAndProduceStrategyFactory { + concurrency: ConcurrencyConfig, + config: KafkaConfig, + topic: Topic, + } + impl ProcessingStrategyFactory for ReverseStringAndProduceStrategyFactory { + fn create(&self) -> Box> { + let producer = KafkaProducer::new(self.config.clone()); + let topic = TopicOrPartition::Topic(self.topic); + let reverse_string_and_produce_strategy = RunTask::new( + reverse_string, + Produce::new(Noop {}, producer, &self.concurrency, topic), + ); + Box::new(reverse_string_and_produce_strategy) + } + } + + let config = KafkaConfig::new_consumer_config( + vec!["0.0.0.0:9092".to_string()], + "my_group".to_string(), + InitialOffset::Latest, + false, + 30_000, + None, + ); + + let factory = ReverseStringAndProduceStrategyFactory { + concurrency: ConcurrencyConfig::new(5), + config: config.clone(), + topic: Topic::new("test_out"), + }; + + let processor = StreamProcessor::with_kafka(config, factory, Topic::new("test_in"), None); + println!("running processor. 
transforming from test_in to test_out"); + processor.run().unwrap(); +} diff --git a/rust-arroyo/src/backends/kafka/config.rs b/rust-arroyo/src/backends/kafka/config.rs new file mode 100644 index 00000000..ff46a429 --- /dev/null +++ b/rust-arroyo/src/backends/kafka/config.rs @@ -0,0 +1,145 @@ +use rdkafka::config::ClientConfig as RdKafkaConfig; +use std::collections::HashMap; + +use super::InitialOffset; + +#[derive(Debug, Clone)] +pub struct OffsetResetConfig { + pub auto_offset_reset: InitialOffset, + pub strict_offset_reset: bool, +} + +#[derive(Debug, Clone)] +pub struct KafkaConfig { + config_map: HashMap, + // Only applies to consumers + offset_reset_config: Option, +} + +impl KafkaConfig { + pub fn new_config( + bootstrap_servers: Vec, + override_params: Option>, + ) -> Self { + let mut config_map = HashMap::new(); + config_map.insert("bootstrap.servers".to_string(), bootstrap_servers.join(",")); + let config = Self { + config_map, + offset_reset_config: None, + }; + + apply_override_params(config, override_params) + } + + pub fn new_consumer_config( + bootstrap_servers: Vec, + group_id: String, + auto_offset_reset: InitialOffset, + strict_offset_reset: bool, + max_poll_interval_ms: usize, + override_params: Option>, + ) -> Self { + let mut config = KafkaConfig::new_config(bootstrap_servers, None); + config.offset_reset_config = Some(OffsetResetConfig { + auto_offset_reset, + strict_offset_reset, + }); + config.config_map.insert("group.id".to_string(), group_id); + config + .config_map + .insert("enable.auto.commit".to_string(), "false".to_string()); + + config.config_map.insert( + "max.poll.interval.ms".to_string(), + max_poll_interval_ms.to_string(), + ); + + // HACK: If the max poll interval is less than 45 seconds, set the session timeout + // to the same. (its default is 45 seconds and it must be <= to max.poll.interval.ms) + if max_poll_interval_ms < 45_000 { + config.config_map.insert( + "session.timeout.ms".to_string(), + max_poll_interval_ms.to_string(), + ); + } + + apply_override_params(config, override_params) + } + + pub fn new_producer_config( + bootstrap_servers: Vec, + override_params: Option>, + ) -> Self { + let config = KafkaConfig::new_config(bootstrap_servers, None); + + apply_override_params(config, override_params) + } + + pub fn offset_reset_config(&self) -> Option<&OffsetResetConfig> { + self.offset_reset_config.as_ref() + } +} + +impl From for RdKafkaConfig { + fn from(cfg: KafkaConfig) -> Self { + let mut config_obj = RdKafkaConfig::new(); + for (key, val) in cfg.config_map.iter() { + config_obj.set(key, val); + } + + // NOTE: Offsets are explicitly managed as part of the assignment + // callback, so preemptively resetting offsets is not enabled when + // strict_offset_reset is enabled. 
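+        // (Summary of the mapping applied below, for readers of this impl: with
+        // strict_offset_reset=true, `auto.offset.reset` is forced to "error" so that
+        // librdkafka surfaces out-of-range offsets instead of silently resetting;
+        // otherwise the configured auto_offset_reset value is passed through, and the
+        // rebalance callback resolves the actual starting offsets itself.)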
+ if let Some(config) = cfg.offset_reset_config { + let auto_offset_reset = if config.strict_offset_reset { + InitialOffset::Error + } else { + config.auto_offset_reset + }; + config_obj.set("auto.offset.reset", auto_offset_reset.to_string()); + } + config_obj + } +} + +fn apply_override_params( + mut config: KafkaConfig, + override_params: Option>, +) -> KafkaConfig { + if let Some(params) = override_params { + for (param, value) in params { + config.config_map.insert(param, value); + } + } + config +} + +#[cfg(test)] +mod tests { + use crate::backends::kafka::InitialOffset; + + use super::KafkaConfig; + use rdkafka::config::ClientConfig as RdKafkaConfig; + use std::collections::HashMap; + + #[test] + fn test_build_consumer_configuration() { + let config = KafkaConfig::new_consumer_config( + vec!["127.0.0.1:9092".to_string()], + "my-group".to_string(), + InitialOffset::Error, + false, + 30_000, + Some(HashMap::from([( + "queued.max.messages.kbytes".to_string(), + "1000000".to_string(), + )])), + ); + + let rdkafka_config: RdKafkaConfig = config.into(); + assert_eq!( + rdkafka_config.get("queued.max.messages.kbytes"), + Some("1000000") + ); + } +} diff --git a/rust-arroyo/src/backends/kafka/errors.rs b/rust-arroyo/src/backends/kafka/errors.rs new file mode 100644 index 00000000..6d12b656 --- /dev/null +++ b/rust-arroyo/src/backends/kafka/errors.rs @@ -0,0 +1,16 @@ +use rdkafka::error::{KafkaError, RDKafkaErrorCode}; + +use crate::backends::ConsumerError; + +impl From for ConsumerError { + fn from(err: KafkaError) -> Self { + match err { + KafkaError::OffsetFetch(RDKafkaErrorCode::OffsetOutOfRange) => { + ConsumerError::OffsetOutOfRange { + source: Box::new(err), + } + } + other => ConsumerError::BrokerError(Box::new(other)), + } + } +} diff --git a/rust-arroyo/src/backends/kafka/mod.rs b/rust-arroyo/src/backends/kafka/mod.rs new file mode 100644 index 00000000..805ad45c --- /dev/null +++ b/rust-arroyo/src/backends/kafka/mod.rs @@ -0,0 +1,741 @@ +use super::kafka::config::KafkaConfig; +use super::AssignmentCallbacks; +use super::CommitOffsets; +use super::Consumer as ArroyoConsumer; +use super::ConsumerError; +use crate::backends::kafka::types::KafkaPayload; +use crate::types::{BrokerMessage, Partition, Topic}; +use chrono::{DateTime, NaiveDateTime, Utc}; +use parking_lot::Mutex; +use rdkafka::client::ClientContext; +use rdkafka::config::{ClientConfig, RDKafkaLogLevel}; +use rdkafka::consumer::base_consumer::BaseConsumer; +use rdkafka::consumer::{CommitMode, Consumer, ConsumerContext}; +use rdkafka::error::KafkaError; +use rdkafka::message::{BorrowedMessage, Message}; +use rdkafka::topic_partition_list::{Offset, TopicPartitionList}; +use rdkafka::types::{RDKafkaErrorCode, RDKafkaRespErr}; +use sentry::Hub; +use std::collections::HashMap; +use std::collections::HashSet; +use std::fmt; +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; + +pub mod config; +mod errors; +pub mod producer; +pub mod types; + +#[derive(Eq, Hash, PartialEq)] +enum KafkaConsumerState { + Consuming, + #[allow(dead_code)] + Error, + #[allow(dead_code)] + Assigning, + #[allow(dead_code)] + Revoking, +} + +#[derive(Debug, Clone, Copy, Default)] +pub enum InitialOffset { + Earliest, + Latest, + #[default] + Error, +} + +impl fmt::Display for InitialOffset { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + InitialOffset::Earliest => write!(f, "earliest"), + InitialOffset::Latest => write!(f, "latest"), + InitialOffset::Error => write!(f, "error"), + } + } +} + +impl FromStr 
for InitialOffset { + type Err = ConsumerError; + fn from_str(auto_offset_reset: &str) -> Result { + match auto_offset_reset { + "earliest" => Ok(InitialOffset::Earliest), + "latest" => Ok(InitialOffset::Latest), + "error" => Ok(InitialOffset::Error), + _ => Err(ConsumerError::InvalidConfig), + } + } +} + +impl KafkaConsumerState { + fn assert_consuming_state(&self) -> Result<(), ConsumerError> { + match self { + KafkaConsumerState::Error => Err(ConsumerError::ConsumerErrored), + _ => Ok(()), + } + } +} + +fn create_kafka_message(topics: &[Topic], msg: BorrowedMessage) -> BrokerMessage { + let topic = msg.topic(); + // NOTE: We avoid calling `Topic::new` here, as that uses a lock to intern the `topic` name. + // As we only ever expect one of our pre-defined topics, we can also guard against Broker errors. + let Some(&topic) = topics.iter().find(|t| t.as_str() == topic) else { + panic!("Received message for topic `{topic}` that we never subscribed to"); + }; + let partition = Partition { + topic, + index: msg.partition() as u16, + }; + let time_millis = msg.timestamp().to_millis().unwrap_or(0); + + BrokerMessage::new( + KafkaPayload::new( + msg.key().map(|k| k.to_vec()), + msg.headers().map(|h| h.into()), + msg.payload().map(|p| p.to_vec()), + ), + partition, + msg.offset() as u64, + DateTime::from_naive_utc_and_offset( + NaiveDateTime::from_timestamp_millis(time_millis).unwrap_or(NaiveDateTime::MIN), + Utc, + ), + ) +} + +fn commit_impl( + consumer: &BaseConsumer>, + offsets: HashMap, +) -> Result<(), ConsumerError> { + let mut partitions = TopicPartitionList::with_capacity(offsets.len()); + for (partition, offset) in &offsets { + partitions.add_partition_offset( + partition.topic.as_str(), + partition.index as i32, + Offset::from_raw(*offset as i64), + )?; + } + + consumer.commit(&partitions, CommitMode::Sync).unwrap(); + Ok(()) +} + +struct OffsetCommitter<'a, C: AssignmentCallbacks> { + consumer: &'a BaseConsumer>, +} + +impl<'a, C: AssignmentCallbacks> CommitOffsets for OffsetCommitter<'a, C> { + fn commit(self, offsets: HashMap) -> Result<(), ConsumerError> { + commit_impl(self.consumer, offsets) + } +} + +pub struct CustomContext { + hub: Arc, + callbacks: C, + offset_state: Arc>, + initial_offset_reset: InitialOffset, +} + +impl ClientContext for CustomContext { + fn log(&self, level: RDKafkaLogLevel, fac: &str, log_message: &str) { + Hub::run(self.hub.clone(), || match level { + RDKafkaLogLevel::Emerg + | RDKafkaLogLevel::Alert + | RDKafkaLogLevel::Critical + | RDKafkaLogLevel::Error => { + tracing::error!("librdkafka: {fac} {log_message}"); + } + RDKafkaLogLevel::Warning => { + tracing::warn!("librdkafka: {fac} {log_message}"); + } + RDKafkaLogLevel::Notice | RDKafkaLogLevel::Info => { + tracing::info!("librdkafka: {fac} {log_message}"); + } + RDKafkaLogLevel::Debug => { + tracing::debug!("librdkafka: {fac} {log_message}"); + } + }) + } + + fn error(&self, error: KafkaError, reason: &str) { + Hub::run(self.hub.clone(), || { + let error: &dyn std::error::Error = &error; + tracing::error!(error, "librdkafka: {error}: {reason}"); + }) + } +} + +impl ConsumerContext for CustomContext { + // handle entire rebalancing flow ourselves, so that we can call rdkafka.assign with a + // customized list of offsets. if we use pre_rebalance and post_rebalance callbacks from + // rust-rdkafka, consumer.assign will be called *for us*, leaving us with this flow on + // partition assignment: + // + // 1. rdkafka.assign done by rdkafka + // 2. post_rebalance called + // 3. 
post_rebalance modifies the assignment (if e.g. strict_offset_reset=true and + // auto_offset_reset=latest) + // 4. rdkafka.assign is called *again* + // + // in comparison, confluent-kafka-python will execute on_assign, and only call rdkafka.assign + // if the callback did not already explicitly call assign. + // + // if we call rdkafka.assign multiple times, we have seen random AutoOffsetReset errors popping + // up in poll(), since we (briefly) assigned invalid offsets to the consumer + fn rebalance( + &self, + base_consumer: &BaseConsumer, + err: RDKafkaRespErr, + tpl: &mut TopicPartitionList, + ) { + match err { + RDKafkaRespErr::RD_KAFKA_RESP_ERR__REVOKE_PARTITIONS => { + let mut partitions: Vec = Vec::new(); + let mut offset_state = self.offset_state.lock(); + for partition in tpl.elements().iter() { + let topic = Topic::new(partition.topic()); + let index = partition.partition() as u16; + let arroyo_partition = Partition::new(topic, index); + + if offset_state.offsets.remove(&arroyo_partition).is_none() { + tracing::warn!( + "failed to delete offset for unknown partition: {}", + arroyo_partition + ); + } + offset_state.paused.remove(&arroyo_partition); + partitions.push(arroyo_partition); + } + + let committer = OffsetCommitter { + consumer: base_consumer, + }; + + // before we give up the assignment, strategies need to flush and commit + self.callbacks.on_revoke(committer, partitions); + + base_consumer + .unassign() + .expect("failed to revoke partitions"); + } + RDKafkaRespErr::RD_KAFKA_RESP_ERR__ASSIGN_PARTITIONS => { + let committed_offsets = base_consumer + .committed_offsets((*tpl).clone(), None) + .unwrap(); + + let mut offset_map: HashMap = + HashMap::with_capacity(committed_offsets.count()); + let mut tpl = TopicPartitionList::with_capacity(committed_offsets.count()); + + for partition in committed_offsets.elements() { + let raw_offset = partition.offset().to_raw().unwrap(); + + let topic = Topic::new(partition.topic()); + + let new_offset = if raw_offset >= 0 { + raw_offset + } else { + // Resolve according to the auto offset reset policy + let (low_watermark, high_watermark) = base_consumer + .fetch_watermarks(partition.topic(), partition.partition(), None) + .unwrap(); + + match self.initial_offset_reset { + InitialOffset::Earliest => low_watermark, + InitialOffset::Latest => high_watermark, + InitialOffset::Error => { + panic!("received unexpected offset"); + } + } + }; + + offset_map.insert( + Partition::new(topic, partition.partition() as u16), + new_offset as u64, + ); + + tpl.add_partition_offset( + partition.topic(), + partition.partition(), + Offset::from_raw(new_offset), + ) + .unwrap(); + } + + // assign() asap, we can create strategies later + base_consumer + .assign(&tpl) + .expect("failed to assign partitions"); + self.offset_state.lock().offsets.extend(&offset_map); + + // Ensure that all partitions are resumed on assignment to avoid + // carrying over state from a previous assignment. 
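+                // (Our own `offset_state.paused` bookkeeping is reset in the revoke
+                // branch above; this `resume` additionally clears any pause flag that
+                // librdkafka itself may still carry for a partition retained across
+                // the rebalance.)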
+ base_consumer + .resume(&tpl) + .expect("failed to resume partitions"); + + self.callbacks.on_assign(offset_map); + } + _ => { + let error_code: RDKafkaErrorCode = err.into(); + // We don't panic here since we will likely re-encounter the error on poll + tracing::error!("Error rebalancing: {}", error_code); + } + } + } +} + +#[derive(Default)] +struct OffsetState { + // offsets: the currently-*read* offset of the consumer, updated on poll() + // staged_offsets do not exist: the Commit strategy takes care of offset staging + offsets: HashMap, + // list of partitions that are currently paused + paused: HashSet, +} + +pub struct KafkaConsumer { + consumer: BaseConsumer>, + topics: Vec, + state: KafkaConsumerState, + offset_state: Arc>, +} + +impl KafkaConsumer { + pub fn new(config: KafkaConfig, topics: &[Topic], callbacks: C) -> Result { + let offset_state = Arc::new(Mutex::new(OffsetState::default())); + let initial_offset_reset = config + .offset_reset_config() + .ok_or(ConsumerError::InvalidConfig)? + .auto_offset_reset; + + let context = CustomContext { + hub: Hub::current(), + callbacks, + offset_state: offset_state.clone(), + initial_offset_reset, + }; + + let mut config_obj: ClientConfig = config.into(); + + // TODO: Can this actually fail? + let consumer: BaseConsumer> = config_obj + .set_log_level(RDKafkaLogLevel::Warning) + .create_with_context(context)?; + + let topic_str: Vec<&str> = topics.iter().map(|t| t.as_str()).collect(); + consumer.subscribe(&topic_str)?; + let topics = topics.to_owned(); + + Ok(Self { + consumer, + topics, + state: KafkaConsumerState::Consuming, + offset_state, + }) + } + + pub fn shutdown(self) {} +} + +impl Drop for KafkaConsumer { + fn drop(&mut self) { + self.consumer.unsubscribe(); + } +} + +impl ArroyoConsumer for KafkaConsumer { + fn poll( + &mut self, + timeout: Option, + ) -> Result>, ConsumerError> { + self.state.assert_consuming_state()?; + + let duration = timeout.unwrap_or(Duration::ZERO); + let res = self.consumer.poll(duration); + + match res { + None => Ok(None), + Some(res) => { + let msg = create_kafka_message(&self.topics, res?); + self.offset_state + .lock() + .offsets + .insert(msg.partition, msg.offset + 1); + + Ok(Some(msg)) + } + } + } + + fn pause(&mut self, partitions: HashSet) -> Result<(), ConsumerError> { + self.state.assert_consuming_state()?; + + let mut topic_partition_list = TopicPartitionList::with_capacity(partitions.len()); + + { + let offset_state = self.offset_state.lock(); + let offsets = &offset_state.offsets; + for partition in &partitions { + let offset = offsets + .get(partition) + .ok_or(ConsumerError::UnassignedPartition)?; + topic_partition_list.add_partition_offset( + partition.topic.as_str(), + partition.index as i32, + Offset::from_raw(*offset as i64), + )?; + } + } + + self.consumer.pause(&topic_partition_list)?; + + { + let mut offset_state = self.offset_state.lock(); + offset_state.paused.extend(partitions.clone()); + } + + Ok(()) + } + + fn resume(&mut self, partitions: HashSet) -> Result<(), ConsumerError> { + self.state.assert_consuming_state()?; + + let mut topic_partition_list = TopicPartitionList::new(); + let mut to_unpause = Vec::new(); + { + let offset_state = self.offset_state.lock(); + let offsets = &offset_state.offsets; + for partition in partitions { + if !offsets.contains_key(&partition) { + return Err(ConsumerError::UnassignedPartition); + } + topic_partition_list + .add_partition(partition.topic.as_str(), partition.index as i32); + to_unpause.push(partition); + } + } + + 
self.consumer.resume(&topic_partition_list)?; + + { + let mut offset_state = self.offset_state.lock(); + for partition in to_unpause { + offset_state.paused.remove(&partition); + } + } + + Ok(()) + } + + fn paused(&self) -> Result, ConsumerError> { + self.state.assert_consuming_state()?; + Ok(self.offset_state.lock().paused.clone()) + } + + fn tell(&self) -> Result, ConsumerError> { + self.state.assert_consuming_state()?; + Ok(self.offset_state.lock().offsets.clone()) + } + + fn seek(&self, offsets: HashMap) -> Result<(), ConsumerError> { + self.state.assert_consuming_state()?; + + { + let offset_state = self.offset_state.lock(); + for key in offsets.keys() { + if !offset_state.offsets.contains_key(key) { + return Err(ConsumerError::UnassignedPartition); + } + } + } + + for (partition, offset) in &offsets { + self.consumer.seek( + partition.topic.as_str(), + partition.index as i32, + Offset::from_raw(*offset as i64), + None, + )?; + } + + { + let mut offset_state = self.offset_state.lock(); + offset_state.offsets.extend(offsets); + } + + Ok(()) + } + + fn commit_offsets(&mut self, offsets: HashMap) -> Result<(), ConsumerError> { + self.state.assert_consuming_state()?; + commit_impl(&self.consumer, offsets) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use super::{AssignmentCallbacks, InitialOffset, KafkaConsumer}; + use crate::backends::kafka::config::KafkaConfig; + use crate::backends::kafka::producer::KafkaProducer; + use crate::backends::kafka::KafkaPayload; + use crate::backends::{Consumer, Producer}; + use crate::testutils::{get_default_broker, TestTopic}; + use crate::types::{BrokerMessage, Partition, Topic}; + use std::collections::HashMap; + use std::thread::sleep; + use std::time::Duration; + + struct EmptyCallbacks {} + impl AssignmentCallbacks for EmptyCallbacks { + fn on_assign(&self, partitions: HashMap) { + println!("assignment event: {:?}", partitions); + } + fn on_revoke(&self, _: C, partitions: Vec) { + println!("revocation event: {:?}", partitions); + } + } + + fn wait_for_assignments(consumer: &mut KafkaConsumer) { + for _ in 0..10 { + consumer.poll(Some(Duration::from_millis(5_000))).unwrap(); + if !consumer.tell().unwrap().is_empty() { + println!("Received assignment"); + break; + } + sleep(Duration::from_millis(200)); + } + } + + fn blocking_poll( + consumer: &mut KafkaConsumer, + ) -> Option> { + let mut consumer_message = None; + + for _ in 0..10 { + consumer_message = consumer.poll(Some(Duration::from_millis(5_000))).unwrap(); + + if consumer_message.is_some() { + break; + } + } + + consumer_message + } + + #[test] + fn test_subscribe() { + let configuration = KafkaConfig::new_consumer_config( + vec![std::env::var("DEFAULT_BROKERS").unwrap_or("127.0.0.1:9092".to_string())], + "my-group".to_string(), + InitialOffset::Latest, + false, + 30_000, + None, + ); + let topic = Topic::new("test"); + KafkaConsumer::new(configuration, &[topic], EmptyCallbacks {}).unwrap(); + } + + #[test] + fn test_tell() { + let topic = TestTopic::create("test-tell"); + let configuration = KafkaConfig::new_consumer_config( + vec![std::env::var("DEFAULT_BROKERS").unwrap_or("127.0.0.1:9092".to_string())], + "my-group-1".to_string(), + InitialOffset::Latest, + false, + 30_000, + None, + ); + let mut consumer = + KafkaConsumer::new(configuration, &[topic.topic], EmptyCallbacks {}).unwrap(); + assert_eq!(consumer.tell().unwrap(), HashMap::new()); + + // Getting the assignment may take a while + for _ in 0..10 { + 
consumer.poll(Some(Duration::from_millis(5_000))).unwrap(); + if consumer.tell().unwrap().len() == 1 { + println!("Received assignment"); + break; + } + sleep(Duration::from_millis(200)); + } + + let offsets = consumer.tell().unwrap(); + // One partition was assigned + assert_eq!(offsets.len(), 1); + consumer.shutdown(); + } + + /// check that consumer does not crash with strict_offset_reset if the offset does not exist + /// yet. + #[test] + fn test_offset_reset_strict() { + let topic = TestTopic::create("test-offset-reset-strict"); + let configuration = KafkaConfig::new_consumer_config( + vec![std::env::var("DEFAULT_BROKERS").unwrap_or("127.0.0.1:9092".to_string())], + "my-group-1".to_string(), + InitialOffset::Earliest, + true, + 30_000, + None, + ); + + let producer_configuration = KafkaConfig::new_producer_config( + vec![std::env::var("DEFAULT_BROKERS").unwrap_or("127.0.0.1:9092".to_string())], + None, + ); + + let producer = KafkaProducer::new(producer_configuration); + let payload = KafkaPayload::new(None, None, Some("asdf".as_bytes().to_vec())); + + producer + .produce(&crate::types::TopicOrPartition::Topic(topic.topic), payload) + .expect("Message produced"); + + let mut consumer = + KafkaConsumer::new(configuration, &[topic.topic], EmptyCallbacks {}).unwrap(); + assert_eq!(consumer.tell().unwrap(), HashMap::new()); + + let mut consumer_message = None; + + for _ in 0..10 { + consumer_message = consumer.poll(Some(Duration::from_millis(5_000))).unwrap(); + + if consumer_message.is_some() { + break; + } + } + + let consumer_message = consumer_message.unwrap(); + + assert_eq!(consumer_message.offset, 0); + let consumer_payload = consumer_message.payload.payload().unwrap(); + assert_eq!(consumer_payload, b"asdf"); + + assert!(consumer + .poll(Some(Duration::from_millis(10))) + .unwrap() + .is_none()); + + consumer + .commit_offsets(HashMap::from([( + consumer_message.partition, + consumer_message.offset + 1, + )])) + .unwrap(); + + consumer.shutdown(); + } + + #[test] + fn test_commit() { + let topic = TestTopic::create("test-commit"); + let configuration = KafkaConfig::new_consumer_config( + vec![std::env::var("DEFAULT_BROKERS").unwrap_or("127.0.0.1:9092".to_string())], + "my-group-2".to_string(), + InitialOffset::Latest, + false, + 30_000, + None, + ); + + let mut consumer = + KafkaConsumer::new(configuration, &[topic.topic], EmptyCallbacks {}).unwrap(); + + let positions = HashMap::from([( + Partition { + topic: topic.topic, + index: 0, + }, + 100, + )]); + + // Wait until the consumer got an assignment + for _ in 0..10 { + consumer.poll(Some(Duration::from_millis(5_000))).unwrap(); + if consumer.tell().unwrap().len() == 1 { + println!("Received assignment"); + break; + } + sleep(Duration::from_millis(200)); + } + + consumer.commit_offsets(positions.clone()).unwrap(); + consumer.shutdown(); + } + + #[test] + fn test_pause() { + let topic = TestTopic::create("test-pause"); + let configuration = KafkaConfig::new_consumer_config( + vec![get_default_broker()], + // for this particular test, a separate consumer group is apparently needed, as + // otherwise random rebalancing events will occur when other tests with the same + // consumer group (but not the same topic) run at the same time + "my-group-1-test-pause".to_string(), + InitialOffset::Earliest, + true, + 60_000, + None, + ); + + let mut consumer = + KafkaConsumer::new(configuration, &[topic.topic], EmptyCallbacks {}).unwrap(); + + wait_for_assignments(&mut consumer); + + let payload = KafkaPayload::new(None, None, 
Some("asdf".as_bytes().to_vec())); + topic.produce(payload); + + let old_offsets = consumer.tell().unwrap(); + assert_eq!( + old_offsets, + HashMap::from([(Partition::new(topic.topic, 0), 0)]) + ); + + let consumer_message = blocking_poll(&mut consumer).unwrap(); + + assert_eq!(consumer_message.offset, 0); + let consumer_payload = consumer_message.payload.payload().unwrap(); + assert_eq!(consumer_payload, b"asdf"); + + // try to reproduce the scenario described at + // https://github.com/getsentry/arroyo/blob/af4be59fa74acbe00e1cf8dd7921de11acb99509/arroyo/backends/kafka/consumer.py#L503-L504 + // -- "Seeking to a specific partition offset and immediately pausing that partition causes + // the seek to be ignored for some reason." + consumer.seek(old_offsets.clone()).unwrap(); + let current_partitions: HashSet<_> = consumer.tell().unwrap().into_keys().collect(); + assert_eq!(current_partitions.len(), 1); + consumer.pause(current_partitions.clone()).unwrap(); + assert_eq!( + consumer.tell().unwrap(), + HashMap::from([(Partition::new(topic.topic, 0), 0)]) + ); + + let empty_poll = consumer.poll(Some(Duration::from_secs(5))).unwrap(); + assert!(empty_poll.is_none(), "{:?}", empty_poll); + consumer.resume(current_partitions).unwrap(); + + assert_eq!(consumer.tell().unwrap(), old_offsets); + assert!(consumer.paused().unwrap().is_empty()); + + assert_eq!( + blocking_poll(&mut consumer) + .unwrap() + .payload + .payload() + .unwrap(), + b"asdf" + ); + + consumer.shutdown(); + } +} diff --git a/rust-arroyo/src/backends/kafka/producer.rs b/rust-arroyo/src/backends/kafka/producer.rs new file mode 100644 index 00000000..030230a4 --- /dev/null +++ b/rust-arroyo/src/backends/kafka/producer.rs @@ -0,0 +1,61 @@ +use crate::backends::kafka::config::KafkaConfig; +use crate::backends::kafka::types::KafkaPayload; +use crate::backends::Producer as ArroyoProducer; +use crate::backends::ProducerError; +use crate::types::TopicOrPartition; +use rdkafka::config::ClientConfig; +use rdkafka::producer::{DefaultProducerContext, ThreadedProducer}; + +pub struct KafkaProducer { + producer: ThreadedProducer, +} + +impl KafkaProducer { + pub fn new(config: KafkaConfig) -> Self { + let config_obj: ClientConfig = config.into(); + let threaded_producer: ThreadedProducer<_> = config_obj.create().unwrap(); + + Self { + producer: threaded_producer, + } + } +} + +impl ArroyoProducer for KafkaProducer { + fn produce( + &self, + destination: &TopicOrPartition, + payload: KafkaPayload, + ) -> Result<(), ProducerError> { + let base_record = payload.to_base_record(destination); + + self.producer + .send(base_record) + .map_err(|_| ProducerError::ProducerErrorred)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::KafkaProducer; + use crate::backends::kafka::config::KafkaConfig; + use crate::backends::kafka::types::KafkaPayload; + use crate::backends::Producer; + use crate::types::{Topic, TopicOrPartition}; + #[test] + fn test_producer() { + let topic = Topic::new("test"); + let destination = TopicOrPartition::Topic(topic); + let configuration = + KafkaConfig::new_producer_config(vec!["127.0.0.1:9092".to_string()], None); + + let producer = KafkaProducer::new(configuration); + + let payload = KafkaPayload::new(None, None, Some("asdf".as_bytes().to_vec())); + producer + .produce(&destination, payload) + .expect("Message produced") + } +} diff --git a/rust-arroyo/src/backends/kafka/types.rs b/rust-arroyo/src/backends/kafka/types.rs new file mode 100644 index 00000000..3b11e1de --- /dev/null +++ 
b/rust-arroyo/src/backends/kafka/types.rs @@ -0,0 +1,146 @@ +use crate::types::TopicOrPartition; +use rdkafka::message::{BorrowedHeaders, Header, OwnedHeaders}; +use rdkafka::producer::BaseRecord; + +use std::sync::Arc; +#[derive(Clone, Debug)] +pub struct Headers { + headers: OwnedHeaders, +} + +impl Headers { + pub fn new() -> Self { + Self { + headers: OwnedHeaders::new(), + } + } + + pub fn insert(self, key: &str, value: Option>) -> Headers { + let headers = self.headers.insert(Header { + key, + value: value.as_ref(), + }); + Self { headers } + } +} + +impl Default for Headers { + fn default() -> Self { + Self::new() + } +} + +impl From<&BorrowedHeaders> for Headers { + fn from(value: &BorrowedHeaders) -> Self { + Self { + headers: value.detach(), + } + } +} + +impl From for OwnedHeaders { + fn from(value: Headers) -> Self { + value.headers + } +} + +#[derive(Clone, Debug)] +struct KafkaPayloadInner { + pub key: Option>, + pub headers: Option, + pub payload: Option>, +} + +#[derive(Clone, Debug)] +pub struct KafkaPayload { + inner: Arc, +} + +impl<'a> KafkaPayload { + pub fn new(key: Option>, headers: Option, payload: Option>) -> Self { + Self { + inner: Arc::new(KafkaPayloadInner { + key, + headers, + payload, + }), + } + } + + pub fn key(&self) -> Option<&Vec> { + self.inner.key.as_ref() + } + + pub fn headers(&self) -> Option<&Headers> { + self.inner.headers.as_ref() + } + + pub fn payload(&self) -> Option<&Vec> { + self.inner.payload.as_ref() + } + + pub fn to_base_record( + &'a self, + destination: &'a TopicOrPartition, + ) -> BaseRecord<'_, Vec, Vec> { + let topic = match destination { + TopicOrPartition::Topic(topic) => topic.as_str(), + TopicOrPartition::Partition(partition) => partition.topic.as_str(), + }; + + let partition = match destination { + TopicOrPartition::Topic(_) => None, + TopicOrPartition::Partition(partition) => Some(partition.index), + }; + + let mut base_record = BaseRecord::to(topic); + + if let Some(msg_key) = self.key() { + base_record = base_record.key(msg_key); + } + + if let Some(msg_payload) = self.payload() { + base_record = base_record.payload(msg_payload); + } + + if let Some(headers) = self.headers() { + base_record = base_record.headers((*headers).clone().into()); + } + + if let Some(index) = partition { + base_record = base_record.partition(index as i32) + } + + base_record + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::Topic; + + #[test] + fn test_kafka_payload() { + let destination = TopicOrPartition::Topic(Topic::new("test")); + let p: KafkaPayload = KafkaPayload::new(None, None, None); + let base_record = p.to_base_record(&destination); + assert_eq!(base_record.topic, "test"); + assert_eq!(base_record.key, None); + assert_eq!(base_record.payload, None); + assert_eq!(base_record.partition, None); + + let mut headers = Headers::new(); + headers = headers.insert("version", Some(b"1".to_vec())); + let p2 = KafkaPayload::new( + Some(b"key".to_vec()), + Some(headers), + Some(b"message".to_vec()), + ); + + let base_record = p2.to_base_record(&destination); + assert_eq!(base_record.topic, "test"); + assert_eq!(base_record.key, Some(&b"key".to_vec())); + assert_eq!(base_record.payload, Some(&b"message".to_vec())); + } +} diff --git a/rust-arroyo/src/backends/local/broker.rs b/rust-arroyo/src/backends/local/broker.rs new file mode 100644 index 00000000..754dddc3 --- /dev/null +++ b/rust-arroyo/src/backends/local/broker.rs @@ -0,0 +1,248 @@ +use crate::backends::storages::{ConsumeError, MessageStorage, TopicDoesNotExist, 
TopicExists}; +use crate::types::{BrokerMessage, Partition, Topic}; +use crate::utils::clock::Clock; +use chrono::DateTime; +use std::collections::{HashMap, HashSet}; +use thiserror::Error; +use uuid::Uuid; + +pub struct LocalBroker { + storage: Box + Send + Sync>, + clock: Box, + offsets: HashMap>, + subscriptions: HashMap>>, +} + +#[derive(Error, Debug, Clone)] +#[error(transparent)] +pub enum BrokerError { + #[error("Partition does not exist")] + PartitionDoesNotExist, + + #[error("Rebalance not supported")] + RebalanceNotSupported, + + #[error("Topic does not exist")] + TopicDoesNotExist, +} + +impl From for BrokerError { + fn from(_: TopicDoesNotExist) -> Self { + BrokerError::TopicDoesNotExist + } +} + +impl LocalBroker { + pub fn new( + storage: Box + Send + Sync>, + clock: Box, + ) -> Self { + Self { + storage, + clock, + offsets: HashMap::new(), + subscriptions: HashMap::new(), + } + } + + pub fn create_topic(&mut self, topic: Topic, partitions: u16) -> Result<(), TopicExists> { + self.storage.create_topic(topic, partitions) + } + + pub fn get_topic_partition_count(&self, topic: &Topic) -> Result { + self.storage.partition_count(topic) + } + + pub fn produce( + &mut self, + partition: &Partition, + payload: TPayload, + ) -> Result { + let time = self.clock.time(); + self.storage + .produce(partition, payload, DateTime::from(time)) + } + + pub fn subscribe( + &mut self, + consumer_id: Uuid, + consumer_group: String, + topics: Vec, + ) -> Result, BrokerError> { + // Handle rebalancing request which is not supported + let group_subscriptions = self.subscriptions.get(&consumer_group); + if let Some(group_s) = group_subscriptions { + let consumer_subscription = group_s.get(&consumer_id); + if let Some(consume_subs) = consumer_subscription { + let subscribed_topics = consume_subs; + let mut non_matches = subscribed_topics + .iter() + .zip(&topics) + .filter(|&(a, b)| a != b); + if non_matches.next().is_some() { + return Err(BrokerError::RebalanceNotSupported); + } + } else { + return Err(BrokerError::RebalanceNotSupported); + } + } + + let mut assignments = HashMap::new(); + let mut assigned_topics = HashSet::new(); + + for topic in topics.iter() { + if !assigned_topics.contains(topic) { + assigned_topics.insert(topic); + let partition_count = self.storage.partition_count(topic)?; + if !self.offsets.contains_key(&consumer_group) { + self.offsets.insert(consumer_group.clone(), HashMap::new()); + } + for n in 0..partition_count { + let p = Partition::new(*topic, n); + let offset = self.offsets[&consumer_group] + .get(&p) + .copied() + .unwrap_or_default(); + assignments.insert(p, offset); + } + } + } + + let group_subscriptions = self.subscriptions.get_mut(&consumer_group); + match group_subscriptions { + None => { + let mut new_group_subscriptions = HashMap::new(); + new_group_subscriptions.insert(consumer_id, topics); + self.subscriptions + .insert(consumer_group.clone(), new_group_subscriptions); + } + Some(group_subscriptions) => { + group_subscriptions.insert(consumer_id, topics); + } + } + Ok(assignments) + } + + pub fn unsubscribe(&mut self, id: Uuid, group: String) -> Result, BrokerError> { + let mut ret_partitions = Vec::new(); + + let Some(group_subscriptions) = self.subscriptions.get_mut(&group) else { + return Ok(vec![]); + }; + + let Some(subscribed_topics) = group_subscriptions.get(&id) else { + return Ok(vec![]); + }; + + for topic in subscribed_topics.iter() { + let partitions = self.storage.partition_count(topic)?; + for n in 0..partitions { + 
ret_partitions.push(Partition::new(*topic, n)); + } + } + group_subscriptions.remove(&id); + Ok(ret_partitions) + } + + pub fn consume( + &self, + partition: &Partition, + offset: u64, + ) -> Result>, ConsumeError> { + self.storage.consume(partition, offset) + } + + pub fn commit(&mut self, consumer_group: &str, offsets: HashMap) { + self.offsets.insert(consumer_group.to_string(), offsets); + } + + #[cfg(test)] + pub(crate) fn storage_mut(&mut self) -> &mut dyn MessageStorage { + &mut *self.storage + } +} + +#[cfg(test)] +mod tests { + use super::LocalBroker; + use crate::backends::storages::memory::MemoryMessageStorage; + use crate::types::{Partition, Topic}; + use crate::utils::clock::SystemClock; + use std::collections::HashMap; + use uuid::Uuid; + + #[test] + fn test_topic_creation() { + let storage: MemoryMessageStorage = Default::default(); + let clock = SystemClock {}; + let mut broker = LocalBroker::new(Box::new(storage), Box::new(clock)); + + let topic = Topic::new("test"); + let res = broker.create_topic(topic, 16); + assert!(res.is_ok()); + + let res2 = broker.create_topic(topic, 16); + assert!(res2.is_err()); + + let partitions = broker.get_topic_partition_count(&topic); + assert_eq!(partitions.unwrap(), 16); + } + + #[test] + fn test_produce_consume() { + let storage: MemoryMessageStorage = Default::default(); + let clock = SystemClock {}; + let mut broker = LocalBroker::new(Box::new(storage), Box::new(clock)); + + let partition = Partition::new(Topic::new("test"), 0); + let _ = broker.create_topic(Topic::new("test"), 1); + let r_prod = broker.produce(&partition, "message".to_string()); + assert!(r_prod.is_ok()); + assert_eq!(r_prod.unwrap(), 0); + + let message = broker.consume(&partition, 0).unwrap().unwrap(); + assert_eq!(message.offset, 0); + assert_eq!(message.partition, partition.clone()); + assert_eq!(message.payload, "message".to_string()); + } + + fn build_broker() -> LocalBroker { + let storage: MemoryMessageStorage = Default::default(); + let clock = SystemClock {}; + let mut broker = LocalBroker::new(Box::new(storage), Box::new(clock)); + + let topic1 = Topic::new("test1"); + let topic2 = Topic::new("test2"); + + let _ = broker.create_topic(topic1, 2); + let _ = broker.create_topic(topic2, 1); + broker + } + + #[test] + fn test_assignment() { + let mut broker = build_broker(); + + let topic1 = Topic::new("test1"); + let topic2 = Topic::new("test2"); + + let r_assignments = + broker.subscribe(Uuid::nil(), "group".to_string(), vec![topic1, topic2]); + assert!(r_assignments.is_ok()); + let expected = HashMap::from([ + (Partition::new(topic1, 0), 0), + (Partition::new(topic1, 1), 0), + (Partition::new(topic2, 0), 0), + ]); + assert_eq!(r_assignments.unwrap(), expected); + + let unassignmnts = broker.unsubscribe(Uuid::nil(), "group".to_string()); + assert!(unassignmnts.is_ok()); + let expected = vec![ + Partition::new(topic1, 0), + Partition::new(topic1, 1), + Partition::new(topic2, 0), + ]; + assert_eq!(unassignmnts.unwrap(), expected); + } +} diff --git a/rust-arroyo/src/backends/local/mod.rs b/rust-arroyo/src/backends/local/mod.rs new file mode 100644 index 00000000..b8065fc8 --- /dev/null +++ b/rust-arroyo/src/backends/local/mod.rs @@ -0,0 +1,489 @@ +pub mod broker; + +use super::{AssignmentCallbacks, CommitOffsets, Consumer, ConsumerError, Producer, ProducerError}; +use crate::types::{BrokerMessage, Partition, Topic, TopicOrPartition}; +use broker::LocalBroker; +use parking_lot::Mutex; +use rand::prelude::*; +use std::collections::HashSet; +use 
std::collections::{HashMap, VecDeque}; +use std::sync::Arc; +use std::time::Duration; +use uuid::Uuid; + +#[derive(Debug, Clone)] +pub struct RebalanceNotSupported; + +enum Callback { + Assign(HashMap), +} + +struct SubscriptionState { + topics: Vec, + callbacks: Option, + offsets: HashMap, + last_eof_at: HashMap, +} + +struct OffsetCommitter<'a, TPayload> { + group: &'a str, + broker: &'a mut LocalBroker, +} + +impl<'a, TPayload> CommitOffsets for OffsetCommitter<'a, TPayload> { + fn commit(self, offsets: HashMap) -> Result<(), ConsumerError> { + self.broker.commit(self.group, offsets); + Ok(()) + } +} + +pub struct LocalConsumer { + id: Uuid, + group: String, + broker: Arc>>, + pending_callback: VecDeque, + paused: HashSet, + // The offset that a the last ``EndOfPartition`` exception that was + // raised at. To maintain consistency with the Confluent consumer, this + // is only sent once per (partition, offset) pair. + subscription_state: SubscriptionState, + enable_end_of_partition: bool, + commit_offset_calls: u32, +} + +impl LocalConsumer { + pub fn new( + id: Uuid, + broker: Arc>>, + group: String, + enable_end_of_partition: bool, + topics: &[Topic], + callbacks: C, + ) -> Self { + let mut ret = Self { + id, + group, + broker, + pending_callback: VecDeque::new(), + paused: HashSet::new(), + subscription_state: SubscriptionState { + topics: topics.to_vec(), + callbacks: Some(callbacks), + offsets: HashMap::new(), + last_eof_at: HashMap::new(), + }, + enable_end_of_partition, + commit_offset_calls: 0, + }; + + let offsets = ret + .broker + .lock() + .subscribe(ret.id, ret.group.clone(), topics.to_vec()) + .unwrap(); + + ret.pending_callback.push_back(Callback::Assign(offsets)); + + ret + } + + fn is_subscribed<'p>(&self, mut partitions: impl Iterator) -> bool { + let subscribed = &self.subscription_state.offsets; + partitions.all(|partition| subscribed.contains_key(partition)) + } + + pub fn shutdown(self) {} +} + +impl Consumer + for LocalConsumer +{ + fn poll( + &mut self, + _timeout: Option, + ) -> Result>, ConsumerError> { + while !self.pending_callback.is_empty() { + let callback = self.pending_callback.pop_front().unwrap(); + match callback { + Callback::Assign(offsets) => { + if let Some(callbacks) = self.subscription_state.callbacks.as_mut() { + callbacks.on_assign(offsets.clone()); + } + self.subscription_state.offsets = offsets; + } + } + } + + let keys = self.subscription_state.offsets.keys(); + let mut new_offset: Option<(Partition, u64)> = None; + let mut ret_message: Option> = None; + for partition in keys { + if self.paused.contains(partition) { + continue; + } + + let offset = self.subscription_state.offsets[partition]; + let message = self.broker.lock().consume(partition, offset).unwrap(); + if let Some(msg) = message { + new_offset = Some((*partition, msg.offset + 1)); + ret_message = Some(msg); + break; + } + + if self.enable_end_of_partition + && (!self.subscription_state.last_eof_at.contains_key(partition) + || offset > self.subscription_state.last_eof_at[partition]) + { + self.subscription_state + .last_eof_at + .insert(*partition, offset); + return Err(ConsumerError::EndOfPartition); + } + } + + Ok(new_offset.and_then(|(partition, offset)| { + self.subscription_state.offsets.insert(partition, offset); + ret_message + })) + } + + fn pause(&mut self, partitions: HashSet) -> Result<(), ConsumerError> { + if !self.is_subscribed(partitions.iter()) { + return Err(ConsumerError::EndOfPartition); + } + + self.paused.extend(partitions); + Ok(()) + } + + fn resume(&mut 
self, partitions: HashSet) -> Result<(), ConsumerError> { + if !self.is_subscribed(partitions.iter()) { + return Err(ConsumerError::UnassignedPartition); + } + + for p in partitions { + self.paused.remove(&p); + } + Ok(()) + } + + fn paused(&self) -> Result, ConsumerError> { + Ok(self.paused.clone()) + } + + fn tell(&self) -> Result, ConsumerError> { + Ok(self.subscription_state.offsets.clone()) + } + + fn seek(&self, _: HashMap) -> Result<(), ConsumerError> { + unimplemented!("Seek is not implemented"); + } + + fn commit_offsets(&mut self, offsets: HashMap) -> Result<(), ConsumerError> { + if !self.is_subscribed(offsets.keys()) { + return Err(ConsumerError::UnassignedPartition); + } + + self.broker.lock().commit(&self.group, offsets); + self.commit_offset_calls += 1; + + Ok(()) + } +} + +impl Drop for LocalConsumer { + fn drop(&mut self) { + if !self.subscription_state.topics.is_empty() { + let broker: &mut LocalBroker<_> = &mut self.broker.lock(); + let partitions = broker.unsubscribe(self.id, self.group.clone()).unwrap(); + + if let Some(c) = self.subscription_state.callbacks.as_mut() { + let offset_stage = OffsetCommitter { + group: &self.group, + broker, + }; + c.on_revoke(offset_stage, partitions); + } + } + } +} + +pub struct LocalProducer { + broker: Arc>>, +} + +impl LocalProducer { + pub fn new(broker: Arc>>) -> Self { + Self { broker } + } +} + +impl Clone for LocalProducer { + fn clone(&self) -> Self { + Self { + broker: self.broker.clone(), + } + } +} + +impl Producer for LocalProducer { + fn produce( + &self, + destination: &TopicOrPartition, + payload: TPayload, + ) -> Result<(), ProducerError> { + let mut broker = self.broker.lock(); + let partition = match destination { + TopicOrPartition::Topic(t) => { + let max_partitions = broker + .get_topic_partition_count(t) + .map_err(|_| ProducerError::ProducerErrorred)?; + let partition = thread_rng().gen_range(0..max_partitions); + Partition::new(*t, partition) + } + TopicOrPartition::Partition(p) => *p, + }; + + broker + .produce(&partition, payload) + .map_err(|_| ProducerError::ProducerErrorred)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::{AssignmentCallbacks, LocalConsumer}; + use crate::backends::local::broker::LocalBroker; + use crate::backends::storages::memory::MemoryMessageStorage; + use crate::backends::{CommitOffsets, Consumer}; + use crate::types::{Partition, Topic}; + use crate::utils::clock::SystemClock; + use parking_lot::Mutex; + use std::collections::{HashMap, HashSet}; + use std::sync::Arc; + use std::time::Duration; + use uuid::Uuid; + + struct EmptyCallbacks {} + impl AssignmentCallbacks for EmptyCallbacks { + fn on_assign(&self, _: HashMap) {} + fn on_revoke(&self, _: C, _: Vec) {} + } + + fn build_broker() -> LocalBroker { + let storage: MemoryMessageStorage = Default::default(); + let clock = SystemClock {}; + let mut broker = LocalBroker::new(Box::new(storage), Box::new(clock)); + + let topic1 = Topic::new("test1"); + let topic2 = Topic::new("test2"); + + let _ = broker.create_topic(topic1, 2); + let _ = broker.create_topic(topic2, 1); + broker + } + + #[test] + fn test_consumer_subscription() { + let broker = build_broker(); + + let topic1 = Topic::new("test1"); + let topic2 = Topic::new("test2"); + + let mut consumer = LocalConsumer::new( + Uuid::nil(), + Arc::new(Mutex::new(broker)), + "test_group".to_string(), + true, + &[topic1, topic2], + EmptyCallbacks {}, + ); + + assert_eq!(consumer.pending_callback.len(), 1); + + let _ = consumer.poll(Some(Duration::from_millis(100))); + let 
expected = HashMap::from([ + (Partition::new(topic1, 0), 0), + (Partition::new(topic1, 1), 0), + (Partition::new(topic2, 0), 0), + ]); + assert_eq!(consumer.subscription_state.offsets, expected); + assert_eq!(consumer.pending_callback.len(), 0); + } + + #[test] + fn test_subscription_callback() { + let broker = build_broker(); + + let topic1 = Topic::new("test1"); + let topic2 = Topic::new("test2"); + + struct TheseCallbacks {} + impl AssignmentCallbacks for TheseCallbacks { + fn on_assign(&self, partitions: HashMap) { + let topic1 = Topic::new("test1"); + let topic2 = Topic::new("test2"); + assert_eq!( + partitions, + HashMap::from([ + ( + Partition { + topic: topic1, + index: 0 + }, + 0 + ), + ( + Partition { + topic: topic1, + index: 1 + }, + 0 + ), + ( + Partition { + topic: topic2, + index: 0 + }, + 0 + ), + ]) + ) + } + + fn on_revoke(&self, _: C, partitions: Vec) { + let topic1 = Topic::new("test1"); + let topic2 = Topic::new("test2"); + assert_eq!( + partitions, + vec![ + Partition { + topic: topic1, + index: 0 + }, + Partition { + topic: topic1, + index: 1 + }, + Partition { + topic: topic2, + index: 0 + }, + ] + ); + } + } + + let mut consumer = LocalConsumer::new( + Uuid::nil(), + Arc::new(Mutex::new(broker)), + "test_group".to_string(), + true, + &[topic1, topic2], + TheseCallbacks {}, + ); + + let _ = consumer.poll(Some(Duration::from_millis(100))); + } + + #[test] + fn test_consume() { + let mut broker = build_broker(); + + let topic2 = Topic::new("test2"); + let partition = Partition::new(topic2, 0); + let _ = broker.produce(&partition, "message1".to_string()); + let _ = broker.produce(&partition, "message2".to_string()); + + struct TheseCallbacks {} + impl AssignmentCallbacks for TheseCallbacks { + fn on_assign(&self, partitions: HashMap) { + let topic2 = Topic::new("test2"); + assert_eq!( + partitions, + HashMap::from([( + Partition { + topic: topic2, + index: 0 + }, + 0 + ),]) + ); + } + fn on_revoke(&self, _: C, _: Vec) {} + } + + let mut consumer = LocalConsumer::new( + Uuid::nil(), + Arc::new(Mutex::new(broker)), + "test_group".to_string(), + true, + &[topic2], + TheseCallbacks {}, + ); + + let msg1 = consumer.poll(Some(Duration::from_millis(100))).unwrap(); + assert!(msg1.is_some()); + let msg_content = msg1.unwrap(); + assert_eq!(msg_content.offset, 0); + assert_eq!(msg_content.payload, "message1".to_string()); + + let msg2 = consumer.poll(Some(Duration::from_millis(100))).unwrap(); + assert!(msg2.is_some()); + let msg_content = msg2.unwrap(); + assert_eq!(msg_content.offset, 1); + assert_eq!(msg_content.payload, "message2".to_string()); + + let ret = consumer.poll(Some(Duration::from_millis(100))); + assert!(ret.is_err()); + } + + #[test] + fn test_paused() { + let broker = build_broker(); + let topic2 = Topic::new("test2"); + let partition = Partition::new(topic2, 0); + let mut consumer = LocalConsumer::new( + Uuid::nil(), + Arc::new(Mutex::new(broker)), + "test_group".to_string(), + false, + &[topic2], + EmptyCallbacks {}, + ); + + assert_eq!(consumer.poll(None).unwrap(), None); + let _ = consumer.pause(HashSet::from([partition])); + assert_eq!(consumer.paused().unwrap(), HashSet::from([partition])); + + let _ = consumer.resume(HashSet::from([partition])); + assert_eq!(consumer.poll(None).unwrap(), None); + } + + #[test] + fn test_commit() { + let broker = build_broker(); + let topic2 = Topic::new("test2"); + let mut consumer = LocalConsumer::new( + Uuid::nil(), + Arc::new(Mutex::new(broker)), + "test_group".to_string(), + false, + &[topic2], + EmptyCallbacks {}, 
+ ); + let _ = consumer.poll(None); + let positions = HashMap::from([(Partition::new(topic2, 0), 100)]); + + let offsets = consumer.commit_offsets(positions.clone()); + assert!(offsets.is_ok()); + + // Stage invalid positions + let invalid_positions = HashMap::from([(Partition::new(topic2, 1), 100)]); + + let commit_result = consumer.commit_offsets(invalid_positions); + assert!(commit_result.is_err()); + } +} diff --git a/rust-arroyo/src/backends/mod.rs b/rust-arroyo/src/backends/mod.rs new file mode 100755 index 00000000..c9e02625 --- /dev/null +++ b/rust-arroyo/src/backends/mod.rs @@ -0,0 +1,163 @@ +use super::types::{BrokerMessage, Partition, TopicOrPartition}; +use std::collections::{HashMap, HashSet}; +use std::time::Duration; +use thiserror::Error; + +pub mod kafka; +pub mod local; +pub mod storages; + +#[non_exhaustive] +#[derive(Error, Debug)] +pub enum ConsumerError { + #[error("Invalid config")] + InvalidConfig, + + #[error("End of partition reached")] + EndOfPartition, + + #[error("Not subscribed to a topic")] + NotSubscribed, + + #[error("The consumer errored")] + ConsumerErrored, + + #[error("The consumer is closed")] + ConsumerClosed, + + #[error("Partition not assigned to consumer")] + UnassignedPartition, + + #[error("Offset out of range")] + OffsetOutOfRange { source: Box }, + + #[error(transparent)] + BrokerError(#[from] Box), +} + +#[non_exhaustive] +#[derive(Error, Debug)] +pub enum ProducerError { + #[error("The producer errored")] + ProducerErrorred, +} + +/// This abstracts the committing of partition offsets. +pub trait CommitOffsets { + /// Commit the partition offsets stored in this object, plus the ones passed in `offsets`. + /// + /// Returns a map of all offsets that were committed. This combines [`Consumer::stage_offsets`] and + /// [`Consumer::commit_offsets`]. + fn commit(self, offsets: HashMap) -> Result<(), ConsumerError>; +} + +/// This is basically an observer pattern to receive the callbacks from +/// the consumer when partitions are assigned/revoked. +pub trait AssignmentCallbacks: Send + Sync { + fn on_assign(&self, partitions: HashMap); + fn on_revoke(&self, commit_offsets: C, partitions: Vec); +} + +/// This abstract class provides an interface for consuming messages from a +/// multiplexed collection of partitioned topic streams. +/// +/// Partitions support sequential access, as well as random access by +/// offsets. There are three types of offsets that a consumer interacts with: +/// working offsets, staged offsets, and committed offsets. Offsets always +/// represent the starting offset of the *next* message to be read. (For +/// example, committing an offset of X means the next message fetched via +/// poll will have a least an offset of X, and the last message read had an +/// offset less than X.) +/// +/// The working offsets are used track the current read position within a +/// partition. This can be also be considered as a cursor, or as high +/// watermark. Working offsets are local to the consumer process. They are +/// not shared with other consumer instances in the same consumer group and +/// do not persist beyond the lifecycle of the consumer instance, unless they +/// are committed. +/// +/// Committed offsets are managed by an external arbiter/service, and are +/// used as the starting point for a consumer when it is assigned a partition +/// during the subscription process. 
To ensure that a consumer roughly "picks +/// up where it left off" after restarting, or that another consumer in the +/// same group doesn't read messages that have been processed by another +/// consumer within the same group during a rebalance operation, positions must +/// be regularly committed by calling ``commit_offsets`` after they have been +/// staged with ``stage_offsets``. Offsets are not staged or committed +/// automatically! +/// +/// During rebalance operations, working offsets are rolled back to the +/// latest committed offset for a partition, and staged offsets are cleared +/// after the revocation callback provided to ``subscribe`` is called. (This +/// occurs even if the consumer retains ownership of the partition across +/// assignments.) For this reason, it is generally good practice to ensure +/// offsets are committed as part of the revocation callback. +pub trait Consumer: Send { + /// Fetch a message from the consumer. If no message is available before + /// the timeout, ``None`` is returned. + /// + /// This method may raise an ``OffsetOutOfRange`` exception if the + /// consumer attempts to read from an invalid location in one of its + /// assigned partitions. (Additional details can be found in the + /// docstring for ``Consumer.seek``.) + fn poll( + &mut self, + timeout: Option, + ) -> Result>, ConsumerError>; + + /// Pause consuming from the provided partitions. + /// + /// A partition that is paused will be automatically resumed during + /// reassignment. This ensures that the behavior is consistent during + /// rebalances, regardless of whether or not this consumer retains + /// ownership of the partition. (If this partition was assigned to a + /// different consumer in the consumer group during a rebalance, that + /// consumer would not have knowledge of whether or not the partition was + /// previously paused and would start consuming from the partition.) If + /// partitions should remain paused across rebalances, this should be + /// implemented in the assignment callback. + /// + /// If any of the provided partitions are not in the assignment set, an + /// exception will be raised and no partitions will be paused. + fn pause(&mut self, partitions: HashSet) -> Result<(), ConsumerError>; + + /// Resume consuming from the provided partitions. + /// + /// If any of the provided partitions are not in the assignment set, an + /// exception will be raised and no partitions will be resumed. + fn resume(&mut self, partitions: HashSet) -> Result<(), ConsumerError>; + + /// Return the currently paused partitions. + fn paused(&self) -> Result, ConsumerError>; + + /// Return the working offsets for all currently assigned positions. + fn tell(&self) -> Result, ConsumerError>; + + /// Update the working offsets for the provided partitions. + /// + /// When using this method, it is possible to set a partition to an + /// invalid offset without an immediate error. (Examples of invalid + /// offsets include an offset that is too low and has already been + /// dropped by the broker due to data retention policies, or an offset + /// that is too high which is not yet associated with a message.) Since + /// this method only updates the local working offset (and does not + /// communicate with the broker), setting an invalid offset will cause a + /// subsequent ``poll`` call to raise an ``OffsetOutOfRange`` exception, + /// even though the call to ``seek`` succeeded.
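(An aside, before the remaining trait methods: the offset convention above, where an offset always names the *next* message to be read, looks like the following in practice. This is a minimal, hedged sketch against the in-memory backend used by the local-consumer tests earlier in this diff; `broker`, the `EmptyCallbacks` helper, and the "test2" topic are assumed to be set up exactly as in those tests.)

```rust
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use parking_lot::Mutex;
use rust_arroyo::backends::local::broker::LocalBroker;
use rust_arroyo::backends::local::LocalConsumer;
use rust_arroyo::backends::Consumer;
use rust_arroyo::types::Topic;
use uuid::Uuid;

// Hedged sketch: poll one message and commit the position *after* it.
// `EmptyCallbacks` is the no-op callback helper used by the tests above.
fn poll_and_commit(broker: LocalBroker<String>) {
    let mut consumer = LocalConsumer::new(
        Uuid::nil(),
        Arc::new(Mutex::new(broker)),
        "example_group".to_string(),
        false,
        &[Topic::new("test2")],
        EmptyCallbacks {},
    );

    if let Some(m) = consumer.poll(Some(Duration::from_millis(100))).unwrap() {
        // Offsets name the *next* message to read, so committing `m.offset + 1`
        // marks the message we just handled as processed.
        consumer
            .commit_offsets(HashMap::from([(m.partition, m.offset + 1)]))
            .unwrap();
    }
}
```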
+ /// + /// If any provided partitions are not in the assignment set, an + /// exception will be raised and no offsets will be modified. + fn seek(&self, offsets: HashMap) -> Result<(), ConsumerError>; + + /// Commit offsets. + fn commit_offsets(&mut self, positions: HashMap) -> Result<(), ConsumerError>; +} + +pub trait Producer: Send + Sync { + /// Produce to a topic or partition. + fn produce( + &self, + destination: &TopicOrPartition, + payload: TPayload, + ) -> Result<(), ProducerError>; +} diff --git a/rust-arroyo/src/backends/storages/memory.rs b/rust-arroyo/src/backends/storages/memory.rs new file mode 100755 index 00000000..f30ae9a4 --- /dev/null +++ b/rust-arroyo/src/backends/storages/memory.rs @@ -0,0 +1,219 @@ +use super::{ConsumeError, MessageStorage, TopicDoesNotExist, TopicExists}; +use crate::types::{BrokerMessage, Partition, Topic}; +use chrono::{DateTime, Utc}; +use std::cmp::Ordering; +use std::collections::HashMap; +use std::convert::TryFrom; + +/// Stores a list of messages for each partition of a topic. +/// +/// `self.messages[i][j]` is the `j`-th message of the `i`-th partition. +struct TopicMessages { + messages: Vec>>, +} + +impl TopicMessages { + /// Creates empty message queues for the given number of partitions. + fn new(partitions: u16) -> Self { + Self { + messages: (0..partitions).map(|_| Vec::new()).collect(), + } + } + + /// Returns the messages of the given partition. + /// + /// # Errors + /// Returns `ConsumeError::PartitionDoesNotExist` if the partition number is out of bounds. + fn get_messages(&self, partition: u16) -> Result<&Vec>, ConsumeError> { + self.messages + .get(partition as usize) + .ok_or(ConsumeError::PartitionDoesNotExist) + } + + /// Appends the given message to its partition's message queue. + /// + /// # Errors + /// Returns `ConsumeError::PartitionDoesNotExist` if the message's partition number is out of bounds. + fn add_message(&mut self, message: BrokerMessage) -> Result<(), ConsumeError> { + let stream = self + .messages + .get_mut(message.partition.index as usize) + .ok_or(ConsumeError::PartitionDoesNotExist)?; + stream.push(message); + Ok(()) + } + + /// Returns the number of partitions. + fn partition_count(&self) -> u16 { + u16::try_from(self.messages.len()).unwrap() + } +} + +/// An implementation of [`MessageStorage`] that holds messages in memory.
+pub struct MemoryMessageStorage { + topics: HashMap>, +} + +impl Default for MemoryMessageStorage { + fn default() -> Self { + MemoryMessageStorage { + topics: HashMap::new(), + } + } +} + +impl MessageStorage for MemoryMessageStorage { + fn create_topic(&mut self, topic: Topic, partitions: u16) -> Result<(), TopicExists> { + if self.topics.contains_key(&topic) { + return Err(TopicExists); + } + self.topics.insert(topic, TopicMessages::new(partitions)); + Ok(()) + } + + fn list_topics(&self) -> Vec<&Topic> { + self.topics.keys().collect() + } + + fn delete_topic(&mut self, topic: &Topic) -> Result<(), TopicDoesNotExist> { + self.topics.remove(topic).ok_or(TopicDoesNotExist)?; + Ok(()) + } + + fn partition_count(&self, topic: &Topic) -> Result { + match self.topics.get(topic) { + Some(x) => Ok(x.partition_count()), + None => Err(TopicDoesNotExist), + } + } + + fn consume( + &self, + partition: &Partition, + offset: u64, + ) -> Result>, ConsumeError> { + let offset = usize::try_from(offset).unwrap(); + let messages = self.topics[&partition.topic].get_messages(partition.index)?; + match messages.len().cmp(&offset) { + Ordering::Greater => Ok(Some(messages[offset].clone())), + Ordering::Less => Err(ConsumeError::OffsetOutOfRange), + Ordering::Equal => Ok(None), + } + } + + fn produce( + &mut self, + partition: &Partition, + payload: TPayload, + timestamp: DateTime, + ) -> Result { + let messages = self + .topics + .get_mut(&partition.topic) + .ok_or(ConsumeError::TopicDoesNotExist)?; + let offset = messages.get_messages(partition.index)?.len(); + let offset = u64::try_from(offset).unwrap(); + let _ = messages.add_message(BrokerMessage::new(payload, *partition, offset, timestamp)); + Ok(offset) + } +} + +#[cfg(test)] +mod tests { + use super::MemoryMessageStorage; + use super::TopicMessages; + use crate::backends::storages::MessageStorage; + use crate::types::{BrokerMessage, Partition, Topic}; + use chrono::Utc; + + #[test] + fn test_partition_count() { + let topic: TopicMessages = TopicMessages::new(64); + assert_eq!(topic.partition_count(), 64); + } + + #[test] + fn test_empty_partitions() { + let topic: TopicMessages = TopicMessages::new(2); + assert_eq!(topic.get_messages(0).unwrap().len(), 0); + assert_eq!(topic.get_messages(1).unwrap().len(), 0); + } + + #[test] + fn test_invalid_partition() { + let topic: TopicMessages = TopicMessages::new(2); + assert!(topic.get_messages(10).is_err()); + } + + #[test] + fn test_add_messages() { + let mut topic: TopicMessages = TopicMessages::new(2); + let now = Utc::now(); + let p = Partition::new(Topic::new("test"), 0); + let res = topic.add_message(BrokerMessage::new("payload".to_string(), p, 10, now)); + + assert!(res.is_ok()); + assert_eq!(topic.get_messages(0).unwrap().len(), 1); + + let queue = topic.get_messages(0).unwrap(); + assert_eq!(queue[0].offset, 10); + + assert_eq!(topic.get_messages(1).unwrap().len(), 0); + } + + #[test] + fn create_manage_topic() { + let mut m: MemoryMessageStorage = Default::default(); + let res = m.create_topic(Topic::new("test"), 16); + assert!(res.is_ok()); + let b = m.list_topics(); + assert_eq!(b.len(), 1); + assert_eq!(b[0].as_str(), "test"); + + let t = Topic::new("test"); + let res2 = m.delete_topic(&t); + assert!(res2.is_ok()); + let b2 = m.list_topics(); + assert_eq!(b2.len(), 0); + } + + #[test] + fn test_mem_partition_count() { + let mut m: MemoryMessageStorage = Default::default(); + let _ = m.create_topic(Topic::new("test"), 16); + + assert_eq!(m.partition_count(&Topic::new("test")).unwrap(), 16); + } + 
+ #[test] + fn test_consume_empty() { + let mut m: MemoryMessageStorage = Default::default(); + let _ = m.create_topic(Topic::new("test"), 16); + let p = Partition::new(Topic::new("test"), 0); + let message = m.consume(&p, 0); + assert!(message.is_ok()); + assert!(message.unwrap().is_none()); + + let err = m.consume(&p, 1); + assert!(err.is_err()); + } + + #[test] + fn test_produce() { + let mut m: MemoryMessageStorage = Default::default(); + let _ = m.create_topic(Topic::new("test"), 2); + let p = Partition::new(Topic::new("test"), 0); + let time = Utc::now(); + let offset = m.produce(&p, "test".to_string(), time).unwrap(); + assert_eq!(offset, 0); + + let msg_c = m.consume(&p, 0).unwrap(); + assert!(msg_c.is_some()); + let existing_msg = msg_c.unwrap(); + assert_eq!(existing_msg.offset, 0); + assert_eq!(existing_msg.payload, "test".to_string()); + + let msg_none = m.consume(&p, 1).unwrap(); + assert!(msg_none.is_none()); + } +} diff --git a/rust-arroyo/src/backends/storages/mod.rs b/rust-arroyo/src/backends/storages/mod.rs new file mode 100755 index 00000000..23ac74a8 --- /dev/null +++ b/rust-arroyo/src/backends/storages/mod.rs @@ -0,0 +1,81 @@ +pub mod memory; +use super::super::types::{BrokerMessage, Partition, Topic}; +use chrono::{DateTime, Utc}; + +#[derive(Debug, Clone)] +pub struct TopicExists; + +#[derive(Debug, Clone)] +pub struct TopicDoesNotExist; + +#[derive(Debug, Clone)] +pub struct PartitionDoesNotExist; + +#[derive(Debug, Clone)] +pub struct OffsetOutOfRange; + +#[derive(Debug)] +pub enum ConsumeError { + TopicDoesNotExist, + PartitionDoesNotExist, + OffsetOutOfRange, +} + +/// TODO: docs +pub trait MessageStorage: Send { + /// Create a topic with the given number of partitions. + /// + /// # Errors + /// If the topic already exists, [`TopicExists`] will be returned. + fn create_topic(&mut self, topic: Topic, partitions: u16) -> Result<(), TopicExists>; + + /// List all topics. + fn list_topics(&self) -> Vec<&Topic>; + + /// Delete a topic. + /// + /// # Errors + /// If the topic does not exist, [`TopicDoesNotExist`] will be returned. + fn delete_topic(&mut self, topic: &Topic) -> Result<(), TopicDoesNotExist>; + + /// Get the number of partitions within a topic. + /// + /// # Errors + /// If the topic does not exist, [`TopicDoesNotExist`] will be returned. + fn partition_count(&self, topic: &Topic) -> Result; + + /// Consume a message from the provided partition, reading from the given + /// offset. If no message exists at the given offset when reading from + /// the tail of the partition, this method returns `Ok(None)`. + /// + /// # Errors + /// * If the offset is out of range (there are no messages, and we're not + /// reading from the tail of the partition where the next message would + /// be if it existed), [`OffsetOutOfRange`] will be returned. + /// + /// * If the topic does not exist, [`TopicDoesNotExist`] will + /// be returned. + /// + /// * If the topic exists but the partition does not, + /// [`PartitionDoesNotExist`] will be returned. + fn consume( + &self, + partition: &Partition, + offset: u64, + ) -> Result>, ConsumeError>; + + /// Produce a single message to the provided partition. + /// + /// # Errors + /// * If the topic does not exist, [`TopicDoesNotExist`] will + /// be returned. + /// + /// * If the topic exists but the partition does not, + /// [`PartitionDoesNotExist`] will be returned. 
+ fn produce( + &mut self, + partition: &Partition, + payload: TPayload, + timestamp: DateTime, + ) -> Result; +} diff --git a/rust-arroyo/src/lib.rs b/rust-arroyo/src/lib.rs new file mode 100644 index 00000000..a8f6e89c --- /dev/null +++ b/rust-arroyo/src/lib.rs @@ -0,0 +1,6 @@ +pub mod backends; +pub mod metrics; +pub mod processing; +pub mod testutils; +pub mod types; +pub mod utils; diff --git a/rust-arroyo/src/metrics/globals.rs b/rust-arroyo/src/metrics/globals.rs new file mode 100644 index 00000000..19764855 --- /dev/null +++ b/rust-arroyo/src/metrics/globals.rs @@ -0,0 +1,44 @@ +use std::sync::OnceLock; + +use super::Metric; + +/// The global [`Recorder`] which will receive [`Metric`] to be recorded. +pub trait Recorder { + /// Instructs the recorder to record the given [`Metric`]. + fn record_metric(&self, metric: Metric<'_>); +} + +impl Recorder for Box { + fn record_metric(&self, metric: Metric<'_>) { + (**self).record_metric(metric) + } +} + +static GLOBAL_RECORDER: OnceLock> = OnceLock::new(); + +/// Initialize the global [`Recorder`]. +/// +/// This will register the given `recorder` as the single global [`Recorder`] instance. +/// +/// This function can only be called once, and subsequent calls will return an +/// [`Err`] in case a global [`Recorder`] has already been initialized. +pub fn init(recorder: R) -> Result<(), R> { + let mut result = Err(recorder); + { + let result = &mut result; + let _ = GLOBAL_RECORDER.get_or_init(|| { + let recorder = std::mem::replace(result, Ok(())).unwrap_err(); + Box::new(recorder) + }); + } + result +} + +/// Records a [`Metric`] with the globally configured [`Recorder`]. +/// +/// This function will be a noop in case no global [`Recorder`] is configured. +pub fn record_metric(metric: Metric<'_>) { + if let Some(recorder) = GLOBAL_RECORDER.get() { + recorder.record_metric(metric) + } +} diff --git a/rust-arroyo/src/metrics/macros.rs b/rust-arroyo/src/metrics/macros.rs new file mode 100644 index 00000000..05f6125c --- /dev/null +++ b/rust-arroyo/src/metrics/macros.rs @@ -0,0 +1,64 @@ +/// Create a [`Metric`]. +/// +/// Instead of creating metrics directly, it is recommended to immediately record +/// metrics using the [`counter!`], [`gauge!`] or [`distribution!`] macros. +/// +/// This is the recommended way to create a [`Metric`], as the +/// implementation details of it might change. +/// +/// [`Metric`]: crate::metrics::Metric +#[macro_export] +macro_rules! metric { + ($ty:ident: $key:expr, $value:expr + $(, $($tag_key:expr => $tag_val:expr),*)? + $(; $($tag_only_val:expr),*)? + ) => {{ + $crate::metrics::Metric { + key: &$key, + ty: $crate::metrics::MetricType::$ty, + + tags: &[ + $($((Some(&$tag_key), &$tag_val),)*)? + $($((None, &$tag_only_val),)*)? + ], + value: $value.into(), + + __private: (), + } + }}; +} + +/// Records a counter [`Metric`](crate::metrics::Metric) with the global [`Recorder`](crate::metrics::Recorder). +#[macro_export] +macro_rules! counter { + ($expr:expr) => { + $crate::__record_metric!(Counter: $expr, 1); + }; + ($($tt:tt)+) => { + $crate::__record_metric!(Counter: $($tt)+); + }; +} + +/// Records a gauge [`Metric`](crate::metrics::Metric) with the global [`Recorder`](crate::metrics::Recorder). +#[macro_export] +macro_rules! gauge { + ($($tt:tt)+) => { + $crate::__record_metric!(Gauge: $($tt)+); + }; +} + +/// Records a timer [`Metric`](crate::metrics::Metric) with the global [`Recorder`](crate::metrics::Recorder). +#[macro_export] +macro_rules! 
timer { + ($($tt:tt)+) => { + $crate::__record_metric!(Timer: $($tt)+); + }; +} + +#[macro_export] +#[doc(hidden)] +macro_rules! __record_metric { + ($($tt:tt)+) => {{ + $crate::metrics::record_metric($crate::metric!($($tt)+)); + }}; +} diff --git a/rust-arroyo/src/metrics/mod.rs b/rust-arroyo/src/metrics/mod.rs new file mode 100644 index 00000000..975426c3 --- /dev/null +++ b/rust-arroyo/src/metrics/mod.rs @@ -0,0 +1,8 @@ +mod globals; +mod macros; +mod statsd; +mod types; + +pub use globals::*; +pub use statsd::*; +pub use types::*; diff --git a/rust-arroyo/src/metrics/statsd.rs b/rust-arroyo/src/metrics/statsd.rs new file mode 100644 index 00000000..be37e91d --- /dev/null +++ b/rust-arroyo/src/metrics/statsd.rs @@ -0,0 +1,125 @@ +use std::cell::RefCell; +use std::fmt::{Debug, Display, Write}; + +use super::{Metric, Recorder}; + +thread_local! { + static STRING_BUFFER: RefCell = const { RefCell::new(String::new()) }; +} + +/// A generic sink used by the [`StatsdRecorder`]. +pub trait MetricSink { + /// Emits a StatsD-formatted `metric`. + fn emit(&self, metric: &str); +} + +/// A recorder emitting StatsD-formatted [`Metric`]s to a configured [`MetricSink`]. +pub struct StatsdRecorder { + prefix: String, + sink: S, + tags: String, +} + +impl Debug for StatsdRecorder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StatsdRecorder") + .field("prefix", &self.prefix) + .field("formatted tags", &self.tags) + .finish_non_exhaustive() + } +} + +impl StatsdRecorder { + /// Creates a new Recorder with the given `prefix` and `sink`. + /// + /// The recorder will emit [`Metric`]s to formatted in `statsd` format to the + /// configured [`MetricSink`]. + pub fn new(prefix: &str, sink: S) -> Self { + let prefix = if prefix.is_empty() { + String::new() + } else { + format!("{}.", prefix.trim_end_matches('.')) + }; + Self { + prefix, + sink, + tags: String::new(), + } + } + + fn write_tag(mut self, key: Option<&dyn Display>, value: &dyn Display) -> Self { + let t = &mut self.tags; + if t.is_empty() { + t.push_str("|#"); + } else { + t.push(','); + } + + if let Some(key) = key { + let _ = write!(t, "{key}"); + t.push(':'); + } + let _ = write!(t, "{value}"); + + self + } + + /// Add a global tag (as key/value) to this Recorder. + pub fn with_tag(self, key: impl Display, value: impl Display) -> Self { + self.write_tag(Some(&key), &value) + } + + /// Add a global tag (as a single value) to this Recorder. 
+ pub fn with_tag_value(self, value: impl Display) -> Self { + self.write_tag(None, &value) + } + + fn write_metric(&self, metric: Metric<'_>, s: &mut String) { + s.push_str(&self.prefix); + metric.write_base_metric(s); + + s.push_str(&self.tags); + if !metric.tags.is_empty() { + if self.tags.is_empty() { + s.push_str("|#"); + } else { + s.push(','); + } + + metric.write_tags(s); + } + } +} + +impl Recorder for StatsdRecorder { + fn record_metric(&self, metric: Metric<'_>) { + STRING_BUFFER.with_borrow_mut(|s| { + s.clear(); + s.reserve(256); + + self.write_metric(metric, s); + + self.sink.emit(s); + }); + } +} + +impl Metric<'_> { + pub(crate) fn write_base_metric(&self, s: &mut String) { + let _ = write!(s, "{}:{}|", self.key, self.value); + s.push_str(self.ty.as_str()); + } + + pub(crate) fn write_tags(&self, s: &mut String) { + for (i, &(key, value)) in self.tags.iter().enumerate() { + if i > 0 { + s.push(','); + } + if let Some(key) = key { + let _ = write!(s, "{key}"); + s.push(':'); + } + let _ = write!(s, "{value}"); + } + } +} diff --git a/rust-arroyo/src/metrics/types.rs b/rust-arroyo/src/metrics/types.rs new file mode 100644 index 00000000..3c1396bb --- /dev/null +++ b/rust-arroyo/src/metrics/types.rs @@ -0,0 +1,109 @@ +use core::fmt::{self, Display}; +use std::time::Duration; + +/// The Type of a Metric. +/// +/// Counters, Gauges and Distributions are supported, +/// with more types to be added later. +#[non_exhaustive] +#[derive(Debug)] +pub enum MetricType { + /// A counter metric, using the StatsD `c` type. + Counter, + /// A gauge metric, using the StatsD `g` type. + Gauge, + /// A timer metric, using the StatsD `ms` type. + Timer, + // Distribution, + // Meter, + // Histogram, + // Set, +} + +impl MetricType { + /// Returns the StatsD metrics type. + pub fn as_str(&self) -> &str { + match self { + MetricType::Counter => "c", + MetricType::Gauge => "g", + MetricType::Timer => "ms", + } + } +} + +impl Display for MetricType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +/// A Metric Value. +/// +/// This supports various numeric values for now, but might gain support for +/// `Duration` and other types later on. +#[non_exhaustive] +#[derive(Debug)] +pub enum MetricValue { + /// A signed value. + I64(i64), + /// An unsigned value. + U64(u64), + /// A floating-point value. + F64(f64), + /// A [`Duration`] value. + Duration(Duration), +} + +impl Display for MetricValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + MetricValue::I64(v) => v.fmt(f), + MetricValue::U64(v) => v.fmt(f), + MetricValue::F64(v) => v.fmt(f), + MetricValue::Duration(d) => d.as_millis().fmt(f), + } + } +} + +macro_rules! into_metric_value { + ($($from:ident),+ => $variant:ident) => { + $( + impl From<$from> for MetricValue { + #[inline(always)] + fn from(f: $from) -> Self { + Self::$variant(f.into()) + } + } + )+ + }; +} + +into_metric_value!(i8, i16, i32, i64 => I64); +into_metric_value!(u8, u16, u32, u64 => U64); +into_metric_value!(f32, f64 => F64); +into_metric_value!(Duration => Duration); + +/// An alias for a list of Metric tags. +pub type MetricTags<'a> = &'a [(Option<&'a dyn Display>, &'a dyn Display)]; + +/// A fully types Metric. +/// +/// Most importantly, the metric has a [`ty`](MetricType), a `key` and a [`value`](MetricValue). +/// It can also have a list of [`tags`](MetricTags). 
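As an aside, the pieces defined in this metrics module compose as in the hedged sketch below: a `MetricSink` implementation, a `StatsdRecorder` registered through the global `init`, and the `counter!`/`timer!` macros. The stdout sink, prefix, tag, and metric names are illustrative placeholders rather than anything defined in this diff.

```rust
use rust_arroyo::metrics::{MetricSink, StatsdRecorder};

// Illustrative sink that just prints the formatted StatsD line.
struct StdoutSink;

impl MetricSink for StdoutSink {
    fn emit(&self, metric: &str) {
        println!("{metric}");
    }
}

fn setup_metrics() {
    // Prefix and tag are placeholders.
    let recorder = StatsdRecorder::new("example.consumer", StdoutSink).with_tag("env", "dev");

    // `init` returns Err(recorder) if a global recorder has already been installed.
    let _ = rust_arroyo::metrics::init(recorder);

    // The macros route through the global recorder; without one installed they are no-ops.
    rust_arroyo::counter!("example.consumer.started");
    rust_arroyo::timer!(
        "example.consumer.poll.time",
        std::time::Duration::from_millis(5)
    );
}
```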
+/// +/// This struct might change in the future, and one should construct it via +/// the [`metric!`](crate::metric) macro instead. +pub struct Metric<'a> { + /// The key, or name, of the metric. + pub key: &'a dyn Display, + /// The type of metric. + pub ty: MetricType, + + /// A list of tags for this metric. + pub tags: MetricTags<'a>, + /// The metric's value. + pub value: MetricValue, + + #[doc(hidden)] + pub __private: (), +} diff --git a/rust-arroyo/src/processing/dlq.rs b/rust-arroyo/src/processing/dlq.rs new file mode 100644 index 00000000..05d4dba1 --- /dev/null +++ b/rust-arroyo/src/processing/dlq.rs @@ -0,0 +1,653 @@ +use std::cmp::Ordering; +use std::collections::{BTreeMap, HashMap, VecDeque}; +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +use tokio::runtime::Handle; +use tokio::task::JoinHandle; + +use crate::backends::kafka::producer::KafkaProducer; +use crate::backends::kafka::types::KafkaPayload; +use crate::backends::Producer; +use crate::types::{BrokerMessage, Partition, Topic, TopicOrPartition}; + +// This is a per-partition max +const MAX_PENDING_FUTURES: usize = 1000; + +pub trait DlqProducer: Send + Sync { + // Send a message to the DLQ. + fn produce( + &self, + message: BrokerMessage, + ) -> Pin> + Send + Sync>>; + + fn build_initial_state( + &self, + limit: DlqLimit, + assignment: &HashMap, + ) -> DlqLimitState; +} + +// Drops all invalid messages. Produce returns an immediately resolved future. +struct NoopDlqProducer {} + +impl DlqProducer for NoopDlqProducer { + fn produce( + &self, + message: BrokerMessage, + ) -> Pin> + Send + Sync>> { + Box::pin(async move { message }) + } + + fn build_initial_state( + &self, + limit: DlqLimit, + _assignment: &HashMap, + ) -> DlqLimitState { + DlqLimitState::new(limit, HashMap::new()) + } +} + +// KafkaDlqProducer forwards invalid messages to a Kafka topic + +// Two additional fields are added to the headers of the Kafka message +// "original_partition": The partition of the original message +// "original_offset": The offset of the original message +pub struct KafkaDlqProducer { + producer: Arc, + topic: TopicOrPartition, +} + +impl KafkaDlqProducer { + pub fn new(producer: KafkaProducer, topic: Topic) -> Self { + Self { + producer: Arc::new(producer), + topic: TopicOrPartition::Topic(topic), + } + } +} + +impl DlqProducer for KafkaDlqProducer { + fn produce( + &self, + message: BrokerMessage, + ) -> Pin> + Send + Sync>> { + let producer = self.producer.clone(); + let topic = self.topic; + + let mut headers = message.payload.headers().cloned().unwrap_or_default(); + + headers = headers.insert( + "original_partition", + Some(message.partition.index.to_string().into_bytes()), + ); + headers = headers.insert( + "original_offset", + Some(message.offset.to_string().into_bytes()), + ); + + let payload = KafkaPayload::new( + message.payload.key().cloned(), + Some(headers), + message.payload.payload().cloned(), + ); + + Box::pin(async move { + producer + .produce(&topic, payload) + .expect("Message was produced"); + + message + }) + } + + fn build_initial_state( + &self, + limit: DlqLimit, + assignment: &HashMap, + ) -> DlqLimitState { + // XXX: We assume the last offsets were invalid when starting the consumer + DlqLimitState::new( + limit, + assignment + .iter() + .filter_map(|(p, o)| { + Some((*p, o.checked_sub(1).map(InvalidMessageStats::invalid_at)?)) + }) + .collect(), + ) + } +} + +/// Defines any limits that should be placed on the number of messages that are +/// forwarded to the DLQ.
This exists to prevent 100% of messages from going into +/// the DLQ if something is misconfigured or bad code is deployed. In this scenario, +/// it may be preferable to stop processing messages altogether and deploy a fix +/// rather than rerouting every message to the DLQ. +/// +/// The ratio and max_consecutive_count are counted on a per-partition basis. +/// +/// The default is no limit. +#[derive(Debug, Clone, Copy, Default)] +pub struct DlqLimit { + pub max_invalid_ratio: Option, + pub max_consecutive_count: Option, +} + +/// A record of valid and invalid messages that have been received on a partition. +#[derive(Debug, Clone, Copy, Default)] +pub struct InvalidMessageStats { + /// The number of valid messages that have been received. + pub valid: u64, + /// The number of invalid messages that have been received. + pub invalid: u64, + /// The length of the current run of received invalid messages. + pub consecutive_invalid: u64, + /// The offset of the last received invalid message. + pub last_invalid_offset: u64, +} + +impl InvalidMessageStats { + /// Creates an empty record with the last invalid message received at `offset`. + /// + /// The `invalid` and `consecutive_invalid` fields are intentionally left at 0. + pub fn invalid_at(offset: u64) -> Self { + Self { + last_invalid_offset: offset, + ..Default::default() + } + } +} + +/// Struct that keeps a record of how many valid and invalid messages have been received +/// per partition and decides whether to produce a message to the DLQ according to a configured limit. +#[derive(Debug, Clone, Default)] +pub struct DlqLimitState { + limit: DlqLimit, + records: HashMap, +} + +impl DlqLimitState { + /// Creates a `DlqLimitState` with a given limit and initial set of records. + pub fn new(limit: DlqLimit, records: HashMap) -> Self { + Self { limit, records } + } + + /// Records an invalid message. + /// + /// This updates the internal statistics about the message's partition and + /// returns `true` if the message should be produced to the DLQ according to the + /// configured limit. + fn record_invalid_message(&mut self, message: &BrokerMessage) -> bool { + let record = self + .records + .entry(message.partition) + .and_modify(|record| { + let last_invalid = record.last_invalid_offset; + match message.offset { + o if o <= last_invalid => { + tracing::error!("Invalid message raised out of order") + } + o if o == last_invalid + 1 => record.consecutive_invalid += 1, + o => { + let valid_count = o - last_invalid + 1; + record.valid += valid_count; + record.consecutive_invalid = 1; + } + } + + record.invalid += 1; + record.last_invalid_offset = message.offset; + }) + .or_insert(InvalidMessageStats { + valid: 0, + invalid: 1, + consecutive_invalid: 1, + last_invalid_offset: message.offset, + }); + + if let Some(max_invalid_ratio) = self.limit.max_invalid_ratio { + if record.valid == 0 { + // When no valid messages have been processed, we should not + // accept the message into the dlq. It could be an indicator + // of severe problems on the pipeline. It is best to let the + // consumer backlog in those cases. + return false; + } + + if (record.invalid as f64) / (record.valid as f64) > max_invalid_ratio { + return false; + } + } + + if let Some(max_consecutive_count) = self.limit.max_consecutive_count { + if record.consecutive_invalid > max_consecutive_count { + return false; + } + } + true + } +} + +/// DLQ policy defines the DLQ configuration, and is passed to the stream processor +/// upon creation of the consumer. 
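The `DlqPolicy` definition continues just below; as a hedged sketch of how these limits are usually wired up, the following builds a `KafkaDlqProducer`, attaches a `DlqLimit`, and hands the policy to `StreamProcessor::with_kafka`. The topic names, the limit values, and the tiny commit-only factory are illustrative assumptions, not part of this diff.

```rust
use rust_arroyo::backends::kafka::config::KafkaConfig;
use rust_arroyo::backends::kafka::producer::KafkaProducer;
use rust_arroyo::backends::kafka::types::KafkaPayload;
use rust_arroyo::processing::dlq::{DlqLimit, DlqPolicy, KafkaDlqProducer};
use rust_arroyo::processing::strategies::commit_offsets::CommitOffsets;
use rust_arroyo::processing::strategies::run_task_in_threads::ConcurrencyConfig;
use rust_arroyo::processing::strategies::{ProcessingStrategy, ProcessingStrategyFactory};
use rust_arroyo::processing::StreamProcessor;
use rust_arroyo::types::Topic;

// Minimal factory that only commits offsets, standing in for a real pipeline.
struct CommitOnlyFactory;

impl ProcessingStrategyFactory<KafkaPayload> for CommitOnlyFactory {
    fn create(&self) -> Box<dyn ProcessingStrategy<KafkaPayload>> {
        Box::new(CommitOffsets::new(chrono::Duration::seconds(1)))
    }
}

fn build_processor(config: KafkaConfig) -> StreamProcessor<KafkaPayload> {
    // The DLQ producer runs on a tokio runtime; reuse the handle from a ConcurrencyConfig,
    // mirroring the tests further down in this file.
    let handle = ConcurrencyConfig::new(1).handle();

    let dlq_producer =
        KafkaDlqProducer::new(KafkaProducer::new(config.clone()), Topic::new("my-dlq"));

    let dlq_policy = DlqPolicy::new(
        handle,
        Box::new(dlq_producer),
        DlqLimit {
            // Placeholder limits: at most 1% invalid messages and 1000 in a row per partition.
            max_invalid_ratio: Some(0.01),
            max_consecutive_count: Some(1_000),
        },
        None, // no explicit cap on buffered messages per partition
    );

    StreamProcessor::with_kafka(config, CommitOnlyFactory, Topic::new("my-topic"), Some(dlq_policy))
}
```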
It consists of the DLQ producer implementation and +/// any limits that should be applied. +pub struct DlqPolicy { + handle: Handle, + producer: Box>, + limit: DlqLimit, + max_buffered_messages_per_partition: Option, +} + +impl DlqPolicy { + pub fn new( + handle: Handle, + producer: Box>, + limit: DlqLimit, + max_buffered_messages_per_partition: Option, + ) -> Self { + DlqPolicy { + handle, + producer, + limit, + max_buffered_messages_per_partition, + } + } +} + +impl fmt::Debug for DlqPolicy { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DlqPolicy") + .field("limit", &self.limit) + .field( + "max_buffered_messages_per_partition", + &self.max_buffered_messages_per_partition, + ) + .finish_non_exhaustive() + } +} + +// Wraps the DLQ policy and keeps track of messages pending produce/commit. +type Futures = VecDeque<(u64, JoinHandle>)>; + +struct Inner { + dlq_policy: DlqPolicy, + dlq_limit_state: DlqLimitState, + futures: BTreeMap>, +} + +pub(crate) struct DlqPolicyWrapper { + inner: Option>, +} + +impl DlqPolicyWrapper { + pub fn new(dlq_policy: Option>) -> Self { + let inner = dlq_policy.map(|dlq_policy| Inner { + dlq_policy, + dlq_limit_state: DlqLimitState::default(), + futures: BTreeMap::new(), + }); + Self { inner } + } + + pub fn max_buffered_messages_per_partition(&self) -> Option { + self.inner + .as_ref() + .and_then(|i| i.dlq_policy.max_buffered_messages_per_partition) + } + + /// Clears the DLQ limits. + pub fn reset_dlq_limits(&mut self, assignment: &HashMap) { + let Some(inner) = self.inner.as_mut() else { + return; + }; + + inner.dlq_limit_state = inner + .dlq_policy + .producer + .build_initial_state(inner.dlq_policy.limit, assignment); + } + + // Removes all completed futures, then appends a future with message to be produced + // to the queue. Blocks if there are too many pending futures until some are done. + pub fn produce(&mut self, message: BrokerMessage) { + let Some(inner) = self.inner.as_mut() else { + tracing::info!("dlq policy missing, dropping message"); + return; + }; + for (_p, values) in inner.futures.iter_mut() { + while !values.is_empty() { + let len = values.len(); + let (_, future) = &mut values[0]; + if future.is_finished() || len >= MAX_PENDING_FUTURES { + let res = inner.dlq_policy.handle.block_on(future); + if let Err(err) = res { + tracing::error!("Error producing to DLQ: {}", err); + } + values.pop_front(); + } else { + break; + } + } + } + + if inner.dlq_limit_state.record_invalid_message(&message) { + tracing::info!("producing message to dlq"); + let (partition, offset) = (message.partition, message.offset); + + let task = inner.dlq_policy.producer.produce(message); + let handle = inner.dlq_policy.handle.spawn(task); + + inner + .futures + .entry(partition) + .or_default() + .push_back((offset, handle)); + } else { + panic!("DLQ limit was reached"); + } + } + + // Blocks until all messages up to the committable have been produced so + // they are safe to commit. 
+ pub fn flush(&mut self, committable: &HashMap) { + let Some(inner) = self.inner.as_mut() else { + return; + }; + + for (&p, &committable_offset) in committable { + if let Some(values) = inner.futures.get_mut(&p) { + while let Some((offset, future)) = values.front_mut() { + // The committable offset is the message's offset + 1 + if committable_offset > *offset { + if let Err(error) = inner.dlq_policy.handle.block_on(future) { + let error: &dyn std::error::Error = &error; + tracing::error!(error, "Error producing to DLQ"); + } + + values.pop_front(); + } else { + break; + } + } + } + } + } +} + +/// Stores messages that are pending commit. +/// +/// This is used to retrieve raw messages in case they need to be placed in the DLQ. +#[derive(Debug, Clone, Default)] +pub struct BufferedMessages { + max_per_partition: Option, + buffered_messages: BTreeMap>>, +} + +impl BufferedMessages { + pub fn new(max_per_partition: Option) -> Self { + BufferedMessages { + max_per_partition, + buffered_messages: BTreeMap::new(), + } + } + + /// Add a message to the buffer. + /// + /// If the configured `max_per_partition` is `0`, this is a no-op. + pub fn append(&mut self, message: BrokerMessage) { + if self.max_per_partition == Some(0) { + return; + } + + let buffered = self.buffered_messages.entry(message.partition).or_default(); + if let Some(max) = self.max_per_partition { + if buffered.len() >= max { + tracing::warn!( + "DLQ buffer exceeded, dropping message on partition {}", + message.partition.index + ); + buffered.pop_front(); + } + } + + buffered.push_back(message); + } + + /// Return the message at the given offset or None if it is not found in the buffer. + /// Messages up to the offset for the given partition are removed. + pub fn pop(&mut self, partition: &Partition, offset: u64) -> Option> { + let messages = self.buffered_messages.get_mut(partition)?; + while let Some(message) = messages.front() { + match message.offset.cmp(&offset) { + Ordering::Equal => { + return messages.pop_front(); + } + Ordering::Greater => { + return None; + } + Ordering::Less => { + messages.pop_front(); + } + }; + } + + None + } + + // Clear the buffer. Should be called on rebalance.
+ pub fn reset(&mut self) { + self.buffered_messages.clear(); + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + use std::sync::Mutex; + + use chrono::Utc; + + use crate::processing::strategies::run_task_in_threads::ConcurrencyConfig; + use crate::types::Topic; + + #[test] + fn test_buffered_messages() { + let mut buffer = BufferedMessages::new(None); + let partition = Partition { + topic: Topic::new("test"), + index: 1, + }; + + for i in 0..10 { + buffer.append(BrokerMessage { + partition, + offset: i, + payload: i, + timestamp: Utc::now(), + }); + } + + assert_eq!(buffer.pop(&partition, 0).unwrap().offset, 0); + assert_eq!(buffer.pop(&partition, 8).unwrap().offset, 8); + assert!(buffer.pop(&partition, 1).is_none()); // Removed when we popped offset 8 + assert_eq!(buffer.pop(&partition, 9).unwrap().offset, 9); + assert!(buffer.pop(&partition, 10).is_none()); // Doesn't exist + } + + #[test] + fn test_buffered_messages_limit() { + let mut buffer = BufferedMessages::new(Some(2)); + let partition = Partition { + topic: Topic::new("test"), + index: 1, + }; + + for i in 0..10 { + buffer.append(BrokerMessage { + partition, + offset: i, + payload: i, + timestamp: Utc::now(), + }); + } + + // It's gone + assert!(buffer.pop(&partition, 1).is_none()); + + assert_eq!(buffer.pop(&partition, 9).unwrap().payload, 9); + } + + #[test] + fn test_no_buffered_messages() { + let mut buffer = BufferedMessages::new(Some(0)); + let partition = Partition { + topic: Topic::new("test"), + index: 1, + }; + + for i in 0..10 { + buffer.append(BrokerMessage { + partition, + offset: i, + payload: i, + timestamp: Utc::now(), + }); + } + + assert!(buffer.pop(&partition, 9).is_none()); + } + + #[derive(Clone)] + struct TestDlqProducer { + pub call_count: Arc>, + } + + impl TestDlqProducer { + fn new() -> Self { + TestDlqProducer { + call_count: Arc::new(Mutex::new(0)), + } + } + } + + impl DlqProducer for TestDlqProducer { + fn produce( + &self, + message: BrokerMessage, + ) -> Pin> + Send + Sync>> { + *self.call_count.lock().unwrap() += 1; + Box::pin(async move { message }) + } + + fn build_initial_state( + &self, + limit: DlqLimit, + assignment: &HashMap, + ) -> DlqLimitState { + DlqLimitState::new( + limit, + assignment + .iter() + .map(|(p, _)| (*p, InvalidMessageStats::default())) + .collect(), + ) + } + } + + #[test] + fn test_dlq_policy_wrapper() { + let partition = Partition { + topic: Topic::new("test"), + index: 1, + }; + + let producer = TestDlqProducer::new(); + + let handle = ConcurrencyConfig::new(10).handle(); + let mut wrapper = DlqPolicyWrapper::new(Some(DlqPolicy::new( + handle, + Box::new(producer.clone()), + DlqLimit::default(), + None, + ))); + + wrapper.reset_dlq_limits(&HashMap::from([(partition, 0)])); + + for i in 0..10 { + wrapper.produce(BrokerMessage { + partition, + offset: i, + payload: i, + timestamp: Utc::now(), + }); + } + + wrapper.flush(&HashMap::from([(partition, 11)])); + + assert_eq!(*producer.call_count.lock().unwrap(), 10); + } + + #[test] + #[should_panic] + fn test_dlq_policy_wrapper_limit_exceeded() { + let partition = Partition { + topic: Topic::new("test"), + index: 1, + }; + + let producer = TestDlqProducer::new(); + + let handle = ConcurrencyConfig::new(10).handle(); + let mut wrapper = DlqPolicyWrapper::new(Some(DlqPolicy::new( + handle, + Box::new(producer), + DlqLimit { + max_consecutive_count: Some(5), + ..Default::default() + }, + None, + ))); + + wrapper.reset_dlq_limits(&HashMap::from([(partition, 0)])); + + for i in 0..10 { + wrapper.produce(BrokerMessage { + 
partition, + offset: i, + payload: i, + timestamp: Utc::now(), + }); + } + + wrapper.flush(&HashMap::from([(partition, 11)])); + } + + #[test] + fn test_dlq_limit_state() { + let partition = Partition::new(Topic::new("test_topic"), 0); + let limit = DlqLimit { + max_consecutive_count: Some(5), + ..Default::default() + }; + + let mut state = DlqLimitState::new( + limit, + HashMap::from([(partition, InvalidMessageStats::invalid_at(3))]), + ); + + // 1 valid message followed by 4 invalid + for i in 4..9 { + let msg = BrokerMessage::new(i, partition, i, chrono::Utc::now()); + assert!(state.record_invalid_message(&msg)); + } + + // Next message should not be accepted + let msg = BrokerMessage::new(9, partition, 9, chrono::Utc::now()); + assert!(!state.record_invalid_message(&msg)); + } +} diff --git a/rust-arroyo/src/processing/metrics_buffer.rs b/rust-arroyo/src/processing/metrics_buffer.rs new file mode 100644 index 00000000..984a22e0 --- /dev/null +++ b/rust-arroyo/src/processing/metrics_buffer.rs @@ -0,0 +1,58 @@ +use crate::timer; +use crate::utils::timing::Deadline; +use core::fmt::Debug; +use std::collections::BTreeMap; +use std::mem; +use std::time::Duration; + +#[derive(Debug)] +pub struct MetricsBuffer { + timers: BTreeMap, + flush_deadline: Deadline, +} + +const FLUSH_INTERVAL: Duration = Duration::from_secs(1); + +impl MetricsBuffer { + // A pretty shitty metrics buffer that only handles timing metrics + // and flushes them every second. Needs to be flush()-ed on shutdown + // Doesn't support tags + // Basically the same as https://github.com/getsentry/arroyo/blob/83f5f54e59892ad0b62946ef35d2daec3b561b10/arroyo/processing/processor.py#L80-L112 + // We may want to replace this with the statsdproxy aggregation step. + pub fn new() -> Self { + Self { + timers: BTreeMap::new(), + flush_deadline: Deadline::new(FLUSH_INTERVAL), + } + } + + pub fn incr_timing(&mut self, metric: &str, duration: Duration) { + if let Some(value) = self.timers.get_mut(metric) { + *value += duration; + } else { + self.timers.insert(metric.to_string(), duration); + } + self.throttled_record(); + } + + pub fn flush(&mut self) { + let timers = mem::take(&mut self.timers); + for (metric, duration) in timers { + timer!(&metric, duration); + } + + self.flush_deadline.restart(); + } + + fn throttled_record(&mut self) { + if self.flush_deadline.has_elapsed() { + self.flush(); + } + } +} + +impl Default for MetricsBuffer { + fn default() -> Self { + Self::new() + } +} diff --git a/rust-arroyo/src/processing/mod.rs b/rust-arroyo/src/processing/mod.rs new file mode 100644 index 00000000..d0b56da2 --- /dev/null +++ b/rust-arroyo/src/processing/mod.rs @@ -0,0 +1,636 @@ +use std::collections::HashMap; +use std::panic::{self, AssertUnwindSafe}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use parking_lot::{Mutex, MutexGuard}; +use thiserror::Error; + +use crate::backends::kafka::config::KafkaConfig; +use crate::backends::kafka::types::KafkaPayload; +use crate::backends::kafka::KafkaConsumer; +use crate::backends::{AssignmentCallbacks, CommitOffsets, Consumer, ConsumerError}; +use crate::processing::dlq::{BufferedMessages, DlqPolicy, DlqPolicyWrapper}; +use crate::processing::strategies::{MessageRejected, StrategyError, SubmitError}; +use crate::types::{InnerMessage, Message, Partition, Topic}; +use crate::utils::timing::Deadline; +use crate::{counter, timer}; + +pub mod dlq; +mod metrics_buffer; +pub mod strategies; + +use strategies::{ProcessingStrategy, 
ProcessingStrategyFactory}; + +#[derive(Debug, Clone)] +pub struct InvalidState; + +#[derive(Debug, Clone)] +pub struct PollError; + +#[derive(Debug, Clone)] +pub struct PauseError; + +#[derive(Debug, Error)] +pub enum RunError { + #[error("invalid state")] + InvalidState, + #[error("poll error")] + Poll(#[source] ConsumerError), + #[error("pause error")] + Pause(#[source] ConsumerError), + #[error("strategy panicked")] + StrategyPanic, + #[error("the strategy errored")] + Strategy(#[source] Box), +} + +const BACKPRESSURE_THRESHOLD: Duration = Duration::from_secs(1); + +#[derive(Clone)] +pub struct ConsumerState(Arc<(AtomicBool, Mutex>)>); + +struct ConsumerStateInner { + processing_factory: Box>, + strategy: Option>>, + backpressure_deadline: Option, + metrics_buffer: metrics_buffer::MetricsBuffer, + dlq_policy: DlqPolicyWrapper, +} + +impl ConsumerState { + pub fn new( + processing_factory: Box>, + dlq_policy: Option>, + ) -> Self { + let inner = ConsumerStateInner { + processing_factory, + strategy: None, + backpressure_deadline: None, + metrics_buffer: metrics_buffer::MetricsBuffer::new(), + dlq_policy: DlqPolicyWrapper::new(dlq_policy), + }; + Self(Arc::new((AtomicBool::new(false), Mutex::new(inner)))) + } + + fn is_paused(&self) -> bool { + self.0 .0.load(Ordering::Relaxed) + } + + fn set_paused(&self, paused: bool) { + self.0 .0.store(paused, Ordering::Relaxed) + } + + fn locked_state(&self) -> MutexGuard> { + self.0 .1.lock() + } +} + +impl ConsumerStateInner { + fn clear_backpressure(&mut self) { + if let Some(deadline) = self.backpressure_deadline.take() { + self.metrics_buffer + .incr_timing("arroyo.consumer.backpressure.time", deadline.elapsed()); + } + } +} + +pub struct Callbacks(pub ConsumerState); + +#[derive(Debug, Clone)] +pub struct ProcessorHandle { + shutdown_requested: Arc, +} + +impl ProcessorHandle { + pub fn signal_shutdown(&mut self) { + self.shutdown_requested.store(true, Ordering::Relaxed); + } +} + +impl AssignmentCallbacks for Callbacks { + // TODO: Having the initialization of the strategy here + // means that ProcessingStrategy and ProcessingStrategyFactory + // have to be Send and Sync, which is really limiting and unnecessary. + // Revisit this so that it is not the callback that perform the + // initialization. But we just provide a signal back to the + // processor to do that. 
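As an aside: the `ProcessorHandle` defined just above is how a host application asks the run loop to stop (`run` and `get_handle` appear further down in this file). A hedged sketch of that pattern, assuming an already-constructed processor that can be moved onto a background thread:

```rust
use rust_arroyo::backends::kafka::types::KafkaPayload;
use rust_arroyo::processing::StreamProcessor;

// Hedged sketch: drive the processor on a worker thread and request a graceful stop.
fn run_until_signaled(processor: StreamProcessor<KafkaPayload>) {
    let mut handle = processor.get_handle();

    let worker = std::thread::spawn(move || processor.run());

    // ... later, e.g. from a signal handler or shutdown hook ...
    handle.signal_shutdown();

    // `run` returns once the shutdown flag is observed; surface any RunError.
    worker.join().unwrap().unwrap();
}
```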
+ fn on_assign(&self, partitions: HashMap) { + counter!( + "arroyo.consumer.partitions_assigned.count", + partitions.len() as i64 + ); + + let start = Instant::now(); + + let mut state = self.0.locked_state(); + state.processing_factory.update_partitions(&partitions); + state.strategy = Some(state.processing_factory.create()); + state.dlq_policy.reset_dlq_limits(&partitions); + + timer!("arroyo.consumer.create_strategy.time", start.elapsed()); + } + + fn on_revoke(&self, commit_offsets: C, partitions: Vec) { + tracing::info!("Start revoke partitions"); + counter!( + "arroyo.consumer.partitions_revoked.count", + partitions.len() as i64, + ); + + let start = Instant::now(); + + let mut state = self.0.locked_state(); + if let Some(s) = state.strategy.as_mut() { + let result = panic::catch_unwind(AssertUnwindSafe(|| { + s.close(); + s.join(None) + })); + + match result { + Ok(join_result) => { + if let Ok(Some(commit_request)) = join_result { + state.dlq_policy.flush(&commit_request.positions); + tracing::info!("Committing offsets"); + let res = commit_offsets.commit(commit_request.positions); + + if let Err(err) = res { + let error: &dyn std::error::Error = &err; + tracing::error!(error, "Failed to commit offsets"); + } + } + } + + Err(err) => { + tracing::error!(?err, "Strategy panicked during close/join"); + } + } + } + state.strategy = None; + self.0.set_paused(false); + state.clear_backpressure(); + + timer!("arroyo.consumer.join.time", start.elapsed()); + + tracing::info!("End revoke partitions"); + + // TODO: Figure out how to flush the metrics buffer from the revocation callback. + } +} + +/// A stream processor manages the relationship between a ``Consumer`` +/// instance and a ``ProcessingStrategy``, ensuring that processing +/// strategies are instantiated on partition assignment and closed on +/// partition revocation. +pub struct StreamProcessor { + consumer: Box>>, + consumer_state: ConsumerState, + message: Option>, + processor_handle: ProcessorHandle, + buffered_messages: BufferedMessages, + metrics_buffer: metrics_buffer::MetricsBuffer, +} + +impl StreamProcessor { + pub fn with_kafka + 'static>( + config: KafkaConfig, + factory: F, + topic: Topic, + dlq_policy: Option>, + ) -> Self { + let consumer_state = ConsumerState::new(Box::new(factory), dlq_policy); + let callbacks = Callbacks(consumer_state.clone()); + + // TODO: Can this fail? + let consumer = Box::new(KafkaConsumer::new(config, &[topic], callbacks).unwrap()); + + Self::new(consumer, consumer_state) + } +} + +impl StreamProcessor { + pub fn new( + consumer: Box>>, + consumer_state: ConsumerState, + ) -> Self { + let max_buffered_messages_per_partition = consumer_state + .locked_state() + .dlq_policy + .max_buffered_messages_per_partition(); + + Self { + consumer, + consumer_state, + message: None, + processor_handle: ProcessorHandle { + shutdown_requested: Arc::new(AtomicBool::new(false)), + }, + buffered_messages: BufferedMessages::new(max_buffered_messages_per_partition), + metrics_buffer: metrics_buffer::MetricsBuffer::new(), + } + } + + pub fn run_once(&mut self) -> Result<(), RunError> { + // In case the strategy panics, we attempt to catch it and return an error. + // This enables the consumer to crash rather than hang indefinitely.
+ panic::catch_unwind(AssertUnwindSafe(|| self._run_once())) + .unwrap_or(Err(RunError::StrategyPanic)) + } + + fn _run_once(&mut self) -> Result<(), RunError> { + counter!("arroyo.consumer.run.count"); + + let consumer_is_paused = self.consumer_state.is_paused(); + if consumer_is_paused { + // If the consumer was paused, it should not be returning any messages + // on `poll`. + let res = self.consumer.poll(Some(Duration::ZERO)).unwrap(); + if res.is_some() { + return Err(RunError::InvalidState); + } + } else if self.message.is_none() { + // Otherwise, we need to try fetch a new message from the consumer, + // even if there is no active assignment and/or processing strategy. + let poll_start = Instant::now(); + //TODO: Support errors properly + match self.consumer.poll(Some(Duration::from_secs(1))) { + Ok(msg) => { + self.metrics_buffer + .incr_timing("arroyo.consumer.poll.time", poll_start.elapsed()); + + if let Some(broker_msg) = msg { + self.message = Some(Message { + inner_message: InnerMessage::BrokerMessage(broker_msg.clone()), + }); + + self.buffered_messages.append(broker_msg); + } + } + Err(err) => { + let error: &dyn std::error::Error = &err; + tracing::error!(error, "poll error"); + return Err(RunError::Poll(err)); + } + } + } + + // since we do not drive the kafka consumer at this point, it is safe to acquire the state + // lock, as we can be sure that for the rest of this function, no assignment callback will + // run. + let mut consumer_state = self.consumer_state.locked_state(); + let consumer_state: &mut ConsumerStateInner<_> = &mut consumer_state; + + let Some(strategy) = consumer_state.strategy.as_mut() else { + match self.message.as_ref() { + None => return Ok(()), + Some(_) => return Err(RunError::InvalidState), + } + }; + let processing_start = Instant::now(); + + match strategy.poll() { + Ok(None) => {} + Ok(Some(request)) => { + for (partition, offset) in &request.positions { + self.buffered_messages.pop(partition, offset - 1); + } + + consumer_state.dlq_policy.flush(&request.positions); + self.consumer.commit_offsets(request.positions).unwrap(); + } + Err(StrategyError::InvalidMessage(e)) => { + match self.buffered_messages.pop(&e.partition, e.offset) { + Some(msg) => { + tracing::error!(?e, "Invalid message"); + consumer_state.dlq_policy.produce(msg); + } + None => { + tracing::error!("Could not find invalid message in buffer"); + } + } + } + + Err(StrategyError::Other(error)) => { + return Err(RunError::Strategy(error)); + } + }; + + let Some(msg_s) = self.message.take() else { + self.metrics_buffer.incr_timing( + "arroyo.consumer.processing.time", + processing_start.elapsed(), + ); + return Ok(()); + }; + + let ret = strategy.submit(msg_s); + self.metrics_buffer.incr_timing( + "arroyo.consumer.processing.time", + processing_start.elapsed(), + ); + + match ret { + Ok(()) => { + // Resume if we are currently in a paused state + if consumer_is_paused { + let partitions = self.consumer.tell().unwrap().into_keys().collect(); + + match self.consumer.resume(partitions) { + Ok(()) => { + self.consumer_state.set_paused(false); + } + Err(err) => { + let error: &dyn std::error::Error = &err; + tracing::error!(error, "pause error"); + return Err(RunError::Pause(err)); + } + } + } + + // Clear backpressure timestamp if it is set + consumer_state.clear_backpressure(); + } + Err(SubmitError::MessageRejected(MessageRejected { message })) => { + // Put back the carried over message + self.message = Some(message); + + let Some(deadline) = consumer_state.backpressure_deadline else 
{ + consumer_state.backpressure_deadline = + Some(Deadline::new(BACKPRESSURE_THRESHOLD)); + return Ok(()); + }; + + // If we are in the backpressure state for more than 1 second, + // we pause the consumer and hold the message until it is + // accepted, at which point we can resume consuming. + if !consumer_is_paused && deadline.has_elapsed() { + tracing::warn!( + "Consumer is in backpressure state for more than 1 second, pausing", + ); + + let partitions = self.consumer.tell().unwrap().into_keys().collect(); + + match self.consumer.pause(partitions) { + Ok(()) => { + self.consumer_state.set_paused(true); + } + Err(err) => { + let error: &dyn std::error::Error = &err; + tracing::error!(error, "pause error"); + return Err(RunError::Pause(err)); + } + } + } + } + Err(SubmitError::InvalidMessage(message)) => { + let invalid_message = self + .buffered_messages + .pop(&message.partition, message.offset); + + if let Some(msg) = invalid_message { + tracing::error!(?message, "Invalid message"); + consumer_state.dlq_policy.produce(msg); + } else { + tracing::error!(?message, "Could not retrieve invalid message from buffer"); + } + } + } + Ok(()) + } + + /// The main run loop, see class docstring for more information. + pub fn run(mut self) -> Result<(), RunError> { + while !self + .processor_handle + .shutdown_requested + .load(Ordering::Relaxed) + { + if let Err(e) = self.run_once() { + let mut trait_callbacks = self.consumer_state.locked_state(); + + if let Some(strategy) = trait_callbacks.strategy.as_mut() { + strategy.terminate(); + } + + drop(trait_callbacks); // unlock mutex so we can close consumer + return Err(e); + } + } + Ok(()) + } + + pub fn get_handle(&self) -> ProcessorHandle { + self.processor_handle.clone() + } + + pub fn tell(&self) -> HashMap { + self.consumer.tell().unwrap() + } + + pub fn shutdown(self) {} +} + +#[cfg(test)] +mod tests { + use super::strategies::{ + CommitRequest, ProcessingStrategy, ProcessingStrategyFactory, StrategyError, SubmitError, + }; + use super::*; + use crate::backends::local::broker::LocalBroker; + use crate::backends::local::LocalConsumer; + use crate::backends::storages::memory::MemoryMessageStorage; + use crate::types::{Message, Partition, Topic}; + use crate::utils::clock::SystemClock; + use std::collections::HashMap; + use std::time::Duration; + use uuid::Uuid; + + struct TestStrategy { + message: Option>, + } + impl ProcessingStrategy for TestStrategy { + #[allow(clippy::manual_map)] + fn poll(&mut self) -> Result, StrategyError> { + Ok(self.message.as_ref().map(|message| CommitRequest { + positions: HashMap::from_iter(message.committable()), + })) + } + + fn submit(&mut self, message: Message) -> Result<(), SubmitError> { + self.message = Some(message); + Ok(()) + } + + fn close(&mut self) {} + + fn terminate(&mut self) {} + + fn join(&mut self, _: Option) -> Result, StrategyError> { + Ok(None) + } + } + + struct TestFactory {} + impl ProcessingStrategyFactory for TestFactory { + fn create(&self) -> Box> { + Box::new(TestStrategy { message: None }) + } + } + + fn build_broker() -> LocalBroker { + let storage: MemoryMessageStorage = Default::default(); + let clock = SystemClock {}; + let mut broker = LocalBroker::new(Box::new(storage), Box::new(clock)); + + let topic1 = Topic::new("test1"); + + let _ = broker.create_topic(topic1, 1); + broker + } + + #[test] + fn test_processor() { + let broker = build_broker(); + + let consumer_state = ConsumerState::new(Box::new(TestFactory {}), None); + + let consumer = Box::new(LocalConsumer::new( + 
Uuid::nil(), + Arc::new(Mutex::new(broker)), + "test_group".to_string(), + false, + &[Topic::new("test1")], + Callbacks(consumer_state.clone()), + )); + + let mut processor = StreamProcessor::new(consumer, consumer_state); + let res = processor.run_once(); + assert!(res.is_ok()) + } + + #[test] + fn test_consume() { + let mut broker = build_broker(); + let topic1 = Topic::new("test1"); + let partition = Partition::new(topic1, 0); + let _ = broker.produce(&partition, "message1".to_string()); + let _ = broker.produce(&partition, "message2".to_string()); + + let consumer_state = ConsumerState::new(Box::new(TestFactory {}), None); + + let consumer = Box::new(LocalConsumer::new( + Uuid::nil(), + Arc::new(Mutex::new(broker)), + "test_group".to_string(), + false, + &[Topic::new("test1")], + Callbacks(consumer_state.clone()), + )); + + let mut processor = StreamProcessor::new(consumer, consumer_state); + let res = processor.run_once(); + assert!(res.is_ok()); + let res = processor.run_once(); + assert!(res.is_ok()); + + let expected = HashMap::from([(partition, 2)]); + + assert_eq!(processor.tell(), expected) + } + + #[test] + fn test_strategy_panic() { + // Tests that a panic in any of the poll, submit, join, or close methods will crash the consumer + // and not deadlock + struct TestStrategy { + panic_on: &'static str, // poll, submit, join, close + } + impl ProcessingStrategy for TestStrategy { + fn poll(&mut self) -> Result, StrategyError> { + if self.panic_on == "poll" { + panic!("panic in poll"); + } + Ok(None) + } + + fn submit(&mut self, _message: Message) -> Result<(), SubmitError> { + if self.panic_on == "submit" { + panic!("panic in submit"); + } + + Ok(()) + } + + fn close(&mut self) { + if self.panic_on == "close" { + panic!("panic in close"); + } + } + + fn terminate(&mut self) {} + + fn join( + &mut self, + _: Option, + ) -> Result, StrategyError> { + if self.panic_on == "join" { + panic!("panic in join"); + } + + Ok(None) + } + } + + struct TestFactory { + panic_on: &'static str, + } + impl ProcessingStrategyFactory for TestFactory { + fn create(&self) -> Box> { + Box::new(TestStrategy { + panic_on: self.panic_on, + }) + } + } + + fn build_processor( + broker: LocalBroker, + panic_on: &'static str, + ) -> StreamProcessor { + let consumer_state = ConsumerState::new(Box::new(TestFactory { panic_on }), None); + + let consumer = Box::new(LocalConsumer::new( + Uuid::nil(), + Arc::new(Mutex::new(broker)), + "test_group".to_string(), + false, + &[Topic::new("test1")], + Callbacks(consumer_state.clone()), + )); + + StreamProcessor::new(consumer, consumer_state) + } + + let topic1 = Topic::new("test1"); + let partition = Partition::new(topic1, 0); + + let test_cases = ["poll", "submit", "join", "close"]; + + for test_case in test_cases { + let mut broker = build_broker(); + let _ = broker.produce(&partition, "message1".to_string()); + let _ = broker.produce(&partition, "message2".to_string()); + let mut processor = build_processor(broker, test_case); + + let res = processor.run_once(); + + if test_case == "join" || test_case == "close" { + assert!(res.is_ok()); + } else { + assert!(res.is_err()); + } + + processor.shutdown(); + } + } +} diff --git a/rust-arroyo/src/processing/strategies/commit_offsets.rs b/rust-arroyo/src/processing/strategies/commit_offsets.rs new file mode 100644 index 00000000..94c8b392 --- /dev/null +++ b/rust-arroyo/src/processing/strategies/commit_offsets.rs @@ -0,0 +1,154 @@ +use std::collections::HashMap; + +use chrono::{DateTime, Duration, Utc}; + +use 
crate::processing::strategies::{CommitRequest, ProcessingStrategy, SubmitError}; +use crate::timer; +use crate::types::{Message, Partition}; + +use super::StrategyError; + +pub struct CommitOffsets { + partitions: HashMap, + last_commit_time: DateTime, + last_record_time: DateTime, + commit_frequency: Duration, +} + +impl CommitOffsets { + pub fn new(commit_frequency: Duration) -> Self { + CommitOffsets { + partitions: Default::default(), + last_commit_time: Utc::now(), + last_record_time: Utc::now(), + commit_frequency, + } + } + + fn commit(&mut self, force: bool) -> Option { + if Utc::now() - self.last_commit_time <= self.commit_frequency && !force { + return None; + } + + if self.partitions.is_empty() { + return None; + } + + let ret = Some(CommitRequest { + positions: self.partitions.clone(), + }); + self.partitions.clear(); + self.last_commit_time = Utc::now(); + ret + } +} + +impl ProcessingStrategy for CommitOffsets { + fn poll(&mut self) -> Result, StrategyError> { + Ok(self.commit(false)) + } + + fn submit(&mut self, message: Message) -> Result<(), SubmitError> { + let now = Utc::now(); + if now - self.last_record_time > Duration::seconds(1) { + if let Some(timestamp) = message.timestamp() { + // FIXME: this used to be in seconds + timer!( + "arroyo.consumer.latency", + (now - timestamp).to_std().unwrap_or_default() + ); + self.last_record_time = now; + } + } + + for (partition, offset) in message.committable() { + self.partitions.insert(partition, offset); + } + Ok(()) + } + + fn close(&mut self) {} + + fn terminate(&mut self) {} + + fn join( + &mut self, + _: Option, + ) -> Result, StrategyError> { + Ok(self.commit(true)) + } +} + +#[cfg(test)] +mod tests { + use crate::backends::kafka::types::KafkaPayload; + use crate::processing::strategies::commit_offsets::CommitOffsets; + use crate::processing::strategies::{CommitRequest, ProcessingStrategy}; + + use crate::types::{BrokerMessage, InnerMessage, Message, Partition, Topic}; + use chrono::DateTime; + use std::thread::sleep; + use std::time::{Duration, SystemTime}; + + #[test] + fn test_commit_offsets() { + tracing_subscriber::fmt().with_test_writer().init(); + let partition1 = Partition::new(Topic::new("noop-commit"), 0); + let partition2 = Partition::new(Topic::new("noop-commit"), 1); + let timestamp = DateTime::from(SystemTime::now()); + + let m1 = Message { + inner_message: InnerMessage::BrokerMessage(BrokerMessage { + partition: partition1, + offset: 1000, + payload: KafkaPayload::new(None, None, None), + timestamp, + }), + }; + + let m2 = Message { + inner_message: InnerMessage::BrokerMessage(BrokerMessage { + partition: partition2, + offset: 2000, + payload: KafkaPayload::new(None, None, None), + timestamp, + }), + }; + + let mut noop = CommitOffsets::new(chrono::Duration::seconds(1)); + + let mut commit_req1 = CommitRequest { + positions: Default::default(), + }; + commit_req1.positions.insert(partition1, 1001); + noop.submit(m1).expect("Failed to submit"); + assert_eq!( + >::poll(&mut noop).unwrap(), + None + ); + + sleep(Duration::from_secs(2)); + assert_eq!( + >::poll(&mut noop).unwrap(), + Some(commit_req1) + ); + + let mut commit_req2 = CommitRequest { + positions: Default::default(), + }; + commit_req2.positions.insert(partition2, 2001); + noop.submit(m2).expect("Failed to submit"); + assert_eq!( + >::poll(&mut noop).unwrap(), + None + ); + assert_eq!( + >::join( + &mut noop, + Some(Duration::from_secs(5)) + ) + .unwrap(), + Some(commit_req2) + ); + } +} diff --git 
a/rust-arroyo/src/processing/strategies/healthcheck.rs b/rust-arroyo/src/processing/strategies/healthcheck.rs new file mode 100644 index 00000000..a6847e41 --- /dev/null +++ b/rust-arroyo/src/processing/strategies/healthcheck.rs @@ -0,0 +1,73 @@ +use std::path::PathBuf; +use std::time::{Duration, SystemTime}; + +use crate::counter; +use crate::processing::strategies::{ + CommitRequest, ProcessingStrategy, StrategyError, SubmitError, +}; +use crate::types::Message; + +const TOUCH_INTERVAL: Duration = Duration::from_secs(1); + +pub struct HealthCheck { + next_step: Next, + path: PathBuf, + interval: Duration, + deadline: SystemTime, +} + +impl HealthCheck { + pub fn new(next_step: Next, path: impl Into) -> Self { + let interval = TOUCH_INTERVAL; + let deadline = SystemTime::now() + interval; + + Self { + next_step, + path: path.into(), + interval, + deadline, + } + } + + fn maybe_touch_file(&mut self) { + let now = SystemTime::now(); + if now < self.deadline { + return; + } + + if let Err(err) = std::fs::File::create(&self.path) { + let error: &dyn std::error::Error = &err; + tracing::error!(error); + } + + counter!("arroyo.processing.strategies.healthcheck.touch"); + self.deadline = now + self.interval; + } +} + +impl ProcessingStrategy for HealthCheck +where + Next: ProcessingStrategy + 'static, +{ + fn poll(&mut self) -> Result, StrategyError> { + self.maybe_touch_file(); + + self.next_step.poll() + } + + fn submit(&mut self, message: Message) -> Result<(), SubmitError> { + self.next_step.submit(message) + } + + fn close(&mut self) { + self.next_step.close() + } + + fn terminate(&mut self) { + self.next_step.terminate() + } + + fn join(&mut self, timeout: Option) -> Result, StrategyError> { + self.next_step.join(timeout) + } +} diff --git a/rust-arroyo/src/processing/strategies/mod.rs b/rust-arroyo/src/processing/strategies/mod.rs new file mode 100644 index 00000000..61e2995d --- /dev/null +++ b/rust-arroyo/src/processing/strategies/mod.rs @@ -0,0 +1,205 @@ +use crate::types::{Message, Partition}; +use std::collections::HashMap; +use std::time::Duration; + +pub mod commit_offsets; +pub mod healthcheck; +pub mod produce; +pub mod reduce; +pub mod run_task; +pub mod run_task_in_threads; + +#[derive(Debug, Clone)] +pub enum SubmitError { + MessageRejected(MessageRejected), + InvalidMessage(InvalidMessage), +} + +#[derive(Debug, Clone)] +pub struct MessageRejected { + pub message: Message, +} + +#[derive(Debug, Clone)] +pub struct InvalidMessage { + pub partition: Partition, + pub offset: u64, +} + +/// Signals that we need to commit offsets +#[derive(Debug, Clone, PartialEq)] +pub struct CommitRequest { + pub positions: HashMap, +} + +impl CommitRequest { + pub fn merge(mut self, other: CommitRequest) -> Self { + // Merge commit requests, keeping the highest offset for each partition + for (partition, offset) in other.positions { + if let Some(pos_offset) = self.positions.get_mut(&partition) { + *pos_offset = (*pos_offset).max(offset); + } else { + self.positions.insert(partition, offset); + } + } + + self + } +} + +pub fn merge_commit_request( + value: Option, + other: Option, +) -> Option { + match (value, other) { + (None, None) => None, + (Some(x), None) => Some(x), + (None, Some(y)) => Some(y), + (Some(x), Some(y)) => Some(x.merge(y)), + } +} + +#[derive(Debug)] +pub enum StrategyError { + InvalidMessage(InvalidMessage), + Other(Box), +} + +impl From for StrategyError { + fn from(value: InvalidMessage) -> Self { + Self::InvalidMessage(value) + } +} + +/// A processing strategy defines 
how a stream processor processes messages +/// during the course of a single assignment. The processor is instantiated +/// when the assignment is received, and closed when the assignment is +/// revoked. +/// +/// This interface is intentionally not prescriptive, and affords a +/// significant degree of flexibility for the various implementations. +pub trait ProcessingStrategy: Send + Sync { + /// Poll the processor to check on the status of asynchronous tasks or + /// perform other scheduled work. + /// + /// This method is called on each consumer loop iteration, so this method + /// should not be used to perform work that may block for a significant + /// amount of time and block the progress of the consumer or exceed the + /// consumer poll interval timeout. + /// + /// This method may raise exceptions that were thrown by asynchronous + /// tasks since the previous call to ``poll``. + fn poll(&mut self) -> Result, StrategyError>; + + /// Submit a message for processing. + /// + /// Messages may be processed synchronously or asynchronously, depending + /// on the implementation of the processing strategy. Callers of this + /// method should not assume that this method returning successfully + /// implies that the message was successfully processed. + /// + /// If the processing strategy is unable to accept a message (due to it + /// being at or over capacity, for example), this method will raise a + /// ``MessageRejected`` exception. + fn submit(&mut self, message: Message) -> Result<(), SubmitError>; + + /// Close this instance. No more messages should be accepted by the + /// instance after this method has been called. + /// + /// This method should not block. Once this strategy instance has + /// finished processing (or discarded) all messages that were submitted + /// prior to this method being called, the strategy should commit its + /// partition offsets and release any resources that will no longer be + /// used (threads, processes, sockets, files, etc.) + fn close(&mut self); + + /// Close the processing strategy immediately, abandoning any work in + /// progress. No more messages should be accepted by the instance after + /// this method has been called. + fn terminate(&mut self); + + /// Block until the processing strategy has completed all previously + /// submitted work, or the provided timeout has been reached. This method + /// should be called after ``close`` to provide a graceful shutdown. + /// + /// This method is called synchronously by the stream processor during + /// assignment revocation, and blocks the assignment from being released + /// until this function exits, allowing any work in progress to be + /// completed and committed before the continuing the rebalancing + /// process. + fn join(&mut self, timeout: Option) -> Result, StrategyError>; +} + +impl + ?Sized> ProcessingStrategy for Box { + fn poll(&mut self) -> Result, StrategyError> { + (**self).poll() + } + + fn submit(&mut self, message: Message) -> Result<(), SubmitError> { + (**self).submit(message) + } + + fn close(&mut self) { + (**self).close() + } + + fn terminate(&mut self) { + (**self).terminate() + } + + fn join(&mut self, timeout: Option) -> Result, StrategyError> { + (**self).join(timeout) + } +} + +pub trait ProcessingStrategyFactory: Send + Sync { + /// Instantiate and return a ``ProcessingStrategy`` instance. + /// + /// This callback is executed on almost every rebalance, so do not do any heavy operations in + /// here. 
+ /// + /// In a future version of Arroyo we might want to call this callback less often, for example + /// only if it is necessary to join and close strategies for ordering guarantees. + fn create(&self) -> Box>; + + /// Callback to find out about the currently assigned partitions. + /// + /// Do not do any heavy work in this callback, even less than in `create`. This is guaranteed + /// to be called every rebalance. + fn update_partitions(&self, _partitions: &HashMap) {} +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::Topic; + + #[test] + fn merge() { + let partition = Partition::new(Topic::new("topic"), 0); + let partition_2 = Partition::new(Topic::new("topic"), 1); + + let a = Some(CommitRequest { + positions: HashMap::from([(partition, 1)]), + }); + + let b = Some(CommitRequest { + positions: HashMap::from([(partition, 2)]), + }); + + let c = Some(CommitRequest { + positions: HashMap::from([(partition_2, 2)]), + }); + + assert_eq!(merge_commit_request(a.clone(), b.clone()), b.clone()); + + assert_eq!( + merge_commit_request(a.clone(), c.clone()), + Some(CommitRequest { + positions: HashMap::from([(partition, 1), (partition_2, 2)]), + }) + ); + + assert_eq!(merge_commit_request(c.clone(), None), c.clone()); + } +} diff --git a/rust-arroyo/src/processing/strategies/produce.rs b/rust-arroyo/src/processing/strategies/produce.rs new file mode 100644 index 00000000..fb19017a --- /dev/null +++ b/rust-arroyo/src/processing/strategies/produce.rs @@ -0,0 +1,280 @@ +use crate::backends::kafka::types::KafkaPayload; +use crate::backends::{Producer, ProducerError}; +use crate::processing::strategies::run_task_in_threads::{ + ConcurrencyConfig, RunTaskFunc, RunTaskInThreads, TaskRunner, +}; +use crate::processing::strategies::{ + CommitRequest, ProcessingStrategy, StrategyError, SubmitError, +}; +use crate::types::{Message, TopicOrPartition}; +use std::sync::Arc; +use std::time::Duration; + +use super::run_task_in_threads::RunTaskError; + +struct ProduceMessage { + producer: Arc>, + topic: TopicOrPartition, +} + +impl ProduceMessage { + pub fn new(producer: impl Producer + 'static, topic: TopicOrPartition) -> Self { + ProduceMessage { + producer: Arc::new(producer), + topic, + } + } +} + +impl TaskRunner for ProduceMessage { + fn get_task(&self, message: Message) -> RunTaskFunc { + let producer = self.producer.clone(); + let topic = self.topic; + + Box::pin(async move { + producer + .produce(&topic, message.payload().clone()) + .map_err(RunTaskError::Other)?; + Ok(message) + }) + } +} + +pub struct Produce { + inner: RunTaskInThreads, +} + +impl Produce { + pub fn new( + next_step: N, + producer: impl Producer + 'static, + concurrency: &ConcurrencyConfig, + topic: TopicOrPartition, + ) -> Self + where + N: ProcessingStrategy + 'static, + { + let inner = RunTaskInThreads::new( + next_step, + Box::new(ProduceMessage::new(producer, topic)), + concurrency, + Some("produce"), + ); + + Produce { inner } + } +} + +impl ProcessingStrategy for Produce { + fn poll(&mut self) -> Result, StrategyError> { + self.inner.poll() + } + + fn submit(&mut self, message: Message) -> Result<(), SubmitError> { + self.inner.submit(message) + } + + fn close(&mut self) { + self.inner.close(); + } + + fn terminate(&mut self) { + self.inner.terminate(); + } + + fn join(&mut self, timeout: Option) -> Result, StrategyError> { + self.inner.join(timeout) + } +} + +#[cfg(test)] +mod tests { + use parking_lot::Mutex; + use std::time::SystemTime; + + use super::*; + use crate::backends::kafka::config::KafkaConfig; 
+ use crate::backends::kafka::producer::KafkaProducer; + use crate::backends::kafka::InitialOffset; + use crate::backends::local::broker::LocalBroker; + use crate::backends::local::LocalProducer; + use crate::backends::storages::memory::MemoryMessageStorage; + use crate::processing::strategies::StrategyError; + use crate::types::{BrokerMessage, InnerMessage, Partition, Topic}; + use crate::utils::clock::TestingClock; + use chrono::Utc; + + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] + struct Counts { + submit: u8, + polled: bool, + } + + struct Mock(Arc>); + + impl Mock { + fn new() -> Self { + Self(Arc::new(Mutex::new(Default::default()))) + } + + fn counts(&self) -> Arc> { + self.0.clone() + } + } + + impl ProcessingStrategy for Mock { + fn poll(&mut self) -> Result, StrategyError> { + self.0.lock().polled = true; + Ok(None) + } + fn submit( + &mut self, + _message: Message, + ) -> Result<(), SubmitError> { + self.0.lock().submit += 1; + Ok(()) + } + fn close(&mut self) {} + fn terminate(&mut self) {} + fn join( + &mut self, + _timeout: Option, + ) -> Result, StrategyError> { + Ok(None) + } + } + + #[test] + fn test_produce() { + let config = KafkaConfig::new_consumer_config( + vec![std::env::var("DEFAULT_BROKERS").unwrap_or("127.0.0.1:9092".to_string())], + "my_group".to_string(), + InitialOffset::Latest, + false, + 30_000, + None, + ); + + let partition = Partition::new(Topic::new("test"), 0); + + struct Noop {} + impl ProcessingStrategy for Noop { + fn poll(&mut self) -> Result, StrategyError> { + Ok(None) + } + fn submit( + &mut self, + _message: Message, + ) -> Result<(), SubmitError> { + Ok(()) + } + fn close(&mut self) {} + fn terminate(&mut self) {} + fn join( + &mut self, + _timeout: Option, + ) -> Result, StrategyError> { + Ok(None) + } + } + + let producer: KafkaProducer = KafkaProducer::new(config); + let concurrency = ConcurrencyConfig::new(10); + let mut strategy = Produce::new( + Noop {}, + producer, + &concurrency, + TopicOrPartition::Topic(partition.topic), + ); + + let payload_str = "hello world".to_string().as_bytes().to_vec(); + let message = Message { + inner_message: InnerMessage::BrokerMessage(BrokerMessage { + payload: KafkaPayload::new(None, None, Some(payload_str.clone())), + partition, + offset: 0, + timestamp: Utc::now(), + }), + }; + + strategy.submit(message).unwrap(); + strategy.close(); + let _ = strategy.join(None); + } + + #[test] + fn test_produce_local() { + let orig_topic = Topic::new("orig-topic"); + let result_topic = Topic::new("result-topic"); + let clock = TestingClock::new(SystemTime::now()); + let storage = MemoryMessageStorage::default(); + let mut broker = LocalBroker::new(Box::new(storage), Box::new(clock)); + broker.create_topic(result_topic, 1).unwrap(); + + let broker = Arc::new(Mutex::new(broker)); + let producer = LocalProducer::new(broker.clone()); + + let next_step = Mock::new(); + let counts = next_step.counts(); + let concurrency_config = ConcurrencyConfig::new(1); + let mut strategy = Produce::new( + next_step, + producer, + &concurrency_config, + result_topic.into(), + ); + + let value = br#"{"something": "something"}"#.to_vec(); + let data = KafkaPayload::new(None, None, Some(value.clone())); + let now = chrono::Utc::now(); + + let message = Message::new_broker_message(data, Partition::new(orig_topic, 0), 1, now); + strategy.submit(message.clone()).unwrap(); + strategy.join(None).unwrap(); + + let produced_message = broker + .lock() + .storage_mut() + .consume(&Partition::new(result_topic, 0), 0) + .unwrap() + 
.unwrap(); + + assert_eq!(produced_message.payload.payload().unwrap(), &value); + + assert!(broker + .lock() + .storage_mut() + .consume(&Partition::new(result_topic, 0), 1) + .unwrap() + .is_none()); + + strategy.poll().unwrap(); + assert_eq!( + *counts.lock(), + Counts { + submit: 1, + polled: true + } + ); + + strategy.submit(message.clone()).unwrap(); + strategy.join(None).unwrap(); + assert_eq!( + *counts.lock(), + Counts { + submit: 2, + polled: true, + } + ); + + let mut result = Ok(()); + for _ in 0..3 { + result = strategy.submit(message.clone()); + if result.is_err() { + break; + } + } + + assert!(result.is_err()); + } +} diff --git a/rust-arroyo/src/processing/strategies/reduce.rs b/rust-arroyo/src/processing/strategies/reduce.rs new file mode 100644 index 00000000..1ce00125 --- /dev/null +++ b/rust-arroyo/src/processing/strategies/reduce.rs @@ -0,0 +1,409 @@ +use crate::processing::strategies::{ + merge_commit_request, CommitRequest, MessageRejected, ProcessingStrategy, StrategyError, + SubmitError, +}; +use crate::timer; +use crate::types::{Message, Partition}; +use crate::utils::timing::Deadline; +use std::collections::BTreeMap; +use std::mem; +use std::sync::Arc; +use std::time::Duration; + +use super::InvalidMessage; + +struct BatchState { + value: Option, + accumulator: Arc TResult + Send + Sync>, + offsets: BTreeMap, + batch_start_time: Deadline, + message_count: usize, + compute_batch_size: fn(&T) -> usize, +} + +impl BatchState { + fn new( + initial_value: TResult, + accumulator: Arc TResult + Send + Sync>, + max_batch_time: Duration, + compute_batch_size: fn(&T) -> usize, + ) -> BatchState { + BatchState { + value: Some(initial_value), + accumulator, + offsets: Default::default(), + batch_start_time: Deadline::new(max_batch_time), + message_count: 0, + compute_batch_size, + } + } + + fn add(&mut self, message: Message) { + for (partition, offset) in message.committable() { + self.offsets.insert(partition, offset); + } + + let tmp = self.value.take().unwrap(); + let payload = message.into_payload(); + self.message_count += (self.compute_batch_size)(&payload); + self.value = Some((self.accumulator)(tmp, payload)); + } +} + +pub struct Reduce { + next_step: Box>, + accumulator: Arc TResult + Send + Sync>, + initial_value: TResult, + max_batch_size: usize, + max_batch_time: Duration, + batch_state: BatchState, + message_carried_over: Option>, + commit_request_carried_over: Option, + compute_batch_size: fn(&T) -> usize, +} + +impl ProcessingStrategy for Reduce { + fn poll(&mut self) -> Result, StrategyError> { + let commit_request = self.next_step.poll()?; + self.commit_request_carried_over = + merge_commit_request(self.commit_request_carried_over.take(), commit_request); + + self.flush(false)?; + + Ok(self.commit_request_carried_over.take()) + } + + fn submit(&mut self, message: Message) -> Result<(), SubmitError> { + if self.message_carried_over.is_some() { + return Err(SubmitError::MessageRejected(MessageRejected { message })); + } + + self.batch_state.add(message); + + Ok(()) + } + + fn close(&mut self) { + self.next_step.close(); + } + + fn terminate(&mut self) { + self.next_step.terminate(); + } + + fn join(&mut self, timeout: Option) -> Result, StrategyError> { + let deadline = timeout.map(Deadline::new); + if self.message_carried_over.is_some() { + while self.message_carried_over.is_some() { + let next_commit = self.next_step.poll()?; + self.commit_request_carried_over = + merge_commit_request(self.commit_request_carried_over.take(), next_commit); + 
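+ // Force a flush so the carried-over batch is re-offered to the next step on each iteration.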
self.flush(true)?; + + if deadline.map_or(false, |d| d.has_elapsed()) { + tracing::warn!("Timeout reached while waiting for tasks to finish"); + break; + } + } + } else { + self.flush(true)?; + } + + let next_commit = self.next_step.join(deadline.map(|d| d.remaining()))?; + + Ok(merge_commit_request( + self.commit_request_carried_over.take(), + next_commit, + )) + } +} + +impl Reduce { + pub fn new( + next_step: N, + accumulator: Arc TResult + Send + Sync>, + initial_value: TResult, + max_batch_size: usize, + max_batch_time: Duration, + compute_batch_size: fn(&T) -> usize, + ) -> Self + where + N: ProcessingStrategy + 'static, + { + let batch_state = BatchState::new( + initial_value.clone(), + accumulator.clone(), + max_batch_time, + compute_batch_size, + ); + Reduce { + next_step: Box::new(next_step), + accumulator, + initial_value, + max_batch_size, + max_batch_time, + batch_state, + message_carried_over: None, + commit_request_carried_over: None, + compute_batch_size, + } + } + + fn flush(&mut self, force: bool) -> Result<(), InvalidMessage> { + // Try re-submitting the carried over message if there is one + if let Some(message) = self.message_carried_over.take() { + match self.next_step.submit(message) { + Err(SubmitError::MessageRejected(MessageRejected { + message: transformed_message, + })) => { + self.message_carried_over = Some(transformed_message); + } + Err(SubmitError::InvalidMessage(invalid_message)) => { + return Err(invalid_message); + } + Ok(_) => {} + } + } + + if self.batch_state.message_count == 0 { + return Ok(()); + } + + let batch_time = self.batch_state.batch_start_time.elapsed(); + let batch_complete = self.batch_state.message_count >= self.max_batch_size + || batch_time >= self.max_batch_time; + + if !batch_complete && !force { + return Ok(()); + } + + // FIXME: this used to be in seconds + timer!("arroyo.strategies.reduce.batch_time", batch_time); + + let batch_state = mem::replace( + &mut self.batch_state, + BatchState::new( + self.initial_value.clone(), + self.accumulator.clone(), + self.max_batch_time, + self.compute_batch_size, + ), + ); + + let next_message = + Message::new_any_message(batch_state.value.unwrap(), batch_state.offsets); + + match self.next_step.submit(next_message) { + Err(SubmitError::MessageRejected(MessageRejected { + message: transformed_message, + })) => { + self.message_carried_over = Some(transformed_message); + Ok(()) + } + Err(SubmitError::InvalidMessage(invalid_message)) => Err(invalid_message), + Ok(_) => Ok(()), + } + } +} + +#[cfg(test)] +mod tests { + use crate::processing::strategies::reduce::Reduce; + use crate::processing::strategies::{ + CommitRequest, ProcessingStrategy, StrategyError, SubmitError, + }; + use crate::types::{BrokerMessage, InnerMessage, Message, Partition, Topic}; + use std::sync::{Arc, Mutex}; + use std::time::Duration; + + struct NextStep { + pub submitted: Arc>>, + } + + impl ProcessingStrategy for NextStep { + fn poll(&mut self) -> Result, StrategyError> { + Ok(None) + } + + fn submit(&mut self, message: Message) -> Result<(), SubmitError> { + self.submitted.lock().unwrap().push(message.into_payload()); + Ok(()) + } + + fn close(&mut self) {} + + fn terminate(&mut self) {} + + fn join(&mut self, _: Option) -> Result, StrategyError> { + Ok(None) + } + } + + #[test] + fn test_reduce() { + let submitted_messages = Arc::new(Mutex::new(Vec::new())); + let submitted_messages_clone = submitted_messages.clone(); + + let partition1 = Partition::new(Topic::new("test"), 0); + + let max_batch_size = 2; + let 
max_batch_time = Duration::from_secs(1); + + let initial_value = Vec::new(); + let accumulator = Arc::new(|mut acc: Vec, value: u64| { + acc.push(value); + acc + }); + let compute_batch_size = |_: &_| -> usize { 1 }; + + let next_step = NextStep { + submitted: submitted_messages, + }; + + let mut strategy = Reduce::new( + next_step, + accumulator, + initial_value, + max_batch_size, + max_batch_time, + compute_batch_size, + ); + + for i in 0..3 { + let msg = Message { + inner_message: InnerMessage::BrokerMessage(BrokerMessage::new( + i, + partition1, + i, + chrono::Utc::now(), + )), + }; + strategy.submit(msg).unwrap(); + let _ = strategy.poll(); + } + + // 3 messages with a max batch size of 2 means 1 batch was cleared + // and 1 message is left before next size limit. + assert_eq!(strategy.batch_state.message_count, 1); + + strategy.close(); + let _ = strategy.join(None); + + // 2 batches were created + assert_eq!( + *submitted_messages_clone.lock().unwrap(), + vec![vec![0, 1], vec![2]] + ); + } + + #[test] + fn test_reduce_with_custom_batch_size() { + let submitted_messages = Arc::new(Mutex::new(Vec::new())); + let submitted_messages_clone = submitted_messages.clone(); + + let partition1 = Partition::new(Topic::new("test"), 0); + + let max_batch_size = 10; + let max_batch_time = Duration::from_secs(1); + + let initial_value = Vec::new(); + let accumulator = Arc::new(|mut acc: Vec, value: u64| { + acc.push(value); + acc + }); + let compute_batch_size = |_: &_| -> usize { 5 }; + + let next_step = NextStep { + submitted: submitted_messages, + }; + + let mut strategy = Reduce::new( + next_step, + accumulator, + initial_value, + max_batch_size, + max_batch_time, + compute_batch_size, + ); + + for i in 0..3 { + let msg = Message { + inner_message: InnerMessage::BrokerMessage(BrokerMessage::new( + i, + partition1, + i, + chrono::Utc::now(), + )), + }; + strategy.submit(msg).unwrap(); + let _ = strategy.poll(); + } + + // 3 messages returning 5 items each and a max batch size of 10 + // means 1 batch was cleared and 5 items are in the current batch. 
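+ // (message_count tracks accumulated items, not messages, since compute_batch_size returns 5 per message.)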
+ assert_eq!(strategy.batch_state.message_count, 5); + + strategy.close(); + let _ = strategy.join(None); + + // 2 batches were created + assert_eq!( + *submitted_messages_clone.lock().unwrap(), + vec![vec![0, 1], vec![2]] + ); + } + + #[test] + fn test_reduce_with_zero_batch_size() { + let submitted_messages = Arc::new(Mutex::new(Vec::new())); + let submitted_messages_clone = submitted_messages.clone(); + + let partition1 = Partition::new(Topic::new("test"), 0); + + let max_batch_size = 1; + let max_batch_time = Duration::from_secs(100); + + let initial_value = Vec::new(); + let accumulator = Arc::new(|mut acc: Vec, value: u64| { + acc.push(value); + acc + }); + let compute_batch_size = |_: &_| -> usize { 0 }; + + let next_step = NextStep { + submitted: submitted_messages, + }; + + let mut strategy = Reduce::new( + next_step, + accumulator, + initial_value, + max_batch_size, + max_batch_time, + compute_batch_size, + ); + + for i in 0..3 { + let msg = Message { + inner_message: InnerMessage::BrokerMessage(BrokerMessage::new( + i, + partition1, + i, + chrono::Utc::now(), + )), + }; + strategy.submit(msg).unwrap(); + let _ = strategy.poll(); + } + + // since all submitted values had length 0, do not forward any messages to the next step + // until timeout (which will not happen as part of this test) + assert_eq!(strategy.batch_state.message_count, 0); + + strategy.close(); + let _ = strategy.join(None); + + // no batches were created + assert!(submitted_messages_clone.lock().unwrap().is_empty()); + } +} diff --git a/rust-arroyo/src/processing/strategies/run_task.rs b/rust-arroyo/src/processing/strategies/run_task.rs new file mode 100644 index 00000000..20ac12b2 --- /dev/null +++ b/rust-arroyo/src/processing/strategies/run_task.rs @@ -0,0 +1,145 @@ +use crate::processing::strategies::{ + merge_commit_request, CommitRequest, InvalidMessage, MessageRejected, ProcessingStrategy, + StrategyError, SubmitError, +}; +use crate::types::Message; +use std::time::Duration; + +pub struct RunTask { + pub function: + Box Result + Send + Sync + 'static>, + pub next_step: Box>, + pub message_carried_over: Option>, + pub commit_request_carried_over: Option, +} + +impl RunTask { + pub fn new(function: F, next_step: N) -> Self + where + N: ProcessingStrategy + 'static, + F: Fn(TPayload) -> Result + Send + Sync + 'static, + { + Self { + function: Box::new(function), + next_step: Box::new(next_step), + message_carried_over: None, + commit_request_carried_over: None, + } + } +} + +impl ProcessingStrategy + for RunTask +{ + fn poll(&mut self) -> Result, StrategyError> { + match self.next_step.poll() { + Ok(commit_request) => { + self.commit_request_carried_over = + merge_commit_request(self.commit_request_carried_over.take(), commit_request) + } + Err(invalid_message) => return Err(invalid_message), + } + + if let Some(message) = self.message_carried_over.take() { + match self.next_step.submit(message) { + Err(SubmitError::MessageRejected(MessageRejected { + message: transformed_message, + })) => { + self.message_carried_over = Some(transformed_message); + } + Err(SubmitError::InvalidMessage(invalid_message)) => { + return Err(invalid_message.into()); + } + Ok(_) => {} + } + } + + Ok(self.commit_request_carried_over.take()) + } + + fn submit(&mut self, message: Message) -> Result<(), SubmitError> { + if self.message_carried_over.is_some() { + return Err(SubmitError::MessageRejected(MessageRejected { message })); + } + + let next_message = message + .try_map(&self.function) + .map_err(SubmitError::InvalidMessage)?; 
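+ // Forward the transformed payload. If the next step rejects it, stash it so poll() can retry,
+ // and reject further submissions until the stashed message has been delivered.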
+ + match self.next_step.submit(next_message) { + Err(SubmitError::MessageRejected(MessageRejected { + message: transformed_message, + })) => { + self.message_carried_over = Some(transformed_message); + } + Err(SubmitError::InvalidMessage(invalid_message)) => { + return Err(SubmitError::InvalidMessage(invalid_message)); + } + Ok(_) => {} + } + Ok(()) + } + + fn close(&mut self) { + self.next_step.close() + } + + fn terminate(&mut self) { + self.next_step.terminate() + } + + fn join(&mut self, timeout: Option) -> Result, StrategyError> { + let next_commit = self.next_step.join(timeout)?; + Ok(merge_commit_request( + self.commit_request_carried_over.take(), + next_commit, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::{BrokerMessage, InnerMessage, Message, Partition, Topic}; + use chrono::Utc; + + #[test] + fn test_run_task() { + fn identity(value: String) -> Result { + Ok(value) + } + + struct Noop {} + impl ProcessingStrategy for Noop { + fn poll(&mut self) -> Result, StrategyError> { + Ok(None) + } + fn submit(&mut self, _message: Message) -> Result<(), SubmitError> { + Ok(()) + } + fn close(&mut self) {} + fn terminate(&mut self) {} + fn join( + &mut self, + _timeout: Option, + ) -> Result, StrategyError> { + Ok(None) + } + } + + let mut strategy = RunTask::new(identity, Noop {}); + + let partition = Partition::new(Topic::new("test"), 0); + + strategy + .submit(Message { + inner_message: InnerMessage::BrokerMessage(BrokerMessage::new( + "Hello world".to_string(), + partition, + 0, + Utc::now(), + )), + }) + .unwrap(); + } +} diff --git a/rust-arroyo/src/processing/strategies/run_task_in_threads.rs b/rust-arroyo/src/processing/strategies/run_task_in_threads.rs new file mode 100644 index 00000000..5de3e763 --- /dev/null +++ b/rust-arroyo/src/processing/strategies/run_task_in_threads.rs @@ -0,0 +1,387 @@ +use std::collections::VecDeque; +use std::future::Future; +use std::pin::Pin; +use std::time::Duration; + +use tokio::runtime::{Handle, Runtime}; +use tokio::task::JoinHandle; + +use crate::gauge; +use crate::processing::strategies::{ + merge_commit_request, CommitRequest, InvalidMessage, MessageRejected, ProcessingStrategy, + SubmitError, +}; +use crate::types::Message; +use crate::utils::timing::Deadline; + +use super::StrategyError; + +#[derive(Clone, Debug)] +pub enum RunTaskError { + RetryableError, + InvalidMessage(InvalidMessage), + Other(TError), +} + +pub type RunTaskFunc = + Pin, RunTaskError>> + Send>>; + +pub trait TaskRunner: Send + Sync { + fn get_task(&self, message: Message) -> RunTaskFunc; +} + +/// This is configuration for the [`RunTaskInThreads`] strategy. +/// +/// It defines the runtime on which tasks are being spawned, and the number of +/// concurrently running tasks. +pub struct ConcurrencyConfig { + /// The configured number of concurrently running tasks. + pub concurrency: usize, + runtime: RuntimeOrHandle, +} + +impl ConcurrencyConfig { + /// Creates a new [`ConcurrencyConfig`], spawning a new [`Runtime`]. + /// + /// The runtime will use the number of worker threads given by the `concurrency`, + /// and also limit the number of concurrently running tasks as well. + pub fn new(concurrency: usize) -> Self { + let runtime = tokio::runtime::Builder::new_multi_thread() + .worker_threads(concurrency) + .enable_all() + .build() + .unwrap(); + Self { + concurrency, + runtime: RuntimeOrHandle::Runtime(runtime), + } + } + + /// Creates a new [`ConcurrencyConfig`], reusing an existing [`Runtime`] via + /// its [`Handle`]. 
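+ /// This allows several strategies to share one runtime instead of each spawning its own.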
+ pub fn with_runtime(concurrency: usize, runtime: Handle) -> Self { + Self { + concurrency, + runtime: RuntimeOrHandle::Handle(runtime), + } + } + + /// Returns a [`Handle`] to the underlying runtime. + pub fn handle(&self) -> Handle { + match &self.runtime { + RuntimeOrHandle::Handle(handle) => handle.clone(), + RuntimeOrHandle::Runtime(runtime) => runtime.handle().to_owned(), + } + } +} + +enum RuntimeOrHandle { + Handle(Handle), + Runtime(Runtime), +} + +pub struct RunTaskInThreads { + next_step: Box>, + task_runner: Box>, + concurrency: usize, + runtime: Handle, + handles: VecDeque, RunTaskError>>>, + message_carried_over: Option>, + commit_request_carried_over: Option, + metric_strategy_name: &'static str, +} + +impl RunTaskInThreads { + pub fn new( + next_step: N, + task_runner: Box>, + concurrency: &ConcurrencyConfig, + // If provided, this name is used for metrics + custom_strategy_name: Option<&'static str>, + ) -> Self + where + N: ProcessingStrategy + 'static, + { + let strategy_name = custom_strategy_name.unwrap_or("run_task_in_threads"); + + RunTaskInThreads { + next_step: Box::new(next_step), + task_runner, + concurrency: concurrency.concurrency, + runtime: concurrency.handle(), + handles: VecDeque::new(), + message_carried_over: None, + commit_request_carried_over: None, + metric_strategy_name: strategy_name, + } + } +} + +impl ProcessingStrategy + for RunTaskInThreads +where + TTransformed: Send + Sync + 'static, + TError: Into> + Send + Sync + 'static, +{ + fn poll(&mut self) -> Result, StrategyError> { + let commit_request = self.next_step.poll()?; + self.commit_request_carried_over = + merge_commit_request(self.commit_request_carried_over.take(), commit_request); + + gauge!("arroyo.strategies.run_task_in_threads.threads", + self.handles.len() as u64, + "strategy_name" => self.metric_strategy_name + ); + gauge!("arroyo.strategies.run_task_in_threads.concurrency", + self.concurrency as u64, + "strategy_name" => self.metric_strategy_name + ); + + if let Some(message) = self.message_carried_over.take() { + match self.next_step.submit(message) { + Err(SubmitError::MessageRejected(MessageRejected { + message: transformed_message, + })) => { + self.message_carried_over = Some(transformed_message); + } + Err(SubmitError::InvalidMessage(invalid_message)) => { + return Err(invalid_message.into()); + } + Ok(_) => {} + } + } + + while !self.handles.is_empty() { + if let Some(front) = self.handles.front() { + if !front.is_finished() { + break; + } + let handle = self.handles.pop_front().unwrap(); + match self.runtime.block_on(handle) { + Ok(Ok(message)) => match self.next_step.submit(message) { + Err(SubmitError::MessageRejected(MessageRejected { + message: transformed_message, + })) => { + self.message_carried_over = Some(transformed_message); + } + Err(SubmitError::InvalidMessage(invalid_message)) => { + return Err(invalid_message.into()); + } + Ok(_) => {} + }, + Ok(Err(RunTaskError::InvalidMessage(e))) => { + return Err(e.into()); + } + Ok(Err(RunTaskError::RetryableError)) => { + tracing::error!("retryable error"); + } + Ok(Err(RunTaskError::Other(error))) => { + return Err(StrategyError::Other(error.into())); + } + Err(e) => { + return Err(StrategyError::Other(e.into())); + } + } + } + } + + Ok(self.commit_request_carried_over.take()) + } + + fn submit(&mut self, message: Message) -> Result<(), SubmitError> { + if self.message_carried_over.is_some() { + return Err(SubmitError::MessageRejected(MessageRejected { message })); + } + + if self.handles.len() > self.concurrency { + 
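+ // Every worker slot is occupied: reject the message so the caller applies backpressure and retries later.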
return Err(SubmitError::MessageRejected(MessageRejected { message })); + } + + let task = self.task_runner.get_task(message); + let handle = self.runtime.spawn(task); + self.handles.push_back(handle); + + Ok(()) + } + + fn close(&mut self) { + self.next_step.close(); + } + + fn terminate(&mut self) { + for handle in &self.handles { + handle.abort(); + } + self.handles.clear(); + self.next_step.terminate(); + } + + fn join(&mut self, timeout: Option) -> Result, StrategyError> { + let deadline = timeout.map(Deadline::new); + + // Poll until there are no more messages or timeout is hit + while self.message_carried_over.is_some() || !self.handles.is_empty() { + if deadline.map_or(false, |d| d.has_elapsed()) { + tracing::warn!( + %self.metric_strategy_name, + "Timeout reached while waiting for tasks to finish", + ); + break; + } + + let commit_request = self.poll()?; + self.commit_request_carried_over = + merge_commit_request(self.commit_request_carried_over.take(), commit_request); + } + + // Cancel remaining tasks if any + for handle in &self.handles { + handle.abort(); + } + self.handles.clear(); + + let next_commit = self.next_step.join(deadline.map(|d| d.remaining()))?; + + Ok(merge_commit_request( + self.commit_request_carried_over.take(), + next_commit, + )) + } +} + +#[cfg(test)] +mod tests { + use crate::types::{Partition, Topic}; + + use super::*; + use std::collections::BTreeMap; + use std::sync::{Arc, Mutex}; + + struct IdentityTaskRunner {} + + impl TaskRunner for IdentityTaskRunner { + fn get_task(&self, message: Message) -> RunTaskFunc { + Box::pin(async move { Ok(message) }) + } + } + + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] + struct Counts { + submit: u8, + polled: bool, + } + + struct Mock(Arc>); + + impl Mock { + fn new() -> Self { + Self(Arc::new(Mutex::new(Default::default()))) + } + + fn counts(&self) -> Arc> { + self.0.clone() + } + } + + impl ProcessingStrategy for Mock { + fn poll(&mut self) -> Result, StrategyError> { + self.0.lock().unwrap().polled = true; + Ok(None) + } + fn submit(&mut self, _message: Message) -> Result<(), SubmitError> { + self.0.lock().unwrap().submit += 1; + Ok(()) + } + fn close(&mut self) {} + fn terminate(&mut self) {} + fn join( + &mut self, + _timeout: Option, + ) -> Result, StrategyError> { + Ok(None) + } + } + + #[test] + fn test() { + let concurrency = ConcurrencyConfig::new(1); + let mut strategy = RunTaskInThreads::new( + Mock::new(), + Box::new(IdentityTaskRunner {}), + &concurrency, + None, + ); + + let message = Message::new_any_message("hello_world".to_string(), BTreeMap::new()); + + strategy.submit(message).unwrap(); + let _ = strategy.poll(); + let _ = strategy.join(None); + } + + #[test] + fn test_run_task_in_threads() { + for poll_after_msg in [false, true] { + for poll_before_join in [false, true] { + let next_step = Mock::new(); + let counts = next_step.counts(); + let concurrency = ConcurrencyConfig::new(2); + let mut strategy = RunTaskInThreads::new( + next_step, + Box::new(IdentityTaskRunner {}), + &concurrency, + None, + ); + + let partition = Partition::new(Topic::new("topic"), 0); + + strategy + .submit(Message::new_broker_message( + "hello".to_string(), + partition, + 0, + chrono::Utc::now(), + )) + .unwrap(); + + if poll_after_msg { + strategy.poll().unwrap(); + } + + strategy + .submit(Message::new_broker_message( + "world".to_string(), + partition, + 1, + chrono::Utc::now(), + )) + .unwrap(); + + if poll_after_msg { + strategy.poll().unwrap(); + } + + if poll_before_join { + for _ in 0..10 { + if 
counts.lock().unwrap().submit < 2 { + strategy.poll().unwrap(); + std::thread::sleep(Duration::from_millis(100)); + } else { + break; + } + } + + let counts = counts.lock().unwrap(); + assert_eq!(counts.submit, 2); + assert!(counts.polled); + } + + strategy.join(None).unwrap(); + + let counts = counts.lock().unwrap(); + assert_eq!(counts.submit, 2); + assert!(counts.polled); + } + } + } +} diff --git a/rust-arroyo/src/testutils.rs b/rust-arroyo/src/testutils.rs new file mode 100644 index 00000000..1c41203f --- /dev/null +++ b/rust-arroyo/src/testutils.rs @@ -0,0 +1,121 @@ +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use rdkafka::admin::{AdminClient, AdminOptions, NewTopic, TopicReplication}; +use rdkafka::client::DefaultClientContext; +use rdkafka::ClientConfig; +use tokio::runtime::Runtime; + +use crate::backends::kafka::config::KafkaConfig; +use crate::backends::kafka::producer::KafkaProducer; +use crate::backends::kafka::types::KafkaPayload; +use crate::backends::Producer; +use crate::processing::strategies::{ + CommitRequest, ProcessingStrategy, StrategyError, SubmitError, +}; +use crate::types::Message; +use crate::types::Topic; + +#[derive(Clone)] +pub struct TestStrategy { + pub messages: Arc>>>, +} + +impl Default for TestStrategy { + fn default() -> Self { + TestStrategy { + messages: Arc::new(Mutex::new(Vec::new())), + } + } +} + +impl TestStrategy { + pub fn new() -> Self { + TestStrategy::default() + } +} + +impl ProcessingStrategy for TestStrategy { + fn poll(&mut self) -> Result, StrategyError> { + Ok(None) + } + + fn submit(&mut self, message: Message) -> Result<(), SubmitError> { + self.messages.lock().unwrap().push(message); + Ok(()) + } + + fn close(&mut self) {} + fn terminate(&mut self) {} + fn join(&mut self, _timeout: Option) -> Result, StrategyError> { + Ok(None) + } +} + +pub fn get_default_broker() -> String { + std::env::var("DEFAULT_BROKERS").unwrap_or("127.0.0.1:9092".to_string()) +} + +fn get_admin_client() -> AdminClient { + let mut config = ClientConfig::new(); + config.set("bootstrap.servers".to_string(), get_default_broker()); + + config.create().unwrap() +} + +async fn create_topic(topic_name: &str, partition_count: i32) { + let client = get_admin_client(); + let topics = [NewTopic::new( + topic_name, + partition_count, + TopicReplication::Fixed(1), + )]; + client + .create_topics(&topics, &AdminOptions::new()) + .await + .unwrap(); +} + +async fn delete_topic(topic_name: &str) { + let client = get_admin_client(); + client + .delete_topics(&[topic_name], &AdminOptions::new()) + .await + .unwrap(); +} + +pub struct TestTopic { + runtime: Runtime, + pub topic: Topic, +} + +impl TestTopic { + pub fn create(name: &str) -> Self { + let runtime = Runtime::new().unwrap(); + let name = format!("rust-arroyo-{}-{}", name, uuid::Uuid::new_v4()); + runtime.block_on(create_topic(&name, 1)); + Self { + runtime, + topic: Topic::new(&name), + } + } + + pub fn produce(&self, payload: KafkaPayload) { + let producer_configuration = + KafkaConfig::new_producer_config(vec![get_default_broker()], None); + + let producer = KafkaProducer::new(producer_configuration); + + producer + .produce(&crate::types::TopicOrPartition::Topic(self.topic), payload) + .expect("Message produced"); + } +} + +impl Drop for TestTopic { + fn drop(&mut self) { + let name = self.topic.as_str(); + // i really wish i had async drop now + self.runtime.block_on(delete_topic(name)); + } +} diff --git a/rust-arroyo/src/types/mod.rs b/rust-arroyo/src/types/mod.rs new file mode 100755 index 
00000000..33cebdb4 --- /dev/null +++ b/rust-arroyo/src/types/mod.rs @@ -0,0 +1,396 @@ +use std::any::type_name; +use std::cmp::Eq; +use std::collections::{BTreeMap, HashSet}; +use std::fmt; +use std::hash::Hash; +use std::sync::Mutex; + +use chrono::{DateTime, Utc}; +use once_cell::sync::Lazy; + +#[derive(Clone, Copy, Eq, Hash, PartialEq, PartialOrd, Ord)] +pub struct Topic(&'static str); + +impl Topic { + pub fn new(name: &str) -> Self { + static INTERNED_TOPICS: Lazy>> = Lazy::new(Default::default); + let mut interner = INTERNED_TOPICS.lock().unwrap(); + interner.insert(name.into()); + let interned_name = interner.get(name).unwrap(); + + // SAFETY: + // - The interner is static, append-only, and only defined within this function. + // - We insert heap-allocated `String`s that do not move. + let interned_name = unsafe { std::mem::transmute::<&str, &'static str>(interned_name) }; + Self(interned_name) + } + + pub fn as_str(&self) -> &str { + self.0 + } +} + +impl fmt::Debug for Topic { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = self.as_str(); + f.debug_tuple("Topic").field(&s).finish() + } +} + +impl fmt::Display for Topic { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, PartialOrd, Ord)] +pub struct Partition { + pub topic: Topic, + pub index: u16, +} + +impl Partition { + pub fn new(topic: Topic, index: u16) -> Self { + Self { topic, index } + } +} + +impl fmt::Display for Partition { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Partition({} topic={})", self.index, &self.topic) + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum TopicOrPartition { + Topic(Topic), + Partition(Partition), +} + +impl TopicOrPartition { + pub fn topic(&self) -> Topic { + match self { + TopicOrPartition::Topic(topic) => *topic, + TopicOrPartition::Partition(partition) => partition.topic, + } + } +} + +impl From for TopicOrPartition { + fn from(value: Topic) -> Self { + Self::Topic(value) + } +} + +impl From for TopicOrPartition { + fn from(value: Partition) -> Self { + Self::Partition(value) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct BrokerMessage { + pub payload: T, + pub partition: Partition, + pub offset: u64, + pub timestamp: DateTime, +} + +impl BrokerMessage { + pub fn new(payload: T, partition: Partition, offset: u64, timestamp: DateTime) -> Self { + Self { + payload, + partition, + offset, + timestamp, + } + } + + pub fn replace(self, replacement: TReplaced) -> BrokerMessage { + BrokerMessage { + payload: replacement, + partition: self.partition, + offset: self.offset, + timestamp: self.timestamp, + } + } + + /// Map a fallible function over this messages's payload. 
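+ /// The partition, offset, and timestamp are carried over unchanged; only the payload type changes.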
+ pub fn try_map Result>( + self, + f: F, + ) -> Result, E> { + let Self { + payload, + partition, + offset, + timestamp, + } = self; + + let payload = f(payload)?; + + Ok(BrokerMessage { + payload, + partition, + offset, + timestamp, + }) + } +} + +impl fmt::Display for BrokerMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "BrokerMessage(partition={} offset={})", + self.partition, self.offset + ) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct AnyMessage { + pub payload: T, + pub committable: BTreeMap, +} + +impl AnyMessage { + pub fn new(payload: T, committable: BTreeMap) -> Self { + Self { + payload, + committable, + } + } + + pub fn replace(self, replacement: TReplaced) -> AnyMessage { + AnyMessage { + payload: replacement, + committable: self.committable, + } + } + + /// Map a fallible function over this messages's payload. + pub fn try_map Result>( + self, + f: F, + ) -> Result, E> { + let Self { + payload, + committable, + } = self; + + let payload = f(payload)?; + + Ok(AnyMessage { + payload, + committable, + }) + } +} + +impl fmt::Display for AnyMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "AnyMessage(committable={:?})", self.committable) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum InnerMessage { + BrokerMessage(BrokerMessage), + AnyMessage(AnyMessage), +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Message { + pub inner_message: InnerMessage, +} + +impl Message { + pub fn new_broker_message( + payload: T, + partition: Partition, + offset: u64, + timestamp: DateTime, + ) -> Self { + Self { + inner_message: InnerMessage::BrokerMessage(BrokerMessage { + payload, + partition, + offset, + timestamp, + }), + } + } + + pub fn new_any_message(payload: T, committable: BTreeMap) -> Self { + Self { + inner_message: InnerMessage::AnyMessage(AnyMessage { + payload, + committable, + }), + } + } + + pub fn payload(&self) -> &T { + match &self.inner_message { + InnerMessage::BrokerMessage(BrokerMessage { payload, .. }) => payload, + InnerMessage::AnyMessage(AnyMessage { payload, .. }) => payload, + } + } + + /// Consumes the message and returns its payload. + pub fn into_payload(self) -> T { + match self.inner_message { + InnerMessage::BrokerMessage(BrokerMessage { payload, .. }) => payload, + InnerMessage::AnyMessage(AnyMessage { payload, .. }) => payload, + } + } + + /// Returns an iterator over this message's committable offsets. + pub fn committable(&self) -> Committable { + match &self.inner_message { + InnerMessage::BrokerMessage(BrokerMessage { + partition, offset, .. + }) => Committable(CommittableInner::Broker(std::iter::once(( + *partition, + offset + 1, + )))), + InnerMessage::AnyMessage(AnyMessage { committable, .. }) => { + Committable(CommittableInner::Any(committable.iter())) + } + } + } + + pub fn replace(self, replacement: TReplaced) -> Message { + match self.inner_message { + InnerMessage::BrokerMessage(inner) => Message { + inner_message: InnerMessage::BrokerMessage(inner.replace(replacement)), + }, + InnerMessage::AnyMessage(inner) => Message { + inner_message: InnerMessage::AnyMessage(inner.replace(replacement)), + }, + } + } + + /// Map a fallible function over this messages's payload. 
+ pub fn try_map Result>( + self, + f: F, + ) -> Result, E> { + match self.inner_message { + InnerMessage::BrokerMessage(inner) => { + let inner = inner.try_map(f)?; + Ok(inner.into()) + } + InnerMessage::AnyMessage(inner) => { + let inner = inner.try_map(f)?; + Ok(inner.into()) + } + } + } + + // Returns this message's timestamp, if it has one. + pub fn timestamp(&self) -> Option> { + match &self.inner_message { + InnerMessage::BrokerMessage(m) => Some(m.timestamp), + InnerMessage::AnyMessage(_) => None, + } + } +} + +impl fmt::Display for Message { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.inner_message { + InnerMessage::BrokerMessage(BrokerMessage { + partition, offset, .. + }) => { + write!( + f, + "Message<{}>(partition={partition}), offset={offset}", + type_name::(), + ) + } + InnerMessage::AnyMessage(AnyMessage { committable, .. }) => { + write!( + f, + "Message<{}>(committable={committable:?})", + type_name::(), + ) + } + } + } +} + +impl From> for Message { + fn from(value: BrokerMessage) -> Self { + Self { + inner_message: InnerMessage::BrokerMessage(value), + } + } +} + +impl From> for Message { + fn from(value: AnyMessage) -> Self { + Self { + inner_message: InnerMessage::AnyMessage(value), + } + } +} + +#[derive(Debug, Clone)] +enum CommittableInner<'a> { + Any(std::collections::btree_map::Iter<'a, Partition, u64>), + Broker(std::iter::Once<(Partition, u64)>), +} + +/// An iterator over a `Message`'s committable offsets. +/// +/// This is produced by [`Message::committable`]. +#[derive(Debug, Clone)] +pub struct Committable<'a>(CommittableInner<'a>); + +impl<'a> Iterator for Committable<'a> { + type Item = (Partition, u64); + + fn next(&mut self) -> Option { + match self.0 { + CommittableInner::Any(ref mut inner) => inner.next().map(|(k, v)| (*k, *v)), + CommittableInner::Broker(ref mut inner) => inner.next(), + } + } +} + +#[cfg(test)] +mod tests { + use super::{BrokerMessage, Partition, Topic}; + use chrono::Utc; + + #[test] + fn message() { + let now = Utc::now(); + let topic = Topic::new("test"); + let part = Partition { topic, index: 10 }; + let message = BrokerMessage::new("payload".to_string(), part, 10, now); + + assert_eq!(message.partition.topic.as_str(), "test"); + assert_eq!(message.partition.index, 10); + assert_eq!(message.offset, 10); + assert_eq!(message.payload, "payload"); + assert_eq!(message.timestamp, now); + } + + #[test] + fn fmt_display() { + let now = Utc::now(); + let part = Partition { + topic: Topic::new("test"), + index: 10, + }; + let message = BrokerMessage::new("payload".to_string(), part, 10, now); + + assert_eq!( + message.to_string(), + "BrokerMessage(partition=Partition(10 topic=test) offset=10)" + ) + } +} diff --git a/rust-arroyo/src/utils/clock.rs b/rust-arroyo/src/utils/clock.rs new file mode 100644 index 00000000..b6b40a03 --- /dev/null +++ b/rust-arroyo/src/utils/clock.rs @@ -0,0 +1,39 @@ +use std::thread::sleep; +use std::time::{Duration, SystemTime}; + +pub trait Clock: Send { + fn time(&self) -> SystemTime; + + fn sleep(&mut self, duration: Duration); +} + +pub struct SystemClock {} +impl Clock for SystemClock { + fn time(&self) -> SystemTime { + SystemTime::now() + } + + fn sleep(&mut self, duration: Duration) { + sleep(duration) + } +} + +pub struct TestingClock { + time: SystemTime, +} + +impl TestingClock { + pub fn new(now: SystemTime) -> Self { + Self { time: now } + } +} + +impl Clock for TestingClock { + fn time(&self) -> SystemTime { + self.time + } + + fn sleep(&mut self, duration: 
Duration) { + self.time += duration; + } +} diff --git a/rust-arroyo/src/utils/mod.rs b/rust-arroyo/src/utils/mod.rs new file mode 100644 index 00000000..da550c9d --- /dev/null +++ b/rust-arroyo/src/utils/mod.rs @@ -0,0 +1,2 @@ +pub mod clock; +pub mod timing; diff --git a/rust-arroyo/src/utils/timing.rs b/rust-arroyo/src/utils/timing.rs new file mode 100644 index 00000000..d1d75e64 --- /dev/null +++ b/rust-arroyo/src/utils/timing.rs @@ -0,0 +1,44 @@ +use std::time::Duration; + +/// Represents a Deadline to be reached. +#[derive(Clone, Copy, Debug)] +pub struct Deadline { + start: coarsetime::Instant, + duration: Duration, +} + +fn now() -> coarsetime::Instant { + coarsetime::Instant::now_without_cache_update() +} + +impl Deadline { + /// Creates a new [`Deadline`]. + pub fn new(duration: Duration) -> Self { + Self { + start: now(), + duration, + } + } + + /// Returns the start since creation. + pub fn elapsed(&self) -> Duration { + now().duration_since(self.start).into() + } + + /// Checks whether the deadline has elapsed. + pub fn has_elapsed(&self) -> bool { + self.elapsed() >= self.duration + } + + /// Returns the remaining [`Duration`]. + /// + /// This will be [`Duration::ZERO`] if the deadline has elapsed + pub fn remaining(&self) -> Duration { + self.duration.saturating_sub(self.elapsed()) + } + + /// Restarts the deadline with the initial [`Duration`]. + pub fn restart(&mut self) { + self.start = now(); + } +}
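For illustration, a minimal sketch of how the `Deadline` helper above can bound a draining loop, the same pattern used by `RunTaskInThreads::join` and `Reduce::join`; the `drain_with_timeout` function, its `work_is_pending` closure, and the one-second timeout are illustrative assumptions, not part of this patch:

    use std::time::Duration;
    use crate::utils::timing::Deadline;

    // Hypothetical helper: poll some pending work until it finishes or the deadline passes.
    fn drain_with_timeout(mut work_is_pending: impl FnMut() -> bool) {
        // Give outstanding work at most one second before giving up.
        let deadline = Deadline::new(Duration::from_secs(1));
        while work_is_pending() {
            if deadline.has_elapsed() {
                break;
            }
            // remaining() saturates at zero, so this sleep never underflows.
            std::thread::sleep(deadline.remaining().min(Duration::from_millis(10)));
        }
    }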