From ee649235ef87e48667d897aab7f6074109cfcd4b Mon Sep 17 00:00:00 2001 From: Daisuke Murase Date: Mon, 13 Jan 2025 12:21:50 -0800 Subject: [PATCH] remove wait_for_low function, add functionality to wait it in publish_data --- livekit/src/room/mod.rs | 6 -- .../src/room/participant/local_participant.rs | 45 ++------------- livekit/src/rtc_engine/mod.rs | 8 --- livekit/src/rtc_engine/rtc_events.rs | 10 ---- livekit/src/rtc_engine/rtc_session.rs | 55 ++++++++++++------- 5 files changed, 40 insertions(+), 84 deletions(-) diff --git a/livekit/src/room/mod.rs b/livekit/src/room/mod.rs index 3cb8b260..051b56f0 100644 --- a/livekit/src/room/mod.rs +++ b/livekit/src/room/mod.rs @@ -734,12 +734,6 @@ impl RoomSession { EngineEvent::DataStreamChunk { chunk, participant_identity } => { self.handle_data_stream_chunk(chunk, participant_identity); } - EngineEvent::DataChannelBufferedAmountChanged { buffered_amount } => { - let local_participant = self.local_participant.clone(); - livekit_runtime::spawn(async move { - local_participant.handle_dc_buffered_amount_changed(buffered_amount).await; - }); - }, _ => {} } diff --git a/livekit/src/room/participant/local_participant.rs b/livekit/src/room/participant/local_participant.rs index 9359da26..861b19b5 100644 --- a/livekit/src/room/participant/local_participant.rs +++ b/livekit/src/room/participant/local_participant.rs @@ -481,48 +481,13 @@ impl LocalParticipant { self.inner.rtc_engine.publish_data(&data, kind).await.map_err(Into::into) } - pub async fn wait_for_dc_buffer_low(&self) -> RoomResult<()> { - let threshold = - self.local.dc_buffered_amount_low_threshold.load(std::sync::atomic::Ordering::Relaxed); - let amount = - self.inner.rtc_engine.session().data_channel_buffered_amount(DataPacketKind::Reliable); - - if amount <= threshold { - return Ok(()); - } - - // wait buffered amount becam low - let rx = { - let (tx, rx) = oneshot::channel(); - - let mut low_tx = self.local.dc_buffered_amount_low_tx.lock(); - if low_tx.is_some() { - return Err(RoomError::Request { - reason: Reason::NotAllowed, - message: "Another wait request is already in progress.".into(), - }); - } - *low_tx = Some(tx); - rx - }; - - match rx.await { - Ok(()) => Ok(()), - Err(err) => Err(RoomError::Internal(format!("failed to wait: {}", err))), - } + pub async fn set_data_channel_buffered_amount_low_threshold(&self, threshold: u64) -> RoomResult<()> { + self.inner.rtc_engine.session().set_data_channel_buffered_amount_low_threshold(threshold); + Ok(()) } - pub(crate) async fn handle_dc_buffered_amount_changed(&self, buffered_amount: u64) { - let threshold = - self.local.dc_buffered_amount_low_threshold.load(std::sync::atomic::Ordering::Relaxed); - if buffered_amount > threshold { - return; - } - let Some(tx) = self.local.dc_buffered_amount_low_tx.lock().take() else { - return; - }; - log::debug!("return wait_for_dc_buffer_low: buffered={}, threshold={}", buffered_amount, threshold); - let _ = tx.send(()); + pub async fn data_channel_buffered_amount_low_threshold(&self) -> RoomResult { + Ok(self.inner.rtc_engine.session().data_channel_buffered_amount_low_threshold()) } pub async fn publish_transcription(&self, packet: Transcription) -> RoomResult<()> { diff --git a/livekit/src/rtc_engine/mod.rs b/livekit/src/rtc_engine/mod.rs index b3c6aa13..2cd54521 100644 --- a/livekit/src/rtc_engine/mod.rs +++ b/livekit/src/rtc_engine/mod.rs @@ -167,9 +167,6 @@ pub enum EngineEvent { chunk: proto::data_stream::Chunk, participant_identity: String, }, - DataChannelBufferedAmountChanged { - buffered_amount: u64, - }, } /// Represents a running RtcSession with the ability to close the session @@ -545,11 +542,6 @@ impl EngineInner { .engine_tx .send(EngineEvent::DataStreamChunk { chunk, participant_identity }); } - SessionEvent::DataChannelBufferedAmountChanged { buffered_amount } => { - let _ = self - .engine_tx - .send(EngineEvent::DataChannelBufferedAmountChanged { buffered_amount }); - } } Ok(()) } diff --git a/livekit/src/rtc_engine/rtc_events.rs b/livekit/src/rtc_engine/rtc_events.rs index 82c69ef0..004fb229 100644 --- a/livekit/src/rtc_engine/rtc_events.rs +++ b/livekit/src/rtc_engine/rtc_events.rs @@ -51,9 +51,6 @@ pub enum RtcEvent { data: Vec, binary: bool, }, - DataChannelStateChange { - state: DataChannelState, - }, DataChannelBufferedAmountChange { sent: u64, amount: u64, @@ -149,12 +146,6 @@ fn on_message(emitter: RtcEmitter) -> rtc::data_channel::OnMessage { }) } -fn on_state_change(emitter: RtcEmitter) -> rtc::data_channel::OnStateChange { - Box::new(move |state| { - let _ = emitter.send(RtcEvent::DataChannelStateChange { state }); - }) -} - fn on_buffered_amount_change( emitter: RtcEmitter, dc: DataChannel, @@ -168,6 +159,5 @@ fn on_buffered_amount_change( pub fn forward_dc_events(dc: &mut DataChannel, kind: DataPacketKind, rtc_emitter: RtcEmitter) { dc.on_message(Some(on_message(rtc_emitter.clone()))); - dc.on_state_change(Some(on_state_change(rtc_emitter.clone()))); dc.on_buffered_amount_change(Some(on_buffered_amount_change(rtc_emitter, dc.clone(), kind))); } diff --git a/livekit/src/rtc_engine/rtc_session.rs b/livekit/src/rtc_engine/rtc_session.rs index c4a12534..5b888d6b 100644 --- a/livekit/src/rtc_engine/rtc_session.rs +++ b/livekit/src/rtc_engine/rtc_session.rs @@ -35,7 +35,7 @@ use proto::{ SignalTarget, }; use serde::{de::IntoDeserializer, Deserialize, Serialize}; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{mpsc, oneshot, watch, Notify}; use super::{rtc_events, EngineError, EngineOptions, EngineResult, SimulateScenario}; use crate::{id::ParticipantIdentity, ChatMessage, TranscriptionSegment}; @@ -143,9 +143,6 @@ pub enum SessionEvent { chunk: proto::data_stream::Chunk, participant_identity: String, }, - DataChannelBufferedAmountChanged { - buffered_amount: u64, - }, } #[derive(Serialize, Deserialize)] @@ -169,9 +166,10 @@ struct SessionInner { // Publisher data channels // used to send data to other participants (The SFU forwards the messages) lossy_dc: DataChannel, - lossy_dc_buffered_amount: AtomicU64, reliable_dc: DataChannel, reliable_dc_buffered_amount: AtomicU64, + reliable_dc_buffered_amount_low_threshold: AtomicU64, + reliable_dc_buffered_amount_low_notify: Notify, // Keep a strong reference to the subscriber datachannels, // so we can receive data from other participants @@ -263,9 +261,10 @@ impl RtcSession { subscriber_pc, pending_tracks: Default::default(), lossy_dc, - lossy_dc_buffered_amount: Default::default(), reliable_dc, reliable_dc_buffered_amount: Default::default(), + reliable_dc_buffered_amount_low_threshold: Default::default(), + reliable_dc_buffered_amount_low_notify: Default::default(), sub_lossy_dc: Mutex::new(None), sub_reliable_dc: Mutex::new(None), closed: Default::default(), @@ -377,18 +376,23 @@ impl RtcSession { self.inner.data_channel(target, kind) } - pub fn data_channel_buffered_amount(&self, kind: DataPacketKind) -> u64 { - match kind { - DataPacketKind::Lossy => self.inner.lossy_dc_buffered_amount.load(Ordering::Relaxed), - DataPacketKind::Reliable => { - self.inner.reliable_dc_buffered_amount.load(Ordering::Relaxed) - } - } + pub fn data_channel_buffered_amount(&self) -> u64 { + self.inner.reliable_dc_buffered_amount.load(Ordering::Relaxed) + } + + pub fn data_channel_buffered_amount_low_threshold(&self) -> u64 { + self.inner.reliable_dc_buffered_amount_low_threshold.load(Ordering::Relaxed) + } + + pub fn set_data_channel_buffered_amount_low_threshold(&self, threshold: u64) { + self.inner.reliable_dc_buffered_amount_low_threshold.store(threshold, Ordering::Relaxed) } pub async fn get_response(&self, request_id: u32) -> proto::RequestResponse { self.inner.get_response(request_id).await } + + } impl SessionInner { @@ -763,19 +767,17 @@ impl SessionInner { } } } - RtcEvent::DataChannelStateChange { state } => { - } RtcEvent::DataChannelBufferedAmountChange { sent: _, amount, kind } => { match kind { DataPacketKind::Lossy => { - self.lossy_dc_buffered_amount.store(amount, Ordering::Relaxed) + // Do nothing at this moment } DataPacketKind::Reliable => { self.reliable_dc_buffered_amount.store(amount, Ordering::Relaxed); - // Only reliable dc is needed this event at this time - let _ = self.emitter.send(SessionEvent::DataChannelBufferedAmountChanged { - buffered_amount: amount, - }); + let threshold = self.reliable_dc_buffered_amount_low_threshold.load(Ordering::Relaxed); + if amount <= threshold { + self.reliable_dc_buffered_amount_low_notify.notify_one(); + } } } } @@ -990,6 +992,9 @@ impl SessionInner { kind: DataPacketKind, ) -> Result<(), EngineError> { self.ensure_publisher_connected(kind).await?; + if kind == DataPacketKind::Reliable { + self.wait_buffer_low().await?; + } self.data_channel(SignalTarget::Publisher, kind) .unwrap() .send(&data.encode_to_vec(), true) @@ -998,6 +1003,16 @@ impl SessionInner { }) } + async fn wait_buffer_low(&self) -> Result<(), EngineError> { + let amount = self.reliable_dc_buffered_amount.load(Ordering::Relaxed); + let threshold = self.reliable_dc_buffered_amount_low_threshold.load(Ordering::Relaxed); + if amount <= threshold { + return Ok(()) + } + self.reliable_dc_buffered_amount_low_notify.notified().await; + Ok(()) + } + /// This reconnection if more seemless compared to the full reconnection implemented in /// ['RTCEngine'] async fn restart(&self) -> EngineResult {