From de31d8c95ed507f15a40c0a170827d8539ea12ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-S=C3=A9bastien=20Bour?= Date: Mon, 21 Oct 2024 20:23:40 +0200 Subject: [PATCH] Generate extensible versions of methods by allowing generic requests and response Fixes #280 --- Cargo.toml | 4 +- async-openai-macros/Cargo.toml | 23 +++++++ async-openai-macros/README.md | 26 ++++++++ async-openai-macros/src/lib.rs | 107 +++++++++++++++++++++++++++++++ async-openai/Cargo.toml | 4 +- async-openai/README.md | 50 +++++++++++++++ async-openai/src/chat.rs | 15 +++-- async-openai/src/lib.rs | 46 +++++++++++++ async-openai/src/types/chat.rs | 12 +++- async-openai/src/types/mod.rs | 5 ++ async-openai/tests/extensible.rs | 71 ++++++++++++++++++++ 11 files changed, 354 insertions(+), 9 deletions(-) create mode 100644 async-openai-macros/Cargo.toml create mode 100644 async-openai-macros/README.md create mode 100644 async-openai-macros/src/lib.rs create mode 100644 async-openai/tests/extensible.rs diff --git a/Cargo.toml b/Cargo.toml index 75ee4f59..b77c3454 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = [ "async-openai", "examples/*" ] +members = [ "async-openai", "async-openai-macros", "examples/*" ] # Only check / build main crates by default (check all with `--workspace`) -default-members = ["async-openai"] +default-members = ["async-openai", "async-openai-macros"] resolver = "2" diff --git a/async-openai-macros/Cargo.toml b/async-openai-macros/Cargo.toml new file mode 100644 index 00000000..756e3258 --- /dev/null +++ b/async-openai-macros/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "async-openai-macros" +version = "0.1.0" +authors = ["Jean-Sébastien Bour "] +categories = ["api-bindings", "web-programming", "asynchronous"] +keywords = ["openai", "async", "openapi", "ai"] +description = "Procedural macros for async-openai" +edition = "2021" +rust-version = "1.70" +license = "MIT" +readme = "README.md" +homepage = "https://github.com/64bit/async-openai" +repository = "https://github.com/64bit/async-openai" + +[lib] +proc-macro = true + +[dependencies] +darling = "0.20" +itertools = "0.13" +proc-macro2 = "1" +quote = "1" +syn = "2" diff --git a/async-openai-macros/README.md b/async-openai-macros/README.md new file mode 100644 index 00000000..9a981f4d --- /dev/null +++ b/async-openai-macros/README.md @@ -0,0 +1,26 @@ +
+ + + +
+

async-openai-macros

+

Procedural macros for async-openai

