From a8a34d8edb3063cb3ffec9a17193425be8886373 Mon Sep 17 00:00:00 2001 From: Lukas Pukenis Date: Wed, 25 Sep 2024 11:37:36 +0300 Subject: [PATCH] Add network-activity trigger ability for batcher Add triggering ability to batcher so it could evaluate deadlines and thresholds on demand. The approach is simple - any activity on sent or received data on any peer will trigger the batcher. This is built on assumption that triggering on incoming data would sync the batchers between two devices. However triggering only between two devices would leave other devices unsynced, thus simplified approach works even better to "sync the clocks" across all the nodes. Side effects: consider having 3 interconnected peers: A, B and C. Peer C is idling and A streams data to B. Now A and B is each triggered on every packet and in turn send premature keepalives to C at T_new=T_orig-threshold. Triggering will happen mostly on wg_consolidate() which happens every second. Signed-off-by: Lukas Pukenis --- .../LLT-5026_trigger_batching_incoming | 0 crates/telio-batcher/src/batcher.rs | 178 +++++++++++++++--- crates/telio-traversal/src/session_keeper.rs | 44 ++++- crates/telio-wg/src/wg.rs | 44 ++++- nat-lab/tests/test_batching.py | 75 ++++---- nat-lab/tests/utils/batching.py | 26 ++- src/device.rs | 8 +- 7 files changed, 290 insertions(+), 85 deletions(-) create mode 100644 .unreleased/LLT-5026_trigger_batching_incoming diff --git a/.unreleased/LLT-5026_trigger_batching_incoming b/.unreleased/LLT-5026_trigger_batching_incoming new file mode 100644 index 000000000..e69de29bb diff --git a/crates/telio-batcher/src/batcher.rs b/crates/telio-batcher/src/batcher.rs index 7af1eb22a..c970890ec 100644 --- a/crates/telio-batcher/src/batcher.rs +++ b/crates/telio-batcher/src/batcher.rs @@ -3,15 +3,28 @@ use std::fmt::Debug; use std::hash::Hash; use std::{collections::HashMap, sync::Arc}; -use telio_utils::telio_log_debug; +use telio_utils::{telio_log_debug, telio_log_warn}; +use tokio; use tokio::time::sleep_until; type Action> = Arc Fn(&'a mut V) -> BoxFuture<'a, R> + Sync + Send>; +/// Guards against triggers that happened long time ago +const TRIGGER_EFFECTIVE_DURATION: tokio::time::Duration = tokio::time::Duration::from_secs(5); + +/// Batcher holds (actions, interval, threshold). When polled, batcher will +/// return a list of actions that should be executed now. pub struct Batcher { actions: HashMap)>, + + /// Adding new action must be immediately returned from `get_actions` if polled right away. + /// In case we're already polling, we need to notify the tokio::select! about such an event. notify_add: tokio::sync::Notify, + + /// Triggering batcher must be handled inside of tokio::select! + notify_trigger_timestamp: Option, + notify_trigger: tokio::sync::Notify, } struct BatchEntry { @@ -37,13 +50,15 @@ where Self { actions: HashMap::new(), notify_add: tokio::sync::Notify::new(), + notify_trigger: tokio::sync::Notify::new(), + notify_trigger_timestamp: None, } } - /// Batching works by sleeping until the nearest future and then trying to batch more actions - /// based on the threshold value. Higher delay before calling the function will increase the chances of batching - /// because the deadlines will _probably_ be in the past already. - /// Adding a new action wakes up the batcher due to immediate trigger of the action. + /// Batcher works in a polling manner, meaning the call site must invoke the actions. + /// When polled, batcher will await until the nearest future, new action addition or a trigger. 
+ /// Once resolved, batcher will try to batch all the actions that are within their respective + /// thresholds. pub async fn get_actions(&mut self) -> Vec<(K, Action)> { let mut batched_actions: Vec<(K, Action)> = vec![]; @@ -51,42 +66,61 @@ where if !self.actions.is_empty() { let actions = &mut self.actions; - // TODO: This can be optimized by early breaking and precollecting items beforehand - if let Some(closest_entry) = actions.values().min_by_key(|entry| entry.0.deadline) { - tokio::select! { - _ = self.notify_add.notified() => { - // Item was added, we need to immediately emit it - } - _ = sleep_until(closest_entry.0.deadline) => { - // Closest action should now be emitted + let active_trigger = self + .notify_trigger_timestamp + .take() + .map_or(false, |ts| ts.elapsed() < TRIGGER_EFFECTIVE_DURATION); + + if !active_trigger { + if let Some(closest_entry) = + actions.values().min_by_key(|entry| entry.0.deadline) + { + tokio::select! { + _ = self.notify_add.notified() => { + telio_log_debug!("New item added"); + } + _ = sleep_until(closest_entry.0.deadline) => { + telio_log_debug!("Action deadline reached"); + } + _ = self.notify_trigger.notified() => { + telio_log_debug!("Trigger received"); + } } } + } - let now = tokio::time::Instant::now(); - // at this point in time we know we're at the earliest spot for batching, thus we can check if we have more actions to add - for (key, action) in actions.iter_mut() { - let adjusted_action_deadline = now + action.0.threshold; + let now = tokio::time::Instant::now(); + for (key, action) in actions.iter_mut() { + let adjusted_deadline = now + action.0.threshold; - if action.0.deadline <= adjusted_action_deadline { - action.0.deadline = now + action.0.interval; - batched_actions.push((key.clone(), action.1.clone())); - } + if action.0.deadline <= adjusted_deadline { + action.0.deadline = now + action.0.interval; + batched_actions.push((key.clone(), action.1.clone())); } } return batched_actions; } else { - let _ = self.notify_add.notified().await; + _ = self.notify_add.notified().await; } } } - /// Remove batcher action. Action is no longer eligible for batching pub fn remove(&mut self, key: &K) { - telio_log_debug!("removing item from batcher with key({:?})", key); + telio_log_debug!("Removing item from batcher with key({:?})", key); self.actions.remove(key); } + /// Due to async nature of batcher code it will await until an action becomes available. + /// This function allows for premature evaluation of actions. + /// Calling this function in a tight loop with result in actions + /// being returned at T-threshold time. + pub fn trigger(&mut self) { + telio_log_debug!("Triggering batcher"); + self.notify_trigger_timestamp = Some(tokio::time::Instant::now()); + self.notify_trigger.notify_waiters(); + } + /// Add batcher action. Batcher itself doesn't run the tasks and depends /// on actions being manually invoked. Adding an action immediately triggers it /// thus if the call site awaits for the future then it will resolve immediately after this @@ -99,11 +133,25 @@ where action: Action, ) { telio_log_debug!( - "adding item to batcher with key({:?}), interval({:?}), threshold({:?})", + "Adding item to batcher with key({:?}), interval({:?}), threshold({:?})", key, interval, threshold, ); + + let threshold = { + if threshold >= interval { + let capped_threshold = interval / 2; + telio_log_warn!( + "Threshold should not be bigger than the interval. 
Overriding to ({:?})", + capped_threshold + ); + capped_threshold + } else { + threshold + } + }; + let entry = BatchEntry { deadline: tokio::time::Instant::now(), interval, @@ -133,6 +181,86 @@ mod tests { use crate::batcher::Batcher; + #[tokio::test(start_paused = true)] + async fn batch_and_trigger() { + let start_time = tokio::time::Instant::now(); + let mut batcher = Batcher::::new(); + + batcher.add( + "key0".to_owned(), + Duration::from_secs(100), + Duration::from_secs(50), + Arc::new(|s: _| { + Box::pin(async move { + s.values + .push(("key0".to_owned(), tokio::time::Instant::now())); + Ok(()) + }) + }), + ); + + let mut test_checker = TestChecker { values: Vec::new() }; + + // pick up the immediate fire + for ac in batcher.get_actions().await { + ac.1(&mut test_checker).await.unwrap(); + } + assert!(test_checker.values.len() == 1); + + let create_time_checkpoint = + |add: u64| tokio::time::Instant::now() + tokio::time::Duration::from_secs(add); + + let mut trigger_timepoints = vec![ + create_time_checkpoint(10), + create_time_checkpoint(20), + create_time_checkpoint(60), + create_time_checkpoint(90), + create_time_checkpoint(200), + create_time_checkpoint(270), + create_time_checkpoint(280), + create_time_checkpoint(730), + create_time_checkpoint(1000), + ]; + + use tokio::time::sleep_until; + loop { + tokio::select! { + _ = sleep_until(trigger_timepoints[0]) => { + batcher.trigger(); + trigger_timepoints.remove(0); + if trigger_timepoints.len() == 0 { + break + } + } + + actions = batcher.get_actions() => { + for ac in &actions { + ac.1(&mut test_checker).await.unwrap(); + } + } + } + } + + let key0_entries: Vec = test_checker + .values + .iter() + .filter(|e| e.0 == "key0") + .map(|e| e.1.duration_since(start_time)) + .collect(); + + let expected_diff_values: Vec = + vec![0, 60, 160, 260, 360, 460, 560, 660, 730, 830, 930] + .iter() + .map(|v| tokio::time::Duration::from_secs(*v)) + .collect(); + assert!( + key0_entries == expected_diff_values, + "expected: {:?}, got: {:?}", + expected_diff_values, + key0_entries + ); + } + #[tokio::test(start_paused = true)] async fn batch_one() { let start_time = tokio::time::Instant::now(); diff --git a/crates/telio-traversal/src/session_keeper.rs b/crates/telio-traversal/src/session_keeper.rs index 3b4130256..7ea587079 100644 --- a/crates/telio-traversal/src/session_keeper.rs +++ b/crates/telio-traversal/src/session_keeper.rs @@ -15,7 +15,7 @@ use telio_task::{task_exec, BoxAction, Runtime, Task}; use telio_utils::{ dual_target, repeated_actions, telio_log_debug, telio_log_warn, DualTarget, RepeatedActions, }; - +use telio_wg::NetworkActivityGetter; const PING_PAYLOAD_SIZE: usize = 56; /// Possible [SessionKeeper] errors. 
@@ -62,7 +62,10 @@ pub struct SessionKeeper { } impl SessionKeeper { - pub fn start(sock_pool: Arc) -> Result { + pub fn start( + sock_pool: Arc, + network_activity_getter: Option>, + ) -> Result { let (client_v4, client_v6) = ( PingerClient::new(&Self::make_builder(ICMP::V4).build()) .map_err(|e| Error::PingerCreationError(ICMP::V4, e))?, @@ -81,6 +84,9 @@ impl SessionKeeper { }, batched_actions: Batcher::new(), nonbatched_actions: RepeatedActions::default(), + network_activity_getter, + last_tx_ts: None, + last_rx_ts: None, }), }) } @@ -128,14 +134,19 @@ async fn ping(pingers: &Pingers, targets: (&PublicKey, &DualTarget)) -> Result<( let (primary, secondary) = targets.1.get_targets()?; let public_key = targets.0; - telio_log_debug!("Pinging primary target {:?} on {:?}", public_key, primary); - let primary_client = match primary { IpAddr::V4(_) => &pingers.pinger_client_v4, IpAddr::V6(_) => &pingers.pinger_client_v6, }; let ping_id = PingIdentifier(rand::random()); + + telio_log_debug!( + "Pinging primary target {:?} on {:?} with ping_id: {:?}", + public_key, + primary, + ping_id + ); if let Err(e) = primary_client .pinger(primary, ping_id) .await @@ -264,7 +275,11 @@ struct State { pingers: Pingers, batched_actions: Batcher, nonbatched_actions: RepeatedActions>, + network_activity_getter: Option>, + last_tx_ts: Option, + last_rx_ts: Option, } + #[async_trait] impl Runtime for State { const NAME: &'static str = "SessionKeeper"; @@ -274,6 +289,25 @@ impl Runtime for State { where F: Future>> + Send, { + let mut tx_has_changed = false; + let mut rx_has_changed = false; + + // We just care about any network activity, thus no per-peer filtering. + if let Some(wg) = self.network_activity_getter.as_ref() { + if let Ok(Some(timestamps)) = wg.get_ts().await { + tx_has_changed = self.last_tx_ts.map_or(false, |ts| timestamps.tx_ts > ts); + self.last_tx_ts = Some(timestamps.tx_ts); + + rx_has_changed = self.last_rx_ts.map_or(false, |ts| timestamps.rx_ts > ts); + self.last_rx_ts = Some(timestamps.rx_ts); + } + } + + if tx_has_changed || rx_has_changed { + telio_log_debug!("Triggering batcher based on network activity"); + self.batched_actions.trigger(); + } + tokio::select! { Ok((pk, action)) = self.nonbatched_actions.select_action() => { let pk = *pk; @@ -324,7 +358,7 @@ mod tests { ) .unwrap(), )); - let sess_keep = SessionKeeper::start(socket_pool).unwrap(); + let sess_keep = SessionKeeper::start(socket_pool, None).unwrap(); let pk = "REjdn4zY2TFx2AMujoNGPffo9vDiRDXpGG4jHPtx2AY=" .parse::() diff --git a/crates/telio-wg/src/wg.rs b/crates/telio-wg/src/wg.rs index 772583d63..e635d2426 100644 --- a/crates/telio-wg/src/wg.rs +++ b/crates/telio-wg/src/wg.rs @@ -44,6 +44,22 @@ use std::{ time::Duration, }; +/// Interface for retrieving stats about network activity +#[async_trait] +pub trait NetworkActivityGetter: Sync + Send { + /// Get network activity timestamps + async fn get_ts(&self) -> Result, Error>; +} + +#[async_trait] +impl NetworkActivityGetter for DynamicWg { + /// Retrieves latest tx/rx change accross all the nodes. Essentially showing the time of last + /// egress or ingress activity + async fn get_ts(&self) -> Result, Error> { + Ok(task_exec!(&self.task, async move |s| Ok(s.network_activity_ts)).await?) 
+ } +} + /// WireGuard adapter interface #[cfg_attr(any(test, feature = "mockall"), mockall::automock)] #[async_trait] @@ -208,6 +224,16 @@ impl BytesAndTimestamps { } } +/// Timestamp pair for egress and ingress activity +#[derive(Copy, Clone, Debug)] +pub struct TxRxTimestampPair { + /// Egress activity timestamp + pub tx_ts: Instant, + + /// Ingress activity timestamp + pub rx_ts: Instant, +} + struct State { #[cfg(unix)] cfg: Config, @@ -227,7 +253,7 @@ struct State { libtelio_event: Option>>, stats: HashMap>>, - + network_activity_ts: Option, ip_stack: Option, } @@ -362,6 +388,7 @@ impl DynamicWg { libtelio_event: io.libtelio_wide_event_publisher, stats: HashMap::new(), ip_stack: None, + network_activity_ts: Default::default(), }), } } @@ -786,10 +813,17 @@ impl State { if let Some(stats) = self.stats.get_mut(key) { match stats.lock().as_mut() { - Ok(s) => s.update( - new.rx_bytes.unwrap_or_default(), - new.tx_bytes.unwrap_or_default(), - ), + Ok(s) => { + s.update( + new.rx_bytes.unwrap_or_default(), + new.tx_bytes.unwrap_or_default(), + ); + + if let (Some(tx_ts), Some(rx_ts)) = (s.get_tx_ts(), s.get_rx_ts()) { + self.network_activity_ts = Some(TxRxTimestampPair { tx_ts, rx_ts }); + } + } + Err(e) => { telio_log_error!("poisoned lock - {}", e); } diff --git a/nat-lab/tests/test_batching.py b/nat-lab/tests/test_batching.py index 444f14d5d..a7c542034 100644 --- a/nat-lab/tests/test_batching.py +++ b/nat-lab/tests/test_batching.py @@ -6,7 +6,7 @@ from itertools import zip_longest from scapy.layers.inet import TCP, UDP # type: ignore from timeouts import TEST_BATCHING_TIMEOUT -from typing import List, Tuple +from typing import List, Tuple, Any from utils.batching import ( capture_traffic, print_histogram, @@ -23,9 +23,10 @@ ) from utils.connection import DockerConnection from utils.connection_util import DOCKER_GW_MAP, ConnectionTag, container_id +from utils.ping import ping -BATCHING_MISALIGN_RANGE = (0, 5) # Seconds to sleep for peers before starting -BATCHING_CAPTURE_TIME = 240 # Tied to TEST_BATCHING_TIMEOUT +BATCHING_MISALIGN_RANGE = (0, 3) # Seconds to sleep for peers before starting +BATCHING_CAPTURE_TIME = 30 # Tied to TEST_BATCHING_TIMEOUT def _generate_setup_parameters( @@ -49,9 +50,7 @@ def _generate_setup_parameters( ) return SetupParameters( - connection_tag=conn_tag, - adapter_type_override=adapter, - features=features, + connection_tag=conn_tag, adapter_type_override=adapter, features=features ) @@ -167,20 +166,14 @@ async def test_batching( capture_duration: int, ) -> None: async with AsyncExitStack() as exit_stack: - env = await exit_stack.enter_async_context( - setup_environment(exit_stack, setup_params) - ) - - await asyncio.gather(*[ - client.wait_for_state_on_any_derp([RelayState.CONNECTED]) - for client, instance in zip_longest(env.clients, setup_params) - if instance.derp_servers != [] - ]) - # We capture the traffic from all nodes and gateways. # On gateways we are sure the traffic has left the machine, however no easy way to # inspect the packets(encrypted by wireguard). For packet inspection # client traffic can be inspected. 
+ env = await exit_stack.enter_async_context( + setup_environment(exit_stack, setup_params) + ) + gateways = [DOCKER_GW_MAP[param.connection_tag] for param in setup_params] gateway_container_names = [container_id(conn_tag) for conn_tag in gateways] conns = [client.get_connection() for client in env.clients] @@ -191,22 +184,39 @@ async def test_batching( ] container_names = gateway_container_names + node_container_names - print("Will capture batching on containers: ", container_names) - cnodes = zip(env.clients, env.nodes) + print("Will capture traffic on containers: ", container_names) - # Misalign the peers by first stopping all of them and then restarting after various delays. - # This will have an effect of forcing neighboring libtelio node to add the peer to internal lists - # for keepalives at various points in time thus allowing us to observe better - # if the local batching is in action. + pcap_capture_tasks: List[Any] = [] + for name in container_names: + pcap_task = asyncio.create_task( + capture_traffic( + name, + capture_duration, + ) + ) + pcap_capture_tasks.append(pcap_task) + + # at this point packet captures are running + await asyncio.gather(*[ + client.wait_for_state_on_any_derp([RelayState.CONNECTED]) + for client, instance in zip_longest(env.clients, setup_params) + if instance.derp_servers != [] + ]) + + # At this stage all peers have been started and connected to DERP server meaning they are ready. + # It's a good time to misalign the peers by stopping all of them and then sleeping for random amounts + # of time in parallel before starting again. This gives a more realistic view as when peer comes online, + # it's added to other peers meshmaps and misalignment occurs naturally since peers already were online. + # In NatLab all peers start at more or less the same time normally, preventing batching to do anything useful. 
for client in env.clients: await client.stop_device() - # misalign the peers by sleeping some before starting each node again async def start_node_manually(client, node, sleep_min: int, sleep_max: int): await asyncio.sleep(random.randint(sleep_min, sleep_max)) await client.simple_start() await client.set_meshnet_config(env.api.get_meshnet_config(node.id)) + cnodes = zip(env.clients, env.nodes) await asyncio.gather(*[ start_node_manually( client, node, misalign_sleep_range[0], misalign_sleep_range[1] @@ -229,21 +239,20 @@ async def start_node_manually(client, node, sleep_min: int, sleep_max: int): ), ] - pcap_capture_tasks = [] - for name in container_names: - pcap_task = asyncio.create_task( - capture_traffic( - name, - capture_duration, - ) - ) - pcap_capture_tasks.append(pcap_task) + await asyncio.gather(*[ + client.wait_for_state_on_any_derp([RelayState.CONNECTED]) + for client, instance in zip_longest(env.clients, setup_params) + if instance.derp_servers != [] + ]) - pcap_paths = await asyncio.gather(*pcap_capture_tasks) + pcap_paths: list[str] = await asyncio.gather(*pcap_capture_tasks) + # Once capture tasks end, we reached the end of the test for container, pcap_path in zip(container_names, pcap_paths): for filt in allow_pcap_filters: filter_name = filt[0] hs = generate_histogram_from_pcap(pcap_path, capture_duration, filt[1]) title = f"{container}-filter({filter_name})" print_histogram(title, hs, max_height=12) + + # In the end, the histograms are captured that can now be used for observing the results diff --git a/nat-lab/tests/utils/batching.py b/nat-lab/tests/utils/batching.py index 8c84e1c37..5f5a65bd9 100644 --- a/nat-lab/tests/utils/batching.py +++ b/nat-lab/tests/utils/batching.py @@ -31,9 +31,6 @@ def _generate_histogram( async def capture_traffic(container_name: str, duration_s: int) -> str: - cmd_rm = f"docker exec --privileged {container_name} rm /home/capture.pcap" - os.system(cmd_rm) - iface = "any" capture_path = "/home/capture.pcap" @@ -44,20 +41,19 @@ async def capture_traffic(container_name: str, duration_s: int) -> str: await asyncio.sleep(duration_s) - with tempfile.NamedTemporaryFile() as tmpfile: - local_path = f"{tmpfile.name}.pcap" - print(f"Copying pcap to {local_path}") - subprocess.run([ - "docker", - "cp", - container_name + ":" + "/home/capture.pcap", - local_path, - ]) + local_path = f"{tempfile.mkstemp(suffix='.pcap')[1]}" + print(f"Copying pcap to {local_path}") + subprocess.run([ + "docker", + "cp", + container_name + ":" + "/home/capture.pcap", + local_path, + ]) - cmd_rm = f"docker exec --privileged {container_name} pkill tcpdump" - os.system(cmd_rm) + cmd_rm = f"docker exec --privileged {container_name} pkill tcpdump" + os.system(cmd_rm) - return local_path + return local_path # Render ASCII histogram drawing for visual inspection diff --git a/src/device.rs b/src/device.rs index c2ec82aed..be7bfc74b 100644 --- a/src/device.rs +++ b/src/device.rs @@ -1430,7 +1430,12 @@ impl Runtime { self.requested_state.device_config.private_key.public(), )?); - match SessionKeeper::start(self.entities.socket_pool.clone()).map(Arc::new) { + match SessionKeeper::start( + self.entities.socket_pool.clone(), + Some(self.entities.wireguard_interface.clone()), + ) + .map(Arc::new) + { Ok(session_keeper) => Some(DirectEntities { local_interfaces_endpoint_provider, stun_endpoint_provider, @@ -1835,7 +1840,6 @@ impl Runtime { // Update configuration for DERP client meshnet_entities.derp.configure(Some(derp_config)).await; - // Refresh the lists of servers for STUN endpoint 
provider if let Some(direct) = meshnet_entities.direct.as_ref() { if let Some(stun_ep) = direct.stun_endpoint_provider.as_ref() {
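
Since SessionKeeper::start() now accepts an optional NetworkActivityGetter, a stub implementation is convenient for unit tests. A minimal sketch, assuming the trait signature reconstructed from the wg.rs and session_keeper.rs hunks above (Result<Option<TxRxTimestampPair>, Error>) and assuming telio_wg exports these items; StubActivity and its field names are illustrative, not part of this change:

    use async_trait::async_trait;
    use std::sync::{Arc, Mutex};
    use telio_wg::{Error, NetworkActivityGetter, TxRxTimestampPair};

    // Reports whatever timestamp pair it currently holds; bumping the timestamps
    // between calls makes SessionKeeper trigger its batcher on the next wait cycle.
    struct StubActivity {
        ts: Mutex<Option<TxRxTimestampPair>>,
    }

    #[async_trait]
    impl NetworkActivityGetter for StubActivity {
        async fn get_ts(&self) -> Result<Option<TxRxTimestampPair>, Error> {
            Ok(*self.ts.lock().unwrap())
        }
    }

    // Possible usage in the session_keeper tests:
    // let getter = Arc::new(StubActivity { ts: Mutex::new(None) });
    // let sess_keep = SessionKeeper::start(socket_pool, Some(getter)).unwrap();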