Skip to content

Commit

Permalink
remove wait_for_low function, add functionality to wait it in publish…
Browse files Browse the repository at this point in the history
…_data
  • Loading branch information
typester committed Jan 13, 2025
1 parent 7a4be4f commit ee64923
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 84 deletions.
6 changes: 0 additions & 6 deletions livekit/src/room/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -734,12 +734,6 @@ impl RoomSession {
EngineEvent::DataStreamChunk { chunk, participant_identity } => {
self.handle_data_stream_chunk(chunk, participant_identity);
}
EngineEvent::DataChannelBufferedAmountChanged { buffered_amount } => {
let local_participant = self.local_participant.clone();
livekit_runtime::spawn(async move {
local_participant.handle_dc_buffered_amount_changed(buffered_amount).await;
});
},
_ => {}
}

Expand Down
45 changes: 5 additions & 40 deletions livekit/src/room/participant/local_participant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -481,48 +481,13 @@ impl LocalParticipant {
self.inner.rtc_engine.publish_data(&data, kind).await.map_err(Into::into)
}

pub async fn wait_for_dc_buffer_low(&self) -> RoomResult<()> {
let threshold =
self.local.dc_buffered_amount_low_threshold.load(std::sync::atomic::Ordering::Relaxed);
let amount =
self.inner.rtc_engine.session().data_channel_buffered_amount(DataPacketKind::Reliable);

if amount <= threshold {
return Ok(());
}

// wait buffered amount becam low
let rx = {
let (tx, rx) = oneshot::channel();

let mut low_tx = self.local.dc_buffered_amount_low_tx.lock();
if low_tx.is_some() {
return Err(RoomError::Request {
reason: Reason::NotAllowed,
message: "Another wait request is already in progress.".into(),
});
}
*low_tx = Some(tx);
rx
};

match rx.await {
Ok(()) => Ok(()),
Err(err) => Err(RoomError::Internal(format!("failed to wait: {}", err))),
}
pub async fn set_data_channel_buffered_amount_low_threshold(&self, threshold: u64) -> RoomResult<()> {
self.inner.rtc_engine.session().set_data_channel_buffered_amount_low_threshold(threshold);
Ok(())
}

pub(crate) async fn handle_dc_buffered_amount_changed(&self, buffered_amount: u64) {
let threshold =
self.local.dc_buffered_amount_low_threshold.load(std::sync::atomic::Ordering::Relaxed);
if buffered_amount > threshold {
return;
}
let Some(tx) = self.local.dc_buffered_amount_low_tx.lock().take() else {
return;
};
log::debug!("return wait_for_dc_buffer_low: buffered={}, threshold={}", buffered_amount, threshold);
let _ = tx.send(());
pub async fn data_channel_buffered_amount_low_threshold(&self) -> RoomResult<u64> {
Ok(self.inner.rtc_engine.session().data_channel_buffered_amount_low_threshold())
}

pub async fn publish_transcription(&self, packet: Transcription) -> RoomResult<()> {
Expand Down
8 changes: 0 additions & 8 deletions livekit/src/rtc_engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,6 @@ pub enum EngineEvent {
chunk: proto::data_stream::Chunk,
participant_identity: String,
},
DataChannelBufferedAmountChanged {
buffered_amount: u64,
},
}

/// Represents a running RtcSession with the ability to close the session
Expand Down Expand Up @@ -545,11 +542,6 @@ impl EngineInner {
.engine_tx
.send(EngineEvent::DataStreamChunk { chunk, participant_identity });
}
SessionEvent::DataChannelBufferedAmountChanged { buffered_amount } => {
let _ = self
.engine_tx
.send(EngineEvent::DataChannelBufferedAmountChanged { buffered_amount });
}
}
Ok(())
}
Expand Down
10 changes: 0 additions & 10 deletions livekit/src/rtc_engine/rtc_events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ pub enum RtcEvent {
data: Vec<u8>,
binary: bool,
},
DataChannelStateChange {
state: DataChannelState,
},
DataChannelBufferedAmountChange {
sent: u64,
amount: u64,
Expand Down Expand Up @@ -149,12 +146,6 @@ fn on_message(emitter: RtcEmitter) -> rtc::data_channel::OnMessage {
})
}

fn on_state_change(emitter: RtcEmitter) -> rtc::data_channel::OnStateChange {
Box::new(move |state| {
let _ = emitter.send(RtcEvent::DataChannelStateChange { state });
})
}

