diff --git a/.nanpa/dc-buffered-amount-low-threshold.kdl b/.nanpa/dc-buffered-amount-low-threshold.kdl new file mode 100644 index 00000000..e5a46b93 --- /dev/null +++ b/.nanpa/dc-buffered-amount-low-threshold.kdl @@ -0,0 +1,6 @@ +patch type="added" package="libwebrtc" "Expose DataChannel.bufferedAmount property" +patch type="fixed" package="livekit" "Wait for the buffered amount to become low before sending data during publish_data for Reliable Data Channel" +patch type="added" package="livekit" "Add an API to set buffered_amount_low_threshold for DataChannel" +patch type="added" package="livekit" "Update RoomInfo to contain buffered_amount_low_threshold for DataChannel" +patch type="added" package="livekit-ffi" "Add an API to set buffered_amount_low_threshold for DataChannel" +patch type="added" package="livekit-ffi" "Update RoomInfo to contain buffered_amount_low_threshold for DataChannel" diff --git a/libwebrtc/src/data_channel.rs b/libwebrtc/src/data_channel.rs index a8ac1f0e..3b1629ef 100644 --- a/libwebrtc/src/data_channel.rs +++ b/libwebrtc/src/data_channel.rs @@ -97,6 +97,10 @@ impl DataChannel { self.handle.close() } + pub fn buffered_amount(&self) -> u64 { + self.handle.buffered_amount() + } + pub fn on_state_change(&self, callback: Option) { self.handle.on_state_change(callback) } diff --git a/libwebrtc/src/native/data_channel.rs b/libwebrtc/src/native/data_channel.rs index ba36f041..fa5e6b75 100644 --- a/libwebrtc/src/native/data_channel.rs +++ b/libwebrtc/src/native/data_channel.rs @@ -94,6 +94,10 @@ impl DataChannel { self.sys_handle.close(); } + pub fn buffered_amount(&self) -> u64 { + self.sys_handle.buffered_amount() + } + pub fn on_state_change(&self, handler: Option) { *self.observer.state_change_handler.lock() = handler; } diff --git a/livekit-ffi/protocol/ffi.proto b/livekit-ffi/protocol/ffi.proto index daf69f83..8ccd2d17 100644 --- a/livekit-ffi/protocol/ffi.proto +++ b/livekit-ffi/protocol/ffi.proto @@ -116,6 +116,9 @@ message FfiRequest { 
SendStreamHeaderRequest send_stream_header = 44; SendStreamChunkRequest send_stream_chunk = 45; SendStreamTrailerRequest send_stream_trailer = 46; + + // Data Channel + SetDataChannelBufferedAmountLowThresholdRequest set_data_channel_buffered_amount_low_threshold = 47; } } @@ -178,6 +181,9 @@ message FfiResponse { SendStreamHeaderResponse send_stream_header = 43; SendStreamChunkResponse send_stream_chunk = 44; SendStreamTrailerResponse send_stream_trailer = 45; + + // Data Channel + SetDataChannelBufferedAmountLowThresholdResponse set_data_channel_buffered_amount_low_threshold = 46; } } diff --git a/livekit-ffi/protocol/room.proto b/livekit-ffi/protocol/room.proto index aa10986d..87ad8fb5 100644 --- a/livekit-ffi/protocol/room.proto +++ b/livekit-ffi/protocol/room.proto @@ -370,6 +370,7 @@ message RoomEvent { DataStreamHeaderReceived stream_header_received = 30; DataStreamChunkReceived stream_chunk_received = 31; DataStreamTrailerReceived stream_trailer_received = 32; + DataChannelBufferedAmountLowThresholdChanged data_channel_low_threshold_changed = 33; } } @@ -377,6 +378,8 @@ message RoomInfo { optional string sid = 1; required string name = 2; required string metadata = 3; + required uint64 lossy_dc_buffered_amount_low_threshold = 4; + required uint64 reliable_dc_buffered_amount_low_threshold = 5; } message OwnedRoom { @@ -647,3 +650,17 @@ message SendStreamTrailerCallback { required uint64 async_id = 1; optional string error = 2; } + +message SetDataChannelBufferedAmountLowThresholdRequest { + required uint64 local_participant_handle = 1; + required uint64 threshold = 2; + required DataPacketKind kind = 3; +} + +message SetDataChannelBufferedAmountLowThresholdResponse { +} + +message DataChannelBufferedAmountLowThresholdChanged { + required DataPacketKind kind = 1; + required uint64 threshold = 2; +} diff --git a/livekit-ffi/src/conversion/room.rs b/livekit-ffi/src/conversion/room.rs index 260c280a..adf1bfce 100644 --- a/livekit-ffi/src/conversion/room.rs +++ 
b/livekit-ffi/src/conversion/room.rs @@ -252,6 +252,12 @@ impl From<&FfiRoom> for proto::RoomInfo { sid: room.maybe_sid().map(|x| x.to_string()), name: room.name(), metadata: room.metadata(), + lossy_dc_buffered_amount_low_threshold: room + .data_channel_options(DataPacketKind::Lossy) + .buffered_amount_low_threshold, + reliable_dc_buffered_amount_low_threshold: room + .data_channel_options(DataPacketKind::Reliable) + .buffered_amount_low_threshold, } } } diff --git a/livekit-ffi/src/livekit.proto.rs b/livekit-ffi/src/livekit.proto.rs index 45e0cd05..a10af08d 100644 --- a/livekit-ffi/src/livekit.proto.rs +++ b/livekit-ffi/src/livekit.proto.rs @@ -1,5 +1,4 @@ // @generated -// This file is @generated by prost-build. #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FrameCryptor { @@ -2635,7 +2634,7 @@ pub struct OwnedBuffer { pub struct RoomEvent { #[prost(uint64, required, tag="1")] pub room_handle: u64, - #[prost(oneof="room_event::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32")] + #[prost(oneof="room_event::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33")] pub message: ::core::option::Option, } /// Nested message and enum types in `RoomEvent`. 
@@ -2707,6 +2706,8 @@ pub mod room_event { StreamChunkReceived(super::DataStreamChunkReceived), #[prost(message, tag="32")] StreamTrailerReceived(super::DataStreamTrailerReceived), + #[prost(message, tag="33")] + DataChannelLowThresholdChanged(super::DataChannelBufferedAmountLowThresholdChanged), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -2718,6 +2719,10 @@ pub struct RoomInfo { pub name: ::prost::alloc::string::String, #[prost(string, required, tag="3")] pub metadata: ::prost::alloc::string::String, + #[prost(uint64, required, tag="4")] + pub lossy_dc_buffered_amount_low_threshold: u64, + #[prost(uint64, required, tag="5")] + pub reliable_dc_buffered_amount_low_threshold: u64, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -3217,6 +3222,28 @@ pub struct SendStreamTrailerCallback { #[prost(string, optional, tag="2")] pub error: ::core::option::Option<::prost::alloc::string::String>, } +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SetDataChannelBufferedAmountLowThresholdRequest { + #[prost(uint64, required, tag="1")] + pub local_participant_handle: u64, + #[prost(uint64, required, tag="2")] + pub threshold: u64, + #[prost(enumeration="DataPacketKind", required, tag="3")] + pub kind: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SetDataChannelBufferedAmountLowThresholdResponse { +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DataChannelBufferedAmountLowThresholdChanged { + #[prost(enumeration="DataPacketKind", required, tag="1")] + pub kind: i32, + #[prost(uint64, required, tag="2")] + pub threshold: u64, +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum IceTransportType { @@ -3989,7 +4016,7 @@ pub struct RpcMethodInvocationEvent { 
#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FfiRequest { - #[prost(oneof="ffi_request::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46")] + #[prost(oneof="ffi_request::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47")] pub message: ::core::option::Option, } /// Nested message and enum types in `FfiRequest`. @@ -4094,13 +4121,16 @@ pub mod ffi_request { SendStreamChunk(super::SendStreamChunkRequest), #[prost(message, tag="46")] SendStreamTrailer(super::SendStreamTrailerRequest), + /// Data Channel + #[prost(message, tag="47")] + SetDataChannelBufferedAmountLowThreshold(super::SetDataChannelBufferedAmountLowThresholdRequest), } } /// This is the output of livekit_ffi_request function. #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FfiResponse { - #[prost(oneof="ffi_response::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45")] + #[prost(oneof="ffi_response::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46")] pub message: ::core::option::Option, } /// Nested message and enum types in `FfiResponse`. 
@@ -4203,6 +4233,9 @@ pub mod ffi_response { SendStreamChunk(super::SendStreamChunkResponse), #[prost(message, tag="45")] SendStreamTrailer(super::SendStreamTrailerResponse), + /// Data Channel + #[prost(message, tag="46")] + SetDataChannelBufferedAmountLowThreshold(super::SetDataChannelBufferedAmountLowThresholdResponse), } } /// To minimize complexity, participant events are not included in the protocol. diff --git a/livekit-ffi/src/server/requests.rs b/livekit-ffi/src/server/requests.rs index a002a67b..8a8b5be8 100644 --- a/livekit-ffi/src/server/requests.rs +++ b/livekit-ffi/src/server/requests.rs @@ -905,6 +905,20 @@ fn on_rpc_method_invocation_response( Ok(proto::RpcMethodInvocationResponseResponse { error }) } +fn on_set_data_channel_buffered_amount_low_threshold( + server: &'static FfiServer, + set_data_channel_buffered_amount_low_threshold: proto::SetDataChannelBufferedAmountLowThresholdRequest, +) -> FfiResult { + let ffi_participant = server + .retrieve_handle::( + set_data_channel_buffered_amount_low_threshold.local_participant_handle, + )? + .clone(); + Ok(ffi_participant.room.set_data_channel_buffered_amount_low_threshold( + set_data_channel_buffered_amount_low_threshold, + )) +} + #[allow(clippy::field_reassign_with_default)] // Avoid uggly format pub fn handle_request( server: &'static FfiServer, @@ -1078,6 +1092,11 @@ pub fn handle_request( server, request, )?) 
} + proto::ffi_request::Message::SetDataChannelBufferedAmountLowThreshold(request) => { + proto::ffi_response::Message::SetDataChannelBufferedAmountLowThreshold( + on_set_data_channel_buffered_amount_low_threshold(server, request)?, + ) + } }); Ok(res) diff --git a/livekit-ffi/src/server/room.rs b/livekit-ffi/src/server/room.rs index fcfeee26..e3eea0d9 100644 --- a/livekit-ffi/src/server/room.rs +++ b/livekit-ffi/src/server/room.rs @@ -774,6 +774,17 @@ impl RoomInner { ) -> Option>> { return self.rpc_method_invocation_waiters.lock().remove(&invocation_id); } + + pub fn set_data_channel_buffered_amount_low_threshold( + &self, + request: proto::SetDataChannelBufferedAmountLowThresholdRequest, + ) -> proto::SetDataChannelBufferedAmountLowThresholdResponse { + let _ = self.room.local_participant().set_data_channel_buffered_amount_low_threshold( + request.threshold, + request.kind().into(), + ); + proto::SetDataChannelBufferedAmountLowThresholdResponse {} + } } // Task used to publish data without blocking the client thread @@ -1246,6 +1257,14 @@ async fn forward_event( proto::DataStreamTrailerReceived { trailer: trailer.into(), participant_identity }, )); } + RoomEvent::DataChannelBufferedAmountLowThresholdChanged { kind, threshold } => { + let _ = send_event(proto::room_event::Message::DataChannelLowThresholdChanged( + proto::DataChannelBufferedAmountLowThresholdChanged { + kind: proto::DataPacketKind::from(kind).into(), + threshold, + }, + )); + } _ => { log::warn!("unhandled room event: {:?}", event); } diff --git a/livekit/src/room/mod.rs b/livekit/src/room/mod.rs index 9437e5ff..762bec1f 100644 --- a/livekit/src/room/mod.rs +++ b/livekit/src/room/mod.rs @@ -43,7 +43,7 @@ use crate::{ prelude::*, rtc_engine::{ EngineError, EngineEvent, EngineEvents, EngineOptions, EngineResult, RtcEngine, - SessionStats, + SessionStats, INITIAL_BUFFERED_AMOUNT_LOW_THRESHOLD, }, }; @@ -196,6 +196,10 @@ pub enum RoomEvent { }, Reconnecting, Reconnected, + 
DataChannelBufferedAmountLowThresholdChanged { + kind: DataPacketKind, + threshold: u64, + }, } #[derive(Debug, Clone, Copy, Eq, PartialEq)] @@ -360,6 +364,19 @@ impl Debug for Room { struct RoomInfo { metadata: String, state: ConnectionState, + lossy_dc_options: DataChannelOptions, + reliable_dc_options: DataChannelOptions, +} + +#[derive(Clone)] +pub struct DataChannelOptions { + pub buffered_amount_low_threshold: u64, +} + +impl Default for DataChannelOptions { + fn default() -> Self { + Self { buffered_amount_low_threshold: INITIAL_BUFFERED_AMOUNT_LOW_THRESHOLD } + } } pub(crate) struct RoomSession { @@ -506,6 +523,8 @@ impl Room { info: RwLock::new(RoomInfo { state: ConnectionState::Disconnected, metadata: room_info.metadata, + lossy_dc_options: Default::default(), + reliable_dc_options: Default::default(), }), remote_participants: Default::default(), active_speakers: Default::default(), @@ -623,6 +642,13 @@ impl Room { pub fn e2ee_manager(&self) -> &E2eeManager { &self.inner.e2ee_manager } + + pub fn data_channel_options(&self, kind: DataPacketKind) -> DataChannelOptions { + match kind { + DataPacketKind::Lossy => self.inner.info.read().lossy_dc_options.clone(), + DataPacketKind::Reliable => self.inner.info.read().reliable_dc_options.clone(), + } + } } impl RoomSession { @@ -741,6 +767,9 @@ impl RoomSession { EngineEvent::DataStreamTrailer { trailer, participant_identity } => { self.handle_data_stream_trailer(trailer, participant_identity); } + EngineEvent::DataChannelBufferedAmountLowThresholdChanged { kind, threshold } => { + self.handle_data_channel_buffered_low_threshold_change(kind, threshold); + } _ => {} } @@ -1278,6 +1307,24 @@ impl RoomSession { self.dispatcher.dispatch(&event); } + fn handle_data_channel_buffered_low_threshold_change( + &self, + kind: DataPacketKind, + threshold: u64, + ) { + let mut info = self.info.write(); + match kind { + DataPacketKind::Lossy => { + info.lossy_dc_options.buffered_amount_low_threshold = threshold; + } + 
DataPacketKind::Reliable => { + info.reliable_dc_options.buffered_amount_low_threshold = threshold; + } + } + let event = RoomEvent::DataChannelBufferedAmountLowThresholdChanged { kind, threshold }; + self.dispatcher.dispatch(&event); + } + /// Create a new participant /// Also add it to the participants list fn create_participant( diff --git a/livekit/src/room/participant/local_participant.rs b/livekit/src/room/participant/local_participant.rs index bbf269c5..ad0db760 100644 --- a/livekit/src/room/participant/local_participant.rs +++ b/livekit/src/room/participant/local_participant.rs @@ -377,7 +377,7 @@ impl LocalParticipant { ..Default::default() }; - match self.inner.rtc_engine.publish_data(&data, DataPacketKind::Reliable).await { + match self.inner.rtc_engine.publish_data(data, DataPacketKind::Reliable).await { Ok(_) => Ok(ChatMessage::from(chat_message)), Err(e) => Err(Into::into(e)), } @@ -403,7 +403,7 @@ impl LocalParticipant { ..Default::default() }; - match self.inner.rtc_engine.publish_data(&data, DataPacketKind::Reliable).await { + match self.inner.rtc_engine.publish_data(data, DataPacketKind::Reliable).await { Ok(_) => Ok(ChatMessage::from(proto_msg)), Err(e) => Err(Into::into(e)), } @@ -447,7 +447,7 @@ impl LocalParticipant { true => DataPacketKind::Reliable, false => DataPacketKind::Lossy, }; - self.inner.rtc_engine.publish_data(&packet, kind).await.map_err(Into::into) + self.inner.rtc_engine.publish_data(packet, kind).await.map_err(Into::into) } pub async fn publish_data(&self, packet: DataPacket) -> RoomResult<()> { @@ -468,7 +468,26 @@ impl LocalParticipant { ..Default::default() }; - self.inner.rtc_engine.publish_data(&data, kind).await.map_err(Into::into) + self.inner.rtc_engine.publish_data(data, kind).await.map_err(Into::into) + } + + pub fn set_data_channel_buffered_amount_low_threshold( + &self, + threshold: u64, + kind: DataPacketKind, + ) -> RoomResult<()> { + self.inner + .rtc_engine + .session() + 
.set_data_channel_buffered_amount_low_threshold(threshold, kind); + Ok(()) + } + + pub fn data_channel_buffered_amount_low_threshold( + &self, + kind: DataPacketKind, + ) -> RoomResult { + Ok(self.inner.rtc_engine.session().data_channel_buffered_amount_low_threshold(kind)) } pub async fn publish_transcription(&self, packet: Transcription) -> RoomResult<()> { @@ -493,11 +512,7 @@ impl LocalParticipant { value: Some(proto::data_packet::Value::Transcription(transcription_packet)), ..Default::default() }; - self.inner - .rtc_engine - .publish_data(&data, DataPacketKind::Reliable) - .await - .map_err(Into::into) + self.inner.rtc_engine.publish_data(data, DataPacketKind::Reliable).await.map_err(Into::into) } pub async fn publish_dtmf(&self, dtmf: SipDTMF) -> RoomResult<()> { @@ -511,11 +526,7 @@ impl LocalParticipant { ..Default::default() }; - self.inner - .rtc_engine - .publish_data(&data, DataPacketKind::Reliable) - .await - .map_err(Into::into) + self.inner.rtc_engine.publish_data(data, DataPacketKind::Reliable).await.map_err(Into::into) } async fn publish_rpc_request(&self, rpc_request: RpcRequest) -> RoomResult<()> { @@ -535,11 +546,7 @@ impl LocalParticipant { ..Default::default() }; - self.inner - .rtc_engine - .publish_data(&data, DataPacketKind::Reliable) - .await - .map_err(Into::into) + self.inner.rtc_engine.publish_data(data, DataPacketKind::Reliable).await.map_err(Into::into) } async fn publish_rpc_response(&self, rpc_response: RpcResponse) -> RoomResult<()> { @@ -563,11 +570,7 @@ impl LocalParticipant { ..Default::default() }; - self.inner - .rtc_engine - .publish_data(&data, DataPacketKind::Reliable) - .await - .map_err(Into::into) + self.inner.rtc_engine.publish_data(data, DataPacketKind::Reliable).await.map_err(Into::into) } async fn publish_rpc_ack(&self, rpc_ack: RpcAck) -> RoomResult<()> { @@ -581,11 +584,7 @@ impl LocalParticipant { ..Default::default() }; - self.inner - .rtc_engine - .publish_data(&data, DataPacketKind::Reliable) - .await - 
.map_err(Into::into) + self.inner.rtc_engine.publish_data(data, DataPacketKind::Reliable).await.map_err(Into::into) } pub fn get_track_publication(&self, sid: &TrackSid) -> Option { diff --git a/livekit/src/rtc_engine/mod.rs b/livekit/src/rtc_engine/mod.rs index 0b205909..5599861e 100644 --- a/livekit/src/rtc_engine/mod.rs +++ b/livekit/src/rtc_engine/mod.rs @@ -25,7 +25,7 @@ use tokio::sync::{ RwLockReadGuard as AsyncRwLockReadGuard, }; -pub use self::rtc_session::SessionStats; +pub use self::rtc_session::{SessionStats, INITIAL_BUFFERED_AMOUNT_LOW_THRESHOLD}; use crate::prelude::ParticipantIdentity; use crate::{ id::ParticipantSid, @@ -171,6 +171,10 @@ pub enum EngineEvent { trailer: proto::data_stream::Trailer, participant_identity: String, }, + DataChannelBufferedAmountLowThresholdChanged { + kind: DataPacketKind, + threshold: u64, + }, } /// Represents a running RtcSession with the ability to close the session @@ -233,7 +237,7 @@ impl RtcEngine { pub async fn publish_data( &self, - data: &proto::DataPacket, + data: proto::DataPacket, kind: DataPacketKind, ) -> EngineResult<()> { let (session, _r_lock) = { @@ -551,6 +555,11 @@ impl EngineInner { .engine_tx .send(EngineEvent::DataStreamTrailer { trailer, participant_identity }); } + SessionEvent::DataChannelBufferedAmountLowThresholdChanged { kind, threshold } => { + let _ = self.engine_tx.send( + EngineEvent::DataChannelBufferedAmountLowThresholdChanged { kind, threshold }, + ); + } } Ok(()) } diff --git a/livekit/src/rtc_engine/rtc_events.rs b/livekit/src/rtc_engine/rtc_events.rs index 576b867d..004fb229 100644 --- a/livekit/src/rtc_engine/rtc_events.rs +++ b/livekit/src/rtc_engine/rtc_events.rs @@ -17,7 +17,7 @@ use livekit_protocol as proto; use tokio::sync::mpsc; use super::peer_transport::PeerTransport; -use crate::rtc_engine::peer_transport::OnOfferCreated; +use crate::{rtc_engine::peer_transport::OnOfferCreated, DataPacketKind}; pub type RtcEmitter = mpsc::UnboundedSender; pub type RtcEvents = 
mpsc::UnboundedReceiver; @@ -51,6 +51,11 @@ pub enum RtcEvent { data: Vec, binary: bool, }, + DataChannelBufferedAmountChange { + sent: u64, + amount: u64, + kind: DataPacketKind, + }, } /// Handlers used to forward events to a channel @@ -141,6 +146,18 @@ fn on_message(emitter: RtcEmitter) -> rtc::data_channel::OnMessage { }) } -pub fn forward_dc_events(dc: &mut DataChannel, rtc_emitter: RtcEmitter) { - dc.on_message(Some(on_message(rtc_emitter))); +fn on_buffered_amount_change( + emitter: RtcEmitter, + dc: DataChannel, + kind: DataPacketKind, +) -> rtc::data_channel::OnBufferedAmountChange { + Box::new(move |sent| { + let amount = dc.buffered_amount(); + let _ = emitter.send(RtcEvent::DataChannelBufferedAmountChange { sent, amount, kind }); + }) +} + +pub fn forward_dc_events(dc: &mut DataChannel, kind: DataPacketKind, rtc_emitter: RtcEmitter) { + dc.on_message(Some(on_message(rtc_emitter.clone()))); + dc.on_buffered_amount_change(Some(on_buffered_amount_change(rtc_emitter, dc.clone(), kind))); } diff --git a/livekit/src/rtc_engine/rtc_session.rs b/livekit/src/rtc_engine/rtc_session.rs index 8de9d959..01e57cd8 100644 --- a/livekit/src/rtc_engine/rtc_session.rs +++ b/livekit/src/rtc_engine/rtc_session.rs @@ -13,12 +13,12 @@ // limitations under the License. 
use std::{ - collections::HashMap, + collections::{HashMap, VecDeque}, convert::TryInto, fmt::Debug, ops::Not, sync::{ - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicU64, Ordering}, Arc, }, time::Duration, @@ -34,7 +34,7 @@ use proto::{ debouncer::{self, Debouncer}, SignalTarget, }; -use serde::{de::IntoDeserializer, Deserialize, Serialize}; +use serde::{Deserialize, Serialize}; use tokio::sync::{mpsc, oneshot, watch}; use super::{rtc_events, EngineError, EngineOptions, EngineResult, SimulateScenario}; @@ -58,6 +58,7 @@ pub const TRACK_PUBLISH_TIMEOUT: Duration = Duration::from_secs(10); pub const LOSSY_DC_LABEL: &str = "_lossy"; pub const RELIABLE_DC_LABEL: &str = "_reliable"; pub const PUBLISHER_NEGOTIATION_FREQUENCY: Duration = Duration::from_millis(150); +pub const INITIAL_BUFFERED_AMOUNT_LOW_THRESHOLD: u64 = 2 * 1024 * 1024; pub type SessionEmitter = mpsc::UnboundedSender; pub type SessionEvents = mpsc::UnboundedReceiver; @@ -147,6 +148,16 @@ pub enum SessionEvent { trailer: proto::data_stream::Trailer, participant_identity: String, }, + DataChannelBufferedAmountLowThresholdChanged { + kind: DataPacketKind, + threshold: u64, + }, +} + +#[derive(Debug)] +enum DataChannelEvent { + PublishData(proto::DataPacket, DataPacketKind, oneshot::Sender>), + BufferedAmountChange(u64, DataPacketKind), } #[derive(Serialize, Deserialize)] @@ -170,7 +181,10 @@ struct SessionInner { // Publisher data channels // used to send data to other participants (The SFU forwards the messages) lossy_dc: DataChannel, + lossy_dc_buffered_amount_low_threshold: AtomicU64, reliable_dc: DataChannel, + reliable_dc_buffered_amount_low_threshold: AtomicU64, + dc_emitter: mpsc::UnboundedSender, // Keep a strong reference to the subscriber datachannels, // so we can receive data from other participants @@ -205,6 +219,7 @@ struct SessionHandle { close_tx: watch::Sender, // false = is_running signal_task: JoinHandle<()>, rtc_task: JoinHandle<()>, + dc_task: JoinHandle<()>, } impl 
RtcSession { @@ -223,6 +238,8 @@ impl RtcSession { let (rtc_emitter, rtc_events) = mpsc::unbounded_channel(); let rtc_config = make_rtc_config_join(join_response.clone(), options.rtc_config.clone()); + let (dc_emitter, dc_events) = mpsc::unbounded_channel(); + let lk_runtime = LkRuntime::instance(); let mut publisher_pc = PeerTransport::new( lk_runtime.pc_factory().create_peer_connection(rtc_config.clone())?, @@ -251,8 +268,8 @@ impl RtcSession { // Forward events received inside the signaling thread to our rtc channel rtc_events::forward_pc_events(&mut publisher_pc, rtc_emitter.clone()); rtc_events::forward_pc_events(&mut subscriber_pc, rtc_emitter.clone()); - rtc_events::forward_dc_events(&mut lossy_dc, rtc_emitter.clone()); - rtc_events::forward_dc_events(&mut reliable_dc, rtc_emitter); + rtc_events::forward_dc_events(&mut lossy_dc, DataPacketKind::Lossy, rtc_emitter.clone()); + rtc_events::forward_dc_events(&mut reliable_dc, DataPacketKind::Reliable, rtc_emitter); let (close_tx, close_rx) = watch::channel(false); let inner = Arc::new(SessionInner { @@ -262,7 +279,14 @@ impl RtcSession { subscriber_pc, pending_tracks: Default::default(), lossy_dc, + lossy_dc_buffered_amount_low_threshold: AtomicU64::new( + INITIAL_BUFFERED_AMOUNT_LOW_THRESHOLD, + ), reliable_dc, + reliable_dc_buffered_amount_low_threshold: AtomicU64::new( + INITIAL_BUFFERED_AMOUNT_LOW_THRESHOLD, + ), + dc_emitter, sub_lossy_dc: Mutex::new(None), sub_reliable_dc: Mutex::new(None), closed: Default::default(), @@ -275,9 +299,11 @@ impl RtcSession { // Start session tasks let signal_task = livekit_runtime::spawn(inner.clone().signal_task(signal_events, close_rx.clone())); - let rtc_task = livekit_runtime::spawn(inner.clone().rtc_session_task(rtc_events, close_rx)); + let rtc_task = + livekit_runtime::spawn(inner.clone().rtc_session_task(rtc_events, close_rx.clone())); + let dc_task = livekit_runtime::spawn(inner.clone().data_channel_task(dc_events, close_rx)); - let handle = 
Mutex::new(Some(SessionHandle { close_tx, signal_task, rtc_task })); + let handle = Mutex::new(Some(SessionHandle { close_tx, signal_task, rtc_task, dc_task })); Ok((Self { inner, handle }, join_response, session_events)) } @@ -319,6 +345,7 @@ impl RtcSession { let _ = handle.close_tx.send(true); let _ = handle.rtc_task.await; let _ = handle.signal_task.await; + let _ = handle.dc_task.await; } // Close the PeerConnections after the task @@ -328,7 +355,7 @@ impl RtcSession { pub async fn publish_data( &self, - data: &proto::DataPacket, + data: proto::DataPacket, kind: DataPacketKind, ) -> Result<(), EngineError> { self.inner.publish_data(data, kind).await @@ -374,6 +401,38 @@ impl RtcSession { self.inner.data_channel(target, kind) } + pub fn data_channel_buffered_amount_low_threshold(&self, kind: DataPacketKind) -> u64 { + match kind { + DataPacketKind::Lossy => { + self.inner.lossy_dc_buffered_amount_low_threshold.load(Ordering::Relaxed) + } + DataPacketKind::Reliable => { + self.inner.reliable_dc_buffered_amount_low_threshold.load(Ordering::Relaxed) + } + } + } + + pub fn set_data_channel_buffered_amount_low_threshold( + &self, + threshold: u64, + kind: DataPacketKind, + ) { + match kind { + DataPacketKind::Lossy => self + .inner + .lossy_dc_buffered_amount_low_threshold + .store(threshold, Ordering::Relaxed), + DataPacketKind::Reliable => self + .inner + .reliable_dc_buffered_amount_low_threshold + .store(threshold, Ordering::Relaxed), + } + let _ = self + .inner + .emitter + .send(SessionEvent::DataChannelBufferedAmountLowThresholdChanged { kind, threshold }); + } + pub async fn get_response(&self, request_id: u32) -> proto::RequestResponse { self.inner.get_response(request_id).await } @@ -469,6 +528,101 @@ impl SessionInner { log::debug!("closing signal_task"); } + async fn data_channel_task( + self: Arc, + mut dc_events: mpsc::UnboundedReceiver, + mut close_rx: watch::Receiver, + ) { + let mut lossy_buffered_amount = 0; + let mut reliable_buffered_amount = 0; 
+ let mut lossy_queue = VecDeque::new(); + let mut reliable_queue = VecDeque::new(); + + loop { + tokio::select! { + event = dc_events.recv() => { + let Some(event) = event else { + // tx closed + break; + }; + + match event { + DataChannelEvent::PublishData(packet, kind, tx) => { + let data = packet.encode_to_vec(); + match kind { + DataPacketKind::Lossy => { + lossy_queue.push_back((data, kind, tx)); + let threshold = self.lossy_dc_buffered_amount_low_threshold.load(Ordering::Relaxed); + self._send_until_threshold(threshold, &mut lossy_buffered_amount, &mut lossy_queue); + } + DataPacketKind::Reliable => { + reliable_queue.push_back((data, kind, tx)); + let threshold = self.reliable_dc_buffered_amount_low_threshold.load(Ordering::Relaxed); + self._send_until_threshold(threshold, &mut reliable_buffered_amount, &mut reliable_queue); + } + } + } + DataChannelEvent::BufferedAmountChange(sent, kind) => { + match kind { + DataPacketKind::Lossy => { + if lossy_buffered_amount < sent { + // I believe never reach here but adding logs just in case + log::error!("unexpected buffer size detected: lossy_buffered_amount={}, sent={}", lossy_buffered_amount, sent); + lossy_buffered_amount = 0; + } else { + lossy_buffered_amount -= sent; + } + let threshold = self.lossy_dc_buffered_amount_low_threshold.load(Ordering::Relaxed); + self._send_until_threshold(threshold, &mut lossy_buffered_amount, &mut lossy_queue); + } + DataPacketKind::Reliable => { + if reliable_buffered_amount < sent { + log::error!("unexpected buffer size detected: reliable_buffered_amount={}, sent={}", reliable_buffered_amount, sent); + reliable_buffered_amount = 0; + } else { + reliable_buffered_amount -= sent; + } + let threshold = self.reliable_dc_buffered_amount_low_threshold.load(Ordering::Relaxed); + self._send_until_threshold(threshold, &mut reliable_buffered_amount, &mut reliable_queue); + } + } + } + } + }, + + _ = close_rx.changed() => { + break; + }, + } + } + + log::debug!("closing 
data_channel_task"); + } + + fn _send_until_threshold( + self: &Arc, + threshold: u64, + buffered_amount: &mut u64, + queue: &mut VecDeque<(Vec, DataPacketKind, oneshot::Sender>)>, + ) { + while *buffered_amount <= threshold { + let Some((data, kind, tx)) = queue.pop_front() else { + break; + }; + + *buffered_amount += data.len() as u64; + let result = self + .data_channel(SignalTarget::Publisher, kind) + .unwrap() + .send(&data, true) + .map_err(|err| { + EngineError::Internal(format!("failed to send data packet: {:?}", err).into()) + }); + + let _ = tx.send(result); + } + } + async fn on_signal_event(&self, event: proto::signal_response::Message) -> EngineResult<()> { match event { proto::signal_response::Message::Answer(answer) => { @@ -757,6 +911,13 @@ impl SessionInner { } } } + RtcEvent::DataChannelBufferedAmountChange { sent, amount: _, kind } => { + if let Err(err) = + self.dc_emitter.send(DataChannelEvent::BufferedAmountChange(sent, kind)) + { + log::error!("failed to send dc_event buffer_amount_change: {:?}", err); + } + } } Ok(()) @@ -964,16 +1125,20 @@ impl SessionInner { async fn publish_data( self: &Arc, - data: &proto::DataPacket, + data: proto::DataPacket, kind: DataPacketKind, ) -> Result<(), EngineError> { self.ensure_publisher_connected(kind).await?; - self.data_channel(SignalTarget::Publisher, kind) - .unwrap() - .send(&data.encode_to_vec(), true) - .map_err(|err| { - EngineError::Internal(format!("failed to send data packet {:?}", err).into()) - }) + + let (tx, rx) = oneshot::channel(); + if let Err(err) = self.dc_emitter.send(DataChannelEvent::PublishData(data, kind, tx)) { + return Err(EngineError::Internal( + format!("failed to push data into queue: {:?}", err).into(), + )); + }; + rx.await.map_err(|e| { + EngineError::Internal(format!("failed to receive data from dc_task: {:?}", e).into()) + })? 
} /// This reconnection if more seemless compared to the full reconnection implemented in diff --git a/webrtc-sys/include/livekit/data_channel.h b/webrtc-sys/include/livekit/data_channel.h index f08e93ea..33fea51b 100644 --- a/webrtc-sys/include/livekit/data_channel.h +++ b/webrtc-sys/include/livekit/data_channel.h @@ -49,6 +49,7 @@ class DataChannel { rust::String label() const; DataState state() const; void close() const; + uint64_t buffered_amount() const; private: mutable webrtc::Mutex mutex_; diff --git a/webrtc-sys/src/data_channel.cpp b/webrtc-sys/src/data_channel.cpp index 9f2005f6..bec5cef4 100644 --- a/webrtc-sys/src/data_channel.cpp +++ b/webrtc-sys/src/data_channel.cpp @@ -92,6 +92,10 @@ void DataChannel::close() const { return data_channel_->Close(); } +uint64_t DataChannel::buffered_amount() const { + return data_channel_->buffered_amount(); +} + NativeDataChannelObserver::NativeDataChannelObserver( rust::Box observer, const DataChannel* dc) diff --git a/webrtc-sys/src/data_channel.rs b/webrtc-sys/src/data_channel.rs index 7073194d..1ae85bf7 100644 --- a/webrtc-sys/src/data_channel.rs +++ b/webrtc-sys/src/data_channel.rs @@ -70,6 +70,7 @@ pub mod ffi { fn label(self: &DataChannel) -> String; fn state(self: &DataChannel) -> DataState; fn close(self: &DataChannel); + fn buffered_amount(self: &DataChannel) -> u64; fn _shared_data_channel() -> SharedPtr; // Ignore }