diff --git a/Cargo.lock b/Cargo.lock index 112399f14..75efa4a79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,21 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "addr2line" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + [[package]] name = "aho-corasick" version = "1.1.2" @@ -74,6 +89,28 @@ version = "1.0.80" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1" +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -106,6 +143,21 @@ dependencies = [ "paste", ] +[[package]] +name = "backtrace" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + [[package]] name = "base64" version = "0.21.7" @@ -148,10 +200,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" [[package]] -name = "byteorder" +name = "bytes" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cc" @@ -317,12 +369,44 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +[[package]] +name = "futures-task" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" + +[[package]] +name = "futures-util" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +dependencies = [ + "futures-core", + "futures-macro", + "futures-sink", + "futures-task", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "getrandom" version = "0.2.12" @@ -334,6 +418,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "gimli" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" + [[package]] name = "glob" version = "0.3.1" @@ -346,6 +436,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + [[package]] name = "home" version = "0.5.9" @@ -456,7 +552,7 @@ dependencies = [ "if-addrs", "log", "polling", - "socket2", + "socket2 0.4.10", ] [[package]] @@ -471,6 +567,26 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "miniz_oxide" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" +dependencies = [ + "adler", +] + +[[package]] +name = "mio" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.48.0", +] + [[package]] name = "mirai-annotations" version = "1.12.0" @@ -487,6 +603,25 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "object" +version = "0.32.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -511,6 +646,12 @@ version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "polling" version = "2.8.0" @@ -655,9 +796,10 @@ name = "rust_cast" version = "0.19.0" dependencies = [ "ansi_term", - "byteorder", + "bytes", "docopt", "env_logger", + "futures-util", "log", "mdns-sd", "protobuf", @@ -667,8 +809,18 @@ dependencies = [ "serde", "serde_derive", "serde_json", + "tokio", + "tokio-rustls", + "tokio-test", + "tokio-util", ] +[[package]] +name = "rustc-demangle" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" + [[package]] name = "rustc-hash" version = "1.1.0" @@ -825,6 +977,15 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + [[package]] name = "socket2" version = "0.4.10" @@ -835,6 +996,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "socket2" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "spin" version = "0.9.8" @@ -899,6 +1070,101 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio" +version = "1.36.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "num_cpus", + "pin-project-lite", + "socket2 0.5.6", + "tokio-macros", + "windows-sys 0.48.0", +] + +[[package]] +name = "tokio-macros" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-rustls" +version = "0.25.0" +source = "git+https://github.com/rustls/tokio-rustls#330d28788f54dc94664a316bd29b354341f0aa38" +dependencies = [ + "rustls", + "rustls-pki-types", + "tokio", +] + +[[package]] +name = "tokio-stream" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-test" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89b3cbabd3ae862100094ae433e1def582cf86451b4e9bf83aa7ac1d8a7d719" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + +[[package]] +name = "tokio-util" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", + "tracing", +] + +[[package]] +name = "tracing" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +dependencies = [ + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", +] + [[package]] name = "unicode-ident" version = "1.0.12" diff --git a/Cargo.toml b/Cargo.toml index 8084d4400..e58018fab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,8 @@ exclude = [ ] [dependencies] -byteorder = "1.5" +bytes = "1.5.0" +futures-util = { version = "0.3.30", features = ["sink"] } log = "0.4" protobuf = "=3.4.0" rustls = "0.23" @@ -26,12 +27,17 @@ rustls-native-certs = "0.7" serde = "1" serde_derive = "1" serde_json = "1" +tokio = { version = "1.36.0", features = ["net", "sync", "io-util"] } +tokio-rustls = { git = "https://github.com/rustls/tokio-rustls" } +tokio-util = { version = "0.7.10", features = ["codec"] } [dev-dependencies] ansi_term = "0.12.1" docopt = "1.1.1" env_logger = "0.11.2" mdns-sd = "0.10.4" +tokio = { version = "1.36.0", features = ["macros", "rt-multi-thread"] } +tokio-test = "0.4.3" [build-dependencies] protobuf-codegen = "=3.4.0" diff --git a/examples/rust_caster.rs b/examples/rust_caster.rs index ff24dfadd..612224a74 100644 --- a/examples/rust_caster.rs +++ b/examples/rust_caster.rs @@ -64,8 +64,8 @@ struct Args { flag_media_seek: Option, } -fn print_info(device: &CastDevice) { - let status = device.receiver.get_status().unwrap(); +async fn print_info(device: &CastDevice<'_>) { + let status = device.receiver.get_status().await.unwrap(); println!( "\n{} {}", @@ -104,8 +104,8 @@ fn print_info(device: &CastDevice) { } } -fn run_app(device: &CastDevice, app_to_run: &CastDeviceApp) { - let app = device.receiver.launch_app(app_to_run).unwrap(); +async fn run_app(device: &CastDevice<'_>, app_to_run: &CastDeviceApp) { + let app = device.receiver.launch_app(app_to_run).await.unwrap(); println!( "{}{}{}{}{}{}{}", @@ -119,8 +119,8 @@ fn run_app(device: &CastDevice, app_to_run: &CastDeviceApp) { ); } -fn stop_app(device: &CastDevice, app_to_run: &CastDeviceApp) { - let status = device.receiver.get_status().unwrap(); +async fn stop_app(device: &CastDevice<'_>, app_to_run: &CastDeviceApp) { + let status = device.receiver.get_status().await.unwrap(); let app = status .applications @@ -129,7 +129,11 @@ fn stop_app(device: &CastDevice, app_to_run: &CastDeviceApp) { match app { Some(app) => { - device.receiver.stop_app(app.session_id.as_str()).unwrap(); + device + .receiver + .stop_app(app.session_id.as_str()) + .await + .unwrap(); println!( "{}{}{}{}{}{}{}", @@ -153,11 +157,15 @@ fn stop_app(device: &CastDevice, app_to_run: &CastDeviceApp) { } } -fn stop_current_app(device: &CastDevice) { - let status = device.receiver.get_status().unwrap(); +async fn stop_current_app(device: &CastDevice<'_>) { + let status = device.receiver.get_status().await.unwrap(); match status.applications.first() { Some(app) => { - device.receiver.stop_app(app.session_id.as_str()).unwrap(); + device + .receiver + .stop_app(app.session_id.as_str()) + .await + .unwrap(); println!( "{}{}{}{}{}{}{}", @@ -174,18 +182,19 @@ fn stop_current_app(device: &CastDevice) { } } -fn play_media( - device: &CastDevice, +async fn play_media( + device: &CastDevice<'_>, app_to_run: &CastDeviceApp, media: String, media_type: String, media_stream_type: StreamType, ) { - let app = device.receiver.launch_app(app_to_run).unwrap(); + let app = device.receiver.launch_app(app_to_run).await.unwrap(); device .connection .connect(app.transport_id.as_str()) + .await .unwrap(); let status = device @@ -201,6 +210,7 @@ fn play_media( metadata: None, }, ) + .await .unwrap(); for i in 0..status.entries.len() { @@ -257,14 +267,14 @@ fn play_media( } } -fn discover() -> Option<(String, u16)> { +async fn discover() -> Option<(String, u16)> { let mdns = ServiceDaemon::new().expect("Failed to create mDNS daemon."); let receiver = mdns .browse(SERVICE_TYPE) .expect("Failed to browse mDNS services."); - while let Ok(event) = receiver.recv() { + while let Ok(event) = receiver.recv_async().await { match event { ServiceEvent::ServiceResolved(info) => { let mut addresses = info @@ -297,7 +307,8 @@ fn discover() -> Option<(String, u16)> { None } -fn main() { +#[tokio::main] +async fn main() { env_logger::init(); let args: Args = Docopt::new(USAGE) @@ -308,14 +319,14 @@ fn main() { Some(address) => (address, args.flag_port), None => { println!("Cast Device address is not specified, trying to discover..."); - discover().unwrap_or_else(|| { + discover().await.unwrap_or_else(|| { println!("No Cast device discovered, please specify device address explicitly."); std::process::exit(1); }) } }; - let cast_device = match CastDevice::connect_without_host_verification(address, port) { + let cast_device = match CastDevice::connect_without_host_verification(address, port).await { Ok(cast_device) => cast_device, Err(err) => panic!("Could not establish connection with Cast Device: {:?}", err), }; @@ -323,32 +334,33 @@ fn main() { cast_device .connection .connect(DEFAULT_DESTINATION_ID.to_string()) + .await .unwrap(); - cast_device.heartbeat.ping().unwrap(); + cast_device.heartbeat.ping().await.unwrap(); // Information about cast device. if args.flag_info.is_some() { - return print_info(&cast_device); + return print_info(&cast_device).await; } // Run specific application. if let Some(app) = args.flag_run { - return run_app(&cast_device, &CastDeviceApp::from_str(&app).unwrap()); + return run_app(&cast_device, &CastDeviceApp::from_str(&app).unwrap()).await; } // Stop specific application. if let Some(app) = args.flag_stop { - return stop_app(&cast_device, &CastDeviceApp::from_str(&app).unwrap()); + return stop_app(&cast_device, &CastDeviceApp::from_str(&app).unwrap()).await; } // Stop currently active application. if args.flag_stop_current { - return stop_current_app(&cast_device); + return stop_current_app(&cast_device).await; } // Adjust volume level. if let Some(level) = args.flag_media_volume { - let volume = cast_device.receiver.set_volume(level).unwrap(); + let volume = cast_device.receiver.set_volume(level).await.unwrap(); println!( "{}{}", Green.paint("Volume level has been set to: "), @@ -360,7 +372,11 @@ fn main() { // Mute/unmute cast device. if args.flag_media_mute || args.flag_media_unmute { let mute_or_unmute = args.flag_media_mute; - let volume = cast_device.receiver.set_volume(mute_or_unmute).unwrap(); + let volume = cast_device + .receiver + .set_volume(mute_or_unmute) + .await + .unwrap(); println!( "{}{}", Green.paint("Cast device is muted: "), @@ -376,7 +392,7 @@ fn main() { || args.flag_media_seek.is_some() { let app_to_manage = CastDeviceApp::from_str(args.flag_media_app.as_str()).unwrap(); - let status = cast_device.receiver.get_status().unwrap(); + let status = cast_device.receiver.get_status().await.unwrap(); let app = status .applications @@ -388,11 +404,13 @@ fn main() { cast_device .connection .connect(app.transport_id.as_str()) + .await .unwrap(); let status = cast_device .media .get_status(app.transport_id.as_str(), None) + .await .unwrap(); let status = status.entries.first().unwrap(); @@ -403,6 +421,7 @@ fn main() { cast_device .media .pause(app.transport_id.as_str(), status.media_session_id) + .await .unwrap(), ); } else if args.flag_media_play { @@ -410,6 +429,7 @@ fn main() { cast_device .media .play(app.transport_id.as_str(), status.media_session_id) + .await .unwrap(), ); } else if args.flag_media_stop { @@ -417,6 +437,7 @@ fn main() { cast_device .media .stop(app.transport_id.as_str(), status.media_session_id) + .await .unwrap(), ); } else if args.flag_media_seek.is_some() { @@ -429,6 +450,7 @@ fn main() { Some(args.flag_media_seek.unwrap()), None, ) + .await .unwrap(), ); } @@ -510,15 +532,16 @@ fn main() { media, media_type, media_stream_type, - ); + ) + .await; loop { - match cast_device.receive() { + match cast_device.receive().await { Ok(ChannelMessage::Heartbeat(response)) => { println!("[Heartbeat] {:?}", response); if let HeartbeatResponse::Ping = response { - cast_device.heartbeat.pong().unwrap(); + cast_device.heartbeat.pong().await.unwrap(); } } diff --git a/src/channels/connection.rs b/src/channels/connection.rs index ac6b273d9..ebf495d13 100644 --- a/src/channels/connection.rs +++ b/src/channels/connection.rs @@ -1,7 +1,6 @@ -use std::{ - borrow::Cow, - io::{Read, Write}, -}; +use std::borrow::Cow; + +use tokio::io::{AsyncRead, AsyncWrite}; use crate::{ cast::proxies, @@ -25,7 +24,7 @@ pub enum ConnectionResponse { pub struct ConnectionChannel<'a, W> where - W: Read + Write, + W: AsyncRead + AsyncWrite, { sender: Cow<'a, str>, message_manager: Lrc>, @@ -33,7 +32,7 @@ where impl<'a, W> ConnectionChannel<'a, W> where - W: Read + Write, + W: AsyncRead + AsyncWrite, { pub fn new(sender: S, message_manager: Lrc>) -> ConnectionChannel<'a, W> where @@ -45,7 +44,7 @@ where } } - pub fn connect(&self, destination: S) -> Result<(), Error> + pub async fn connect(&self, destination: S) -> Result<(), Error> where S: Into>, { @@ -54,15 +53,17 @@ where user_agent: CHANNEL_USER_AGENT.to_string(), })?; - self.message_manager.send(CastMessage { - namespace: CHANNEL_NAMESPACE.to_string(), - source: self.sender.to_string(), - destination: destination.into().to_string(), - payload: CastMessagePayload::String(payload), - }) + self.message_manager + .send(CastMessage { + namespace: CHANNEL_NAMESPACE.to_string(), + source: self.sender.to_string(), + destination: destination.into().to_string(), + payload: CastMessagePayload::String(payload), + }) + .await } - pub fn disconnect(&self, destination: S) -> Result<(), Error> + pub async fn disconnect(&self, destination: S) -> Result<(), Error> where S: Into>, { @@ -71,12 +72,14 @@ where user_agent: CHANNEL_USER_AGENT.to_string(), })?; - self.message_manager.send(CastMessage { - namespace: CHANNEL_NAMESPACE.to_string(), - source: self.sender.to_string(), - destination: destination.into().to_string(), - payload: CastMessagePayload::String(payload), - }) + self.message_manager + .send(CastMessage { + namespace: CHANNEL_NAMESPACE.to_string(), + source: self.sender.to_string(), + destination: destination.into().to_string(), + payload: CastMessagePayload::String(payload), + }) + .await } pub fn can_handle(&self, message: &CastMessage) -> bool { diff --git a/src/channels/heartbeat.rs b/src/channels/heartbeat.rs index 5db3d5ddf..bf6b81c36 100644 --- a/src/channels/heartbeat.rs +++ b/src/channels/heartbeat.rs @@ -1,7 +1,6 @@ -use std::{ - borrow::Cow, - io::{Read, Write}, -}; +use std::borrow::Cow; + +use tokio::io::{AsyncRead, AsyncWrite}; use crate::{ cast::proxies, @@ -24,7 +23,7 @@ pub enum HeartbeatResponse { pub struct HeartbeatChannel<'a, W> where - W: Read + Write, + W: AsyncRead + AsyncWrite, { sender: Cow<'a, str>, receiver: Cow<'a, str>, @@ -33,7 +32,7 @@ where impl<'a, W> HeartbeatChannel<'a, W> where - W: Read + Write, + W: AsyncRead + AsyncWrite, { pub fn new( sender: S, @@ -50,30 +49,34 @@ where } } - pub fn ping(&self) -> Result<(), Error> { + pub async fn ping(&self) -> Result<(), Error> { let payload = serde_json::to_string(&proxies::heartbeat::HeartBeatRequest { typ: MESSAGE_TYPE_PING.to_string(), })?; - self.message_manager.send(CastMessage { - namespace: CHANNEL_NAMESPACE.to_string(), - source: self.sender.to_string(), - destination: self.receiver.to_string(), - payload: CastMessagePayload::String(payload), - }) + self.message_manager + .send(CastMessage { + namespace: CHANNEL_NAMESPACE.to_string(), + source: self.sender.to_string(), + destination: self.receiver.to_string(), + payload: CastMessagePayload::String(payload), + }) + .await } - pub fn pong(&self) -> Result<(), Error> { + pub async fn pong(&self) -> Result<(), Error> { let payload = serde_json::to_string(&proxies::heartbeat::HeartBeatRequest { typ: MESSAGE_TYPE_PONG.to_string(), })?; - self.message_manager.send(CastMessage { - namespace: CHANNEL_NAMESPACE.to_string(), - source: self.sender.to_string(), - destination: self.receiver.to_string(), - payload: CastMessagePayload::String(payload), - }) + self.message_manager + .send(CastMessage { + namespace: CHANNEL_NAMESPACE.to_string(), + source: self.sender.to_string(), + destination: self.receiver.to_string(), + payload: CastMessagePayload::String(payload), + }) + .await } pub fn can_handle(&self, message: &CastMessage) -> bool { diff --git a/src/channels/media.rs b/src/channels/media.rs index 64ce7b374..8edaff7fd 100644 --- a/src/channels/media.rs +++ b/src/channels/media.rs @@ -1,9 +1,6 @@ -use std::{ - borrow::Cow, - io::{Read, Write}, - str::FromStr, - string::ToString, -}; +use std::{borrow::Cow, str::FromStr, string::ToString}; + +use tokio::io::{AsyncRead, AsyncWrite}; use crate::{ cast::proxies, @@ -772,7 +769,7 @@ pub enum MediaResponse { pub struct MediaChannel<'a, W> where - W: Read + Write, + W: AsyncRead + AsyncWrite, { sender: Cow<'a, str>, message_manager: Lrc>, @@ -780,7 +777,7 @@ where impl<'a, W> MediaChannel<'a, W> where - W: Read + Write, + W: AsyncRead + AsyncWrite, { pub fn new(sender: S, message_manager: Lrc>) -> MediaChannel<'a, W> where @@ -803,7 +800,7 @@ where /// # Return value /// /// Returned `Result` should consist of either `Status` instance or an `Error`. - pub fn get_status( + pub async fn get_status( &self, destination: S, media_session_id: Option, @@ -819,37 +816,41 @@ where media_session_id, })?; - self.message_manager.send(CastMessage { - namespace: CHANNEL_NAMESPACE.to_string(), - source: self.sender.to_string(), - destination: destination.into().to_string(), - payload: CastMessagePayload::String(payload), - })?; - - self.message_manager.receive_find_map(|message| { - if !self.can_handle(message) { - return Ok(None); - } + self.message_manager + .send(CastMessage { + namespace: CHANNEL_NAMESPACE.to_string(), + source: self.sender.to_string(), + destination: destination.into().to_string(), + payload: CastMessagePayload::String(payload), + }) + .await?; + + self.message_manager + .receive_find_map(|message| { + if !self.can_handle(message) { + return Ok(None); + } - match self.parse(message)? { - MediaResponse::Status(status) => { - if status.request_id == request_id { - return Ok(Some(status)); + match self.parse(message)? { + MediaResponse::Status(status) => { + if status.request_id == request_id { + return Ok(Some(status)); + } } - } - MediaResponse::InvalidRequest(error) => { - if error.request_id == request_id { - return Err(Error::Internal(format!( - "Invalid request ({}).", - error.reason.unwrap_or_else(|| "Unknown".to_string()) - ))); + MediaResponse::InvalidRequest(error) => { + if error.request_id == request_id { + return Err(Error::Internal(format!( + "Invalid request ({}).", + error.reason.unwrap_or_else(|| "Unknown".to_string()) + ))); + } } + _ => {} } - _ => {} - } - Ok(None) - }) + Ok(None) + }) + .await } /// Loads provided media to the application. @@ -862,11 +863,17 @@ where /// # Return value /// /// Returned `Result` should consist of either `Status` instance or an `Error`. - pub fn load(&self, destination: S, session_id: S, media: &Media) -> Result + pub async fn load( + &self, + destination: S, + session_id: S, + media: &Media, + ) -> Result where S: Into>, { self.load_with_queue(destination, session_id, media, None) + .await } /// Loads provided media to the application. @@ -879,7 +886,7 @@ where /// # Return value /// /// Returned `Result` should consist of either `Status` instance or an `Error`. - pub fn load_with_queue( + pub async fn load_with_queue( &self, destination: S, session_id: S, @@ -904,78 +911,82 @@ where queue_data: queue.map(|qd| qd.encode()), })?; - self.message_manager.send(CastMessage { - namespace: CHANNEL_NAMESPACE.to_string(), - source: self.sender.to_string(), - destination: destination.into().to_string(), - payload: CastMessagePayload::String(payload), - })?; + self.message_manager + .send(CastMessage { + namespace: CHANNEL_NAMESPACE.to_string(), + source: self.sender.to_string(), + destination: destination.into().to_string(), + payload: CastMessagePayload::String(payload), + }) + .await?; // Once media is loaded cast receiver device should emit status update event, or load failed // event if something went wrong. - self.message_manager.receive_find_map(|message| { - if !self.can_handle(message) { - return Ok(None); - } - - match self.parse(message)? { - MediaResponse::Status(status) => { - if status.request_id == request_id { - return Ok(Some(status)); - } + self.message_manager + .receive_find_map(|message| { + if !self.can_handle(message) { + return Ok(None); + } - // [WORKAROUND] In some cases we don't receive response (e.g. from YouTube app), - // so let's just wait for the response with the media we're interested in and - // return it. - let has_media = { - status.entries.iter().any(|entry| { - if let Some(ref loaded_media) = entry.media { - return loaded_media.content_id == media.content_id; - } - - false - }) - }; - - if has_media { - return Ok(Some(status)); + match self.parse(message)? { + MediaResponse::Status(status) => { + if status.request_id == request_id { + return Ok(Some(status)); + } + + // [WORKAROUND] In some cases we don't receive response (e.g. from YouTube app), + // so let's just wait for the response with the media we're interested in and + // return it. + let has_media = { + status.entries.iter().any(|entry| { + if let Some(ref loaded_media) = entry.media { + return loaded_media.content_id == media.content_id; + } + + false + }) + }; + + if has_media { + return Ok(Some(status)); + } } - } - MediaResponse::LoadFailed(error) => { - if error.request_id == request_id { - return Err(Error::Internal("Failed to load media.".to_string())); + MediaResponse::LoadFailed(error) => { + if error.request_id == request_id { + return Err(Error::Internal("Failed to load media.".to_string())); + } } - } - MediaResponse::LoadCancelled(error) => { - if error.request_id == request_id { - return Err(Error::Internal( - "Load cancelled by another request.".to_string(), - )); + MediaResponse::LoadCancelled(error) => { + if error.request_id == request_id { + return Err(Error::Internal( + "Load cancelled by another request.".to_string(), + )); + } } - } - MediaResponse::InvalidPlayerState(error) => { - if error.request_id == request_id { - return Err(Error::Internal( - "Load failed because of invalid player state.".to_string(), - )); + MediaResponse::InvalidPlayerState(error) => { + if error.request_id == request_id { + return Err(Error::Internal( + "Load failed because of invalid player state.".to_string(), + )); + } } - } - MediaResponse::InvalidRequest(error) => { - if error.request_id == request_id { - return Err(Error::Internal(format!( - "Load failed because of invalid media request (reason: {}).", - error.reason.unwrap_or_else(|| "UNKNOWN".to_string()) - ))); + MediaResponse::InvalidRequest(error) => { + if error.request_id == request_id { + return Err(Error::Internal(format!( + "Load failed because of invalid media request (reason: {}).", + error.reason.unwrap_or_else(|| "UNKNOWN".to_string()) + ))); + } } + _ => {} } - _ => {} - } - Ok(None) - }) + Ok(None) + }) + .await } - pub fn load_queue( + pub async fn load_queue( &self, destination: S, _session_id: S, @@ -996,58 +1007,62 @@ where start_index: queue.start_index, })?; - self.message_manager.send(CastMessage { - namespace: CHANNEL_NAMESPACE.to_string(), - source: self.sender.to_string(), - destination: destination.into().to_string(), - payload: CastMessagePayload::String(payload), - })?; + self.message_manager + .send(CastMessage { + namespace: CHANNEL_NAMESPACE.to_string(), + source: self.sender.to_string(), + destination: destination.into().to_string(), + payload: CastMessagePayload::String(payload), + }) + .await?; // Once media is loaded cast receiver device should emit status update event, or load failed // event if something went wrong. - self.message_manager.receive_find_map(|message| { - if !self.can_handle(message) { - return Ok(None); - } + self.message_manager + .receive_find_map(|message| { + if !self.can_handle(message) { + return Ok(None); + } - match self.parse(message)? { - MediaResponse::Status(status) => { - if status.request_id == request_id { - return Ok(Some(status)); + match self.parse(message)? { + MediaResponse::Status(status) => { + if status.request_id == request_id { + return Ok(Some(status)); + } } - } - MediaResponse::LoadFailed(error) => { - if error.request_id == request_id { - return Err(Error::Internal("Failed to load media.".to_string())); + MediaResponse::LoadFailed(error) => { + if error.request_id == request_id { + return Err(Error::Internal("Failed to load media.".to_string())); + } } - } - MediaResponse::LoadCancelled(error) => { - if error.request_id == request_id { - return Err(Error::Internal( - "Load cancelled by another request.".to_string(), - )); + MediaResponse::LoadCancelled(error) => { + if error.request_id == request_id { + return Err(Error::Internal( + "Load cancelled by another request.".to_string(), + )); + } } - } - MediaResponse::InvalidPlayerState(error) => { - if error.request_id == request_id { - return Err(Error::Internal( - "Load failed because of invalid player state.".to_string(), - )); + MediaResponse::InvalidPlayerState(error) => { + if error.request_id == request_id { + return Err(Error::Internal( + "Load failed because of invalid player state.".to_string(), + )); + } } - } - MediaResponse::InvalidRequest(error) => { - if error.request_id == request_id { - return Err(Error::Internal(format!( - "Load failed because of invalid media request (reason: {}).", - error.reason.unwrap_or_else(|| "UNKNOWN".to_string()) - ))); + MediaResponse::InvalidRequest(error) => { + if error.request_id == request_id { + return Err(Error::Internal(format!( + "Load failed because of invalid media request (reason: {}).", + error.reason.unwrap_or_else(|| "UNKNOWN".to_string()) + ))); + } } + _ => {} } - _ => {} - } - Ok(None) - }) + Ok(None) + }) + .await } /// Pauses playback of the current content. Triggers a STATUS event notification to all sender @@ -1061,7 +1076,11 @@ where /// # Return value /// /// Returned `Result` should consist of either `Status` instance or an `Error`. - pub fn pause(&self, destination: S, media_session_id: i32) -> Result + pub async fn pause( + &self, + destination: S, + media_session_id: i32, + ) -> Result where S: Into>, { @@ -1074,14 +1093,17 @@ where custom_data: proxies::media::CustomData::new(), })?; - self.message_manager.send(CastMessage { - namespace: CHANNEL_NAMESPACE.to_string(), - source: self.sender.to_string(), - destination: destination.into().to_string(), - payload: CastMessagePayload::String(payload), - })?; + self.message_manager + .send(CastMessage { + namespace: CHANNEL_NAMESPACE.to_string(), + source: self.sender.to_string(), + destination: destination.into().to_string(), + payload: CastMessagePayload::String(payload), + }) + .await?; self.receive_status_entry(request_id, media_session_id) + .await } /// Begins playback of the content that was loaded with the load call, playback is continued @@ -1095,7 +1117,7 @@ where /// # Return value /// /// Returned `Result` should consist of either `Status` instance or an `Error`. - pub fn play(&self, destination: S, media_session_id: i32) -> Result + pub async fn play(&self, destination: S, media_session_id: i32) -> Result where S: Into>, { @@ -1108,14 +1130,17 @@ where custom_data: proxies::media::CustomData::new(), })?; - self.message_manager.send(CastMessage { - namespace: CHANNEL_NAMESPACE.to_string(), - source: self.sender.to_string(), - destination: destination.into().to_string(), - payload: CastMessagePayload::String(payload), - })?; + self.message_manager + .send(CastMessage { + namespace: CHANNEL_NAMESPACE.to_string(), + source: self.sender.to_string(), + destination: destination.into().to_string(), + payload: CastMessagePayload::String(payload), + }) + .await?; self.receive_status_entry(request_id, media_session_id) + .await } /// Stops playback of the current content. Triggers a STATUS event notification to all sender @@ -1130,7 +1155,7 @@ where /// # Return value /// /// Returned `Result` should consist of either `Status` instance or an `Error`. - pub fn stop(&self, destination: S, media_session_id: i32) -> Result + pub async fn stop(&self, destination: S, media_session_id: i32) -> Result where S: Into>, { @@ -1143,14 +1168,17 @@ where custom_data: proxies::media::CustomData::new(), })?; - self.message_manager.send(CastMessage { - namespace: CHANNEL_NAMESPACE.to_string(), - source: self.sender.to_string(), - destination: destination.into().to_string(), - payload: CastMessagePayload::String(payload), - })?; + self.message_manager + .send(CastMessage { + namespace: CHANNEL_NAMESPACE.to_string(), + source: self.sender.to_string(), + destination: destination.into().to_string(), + payload: CastMessagePayload::String(payload), + }) + .await?; self.receive_status_entry(request_id, media_session_id) + .await } /// Sets the current position in the stream. Triggers a STATUS event notification to all sender @@ -1167,7 +1195,7 @@ where /// # Return value /// /// Returned `Result` should consist of either `Status` instance or an `Error`. - pub fn seek( + pub async fn seek( &self, destination: S, media_session_id: i32, @@ -1188,14 +1216,17 @@ where custom_data: proxies::media::CustomData::new(), })?; - self.message_manager.send(CastMessage { - namespace: CHANNEL_NAMESPACE.to_string(), - source: self.sender.to_string(), - destination: destination.into().to_string(), - payload: CastMessagePayload::String(payload), - })?; + self.message_manager + .send(CastMessage { + namespace: CHANNEL_NAMESPACE.to_string(), + source: self.sender.to_string(), + destination: destination.into().to_string(), + payload: CastMessagePayload::String(payload), + }) + .await?; self.receive_status_entry(request_id, media_session_id) + .await } pub fn can_handle(&self, message: &CastMessage) -> bool { @@ -1285,46 +1316,48 @@ where /// # Return value /// /// Returned `Result` should consist of either `Status` instance or an `Error`. - fn receive_status_entry( + async fn receive_status_entry( &self, request_id: u32, media_session_id: i32, ) -> Result { - self.message_manager.receive_find_map(|message| { - if !self.can_handle(message) { - return Ok(None); - } + self.message_manager + .receive_find_map(|message| { + if !self.can_handle(message) { + return Ok(None); + } - match self.parse(message)? { - MediaResponse::Status(mut status) => { - if status.request_id == request_id { - let position = status - .entries - .iter() - .position(|e| e.media_session_id == media_session_id); + match self.parse(message)? { + MediaResponse::Status(mut status) => { + if status.request_id == request_id { + let position = status + .entries + .iter() + .position(|e| e.media_session_id == media_session_id); - return Ok(position.map(|position| status.entries.remove(position))); + return Ok(position.map(|position| status.entries.remove(position))); + } } - } - MediaResponse::InvalidPlayerState(error) => { - if error.request_id == request_id { - return Err(Error::Internal( - "Request failed because of invalid player state.".to_string(), - )); + MediaResponse::InvalidPlayerState(error) => { + if error.request_id == request_id { + return Err(Error::Internal( + "Request failed because of invalid player state.".to_string(), + )); + } } - } - MediaResponse::InvalidRequest(error) => { - if error.request_id == request_id { - return Err(Error::Internal(format!( - "Invalid request ({}).", - error.reason.unwrap_or_else(|| "Unknown".to_string()) - ))); + MediaResponse::InvalidRequest(error) => { + if error.request_id == request_id { + return Err(Error::Internal(format!( + "Invalid request ({}).", + error.reason.unwrap_or_else(|| "Unknown".to_string()) + ))); + } } + _ => {} } - _ => {} - } - Ok(None) - }) + Ok(None) + }) + .await } } diff --git a/src/channels/receiver.rs b/src/channels/receiver.rs index 3984b317e..e6840f4f3 100644 --- a/src/channels/receiver.rs +++ b/src/channels/receiver.rs @@ -1,12 +1,7 @@ -use std::{ - borrow::Cow, - convert::Into, - io::{Read, Write}, - str::FromStr, - string::ToString, -}; +use std::{borrow::Cow, convert::Into, str::FromStr, string::ToString}; use serde::Serialize; +use tokio::io::{AsyncRead, AsyncWrite}; use crate::{ cast::proxies, @@ -170,7 +165,7 @@ impl ToString for CastDeviceApp { pub struct ReceiverChannel<'a, W> where - W: Write + Read, + W: AsyncRead + AsyncWrite, { sender: Cow<'a, str>, receiver: Cow<'a, str>, @@ -179,7 +174,7 @@ where impl<'a, W> ReceiverChannel<'a, W> where - W: Write + Read, + W: AsyncRead + AsyncWrite, { pub fn new( sender: S, @@ -204,14 +199,16 @@ where /// use std::str::FromStr; /// use rust_cast::{CastDevice, channels::receiver::CastDeviceApp}; /// - /// # let cast_device = CastDevice::connect_without_host_verification("host", 1234).unwrap(); - /// cast_device.receiver.launch_app(&CastDeviceApp::from_str("youtube").unwrap()); + /// # tokio_test::block_on(async { + /// # let cast_device = CastDevice::connect_without_host_verification("host", 1234).await.unwrap(); + /// cast_device.receiver.launch_app(&CastDeviceApp::from_str("youtube").unwrap()).await; + /// # }) /// ``` /// /// # Arguments /// /// * `app` - `CastDeviceApp` instance reference to run. - pub fn launch_app(&self, app: &CastDeviceApp) -> Result { + pub async fn launch_app(&self, app: &CastDeviceApp) -> Result { let request_id = self.message_manager.generate_request_id().get(); let payload = serde_json::to_string(&proxies::receiver::AppLaunchRequest { @@ -220,39 +217,43 @@ where app_id: app.to_string(), })?; - self.message_manager.send(CastMessage { - namespace: CHANNEL_NAMESPACE.to_string(), - source: self.sender.to_string(), - destination: self.receiver.to_string(), - payload: CastMessagePayload::String(payload), - })?; + self.message_manager + .send(CastMessage { + namespace: CHANNEL_NAMESPACE.to_string(), + source: self.sender.to_string(), + destination: self.receiver.to_string(), + payload: CastMessagePayload::String(payload), + }) + .await?; // Once application is run cast receiver device should emit status update event, or launch // error event if something went wrong. - self.message_manager.receive_find_map(|message| { - if !self.can_handle(message) { - return Ok(None); - } + self.message_manager + .receive_find_map(|message| { + if !self.can_handle(message) { + return Ok(None); + } - match self.parse(message)? { - ReceiverResponse::Status(mut status) => { - if status.request_id == request_id { - return Ok(Some(status.applications.remove(0))); + match self.parse(message)? { + ReceiverResponse::Status(mut status) => { + if status.request_id == request_id { + return Ok(Some(status.applications.remove(0))); + } } - } - ReceiverResponse::LaunchError(error) => { - if error.request_id == request_id { - return Err(Error::Internal(format!( - "Could not run application ({}).", - error.reason.unwrap_or_else(|| "Unknown".to_string()) - ))); + ReceiverResponse::LaunchError(error) => { + if error.request_id == request_id { + return Err(Error::Internal(format!( + "Could not run application ({}).", + error.reason.unwrap_or_else(|| "Unknown".to_string()) + ))); + } } + _ => {} } - _ => {} - } - Ok(None) - }) + Ok(None) + }) + .await } /// Broadcasts a message over a cast device's message bus. @@ -271,7 +272,7 @@ where /// /// * `namespace` - Message namespace that should start with `urn:x-cast:`. /// * `message` - Message instance to send. - pub fn broadcast_message( + pub async fn broadcast_message( &self, namespace: &str, message: &M, @@ -283,12 +284,14 @@ where ))); } let payload = serde_json::to_string(message)?; - self.message_manager.send(CastMessage { - namespace: namespace.to_string(), - source: self.sender.to_string(), - destination: "*".into(), - payload: CastMessagePayload::String(payload), - })?; + self.message_manager + .send(CastMessage { + namespace: namespace.to_string(), + source: self.sender.to_string(), + destination: "*".into(), + payload: CastMessagePayload::String(payload), + }) + .await?; Ok(()) } @@ -297,7 +300,7 @@ where /// /// # Arguments /// * `session_id` - identifier of the active application session from `Application` instance. - pub fn stop_app(&self, session_id: S) -> Result<(), Error> + pub async fn stop_app(&self, session_id: S) -> Result<(), Error> where S: Into>, { @@ -309,39 +312,43 @@ where session_id: session_id.into(), })?; - self.message_manager.send(CastMessage { - namespace: CHANNEL_NAMESPACE.to_string(), - source: self.sender.to_string(), - destination: self.receiver.to_string(), - payload: CastMessagePayload::String(payload), - })?; + self.message_manager + .send(CastMessage { + namespace: CHANNEL_NAMESPACE.to_string(), + source: self.sender.to_string(), + destination: self.receiver.to_string(), + payload: CastMessagePayload::String(payload), + }) + .await?; // Once application is stopped cast receiver device should emit status update event, or // invalid request event if provided session id is not valid. - self.message_manager.receive_find_map(|message| { - if !self.can_handle(message) { - return Ok(None); - } + self.message_manager + .receive_find_map(|message| { + if !self.can_handle(message) { + return Ok(None); + } - match self.parse(message)? { - ReceiverResponse::Status(status) => { - if status.request_id == request_id { - return Ok(Some(())); + match self.parse(message)? { + ReceiverResponse::Status(status) => { + if status.request_id == request_id { + return Ok(Some(())); + } } - } - ReceiverResponse::InvalidRequest(error) => { - if error.request_id == request_id { - return Err(Error::Internal(format!( - "Invalid request ({}).", - error.reason.unwrap_or_else(|| "Unknown".to_string()) - ))); + ReceiverResponse::InvalidRequest(error) => { + if error.request_id == request_id { + return Err(Error::Internal(format!( + "Invalid request ({}).", + error.reason.unwrap_or_else(|| "Unknown".to_string()) + ))); + } } + _ => {} } - _ => {} - } - Ok(None) - }) + Ok(None) + }) + .await } /// Retrieves status of the cast device receiver. @@ -349,7 +356,7 @@ where /// # Return value /// /// Returned `Result` should consist of either `Status` instance or an `Error`. - pub fn get_status(&self) -> Result { + pub async fn get_status(&self) -> Result { let request_id = self.message_manager.generate_request_id().get(); let payload = serde_json::to_string(&proxies::receiver::GetStatusRequest { @@ -357,27 +364,31 @@ where request_id, })?; - self.message_manager.send(CastMessage { - namespace: CHANNEL_NAMESPACE.to_string(), - source: self.sender.to_string(), - destination: self.receiver.to_string(), - payload: CastMessagePayload::String(payload), - })?; - - self.message_manager.receive_find_map(|message| { - if !self.can_handle(message) { - return Ok(None); - } + self.message_manager + .send(CastMessage { + namespace: CHANNEL_NAMESPACE.to_string(), + source: self.sender.to_string(), + destination: self.receiver.to_string(), + payload: CastMessagePayload::String(payload), + }) + .await?; + + self.message_manager + .receive_find_map(|message| { + if !self.can_handle(message) { + return Ok(None); + } - let message = self.parse(message)?; - if let ReceiverResponse::Status(status) = message { - if status.request_id == request_id { - return Ok(Some(status)); + let message = self.parse(message)?; + if let ReceiverResponse::Status(status) = message { + if status.request_id == request_id { + return Ok(Some(status)); + } } - } - Ok(None) - }) + Ok(None) + }) + .await } /// Sets volume for the active cast device. @@ -394,7 +405,7 @@ where /// # Errors /// /// Usually method can fail only if network connection with cast device is lost for some reason. - pub fn set_volume(&self, volume: T) -> Result + pub async fn set_volume(&self, volume: T) -> Result where T: Into, { @@ -410,27 +421,31 @@ where }, })?; - self.message_manager.send(CastMessage { - namespace: CHANNEL_NAMESPACE.to_string(), - source: self.sender.to_string(), - destination: self.receiver.to_string(), - payload: CastMessagePayload::String(payload), - })?; - - self.message_manager.receive_find_map(|message| { - if !self.can_handle(message) { - return Ok(None); - } + self.message_manager + .send(CastMessage { + namespace: CHANNEL_NAMESPACE.to_string(), + source: self.sender.to_string(), + destination: self.receiver.to_string(), + payload: CastMessagePayload::String(payload), + }) + .await?; + + self.message_manager + .receive_find_map(|message| { + if !self.can_handle(message) { + return Ok(None); + } - let message = self.parse(message)?; - if let ReceiverResponse::Status(status) = message { - if status.request_id == request_id { - return Ok(Some(status.volume)); + let message = self.parse(message)?; + if let ReceiverResponse::Status(status) = message { + if status.request_id == request_id { + return Ok(Some(status.volume)); + } } - } - Ok(None) - }) + Ok(None) + }) + .await } pub fn can_handle(&self, message: &CastMessage) -> bool { diff --git a/src/lib.rs b/src/lib.rs index 525dad119..1cc1f23d3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,9 +4,8 @@ mod cast; pub mod channels; pub mod errors; pub mod message_manager; -mod utils; -use std::{borrow::Cow, net::TcpStream, sync::Arc}; +use std::{borrow::Cow, sync::Arc}; use channels::{ connection::{ConnectionChannel, ConnectionResponse}, @@ -21,8 +20,10 @@ use rustls::{ client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}, crypto::{aws_lc_rs::default_provider, verify_tls12_signature, verify_tls13_signature}, pki_types::{CertificateDer, ServerName, UnixTime}, - ClientConfig, ClientConnection, DigitallySignedStruct, RootCertStore, StreamOwned, + ClientConfig, DigitallySignedStruct, RootCertStore, }; +use tokio::net::TcpStream; +use tokio_rustls::client::TlsStream; const DEFAULT_SENDER_ID: &str = "sender-0"; const DEFAULT_RECEIVER_ID: &str = "receiver-0"; @@ -50,19 +51,19 @@ pub enum ChannelMessage { /// Structure that manages connection to a cast device. pub struct CastDevice<'a> { - message_manager: Lrc>>, + message_manager: Lrc>>, /// Channel that manages connection responses/requests. - pub connection: ConnectionChannel<'a, StreamOwned>, + pub connection: ConnectionChannel<'a, TlsStream>, /// Channel that allows connection to stay alive (via ping-pong requests/responses). - pub heartbeat: HeartbeatChannel<'a, StreamOwned>, + pub heartbeat: HeartbeatChannel<'a, TlsStream>, /// Channel that manages various media stuff. - pub media: MediaChannel<'a, StreamOwned>, + pub media: MediaChannel<'a, TlsStream>, /// Channel that manages receiving platform (e.g. Chromecast). - pub receiver: ReceiverChannel<'a, StreamOwned>, + pub receiver: ReceiverChannel<'a, TlsStream>, } impl<'a> CastDevice<'a> { @@ -73,8 +74,10 @@ impl<'a> CastDevice<'a> { /// ```no_run /// use rust_cast::CastDevice; /// - /// let device = CastDevice::connect("192.168.1.2", 8009)?; + /// # tokio_test::block_on(async { + /// let device = CastDevice::connect("192.168.1.2", 8009).await?; /// # Ok::<(), rust_cast::errors::Error>(()) + /// # }); /// ``` /// /// # Arguments @@ -90,7 +93,7 @@ impl<'a> CastDevice<'a> { /// # Return value /// /// Instance of `CastDevice` that allows you to manage connection. - pub fn connect(host: S, port: u16) -> Result, Error> + pub async fn connect(host: S, port: u16) -> Result, Error> where S: Into>, { @@ -114,12 +117,14 @@ impl<'a> CastDevice<'a> { .with_root_certificates(root_store) .with_no_client_auth(); config.key_log = Arc::new(rustls::KeyLogFile::new()); + let connor = tokio_rustls::TlsConnector::from(Arc::new(config)); - let conn = ClientConnection::new( - config.into(), - ServerName::try_from(host.as_ref())?.to_owned(), - )?; - let stream = StreamOwned::new(conn, TcpStream::connect((host.as_ref(), port))?); + let stream = connor + .connect( + ServerName::try_from(host.as_ref())?.to_owned(), + TcpStream::connect((host.as_ref(), port)).await?, + ) + .await?; log::debug!("Connection with {host}:{port} successfully established."); @@ -134,8 +139,10 @@ impl<'a> CastDevice<'a> { /// ```no_run /// use rust_cast::CastDevice; /// - /// let device = CastDevice::connect_without_host_verification("192.168.1.2", 8009)?; + /// # tokio_test::block_on(async { + /// let device = CastDevice::connect_without_host_verification("192.168.1.2", 8009).await?; /// # Ok::<(), rust_cast::errors::Error>(()) + /// # }); /// ``` /// /// # Arguments @@ -151,7 +158,10 @@ impl<'a> CastDevice<'a> { /// # Return value /// /// Instance of `CastDevice` that allows you to manage connection. - pub fn connect_without_host_verification(host: S, port: u16) -> Result, Error> + pub async fn connect_without_host_verification( + host: S, + port: u16, + ) -> Result, Error> where S: Into>, { @@ -164,14 +174,14 @@ impl<'a> CastDevice<'a> { .with_custom_certificate_verifier(Arc::new(NoCertificateVerification {})) .with_no_client_auth(); config.key_log = Arc::new(rustls::KeyLogFile::new()); - let stream = StreamOwned::new( - ClientConnection::new( - Arc::new(config), - ServerName::try_from(host.as_ref())?.to_owned(), - )?, - TcpStream::connect((host.as_ref(), port))?, - ); + let connor = tokio_rustls::TlsConnector::from(Arc::new(config)); + let stream = connor + .connect( + ServerName::try_from(host.as_ref())?.to_owned(), + TcpStream::connect((host.as_ref(), port)).await?, + ) + .await?; log::debug!("Connection with {host}:{port} successfully established."); CastDevice::connect_to_device(stream) @@ -186,15 +196,17 @@ impl<'a> CastDevice<'a> { /// use rust_cast::ChannelMessage; /// /// # use rust_cast::CastDevice; - /// # let cast_device = CastDevice::connect_without_host_verification("192.168.1.2", 8009)?; + /// # tokio_test::block_on(async { + /// # let cast_device = CastDevice::connect_without_host_verification("192.168.1.2", 8009).await?; /// - /// match cast_device.receive() { + /// match cast_device.receive().await { /// Ok(ChannelMessage::Connection(res)) => log::debug!("Connection message: {:?}", res), - /// Ok(ChannelMessage::Heartbeat(_)) => cast_device.heartbeat.pong()?, + /// Ok(ChannelMessage::Heartbeat(_)) => cast_device.heartbeat.pong().await?, /// Ok(_) => {}, /// Err(err) => log::error!("Error occurred while receiving message {}", err) /// } /// # Ok::<(), rust_cast::errors::Error>(()) + /// # }); /// ``` /// /// # Errors @@ -204,8 +216,8 @@ impl<'a> CastDevice<'a> { /// # Returned values /// /// Parsed channel message. - pub fn receive(&self) -> Result { - let cast_message = self.message_manager.receive()?; + pub async fn receive(&self) -> Result { + let cast_message = self.message_manager.receive().await?; if self.connection.can_handle(&cast_message) { return Ok(ChannelMessage::Connection( @@ -241,9 +253,7 @@ impl<'a> CastDevice<'a> { /// # Return value /// /// Instance of `CastDevice` that allows you to manage connection. - fn connect_to_device( - ssl_stream: StreamOwned, - ) -> Result, Error> { + fn connect_to_device(ssl_stream: TlsStream) -> Result, Error> { let message_manager_rc = Lrc::new(MessageManager::new(ssl_stream)); let heartbeat = HeartbeatChannel::new( diff --git a/src/message_manager.rs b/src/message_manager.rs index f51f1d15d..a525ffd31 100644 --- a/src/message_manager.rs +++ b/src/message_manager.rs @@ -1,16 +1,20 @@ use std::{ - io::{Read, Write}, num::NonZeroU32, ops::{Deref, DerefMut}, }; +use bytes::{Buf as _, BufMut as _, BytesMut}; +use futures_util::{SinkExt as _, StreamExt as _}; +use protobuf::Message as _; +use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; +use tokio_util::codec; + use crate::{ cast::{ cast_channel, cast_channel::cast_message::{PayloadType, ProtocolVersion}, }, errors::Error, - utils, }; struct Lock( @@ -95,35 +99,66 @@ pub struct CastMessage { pub payload: CastMessagePayload, } -/// Static structure that is responsible for (de)serializing and sending/receiving Cast protocol -/// messages. -pub struct MessageManager -where - S: Write + Read, -{ - message_buffer: Lock>, - stream: Lock, - request_counter: Lock, -} +#[derive(Default)] +struct MessageCodec; -impl MessageManager -where - S: Write + Read, -{ - pub fn new(stream: S) -> Self { - MessageManager { - stream: Lock::new(stream), - message_buffer: Lock::new(vec![]), - request_counter: Lock::new(NonZeroU32::MIN), +// Limit message size to 8MiB +const MAX: u32 = 8 * 1024 * 1024; + +// Basically tokio_util's default LengthDelimitedCodec, with +// protobuf deserialisation on top +impl codec::Decoder for MessageCodec { + type Item = CastMessage; + type Error = Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Error> { + let mut peek: &[u8] = &*src; + if peek.remaining() < 4 { + return Ok(None); + } + let length = peek.get_u32(); + if length > MAX { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Frame of length {} is too large.", length), + ) + .into()); + } + let length: usize = length.try_into().unwrap(); + if peek.remaining() < length { + // Enough space for this message and the length of the next + src.reserve(length - peek.remaining() + 4); + return Ok(None); } + // We have our frame, start consuming the original buffer + // Skip length: + src.advance(4); + let frame = src.split_to(length); + // parse_from_tokio_bytes could save a few allocs, + // if combined with cogen changes: https://lib.rs/crates/protobuf + //let raw_message = cast_channel::CastMessage::parse_from_tokio_bytes(frame)?; + let raw_message = cast_channel::CastMessage::parse_from_bytes(&frame)?; + log::debug!("Message received: {:?}", raw_message); + Ok(Some(CastMessage { + namespace: raw_message.namespace().to_string(), + source: raw_message.source_id().to_string(), + destination: raw_message.destination_id().to_string(), + payload: match raw_message.payload_type() { + PayloadType::STRING => { + CastMessagePayload::String(raw_message.payload_utf8().to_string()) + } + PayloadType::BINARY => { + CastMessagePayload::Binary(raw_message.payload_binary().to_owned()) + } + }, + })) } +} - /// Sends `message` to the Cast Device. - /// - /// # Arguments - /// - /// * `message` - `CastMessage` instance to be sent to the Cast Device. - pub fn send(&self, message: CastMessage) -> Result<(), Error> { +impl codec::Encoder for MessageCodec { + type Error = Error; + + fn encode(&mut self, message: CastMessage, dst: &mut BytesMut) -> Result<(), Error> { let mut raw_message = cast_channel::CastMessage::new(); raw_message.set_protocol_version(ProtocolVersion::CASTV2_1_0); @@ -144,17 +179,57 @@ where } }; - let message_content_buffer = utils::to_vec(&raw_message)?; - let message_length_buffer = - utils::write_u32_to_buffer(message_content_buffer.len() as u32)?; + let message_content_buffer = raw_message.write_to_bytes()?; + dst.put_u32(message_content_buffer.len().try_into().unwrap()); + dst.put_slice(&message_content_buffer); + log::debug!("Message encoded: {:?}", raw_message); - let writer = &mut *self.stream.borrow_mut(); + Ok(()) + } +} - writer.write_all(&message_length_buffer)?; - writer.write_all(&message_content_buffer)?; +/// Static structure that is responsible for (de)serializing and sending/receiving Cast protocol +/// messages. +pub struct MessageManager +where + S: AsyncWrite + AsyncRead, +{ + message_buffer: Lock>, + // Using async mutexes to prevent interleaving; we keep + // the guard across await points until a frame is entirely handled + sender: tokio::sync::Mutex, MessageCodec>>, + // This is an independent mutex so reads and writes can be interleaved + receiver: tokio::sync::Mutex, MessageCodec>>, + request_counter: Lock, +} - log::debug!("Message sent: {:?}", raw_message); +impl MessageManager +where + S: AsyncWrite + AsyncRead, +{ + pub fn new(stream: S) -> Self { + // Would like to use BiLock for splitting, + // but https://github.com/rust-lang/futures-rs/pull/2384 + // was left unmerged. Also, it may not be as useful + // if poll_lock takes mut self. + let (read, write) = tokio::io::split(stream); + let receiver = codec::FramedRead::new(read, MessageCodec).into(); + let sender = codec::FramedWrite::new(write, MessageCodec).into(); + MessageManager { + sender, + receiver, + message_buffer: Lock::new(vec![]), + request_counter: Lock::new(NonZeroU32::MIN), + } + } + /// Sends `message` to the Cast Device. + /// + /// # Arguments + /// + /// * `message` - `CastMessage` instance to be sent to the Cast Device. + pub async fn send(&self, message: CastMessage) -> Result<(), Error> { + self.sender.lock().await.send(message).await?; Ok(()) } @@ -165,12 +240,12 @@ where /// # Return value /// /// `Result` containing parsed `CastMessage` or `Error`. - pub fn receive(&self) -> Result { + pub async fn receive(&self) -> Result { let mut message_buffer = self.message_buffer.borrow_mut(); // If we have messages in the buffer, let's return them from it. if message_buffer.is_empty() { - self.read() + self.read().await } else { Ok(message_buffer.remove(0)) } @@ -183,18 +258,21 @@ where /// # Example /// /// ```no_run - /// # use std::net::TcpStream; + /// # use std::sync::Arc; + /// # use tokio::net::TcpStream; /// # use rust_cast::message_manager::{CastMessage, MessageManager}; /// # use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned}; /// # use rustls::pki_types::ServerName; + /// # use tokio_rustls::TlsConnector; /// # let config = ClientConfig::builder() /// # .with_root_certificates(RootCertStore::empty()) /// # .with_no_client_auth(); + /// # tokio_test::block_on(async { /// # let server_name = ServerName::try_from("0")?.to_owned(); - /// # let conn = ClientConnection::new(config.into(), server_name)?; - /// # let tcp_stream = TcpStream::connect(("0", 8009)).unwrap(); - /// # let ssl_stream = StreamOwned::new(conn, tcp_stream); - /// # let message_manager = MessageManager::new(ssl_stream); + /// # let connor = TlsConnector::from(Arc::new(config)); + /// # let tcp_stream = TcpStream::connect(("0", 8009)).await?; + /// # let tls_stream = connor.connect(server_name, tcp_stream).await?; + /// # let message_manager = MessageManager::new(tls_stream); /// # fn can_handle(message: &CastMessage) -> bool { unimplemented!() } /// # fn parse(message: &CastMessage) { unimplemented!() } /// message_manager.receive_find_map(|message| { @@ -205,8 +283,9 @@ where /// parse(message); /// /// Ok(Some(())) - /// })?; + /// }).await?; /// # Ok::<(), rust_cast::errors::Error>(()) + /// # }); /// ``` /// /// # Arguments @@ -218,12 +297,12 @@ where /// # Return value /// /// `Result` containing parsed `CastMessage` or `Error`. - pub fn receive_find_map(&self, f: F) -> Result + pub async fn receive_find_map(&self, f: F) -> Result where F: Fn(&CastMessage) -> Result, Error>, { loop { - let message = self.read()?; + let message = self.read().await?; // If message is found, just return mapped result, otherwise keep unprocessed message // in the buffer, it can be later retrieved with `receive`. @@ -251,36 +330,14 @@ where /// # Return value /// /// `Result` containing parsed `CastMessage` or `Error`. - fn read(&self) -> Result { - let mut buffer: [u8; 4] = [0; 4]; - - let reader = &mut *self.stream.borrow_mut(); - - reader.read_exact(&mut buffer)?; - - let length = utils::read_u32_from_buffer(&buffer)?; - - let mut buffer: Vec = Vec::with_capacity(length as usize); - let mut limited_reader = reader.take(u64::from(length)); - - limited_reader.read_to_end(&mut buffer)?; - - let raw_message = utils::from_vec::(buffer.to_vec())?; - - log::debug!("Message received: {:?}", raw_message); - - Ok(CastMessage { - namespace: raw_message.namespace().to_string(), - source: raw_message.source_id().to_string(), - destination: raw_message.destination_id().to_string(), - payload: match raw_message.payload_type() { - PayloadType::STRING => { - CastMessagePayload::String(raw_message.payload_utf8().to_string()) - } - PayloadType::BINARY => { - CastMessagePayload::Binary(raw_message.payload_binary().to_owned()) - } - }, - }) + async fn read(&self) -> Result { + let Some(maybe_msg) = self.receiver.lock().await.next().await else { + // Stream has ended cleanly; Ok(None) would be better + // but requires updating users + return Err(Error::Io(std::io::Error::from( + std::io::ErrorKind::UnexpectedEof, + ))); + }; + maybe_msg } } diff --git a/src/utils.rs b/src/utils.rs deleted file mode 100644 index 852b860c2..000000000 --- a/src/utils.rs +++ /dev/null @@ -1,29 +0,0 @@ -use crate::errors::Error; -use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; -use std::io::Cursor; - -pub fn read_u32_from_buffer(buffer: &[u8]) -> Result { - Ok(Cursor::new(buffer).read_u32::()?) -} - -pub fn write_u32_to_buffer(number: u32) -> Result, Error> { - let mut buffer = vec![]; - - buffer.write_u32::(number)?; - - Ok(buffer) -} - -pub fn to_vec(message: &M) -> Result, Error> { - let mut buffer = vec![]; - - message.write_to_writer(&mut buffer)?; - - Ok(buffer) -} - -pub fn from_vec(buffer: Vec) -> Result { - let mut read_buffer = Cursor::new(buffer); - - Ok(M::parse_from_reader(&mut read_buffer)?) -}