Skip to content

Commit

Permalink
Finish porting to async_std
Browse files Browse the repository at this point in the history
  • Loading branch information
mikayla-maki committed Feb 26, 2024
1 parent 87f60ab commit 1746ddc
Show file tree
Hide file tree
Showing 13 changed files with 408 additions and 75 deletions.
233 changes: 196 additions & 37 deletions Cargo.lock

Large diffs are not rendered by default.

35 changes: 21 additions & 14 deletions livekit-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,36 @@ repository = "https://github.com/livekit/rust-sdks"

[features]
# By default ws TLS is not enabled
default = ["services", "access-token", "webhooks"]
default = ["services-tokio", "access-token", "webhooks"]

tokio = ["dep:reqwest"]
# TODO: Pick an async_std http client
async = ["dep:reqwest"]

# TODO: This is already a backend crate, livekit crate can select the feature

# TODO: Can we split the signal client out so we can more easily compile
# tokio-tungstenite and async-tungstenite separately?
signal-client = [
signal-client-tokio = [
"dep:tokio-tungstenite",
"dep:tokio",
"dep:futures-util",
"dep:reqwest",
"dep:livekit-runtime"
]
services = ["dep:reqwest"]

signal-client-async = [
"dep:async-tungstenite",
"dep:tokio", # For macros and sync
"dep:futures-util",
"dep:isahc",
"dep:livekit-runtime"
]

services-tokio = ["dep:reqwest"]
services-async = ["dep:isahc"]
access-token = ["dep:jsonwebtoken"]
webhooks = ["access-token", "dep:serde_json", "dep:base64"]

# Note that the following features only change the behavior of tokio-tungstenite.
# It doesn't change the behavior of libwebrtc/webrtc-sys
native-tls = ["tokio-tungstenite?/native-tls", "reqwest?/native-tls"]
native-tls = [
"tokio-tungstenite?/native-tls",
"async-tungstenite?/async-native-tls",
"reqwest?/native-tls"
]
native-tls-vendored = [
"tokio-tungstenite?/native-tls-vendored",
"reqwest?/native-tls-vendored",
Expand Down Expand Up @@ -66,12 +72,13 @@ jsonwebtoken = { version = "9", default-features = false, optional = true }
# signal_client
livekit-runtime = { path = "../livekit-runtime", version = "0.3.0", optional = true}
tokio-tungstenite = { version = "0.20", optional = true }
# async-tungstenite = { version = "0.25.0", optional = true }
async-tungstenite = { version = "0.25.0", features = [ "async-std-runtime", "async-native-tls"], optional = true }
tokio = { version = "1", default-features = false, features = ["sync", "macros"], optional = true }
futures-util = { version = "0.3", default-features = false, features = [ "sink" ], optional = true }

# This dependency must be kept in sync with reqwest's version
http = "0.2"
http = "0.2.1"
reqwest = { version = "0.11", default-features = false, features = [ "json" ], optional = true }
isahc = { version = "1.7.2", default-features = false, features = [ "json", "text-decoding" ], optional = true }

scopeguard = "1.2.0"
145 changes: 145 additions & 0 deletions livekit-api/src/http_client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#[cfg(any(feature = "services-tokio", feature = "signal-client-tokio"))]
mod tokio {
#[cfg(feature = "signal-client-tokio")]
pub use reqwest::get;

#[cfg(feature = "services-tokio")]
pub use reqwest::Client;
}

#[cfg(any(feature = "services-tokio", feature = "signal-client-tokio"))]
pub use tokio::*;

#[cfg(any(feature = "signal-client-async", feature = "services-async"))]
mod async_std {

// #[cfg(any(feature = "native-tls-vendored", feature = "rustls-tls-native-roots", feature = "rustls-tls-webpki-roots", feature = "__rustls-tls"))]
// compile_error!("the async std compatible libraries do not support these features");


#[cfg(any(feature = "signal-client-async", feature = "services-async"))]
pub struct Response(http::Response<isahc::AsyncBody>);

#[cfg(feature = "signal-client-async")]
mod signal_client {
use std::io;

use isahc::AsyncReadResponseExt;

use super::Response;

impl Response {
pub fn status(&self) -> http::StatusCode {
self.0.status()
}

// TODO: Sub in correct error types
pub async fn text(mut self) -> io::Result<String> {
self.0.text().await
}
}

pub async fn get(url: &str) -> io::Result<Response> {
let response = isahc::get_async(url).await?;
Ok(Response(response))
}
}

#[cfg(feature = "signal-client-async")]
pub use signal_client::*;


#[cfg(feature = "services-async")]
mod services {
use std::io;

use isahc::AsyncReadResponseExt;
use prost::bytes::Bytes;

use super::Response;

use http::header::{Entry, OccupiedEntry};
use url::Url;

impl Response {
pub async fn bytes(self) -> io::Result<Bytes> {
Ok(self.0.bytes().await?.into())
}

pub async fn json<T: serde::de::DeserializeOwned + Unpin>(&mut self) -> io::Result<T> {
self.0.json().await.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
}

#[derive(Debug)]
pub struct Client(isahc::HttpClient);

impl Client {
pub fn new() -> Self {
Self(isahc::HttpClient::new().unwrap())
}
}

impl Client {
pub fn post(&self, url: Url) -> RequestBuilder {
RequestBuilder {
body: Vec::new(),
builder: isahc::http::Request::post(url.as_str()),
client: self.0.clone(),
}
}
}

pub struct RequestBuilder {
builder: isahc::http::request::Builder,
body: Vec<u8>,
client: isahc::HttpClient,
}

impl RequestBuilder {
pub fn headers(mut self, headers: http::HeaderMap) -> Self {
// Copied from: https://docs.rs/reqwest/0.11.24/src/reqwest/util.rs.html#62-89
let self_headers = self.builder.headers_mut().unwrap();
let mut prev_entry: Option<OccupiedEntry<_>> = None;
for (key, value) in headers {
match key {
Some(key) => match self_headers.entry(key) {
Entry::Occupied(mut e) => {
e.insert(value);
prev_entry = Some(e);
}
Entry::Vacant(e) => {
let e = e.insert_entry(value);
prev_entry = Some(e);
}
},
None => match prev_entry {
Some(ref mut entry) => {
entry.append(value);
}
None => unreachable!("HeaderMap::into_iter yielded None first"),
},
}
}
self
}

pub fn body(mut self, body: Vec<u8>) -> Self {
self.body = body;
self
}

pub async fn send(self) -> io::Result<Response> {
let request = self.builder.body(self.body).unwrap();
let response = self.client.send_async(request).await?;
Ok(Response(response))
}
}
}

#[cfg(feature = "services-async")]
pub use services::*;
}

#[cfg(any(feature = "signal-client-async", feature = "services-async"))]
pub use async_std::*;
8 changes: 6 additions & 2 deletions livekit-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.


#[cfg(feature = "access-token")]
pub mod access_token;

#[cfg(feature = "services")]
#[cfg(any(feature = "services-tokio", feature = "services-async"))]
pub mod services;

#[cfg(feature = "signal-client")]
#[cfg(any(feature = "signal-client-tokio", feature = "signal-client-async"))]
pub mod signal_client;

#[cfg(any(feature = "signal-client-tokio", feature = "signal-client-async", feature = "services-tokio", feature = "services-async"))]
mod http_client;

#[cfg(feature = "webhooks")]
pub mod webhooks;

Expand Down
10 changes: 8 additions & 2 deletions livekit-api/src/services/twirp_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,18 @@ use http::{
use serde::Deserialize;
use thiserror::Error;

use crate::http_client;

pub const DEFAULT_PREFIX: &str = "/twirp";

#[derive(Debug, Error)]
pub enum TwirpError {
#[cfg(feature = "services-tokio")]
#[error("failed to execute the request: {0}")]
Request(#[from] reqwest::Error),
#[cfg(feature = "services-async")]
#[error("failed to execute the request: {0}")]
Request(#[from] std::io::Error),
#[error("twirp error: {0}")]
Twirp(TwirpErrorCode),
#[error("url error: {0}")]
Expand Down Expand Up @@ -75,7 +81,7 @@ pub struct TwirpClient {
host: String,
pkg: String,
prefix: String,
client: reqwest::Client,
client: http_client::Client,
}

impl TwirpClient {
Expand All @@ -84,7 +90,7 @@ impl TwirpClient {
host: host.to_owned(),
pkg: pkg.to_owned(),
prefix: prefix.unwrap_or(DEFAULT_PREFIX).to_owned(),
client: reqwest::Client::new(),
client: http_client::Client::new(),
}
}

Expand Down
9 changes: 7 additions & 2 deletions livekit-api/src/signal_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,14 @@ use http::StatusCode;
use thiserror::Error;
use livekit_runtime::{JoinHandle, interval, sleep, Instant};
use tokio::sync::{mpsc, Mutex as AsyncMutex, RwLock as AsyncRwLock};

#[cfg(feature = "signal-client-tokio")]
use tokio_tungstenite::tungstenite::Error as WsError;

use crate::signal_client::signal_stream::SignalStream;
#[cfg(feature = "signal-client-async")]
use async_tungstenite::tungstenite::Error as WsError;

use crate::{http_client, signal_client::signal_stream::SignalStream};

mod signal_stream;

Expand Down Expand Up @@ -217,7 +222,7 @@ impl SignalInner {
segs.extend(&["rtc", "validate"]);
}

if let Ok(res) = reqwest::get(ws_url.as_str()).await {
if let Ok(res) = http_client::get(ws_url.as_str()).await {
let status = res.status();
let body = res.text().await.ok().unwrap_or_default();

Expand Down
14 changes: 7 additions & 7 deletions livekit-api/src/signal_client/signal_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@ use futures_util::{
SinkExt, StreamExt,
};
use livekit_protocol as proto;
use livekit_runtime::JoinHandle;
use livekit_runtime::{JoinHandle, TcpStream};
use prost::Message as ProtoMessage;

// TODO: TCP Stream
use tokio::{
net::TcpStream,
sync::{mpsc, oneshot},
};
use tokio::sync::{mpsc, oneshot};

// TODO: Web sockets
#[cfg(feature = "signal-client-tokio")]
use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};

#[cfg(feature = "signal-client-async")]
use async_tungstenite::{tungstenite::Message, async_std::ClientStream as MaybeTlsStream, WebSocketStream,
async_std::connect_async as connect_async};

use super::{SignalError, SignalResult};

type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
Expand Down
2 changes: 1 addition & 1 deletion livekit-ffi/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ __rustls-tls = ["livekit/__rustls-tls"]
tracing = ["tokio/tracing", "console-subscriber"]

[dependencies]
livekit = { path = "../livekit", version = "0.3.0" }
livekit = { path = "../livekit", features = ["async"], default-features = false, version = "0.3.0" }
livekit-protocol = { path = "../livekit-protocol", version = "0.3.0" }
tokio = { version = "1", features = ["full"] }
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
Expand Down
1 change: 0 additions & 1 deletion livekit-runtime/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ edition = "2021"
repository = "https://github.com/livekit/rust-sdks"

[features]
default = ["async"]
tokio = ["dep:tokio"]
async = ["dep:async-std", "dep:futures", "dep:async-io"]

Expand Down
2 changes: 1 addition & 1 deletion livekit-runtime/src/async_std.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::time::Duration;

pub type JoinHandle<T> = async_std::task::JoinHandle<T>;
pub use std::time::Instant;
pub use async_std::future::timeout;
pub use async_std::task::spawn;
pub use async_std::net::TcpStream;
use futures::{Future, FutureExt, StreamExt};

pub struct Interval {
Expand Down
2 changes: 1 addition & 1 deletion livekit-runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ pub use tokio::*;
mod async_std;

#[cfg(feature = "async")]
pub use async_std::*;
pub use async_std::*;
13 changes: 9 additions & 4 deletions livekit-runtime/src/tokio.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
pub use tokio::task::spawn;
use std::future::Future;
use std::pin::Pin;
use std::time::Duration;

pub use tokio::time::Instant;
pub use tokio::time::sleep;
pub use tokio::time::timeout;
pub use tokio::net::TcpStream;

pub type JoinHandle<T> = TokioJoinHandle<T>;
pub type Interval = tokio::time::Interval;

struct TokioJoinHandle<T> {
handle: JoinHandle<T>,
#[derive(Debug)]
pub struct TokioJoinHandle<T> {
handle: tokio::task::JoinHandle<T>,
}

pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
Expand Down Expand Up @@ -37,7 +42,7 @@ impl<T> Future for TokioJoinHandle<T> {
// TODO: Is this ok? Or should we have some kind of seperate compatibility layer?
// TODO: Confirm that this matches the async-io implementation
pub fn interval(duration: Duration) -> Interval {
let timer = tokio::time::interval(duration);
let mut timer = tokio::time::interval(duration);
timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
timer
}
Loading

0 comments on commit 1746ddc

Please sign in to comment.