Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding ClientExt and ClientProvider traits #299

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions async-openai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ native-tls-vendored = ["reqwest/native-tls-vendored"]
realtime = ["dep:tokio-tungstenite"]

[dependencies]
async-trait = "0.1"
backoff = { version = "0.4.0", features = ["tokio"] }
base64 = "0.22.1"
futures = "0.3.30"
Expand Down
62 changes: 62 additions & 0 deletions async-openai/src/chat.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{
client::{ClientExt, ClientProvider},
config::Config,
error::OpenAIError,
types::{
Expand Down Expand Up @@ -52,3 +53,64 @@ impl<'c, C: Config> Chat<'c, C> {
Ok(self.client.post_stream("/chat/completions", request).await)
}
}

/// Exposes the borrowed [`Client`] so extension traits (via [`ClientExt`])
/// can issue requests on behalf of this `Chat` API group.
impl<'c, C: Config + Send> ClientProvider<'c, C> for Chat<'c, C> {
    /// Returns the client this `Chat` group was created from.
    fn client(&self) -> &'c Client<C> {
        self.client
    }
}

#[cfg(test)]
mod tests {

    use crate::types::{ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};

    use super::*;

    /// Example extension trait layered on [`ClientExt`]/[`ClientProvider`]:
    /// adds a streaming chat-completion call to [`Chat`].
    #[async_trait::async_trait]
    trait ChatExt {
        /// Forces `stream: true` on `request` and opens an SSE stream.
        ///
        /// # Errors
        /// [`OpenAIError::InvalidArgument`] when the caller explicitly set
        /// `stream: false` (use `Chat::create` for non-streaming calls).
        //
        // NOTE: no `mut` here — a pattern is not allowed in a bodiless trait
        // method signature, and mutability is an implementation detail anyway.
        async fn create_annotated_stream(
            &self,
            request: CreateChatCompletionRequest,
        ) -> Result<ChatCompletionResponseStream, OpenAIError>;
    }

    #[async_trait::async_trait]
    impl<'c, C: Config> ChatExt for Chat<'c, C> {
        async fn create_annotated_stream(
            &self,
            mut request: CreateChatCompletionRequest,
        ) -> Result<ChatCompletionResponseStream, OpenAIError> {
            // An explicit `stream: false` contradicts this streaming API;
            // `matches!` checks the Option once instead of is_some()+unwrap().
            if matches!(request.stream, Some(false)) {
                return Err(OpenAIError::InvalidArgument(
                    "When stream is false, use Chat::create".into(),
                ));
            }

            request.stream = Some(true);

            Ok(self.client.post_stream("/chat/completions", request).await)
        }
    }

    /// Smoke test for the extension-trait pattern. NOTE(review): hits the
    /// live OpenAI endpoint and needs a valid API key in the environment.
    #[tokio::test]
    async fn test() {
        let client = Client::new();
        let chat = client.chat();

        let request = CreateChatCompletionRequestArgs::default()
            .model("gpt-3.5-turbo")
            .max_tokens(512u32)
            .messages([ChatCompletionRequestUserMessageArgs::default()
                .content(
                    "Write a marketing blog praising and introducing Rust library async-openai",
                )
                .build()
                .unwrap()
                .into()])
            .build()
            .unwrap();

        let _stream = chat.create_annotated_stream(request).await.unwrap();
    }
}
121 changes: 77 additions & 44 deletions async-openai/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,50 +377,6 @@ impl<C: Config> Client<C> {
Ok(response)
}