+
+ + + + + + +
+
+Logo created by this repo itself +
+ +## Overview + +This crate contains the procedural macros for `async-openai`. It is not meant to be used directly. + +## License + +This project is licensed under [MIT license](https://github.com/64bit/async-openai/blob/main/LICENSE). diff --git a/async-openai-macros/src/lib.rs b/async-openai-macros/src/lib.rs new file mode 100644 index 00000000..4c798ae7 --- /dev/null +++ b/async-openai-macros/src/lib.rs @@ -0,0 +1,107 @@ +use darling::{ast::NestedMeta, FromMeta}; +use itertools::{Either, Itertools}; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse2, parse_macro_input, Expr, FnArg, ItemFn, Meta, MetaList}; + +#[proc_macro_attribute] +pub fn extensible( + _: proc_macro::TokenStream, + item: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let item = parse_macro_input!(item as ItemFn); + extensible_impl(item) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +fn extensible_impl(mut item: ItemFn) -> syn::Result { + // Stream variants use a special result type + let is_stream = item.sig.ident.to_string().ends_with("_stream"); + + // Prepare a generic method with a different name + let mut extension = item.clone(); + extension.sig.ident = format_ident!("{}_ext", extension.sig.ident); + + // Remove our attributes from original method arguments + for input in &mut item.sig.inputs { + match input { + FnArg::Receiver(_) => (), + FnArg::Typed(arg) => arg.attrs.retain(|attr| match &attr.meta { + Meta::List(meta) => !attr_is_ours(meta), + _ => true, + }), + } + } + + // Gather request parameters that must be replaced by generics and their optional bounds + let mut i = 0; + let generics = extension + .sig + .inputs + .iter_mut() + .filter_map(|input| match input { + FnArg::Receiver(_) => None, + FnArg::Typed(arg) => { + let (mine, other): (Vec<_>, Vec<_>) = + arg.attrs + .clone() + .into_iter() + .partition_map(|attr| match &attr.meta { + Meta::List(meta) if attr_is_ours(meta) => Either::Left( + Request::from_list( + &NestedMeta::parse_meta_list(meta.tokens.clone()).unwrap(), + ) + .unwrap(), + ), + _ => Either::Right(attr), + }); + let bounds = mine.into_iter().next(); + arg.attrs = other; + bounds.map(|b| { + let ident = format_ident!("__EXTENSIBLE_REQUEST_{i}"); + arg.ty = Box::new(parse2(quote! { #ident }).unwrap()); + i += 1; + (ident, b) + }) + } + }) + .collect::>(); + + // Add generics and their optional bounds to our method's generics + for (ident, Request { bounds }) in generics { + let bounds = bounds.map(|b| quote! { + #b }); + extension + .sig + .generics + .params + .push(parse2(quote! { #ident : ::serde::Serialize #bounds })?) + } + + // Make the result type generic too + let result = if is_stream { + quote! { std::pin::Pin> + Send>>} + } else { + quote! { __EXTENSIBLE_RESPONSE } + }; + extension.sig.output = parse2(quote! { -> Result<#result, OpenAIError> })?; + let send_and_static = is_stream.then_some(quote! { + Send + 'static }); + extension.sig.generics.params.push(parse2( + quote! { __EXTENSIBLE_RESPONSE: serde::de::DeserializeOwned #send_and_static }, + )?); + + Ok(quote! { + #item + + #extension + }) +} + +#[derive(FromMeta)] +struct Request { + bounds: Option, +} + +fn attr_is_ours(meta: &MetaList) -> bool { + meta.path.get_ident().map(|ident| ident.to_string()) == Some("request".to_string()) +} diff --git a/async-openai/Cargo.toml b/async-openai/Cargo.toml index 5c4cb94d..088d58f5 100644 --- a/async-openai/Cargo.toml +++ b/async-openai/Cargo.toml @@ -6,7 +6,7 @@ categories = ["api-bindings", "web-programming", "asynchronous"] keywords = ["openai", "async", "openapi", "ai"] description = "Rust library for OpenAI" edition = "2021" -rust-version = "1.65" +rust-version = "1.70" license = "MIT" readme = "README.md" homepage = "https://github.com/64bit/async-openai" @@ -25,6 +25,7 @@ native-tls-vendored = ["reqwest/native-tls-vendored"] realtime = ["dep:tokio-tungstenite"] [dependencies] +async-openai-macros = { version = "0.1", path = "../async-openai-macros" } backoff = { version = "0.4.0", features = ["tokio"] } base64 = "0.22.1" futures = "0.3.30" @@ -50,6 +51,7 @@ eventsource-stream = "0.2.3" tokio-tungstenite = { version = "0.24.0", optional = true, default-features = false } [dev-dependencies] +actix-web = "4" tokio-test = "0.4.4" [package.metadata.docs.rs] diff --git a/async-openai/README.md b/async-openai/README.md index c3a9af18..6127c6dc 100644 --- a/async-openai/README.md +++ b/async-openai/README.md @@ -41,6 +41,7 @@ - Requests (except SSE streaming) including form submissions are retried with exponential backoff when [rate limited](https://platform.openai.com/docs/guides/rate-limits). - Ergonomic builder pattern for all request objects. - Microsoft Azure OpenAI Service (only for APIs matching OpenAI spec) +- Extensible to other providers via extension methods where you provide the request/response types ## Usage @@ -108,6 +109,55 @@ async fn main() -> Result<(), Box> { Scaled up for README, actual size 256x256 +## Other providers + +Alternative OpenAI providers that provide the same endpoints, but with different requests/responses (e.g. Azure OpenAPI allows applying content filters on chat completion, and get the results as an additional field in the response), are supported through `_ext` methods (currently, only for the chat completion API) which allow you to "bring your own types": + +```rust +use std::error::Error; + +use async_openai::{config::AzureConfig, types::RequestForStream, Client}; +use serde::{Deserialize, Serialize}; + +#[derive(Default, Serialize)] +struct AzureRequest { + specific_azure_field: String, + stream: Option, +} + +impl RequestForStream for AzureRequest { + fn is_request_for_stream(&self) -> bool { + self.stream == Some(true) + } + + fn set_request_for_stream(&mut self, stream: bool) { + self.stream = Some(stream) + } +} + +#[derive(Deserialize)] +struct AzureResponse { + specific_azure_result: String, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = Client::with_config(AzureConfig::new()); + + let request = AzureRequest::default(); + + let response: AzureResponse = client + .chat() + // Use the extensible method which allows you to bring your own types + .create_ext(request) + .await?; + + println!("Specific azure result: {}", response.specific_azure_result); + + Ok(()) +} +``` + ## Contributing Thank you for taking the time to contribute and improve the project. I'd be happy to have you! diff --git a/async-openai/src/chat.rs b/async-openai/src/chat.rs index c7f9b962..b8763660 100644 --- a/async-openai/src/chat.rs +++ b/async-openai/src/chat.rs @@ -1,8 +1,11 @@ +use async_openai_macros::extensible; + use crate::{ config::Config, error::OpenAIError, types::{ ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse, + RequestForStream, }, Client, }; @@ -20,11 +23,12 @@ impl<'c, C: Config> Chat<'c, C> { } /// Creates a model response for the given chat conversation. + #[extensible] pub async fn create( &self, - request: CreateChatCompletionRequest, + #[request(bounds = RequestForStream)] request: CreateChatCompletionRequest, ) -> Result { - if request.stream.is_some() && request.stream.unwrap() { + if request.is_request_for_stream() { return Err(OpenAIError::InvalidArgument( "When stream is true, use Chat::create_stream".into(), )); @@ -37,17 +41,18 @@ impl<'c, C: Config> Chat<'c, C> { /// partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. /// /// [ChatCompletionResponseStream] is a parsed SSE stream until a \[DONE\] is received from server. + #[extensible] pub async fn create_stream( &self, - mut request: CreateChatCompletionRequest, + #[request(bounds = RequestForStream)] mut request: CreateChatCompletionRequest, ) -> Result { - if request.stream.is_some() && !request.stream.unwrap() { + if !request.is_request_for_stream() { return Err(OpenAIError::InvalidArgument( "When stream is false, use Chat::create".into(), )); } - request.stream = Some(true); + request.set_request_for_stream(true); Ok(self.client.post_stream("/chat/completions", request).await) } diff --git a/async-openai/src/lib.rs b/async-openai/src/lib.rs index c8a06edd..2a132e72 100644 --- a/async-openai/src/lib.rs +++ b/async-openai/src/lib.rs @@ -73,6 +73,52 @@ //! # }); //!``` //! +//! ## Other providers +//! +//! You can use alternative providers that extend OpenAPI specs by using `_ext` extensible methods to bring your own types: +//! +//! ``` +//! # tokio_test::block_on(async { +//! use std::error::Error; +//! +//! use async_openai::{config::AzureConfig, types::RequestForStream, Client}; +//! use serde::{Deserialize, Serialize}; +//! +//! #[derive(Default, Serialize)] +//! struct AzureRequest { +//! specific_azure_field: String, +//! stream: Option, +//! } +//! +//! impl RequestForStream for AzureRequest { +//! fn is_request_for_stream(&self) -> bool { +//! self.stream == Some(true) +//! } +//! +//! fn set_request_for_stream(&mut self, stream: bool) { +//! self.stream = Some(stream) +//! } +//! } +//! +//! #[derive(Deserialize)] +//! struct AzureResponse { +//! specific_azure_result: String, +//! } +//! +//! let client = Client::with_config(AzureConfig::new()); +//! +//! let request = AzureRequest::default(); +//! +//! let response: AzureResponse = client +//! .chat() +//! // Use the extensible method which allows you to bring your own types +//! .create_ext(request) +//! .await?; +//! +//! println!("Specific azure result: {}", response.specific_azure_result); +//! # }); +//! ``` +//! //! ## Examples //! For full working examples for all supported features see [examples](https://github.com/64bit/async-openai/tree/main/examples) directory in the repository. //! diff --git a/async-openai/src/types/chat.rs b/async-openai/src/types/chat.rs index 138d6852..0110fc1e 100644 --- a/async-openai/src/types/chat.rs +++ b/async-openai/src/types/chat.rs @@ -4,7 +4,7 @@ use derive_builder::Builder; use futures::Stream; use serde::{Deserialize, Serialize}; -use crate::error::OpenAIError; +use crate::{error::OpenAIError, types::RequestForStream}; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(untagged)] @@ -635,6 +635,16 @@ pub struct CreateChatCompletionRequest { pub functions: Option>, } +impl RequestForStream for CreateChatCompletionRequest { + fn is_request_for_stream(&self) -> bool { + self.stream == Some(true) + } + + fn set_request_for_stream(&mut self, stream: bool) { + self.stream = Some(stream) + } +} + /// Options for streaming response. Only set this when you set `stream: true`. #[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] pub struct ChatCompletionStreamOptions { diff --git a/async-openai/src/types/mod.rs b/async-openai/src/types/mod.rs index fdc0f51f..d47cc28b 100644 --- a/async-openai/src/types/mod.rs +++ b/async-openai/src/types/mod.rs @@ -72,3 +72,8 @@ impl From for OpenAIError { OpenAIError::InvalidArgument(value.to_string()) } } + +pub trait RequestForStream { + fn is_request_for_stream(&self) -> bool; + fn set_request_for_stream(&mut self, stream: bool); +} diff --git a/async-openai/tests/extensible.rs b/async-openai/tests/extensible.rs new file mode 100644 index 00000000..3d8e3ddb --- /dev/null +++ b/async-openai/tests/extensible.rs @@ -0,0 +1,71 @@ +use actix_web::{web, App, HttpServer}; +use async_openai::{config::OpenAIConfig, types::RequestForStream, Client}; +use futures::StreamExt; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Deserialize, Serialize)] +struct Request { + param: String, + stream: Option, +} + +impl RequestForStream for Request { + fn is_request_for_stream(&self) -> bool { + self.stream.unwrap_or(false) + } + + fn set_request_for_stream(&mut self, stream: bool) { + self.stream = Some(stream) + } +} + +#[derive(Debug, Deserialize, PartialEq, Serialize)] +struct Response { + len: usize, +} + +#[tokio::test] +async fn extensible() { + async fn handle(request: web::Json) -> web::Json { + web::Json(Response { + len: request.param.len(), + }) + } + + let server = HttpServer::new(move || { + App::new().configure(|cfg| { + cfg.service(web::resource("/chat/completions").route(web::post().to(handle))); + }) + }) + .disable_signals() + .bind("127.0.0.1:8080") + .unwrap() + .run(); + + tokio::spawn(server); + + let client = Client::with_config(OpenAIConfig::new().with_api_base("http://127.0.0.1:8080")); + + let mut request = Request { + param: "foo".to_string(), + stream: None, + }; + + let response: Response = client.chat().create_ext(request.clone()).await.unwrap(); + assert_eq!(response, Response { len: 3 }); + + request.stream = Some(true); + let response = client + .chat() + .create_stream_ext::<_, Response>(request) + .await + .unwrap() + .next() + .await + .unwrap() + .unwrap_err(); + assert_eq!( + response.to_string(), + "stream failed: Invalid header value: \"application/json\"" + ) +}