Skip to content

Commit

Permalink
generate_request_id fixes (#27)
Browse files Browse the repository at this point in the history
* Fix generate_request_id with thread_safe

The previous implementation dropped the first guard
before taking another, which could race.

It would be better to change to NonZeroU32 and checked_add as well
(wrapping issues).

* Change generate_request_id to return NonZeroU32

Zero is for broadcast request_ids, change types to make it impossible
to return zero.  i32 wrapping does take a lot of requests, but this is neater.
  • Loading branch information
g2p authored Mar 17, 2024
1 parent 8d57d22 commit 62ab9cd
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 52 deletions.
32 changes: 16 additions & 16 deletions src/cast/proxies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub mod media {
#[derive(Serialize, Debug)]
pub struct GetStatusRequest {
#[serde(rename = "requestId")]
pub request_id: i32,
pub request_id: u32,

#[serde(rename = "type")]
pub typ: String,
Expand All @@ -41,7 +41,7 @@ pub mod media {
#[derive(Serialize, Debug)]
pub struct MediaRequest {
#[serde(rename = "requestId")]
pub request_id: i32,
pub request_id: u32,

#[serde(rename = "sessionId")]
pub session_id: String,
Expand All @@ -63,7 +63,7 @@ pub mod media {
#[derive(Serialize, Debug)]
pub struct PlaybackGenericRequest {
#[serde(rename = "requestId")]
pub request_id: i32,
pub request_id: u32,

#[serde(rename = "mediaSessionId")]
pub media_session_id: i32,
Expand All @@ -78,7 +78,7 @@ pub mod media {
#[derive(Serialize, Debug)]
pub struct PlaybackSeekRequest {
#[serde(rename = "requestId")]
pub request_id: i32,
pub request_id: u32,

#[serde(rename = "mediaSessionId")]
pub media_session_id: i32,
Expand Down Expand Up @@ -246,7 +246,7 @@ pub mod media {
#[derive(Deserialize, Debug)]
pub struct StatusReply {
#[serde(rename = "requestId", default)]
pub request_id: i32,
pub request_id: u32,

#[serde(rename = "type")]
pub typ: String,
Expand All @@ -257,25 +257,25 @@ pub mod media {
#[derive(Deserialize, Debug)]
pub struct LoadCancelledReply {
#[serde(rename = "requestId")]
pub request_id: i32,
pub request_id: u32,
}

#[derive(Deserialize, Debug)]
pub struct LoadFailedReply {
#[serde(rename = "requestId")]
pub request_id: i32,
pub request_id: u32,
}

#[derive(Deserialize, Debug)]
pub struct InvalidPlayerStateReply {
#[serde(rename = "requestId")]
pub request_id: i32,
pub request_id: u32,
}

#[derive(Deserialize, Debug)]
pub struct InvalidRequestReply {
#[serde(rename = "requestId")]
pub request_id: i32,
pub request_id: u32,

#[serde(rename = "type")]
pub typ: String,
Expand All @@ -293,7 +293,7 @@ pub mod receiver {
#[derive(Serialize, Debug)]
pub struct AppLaunchRequest {
#[serde(rename = "requestId")]
pub request_id: i32,
pub request_id: u32,

#[serde(rename = "type")]
pub typ: String,
Expand All @@ -305,7 +305,7 @@ pub mod receiver {
#[derive(Serialize, Debug)]
pub struct AppStopRequest<'a> {
#[serde(rename = "requestId")]
pub request_id: i32,
pub request_id: u32,

#[serde(rename = "type")]
pub typ: String,
Expand All @@ -317,7 +317,7 @@ pub mod receiver {
#[derive(Serialize, Debug)]
pub struct GetStatusRequest {
#[serde(rename = "requestId")]
pub request_id: i32,
pub request_id: u32,

#[serde(rename = "type")]
pub typ: String,
Expand All @@ -326,7 +326,7 @@ pub mod receiver {
#[derive(Serialize, Debug)]
pub struct SetVolumeRequest {
#[serde(rename = "requestId")]
pub request_id: i32,
pub request_id: u32,

#[serde(rename = "type")]
pub typ: String,
Expand All @@ -337,7 +337,7 @@ pub mod receiver {
#[derive(Deserialize, Debug)]
pub struct StatusReply {
#[serde(rename = "requestId")]
pub request_id: i32,
pub request_id: u32,

#[serde(rename = "type")]
pub typ: String,
Expand Down Expand Up @@ -398,7 +398,7 @@ pub mod receiver {
#[derive(Deserialize, Debug)]
pub struct LaunchErrorReply {
#[serde(rename = "requestId")]
pub request_id: i32,
pub request_id: u32,

#[serde(rename = "type")]
pub typ: String,
Expand All @@ -409,7 +409,7 @@ pub mod receiver {
#[derive(Deserialize, Debug)]
pub struct InvalidRequestReply {
#[serde(rename = "requestId")]
pub request_id: i32,
pub request_id: u32,

#[serde(rename = "type")]
pub typ: String,
Expand Down
24 changes: 12 additions & 12 deletions src/channels/media.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ pub struct Media {
#[derive(Clone, Debug)]
pub struct Status {
/// Unique id of the request that requested the status.
pub request_id: i32,
pub request_id: u32,
/// Detailed status of every media status entry.
pub entries: Vec<StatusEntry>,
}
Expand Down Expand Up @@ -369,28 +369,28 @@ pub struct StatusEntry {
#[derive(Copy, Clone, Debug)]
pub struct LoadCancelled {
/// Unique id of the request that caused this error.
pub request_id: i32,
pub request_id: u32,
}

/// Describes the load failed error.
#[derive(Copy, Clone, Debug)]
pub struct LoadFailed {
/// Unique id of the request that caused this error.
pub request_id: i32,
pub request_id: u32,
}

/// Describes the invalid player state error.
#[derive(Copy, Clone, Debug)]
pub struct InvalidPlayerState {
/// Unique id of the request that caused this error.
pub request_id: i32,
pub request_id: u32,
}

/// Describes the invalid request error.
#[derive(Clone, Debug)]
pub struct InvalidRequest {
/// Unique id of the invalid request.
pub request_id: i32,
pub request_id: u32,
/// Description of the invalid request reason if available.
pub reason: Option<String>,
}
Expand Down Expand Up @@ -455,7 +455,7 @@ where
where
S: Into<Cow<'a, str>>,
{
let request_id = self.message_manager.generate_request_id();
let request_id = self.message_manager.generate_request_id().get();

let payload = serde_json::to_string(&proxies::media::GetStatusRequest {
typ: MESSAGE_TYPE_GET_STATUS.to_string(),
Expand Down Expand Up @@ -510,7 +510,7 @@ where
where
S: Into<Cow<'a, str>>,
{
let request_id = self.message_manager.generate_request_id();
let request_id = self.message_manager.generate_request_id().get();

let metadata = media.metadata.as_ref().map(|m| match *m {
Metadata::Generic(ref x) => proxies::media::Metadata {
Expand Down Expand Up @@ -666,7 +666,7 @@ where
where
S: Into<Cow<'a, str>>,
{
let request_id = self.message_manager.generate_request_id();
let request_id = self.message_manager.generate_request_id().get();

let payload = serde_json::to_string(&proxies::media::PlaybackGenericRequest {
request_id,
Expand Down Expand Up @@ -700,7 +700,7 @@ where
where
S: Into<Cow<'a, str>>,
{
let request_id = self.message_manager.generate_request_id();
let request_id = self.message_manager.generate_request_id().get();

let payload = serde_json::to_string(&proxies::media::PlaybackGenericRequest {
request_id,
Expand Down Expand Up @@ -735,7 +735,7 @@ where
where
S: Into<Cow<'a, str>>,
{
let request_id = self.message_manager.generate_request_id();
let request_id = self.message_manager.generate_request_id().get();

let payload = serde_json::to_string(&proxies::media::PlaybackGenericRequest {
request_id,
Expand Down Expand Up @@ -778,7 +778,7 @@ where
where
S: Into<Cow<'a, str>>,
{
let request_id = self.message_manager.generate_request_id();
let request_id = self.message_manager.generate_request_id().get();

let payload = serde_json::to_string(&proxies::media::PlaybackSeekRequest {
request_id,
Expand Down Expand Up @@ -905,7 +905,7 @@ where
/// Returned `Result` should consist of either `Status` instance or an `Error`.
fn receive_status_entry(
&self,
request_id: i32,
request_id: u32,
media_session_id: i32,
) -> Result<StatusEntry, Error> {
self.message_manager.receive_find_map(|message| {
Expand Down
14 changes: 7 additions & 7 deletions src/channels/receiver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ pub struct Application {
#[derive(Clone, Debug)]
pub struct Status {
/// Unique id of the request that requested the status.
pub request_id: i32,
pub request_id: u32,
/// Contains the list of applications that are currently run.
pub applications: Vec<Application>,
/// Determines whether the Cast device is the active input or not.
Expand All @@ -106,7 +106,7 @@ pub struct Status {
#[derive(Clone, Debug)]
pub struct LaunchError {
/// Unique id of the request that tried to launch application.
pub request_id: i32,
pub request_id: u32,
/// Description of the launch error reason if available.
pub reason: Option<String>,
}
Expand All @@ -115,7 +115,7 @@ pub struct LaunchError {
#[derive(Clone, Debug)]
pub struct InvalidRequest {
/// Unique id of the invalid request.
pub request_id: i32,
pub request_id: u32,
/// Description of the invalid request reason if available.
pub reason: Option<String>,
}
Expand Down Expand Up @@ -212,7 +212,7 @@ where
///
/// * `app` - `CastDeviceApp` instance reference to run.
pub fn launch_app(&self, app: &CastDeviceApp) -> Result<Application, Error> {
let request_id = self.message_manager.generate_request_id();
let request_id = self.message_manager.generate_request_id().get();

let payload = serde_json::to_string(&proxies::receiver::AppLaunchRequest {
typ: MESSAGE_TYPE_LAUNCH.to_string(),
Expand Down Expand Up @@ -301,7 +301,7 @@ where
where
S: Into<Cow<'a, str>>,
{
let request_id = self.message_manager.generate_request_id();
let request_id = self.message_manager.generate_request_id().get();

let payload = serde_json::to_string(&proxies::receiver::AppStopRequest {
typ: MESSAGE_TYPE_STOP.to_string(),
Expand Down Expand Up @@ -350,7 +350,7 @@ where
///
/// Returned `Result` should consist of either `Status` instance or an `Error`.
pub fn get_status(&self) -> Result<Status, Error> {
let request_id = self.message_manager.generate_request_id();
let request_id = self.message_manager.generate_request_id().get();

let payload = serde_json::to_string(&proxies::receiver::GetStatusRequest {
typ: MESSAGE_TYPE_GET_STATUS.to_string(),
Expand Down Expand Up @@ -398,7 +398,7 @@ where
where
T: Into<Volume>,
{
let request_id = self.message_manager.generate_request_id();
let request_id = self.message_manager.generate_request_id().get();
let volume = volume.into();

let payload = serde_json::to_string(&proxies::receiver::SetVolumeRequest {
Expand Down
24 changes: 7 additions & 17 deletions src/message_manager.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
io::{Read, Write},
num::NonZeroU32,
ops::{Deref, DerefMut},
};

Expand Down Expand Up @@ -60,16 +61,6 @@ impl<T> Lock<T> {
})
}

fn borrow(&self) -> LockGuard<'_, T> {
LockGuard({
#[cfg(feature = "thread_safe")]
let guard = self.0.lock().unwrap();
#[cfg(not(feature = "thread_safe"))]
let guard = self.0.borrow();
guard
})
}

fn borrow_mut(&self) -> LockGuardMut<'_, T> {
LockGuardMut({
#[cfg(feature = "thread_safe")]
Expand Down Expand Up @@ -112,7 +103,7 @@ where
{
message_buffer: Lock<Vec<CastMessage>>,
stream: Lock<S>,
request_counter: Lock<i32>,
request_counter: Lock<NonZeroU32>,
}

impl<S> MessageManager<S>
Expand All @@ -123,7 +114,7 @@ where
MessageManager {
stream: Lock::new(stream),
message_buffer: Lock::new(vec![]),
request_counter: Lock::new(1),
request_counter: Lock::new(NonZeroU32::MIN),
}
}

Expand Down Expand Up @@ -248,11 +239,10 @@ where
/// # Return value
///
/// Unique (in the scope of this particular `MessageManager` instance) integer number.
pub fn generate_request_id(&self) -> i32 {
let request_id = *self.request_counter.borrow() + 1;

*self.request_counter.borrow_mut() = request_id;

pub fn generate_request_id(&self) -> NonZeroU32 {
let mut counter = self.request_counter.borrow_mut();
let request_id = *counter;
*counter = counter.checked_add(1).unwrap();
request_id
}

Expand Down

0 comments on commit 62ab9cd

Please sign in to comment.