/// Make HTTP POST request to receive SSE
pub(crate) async fn post_stream<I, O>(
    &self,
    path: &str,
    request: I,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
    I: Serialize,
    O: DeserializeOwned + std::marker::Send + 'static,
{
    // Assemble the JSON POST request, then open it as an SSE event source.
    let builder = self
        .http_client
        .post(self.config.url(path))
        .query(&self.config.query())
        .headers(self.config.headers())
        .json(&request);

    // NOTE(review): `eventsource()` fails only for streaming (non-cloneable)
    // request bodies; a JSON body is cloneable, so this unwrap looks safe.
    let event_source = builder.eventsource().unwrap();

    stream(event_source).await
}

/// Make HTTP POST request to receive SSE, translating each raw event with
/// the caller-supplied `event_mapper` before yielding it on the stream.
pub(crate) async fn post_stream_mapped_raw_events<I, O>(
    &self,
    path: &str,
    request: I,
    event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
where
    I: Serialize,
    O: DeserializeOwned + std::marker::Send + 'static,
{
    // Same request shape as `post_stream`; only the event decoding differs.
    let builder = self
        .http_client
        .post(self.config.url(path))
        .query(&self.config.query())
        .headers(self.config.headers())
        .json(&request);

    // NOTE(review): safe for cloneable (JSON) bodies — see `post_stream`.
    let event_source = builder.eventsource().unwrap();

    stream_mapped_raw_events(event_source, event_mapper).await
}

/// Make HTTP GET request to receive SSE
pub(crate) async fn _get_stream<Q, O>(
&self,
Expand Down Expand Up @@ -537,3 +493,80 @@ where

Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}

/// Gives extension traits access to the [`Client`] wrapped by an API group
/// without exposing the group's private field directly.
pub trait ClientProvider<'c, C: Config + Send> {
    /// Returns the client the implementing API group borrows.
    fn client(&self) -> &'c Client<C>;
}

/// Server-sent-events (SSE) request surface of [`Client`], expressed as a
/// trait so downstream crates can build extension traits on top of the
/// streaming plumbing (see [`ClientProvider`]).
#[async_trait::async_trait]
pub trait ClientExt<C: Config + Send>: Send {
    /// Make HTTP POST request to receive SSE
    async fn post_stream<I, O>(
        &self,
        path: &str,
        request: I,
    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
    where
        I: Serialize + Send,
        O: DeserializeOwned + std::marker::Send + 'static;

    /// Make HTTP POST request to receive SSE with a custom event source handler
    //
    // NOTE(review): `impl Fn(...)` here makes this method generic, so the
    // trait is not object-safe — confirm no `dyn ClientExt` usage is planned.
    async fn post_stream_mapped_raw_events<I, O>(
        &self,
        path: &str,
        request: I,
        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
    where
        I: Serialize + Send,
        O: DeserializeOwned + std::marker::Send + 'static;
}

/// [`ClientExt`] implementation for the concrete [`Client`]: both methods
/// build the same JSON POST request and differ only in event decoding.
#[async_trait::async_trait]
impl<C: Config> ClientExt<C> for Client<C>
where
    C: Send,
{
    async fn post_stream<I, O>(
        &self,
        path: &str,
        request: I,
    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
    where
        I: Serialize + Send,
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        // Assemble the request first, then open it as an SSE event source.
        let builder = self
            .http_client
            .post(self.config.url(path))
            .query(&self.config.query())
            .headers(self.config.headers())
            .json(&request);

        // NOTE(review): `eventsource()` errors only for streaming request
        // bodies; a JSON body is cloneable, so the unwrap looks safe.
        let event_source = builder.eventsource().unwrap();

        stream(event_source).await
    }

    async fn post_stream_mapped_raw_events<I, O>(
        &self,
        path: &str,
        request: I,
        event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
    ) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>> + Send>>
    where
        I: Serialize + Send,
        O: DeserializeOwned + std::marker::Send + 'static,
    {
        let builder = self
            .http_client
            .post(self.config.url(path))
            .query(&self.config.query())
            .headers(self.config.headers())
            .json(&request);

        // NOTE(review): safe for cloneable (JSON) bodies — see `post_stream`.
        let event_source = builder.eventsource().unwrap();

        // Each raw SSE event is translated by `event_mapper` before being
        // yielded on the returned stream.
        stream_mapped_raw_events(event_source, event_mapper).await
    }
}
2 changes: 1 addition & 1 deletion async-openai/src/completion.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
client::Client,
client::{Client, ClientExt},
config::Config,
error::OpenAIError,
types::{CompletionResponseStream, CreateCompletionRequest, CreateCompletionResponse},
Expand Down
2 changes: 1 addition & 1 deletion async-openai/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub const OPENAI_BETA_HEADER: &str = "OpenAI-Beta";

/// [crate::Client] relies on this for every API call on OpenAI
/// or Azure OpenAI service
pub trait Config: Clone {
pub trait Config: Clone + Send + Sync {
fn headers(&self) -> HeaderMap;
fn url(&self, path: &str) -> String;
fn query(&self) -> Vec<(&str, &str)>;
Expand Down
2 changes: 1 addition & 1 deletion async-openai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ mod audio;
mod audit_logs;
mod batches;
mod chat;
mod client;
pub mod client;
mod completion;
pub mod config;
mod download;
Expand Down
1 change: 1 addition & 0 deletions async-openai/src/runs.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use serde::Serialize;

use crate::{
client::ClientExt,
config::Config,
error::OpenAIError,
steps::Steps,
Expand Down
1 change: 1 addition & 0 deletions async-openai/src/threads.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{
client::ClientExt,
config::Config,
error::OpenAIError,
types::{
Expand Down