diff --git a/Cargo.lock b/Cargo.lock index 8c6e09c..85e1214 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3040,6 +3040,7 @@ dependencies = [ "futures", "futures-util", "http-body-util", + "humantime", "hyper 1.5.1", "hyper-util", "log", diff --git a/Cargo.toml b/Cargo.toml index 1453afd..1211ca3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ env_logger = "0.11.1" futures = "0.3.30" futures-util = "0.3.30" http-body-util = "0.1.0" +humantime = "2.1.0" hyper = { version = "1.1.0", features = ["server", "http1"] } hyper-util = { version = "0.1.3", features = ["full"] } log = { version = "0.4.20" } diff --git a/sample.ron b/sample.ron index 5e2ce3e..127ec0f 100644 --- a/sample.ron +++ b/sample.ron @@ -2,6 +2,7 @@ automigrate: true, reset_state: false, log_level: "debug", + clock_cycle_interval: "100ms", metrics: Prometheus( Config( port: 9090, diff --git a/src/config.rs b/src/config.rs index 73bd0e5..f310180 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,3 +1,5 @@ +use core::time; + use serde::Deserialize; use crate::grpc; @@ -5,10 +7,11 @@ use crate::metrics; use crate::nats; use crate::postgres; -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone)] pub struct Config { pub automigrate: bool, - pub log_level: String, + pub log_level: log::Level, + pub clock_cycle_interval: time::Duration, pub metrics: Metrics, pub repository: Repository, pub reset_state: bool, @@ -16,23 +19,89 @@ pub struct Config { pub transport: Transport, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub enum Transport { Grpc(grpc::Config), } -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub enum Transmitter { Nats(nats::Config), } -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub enum Repository { Postgres(postgres::Config), InMemory, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub enum Metrics { Prometheus(metrics::Config), } + +pub fn validate(config: &Config) -> Result<(), Box> { + if config.clock_cycle_interval.is_zero() { + return Err("clock cycle interval cannot be zero".into()); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn config() -> Config { + Config { + automigrate: true, + log_level: log::Level::Info, + clock_cycle_interval: time::Duration::from_millis(100), + metrics: Metrics::Prometheus(metrics::Config { + port: 3000, + endpoint: String::from("/metrics"), + }), + repository: Repository::InMemory, + reset_state: true, + transmitter: Transmitter::Nats(nats::Config { + port: 3001, + host: String::from("127.0.0.1"), + }), + transport: Transport::Grpc(grpc::Config { port: 3002 }), + } + } + + #[test] + fn test_validate_config() { + struct TestCase { + name: String, + config: Config, + expected_valid: bool, + } + + let test_cases = vec![ + TestCase { + name: String::from("valid"), + config: config(), + expected_valid: true, + }, + TestCase { + name: String::from("zero duration"), + config: Config { + clock_cycle_interval: time::Duration::from_secs(0), + ..config() + }, + expected_valid: false, + }, + ]; + + for test_case in test_cases { + let valid = validate(&test_case.config).is_ok(); + assert_eq!( + valid, test_case.expected_valid, + "test case failed: {}", + test_case.name + ); + } + } +} diff --git a/src/integration_test.rs b/src/integration_test.rs index 24d98f1..3a4fe2a 100644 --- a/src/integration_test.rs +++ b/src/integration_test.rs @@ -451,6 +451,7 @@ mod tests { nats_connection: &async_nats::Client, ) -> Arc { let scheduler = scheduler::TransmissionScheduler::new( + time::Duration::from_micros(10), Arc::new(postgres_repository().await), Arc::new(nats_publisher(&nats_connection)), Arc::new(now), diff --git a/src/load_config.rs b/src/load_config.rs index ae1611f..f723629 100644 --- a/src/load_config.rs +++ b/src/load_config.rs @@ -1,7 +1,9 @@ +use humantime; use std::env; use std::error::Error; use std::fs::File; use std::io::Read; +use std::str::FromStr; use serde::Deserialize; @@ -14,6 +16,7 @@ const ENV_POSTGRES_PASSWORD: &'static str = "POSTGRES_PASSWORD"; struct FileConfig { automigrate: bool, log_level: String, + clock_cycle_interval: String, metrics: config::Metrics, repository: Repository, reset_state: bool, @@ -60,13 +63,27 @@ pub fn load_config(file_path: &str) -> Result> { let secrets = load_secrets_from_env()?; - Ok(derive_config(config, secrets)) + let config = derive_config(config, secrets)?; + + config::validate(&config)?; + + Ok(config) } -fn derive_config(config: FileConfig, secrets: EnvConfig) -> config::Config { - config::Config { +fn derive_config(config: FileConfig, secrets: EnvConfig) -> Result> { + let log_level = match log::Level::from_str(&config.log_level) { + Ok(log_level) => log_level, + Err(err) => { + return Err(format!("invalid log level '{}': {err}", &config.log_level).into()); + } + }; + + let clock_cycle_interval = humantime::parse_duration(&config.clock_cycle_interval)?; + + Ok(config::Config { automigrate: config.automigrate, - log_level: config.log_level, + log_level, + clock_cycle_interval, metrics: config.metrics, repository: match config.repository { Repository::Postgres(postgres_config) => { @@ -84,12 +101,13 @@ fn derive_config(config: FileConfig, secrets: EnvConfig) -> config::Config { transmitter: config.transmitter, transport: config.transport, reset_state: config.reset_state, - } + }) } #[cfg(test)] mod tests { use super::*; + use std::time; #[test] fn test_load_config() { @@ -104,7 +122,12 @@ mod tests { let configuration = load_config(config_file).expect("could not load configuration"); // Merely asserting the log level is enough to assert the structure of the file contents. - assert_eq!(configuration.log_level, "debug".to_string()); + assert_eq!(configuration.log_level, log::Level::Debug); + + assert_eq!( + configuration.clock_cycle_interval, + time::Duration::from_millis(100) + ); match configuration.repository { config::Repository::Postgres(postgres_config) => { diff --git a/src/main.rs b/src/main.rs index d6b030a..4808a9d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,12 +23,14 @@ const DEFAULT_CONFIG_FILE_PATH: &'static str = "config.ron"; #[tokio::main] async fn main() -> Result<(), Box> { + env_logger::init(); + let args: Vec = std::env::args().collect(); let config_file_path = match args.len() { 1 => DEFAULT_CONFIG_FILE_PATH, 2 => &args[1], _ => { - println!("Please specify the path to the configuration file as the only argument."); + error!("Please specify the path to the configuration file as the only argument."); process::exit(1); } }; @@ -44,8 +46,7 @@ async fn main() -> Result<(), Box> { // Initialise logger. let rust_log = "RUST_LOG"; - env::set_var(rust_log, config.log_level); - env_logger::init(); + env::set_var(rust_log, config.log_level.as_str()); info!("Starting application."); // Construct transmitter. @@ -138,6 +139,7 @@ async fn main() -> Result<(), Box> { // Construct scheduler. let scheduler = Arc::new(scheduler::TransmissionScheduler::new( + config.clock_cycle_interval, repository, transmitter, now_provider, diff --git a/src/metrics.rs b/src/metrics.rs index 4f7e68e..06eff34 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -19,7 +19,7 @@ use tokio::net::TcpListener; const METRIC_NAME: &str = "procedure"; const METRIC_HELP_TEXT: &str = "Number of procedure calls"; -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub struct Config { pub port: u16, pub endpoint: String, diff --git a/src/nats.rs b/src/nats.rs index 90599c9..3bdb91a 100644 --- a/src/nats.rs +++ b/src/nats.rs @@ -1,7 +1,7 @@ use log::info; use serde::Deserialize; -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub struct Config { pub port: u16, pub host: String, diff --git a/src/postgres.rs b/src/postgres.rs index c2972b9..3e46dab 100644 --- a/src/postgres.rs +++ b/src/postgres.rs @@ -2,7 +2,7 @@ use log::info; use serde::Deserialize; use sqlx::postgres::{PgPool, PgPoolOptions}; -#[derive(Debug, Deserialize)] +#[derive(Debug, Clone, Deserialize)] pub struct Config { pub name: String, pub host: String, diff --git a/src/scheduler.rs b/src/scheduler.rs index b77f88f..4820f8e 100644 --- a/src/scheduler.rs +++ b/src/scheduler.rs @@ -1,12 +1,12 @@ use std::error::Error; use std::sync::Arc; +use std::time; use async_trait::async_trait; use chrono::prelude::*; use log::{error, info, trace, warn}; #[cfg(test)] use mockall::predicate::*; -use std::time; use tokio_util::sync::CancellationToken; use uuid::Uuid; @@ -15,11 +15,12 @@ use crate::model::{Message, MetricEvent, Schedule, ScheduleError, Transmission}; static BATCH_SIZE: u32 = 100; static MAX_DELAYED_AGE: chrono::Duration = chrono::Duration::seconds(1); -static MIN_INTERVAL: time::Duration = time::Duration::from_millis(10); static MAX_NATS_SUBJECT_LENGTH: u32 = 256; #[derive(Clone)] pub struct TransmissionScheduler { + // clock_cycle_interval is the duration between each transmission batch. + clock_cycle_interval: time::Duration, // repository keeps the program stateless, by providing a storage interface to store and // retrieve transmissions. repository: Arc, @@ -34,7 +35,7 @@ pub struct TransmissionScheduler { #[async_trait] impl Scheduler for TransmissionScheduler { async fn schedule(&self, when: Schedule, what: Message) -> Result { - validate_schedule(self.now.now(), &when)?; + validate_schedule(self.now.now(), &when, self.clock_cycle_interval)?; validate_message(&what)?; let transmission = Transmission::new(when, what); @@ -51,7 +52,11 @@ impl Scheduler for TransmissionScheduler { } } -fn validate_schedule(now: DateTime, schedule: &Schedule) -> Result<(), ScheduleError> { +fn validate_schedule( + now: DateTime, + schedule: &Schedule, + clock_cycle_interval: time::Duration, +) -> Result<(), ScheduleError> { match schedule { Schedule::Delayed(delayed) => { if delayed.transmit_at - now < -MAX_DELAYED_AGE { @@ -64,7 +69,7 @@ fn validate_schedule(now: DateTime, schedule: &Schedule) -> Result<(), Sche if interval.first_transmission - now < -MAX_DELAYED_AGE { return Err(ScheduleError::AgedSchedule); } - if interval.interval < MIN_INTERVAL { + if interval.interval < clock_cycle_interval { return Err(ScheduleError::TooShortInterval); } @@ -113,12 +118,14 @@ fn validate_message(message: &Message) -> Result<(), ScheduleError> { impl TransmissionScheduler { pub fn new( + clock_cycle_interval: time::Duration, repository: Arc, transmitter: Arc, now: Arc, metrics: Arc, ) -> TransmissionScheduler { TransmissionScheduler { + clock_cycle_interval, repository, transmitter, now, @@ -142,7 +149,7 @@ impl TransmissionScheduler { break; } - _ = tokio::time::sleep(std::time::Duration::from_millis(100)) => {} + _ = tokio::time::sleep(self.clock_cycle_interval) => {} } } @@ -257,6 +264,8 @@ mod tests { use crate::contract::*; use crate::model::*; + const DEFAULT_CLOCK_CYCLE_INTERVAL: time::Duration = time::Duration::from_micros(10); + #[tokio::test] async fn test_poll_non_ready_schedule() { let timestamp_now = @@ -317,6 +326,7 @@ mod tests { .times(1); let scheduler = TransmissionScheduler::new( + DEFAULT_CLOCK_CYCLE_INTERVAL, Arc::new(repository), Arc::new(transmitter), Arc::new(now), @@ -388,6 +398,7 @@ mod tests { .times(1); let scheduler = TransmissionScheduler::new( + DEFAULT_CLOCK_CYCLE_INTERVAL, Arc::new(repository), Arc::new(transmitter), Arc::new(now), // First call to now will be too early for the given time, the next will @@ -477,6 +488,7 @@ mod tests { .times(3); let scheduler = TransmissionScheduler::new( + DEFAULT_CLOCK_CYCLE_INTERVAL, Arc::new(repository), Arc::new(transmitter), Arc::new(Utc::now), @@ -561,6 +573,7 @@ mod tests { .times(3); let scheduler = TransmissionScheduler::new( + DEFAULT_CLOCK_CYCLE_INTERVAL, Arc::new(repository), Arc::new(transmitter), Arc::new(Utc::now), @@ -594,6 +607,7 @@ mod tests { .times(1); let scheduler = TransmissionScheduler::new( + DEFAULT_CLOCK_CYCLE_INTERVAL, Arc::new(repository), Arc::new(transmitter), Arc::new(Utc::now), @@ -624,6 +638,7 @@ mod tests { .times(1); let scheduler = TransmissionScheduler::new( + DEFAULT_CLOCK_CYCLE_INTERVAL, Arc::new(repository), Arc::new(transmitter), Arc::new(Utc::now), @@ -837,6 +852,7 @@ mod tests { }; let scheduler = TransmissionScheduler::new( + DEFAULT_CLOCK_CYCLE_INTERVAL, Arc::new(repository), Arc::new(transmitter), Arc::new(Utc::now), @@ -906,6 +922,7 @@ mod tests { .times(amount_transmissions); let scheduler = TransmissionScheduler::new( + DEFAULT_CLOCK_CYCLE_INTERVAL, Arc::new(repository), Arc::new(transmitter), Arc::new(Utc::now), @@ -960,6 +977,7 @@ mod tests { .times(1); let scheduler = TransmissionScheduler::new( + DEFAULT_CLOCK_CYCLE_INTERVAL, Arc::new(repository), Arc::new(transmitter), Arc::new(Utc::now), @@ -1072,6 +1090,7 @@ mod tests { .times(1); let scheduler = TransmissionScheduler::new( + DEFAULT_CLOCK_CYCLE_INTERVAL, Arc::new(repository), Arc::new(transmitter), Arc::new(Utc::now), @@ -1181,6 +1200,7 @@ mod tests { .times(1); let scheduler = TransmissionScheduler::new( + DEFAULT_CLOCK_CYCLE_INTERVAL, Arc::new(repository), Arc::new(transmitter), Arc::new(now), @@ -1336,7 +1356,7 @@ mod tests { ]; for test_case in test_cases { - let valid = validate_schedule(now, &test_case.schedule); + let valid = validate_schedule(now, &test_case.schedule, DEFAULT_CLOCK_CYCLE_INTERVAL); match test_case.expected_result { Ok(()) => assert!( valid.is_ok(), diff --git a/tests/e2e.ron b/tests/e2e.ron index 2b8f4da..66422ac 100644 --- a/tests/e2e.ron +++ b/tests/e2e.ron @@ -2,6 +2,7 @@ automigrate: true, reset_state: true, log_level: "debug", + clock_cycle_interval: "100ms", metrics: Prometheus( Config( port: 9090,