fn on_buffered_amount_change(
emitter: RtcEmitter,
dc: DataChannel,
Expand All @@ -168,6 +159,5 @@ fn on_buffered_amount_change(

pub fn forward_dc_events(dc: &mut DataChannel, kind: DataPacketKind, rtc_emitter: RtcEmitter) {
dc.on_message(Some(on_message(rtc_emitter.clone())));
dc.on_state_change(Some(on_state_change(rtc_emitter.clone())));
dc.on_buffered_amount_change(Some(on_buffered_amount_change(rtc_emitter, dc.clone(), kind)));
}
55 changes: 35 additions & 20 deletions livekit/src/rtc_engine/rtc_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use proto::{
SignalTarget,
};
use serde::{de::IntoDeserializer, Deserialize, Serialize};
use tokio::sync::{mpsc, oneshot, watch};
use tokio::sync::{mpsc, oneshot, watch, Notify};

use super::{rtc_events, EngineError, EngineOptions, EngineResult, SimulateScenario};
use crate::{id::ParticipantIdentity, ChatMessage, TranscriptionSegment};
Expand Down Expand Up @@ -143,9 +143,6 @@ pub enum SessionEvent {
chunk: proto::data_stream::Chunk,
participant_identity: String,
},
DataChannelBufferedAmountChanged {
buffered_amount: u64,
},
}

#[derive(Serialize, Deserialize)]
Expand All @@ -169,9 +166,10 @@ struct SessionInner {
// Publisher data channels
// used to send data to other participants (The SFU forwards the messages)
lossy_dc: DataChannel,
lossy_dc_buffered_amount: AtomicU64,
reliable_dc: DataChannel,
reliable_dc_buffered_amount: AtomicU64,
reliable_dc_buffered_amount_low_threshold: AtomicU64,
reliable_dc_buffered_amount_low_notify: Notify,

// Keep a strong reference to the subscriber datachannels,
// so we can receive data from other participants
Expand Down Expand Up @@ -263,9 +261,10 @@ impl RtcSession {
subscriber_pc,
pending_tracks: Default::default(),
lossy_dc,
lossy_dc_buffered_amount: Default::default(),
reliable_dc,
reliable_dc_buffered_amount: Default::default(),
reliable_dc_buffered_amount_low_threshold: Default::default(),
reliable_dc_buffered_amount_low_notify: Default::default(),
sub_lossy_dc: Mutex::new(None),
sub_reliable_dc: Mutex::new(None),
closed: Default::default(),
Expand Down Expand Up @@ -377,18 +376,23 @@ impl RtcSession {
self.inner.data_channel(target, kind)
}

pub fn data_channel_buffered_amount(&self, kind: DataPacketKind) -> u64 {
match kind {
DataPacketKind::Lossy => self.inner.lossy_dc_buffered_amount.load(Ordering::Relaxed),
DataPacketKind::Reliable => {
self.inner.reliable_dc_buffered_amount.load(Ordering::Relaxed)
}
}
pub fn data_channel_buffered_amount(&self) -> u64 {
self.inner.reliable_dc_buffered_amount.load(Ordering::Relaxed)
}

pub fn data_channel_buffered_amount_low_threshold(&self) -> u64 {
self.inner.reliable_dc_buffered_amount_low_threshold.load(Ordering::Relaxed)
}

pub fn set_data_channel_buffered_amount_low_threshold(&self, threshold: u64) {
self.inner.reliable_dc_buffered_amount_low_threshold.store(threshold, Ordering::Relaxed)
}

pub async fn get_response(&self, request_id: u32) -> proto::RequestResponse {
self.inner.get_response(request_id).await
}


}

impl SessionInner {
Expand Down Expand Up @@ -763,19 +767,17 @@ impl SessionInner {
}
}
}
RtcEvent::DataChannelStateChange { state } => {
}
RtcEvent::DataChannelBufferedAmountChange { sent: _, amount, kind } => {
match kind {
DataPacketKind::Lossy => {
self.lossy_dc_buffered_amount.store(amount, Ordering::Relaxed)
// Do nothing at this moment
}
DataPacketKind::Reliable => {
self.reliable_dc_buffered_amount.store(amount, Ordering::Relaxed);
// Only reliable dc is needed this event at this time
let _ = self.emitter.send(SessionEvent::DataChannelBufferedAmountChanged {
buffered_amount: amount,
});
let threshold = self.reliable_dc_buffered_amount_low_threshold.load(Ordering::Relaxed);
if amount <= threshold {
self.reliable_dc_buffered_amount_low_notify.notify_one();
}
}
}
}
Expand Down Expand Up @@ -990,6 +992,9 @@ impl SessionInner {
kind: DataPacketKind,
) -> Result<(), EngineError> {
self.ensure_publisher_connected(kind).await?;
if kind == DataPacketKind::Reliable {
self.wait_buffer_low().await?;
}
self.data_channel(SignalTarget::Publisher, kind)
.unwrap()
.send(&data.encode_to_vec(), true)
Expand All @@ -998,6 +1003,16 @@ impl SessionInner {
})
}

async fn wait_buffer_low(&self) -> Result<(), EngineError> {
let amount = self.reliable_dc_buffered_amount.load(Ordering::Relaxed);
let threshold = self.reliable_dc_buffered_amount_low_threshold.load(Ordering::Relaxed);
if amount <= threshold {
return Ok(())
}
self.reliable_dc_buffered_amount_low_notify.notified().await;
Ok(())
}

/// This reconnection if more seemless compared to the full reconnection implemented in
/// ['RTCEngine']
async fn restart(&self) -> EngineResult<proto::ReconnectResponse> {
Expand Down

0 comments on commit ee64923

Please sign in to comment.