diff --git a/Cargo.toml b/Cargo.toml index 75ee4f5..b77c345 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = [ "async-openai", "examples/*" ] +members = [ "async-openai", "async-openai-macros", "examples/*" ] # Only check / build main crates by default (check all with `--workspace`) -default-members = ["async-openai"] +default-members = ["async-openai", "async-openai-macros"] resolver = "2" diff --git a/async-openai-macros/Cargo.toml b/async-openai-macros/Cargo.toml new file mode 100644 index 0000000..756e325 --- /dev/null +++ b/async-openai-macros/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "async-openai-macros" +version = "0.1.0" +authors = ["Jean-Sébastien Bour "] +categories = ["api-bindings", "web-programming", "asynchronous"] +keywords = ["openai", "async", "openapi", "ai"] +description = "Procedural macros for async-openai" +edition = "2021" +rust-version = "1.70" +license = "MIT" +readme = "README.md" +homepage = "https://github.com/64bit/async-openai" +repository = "https://github.com/64bit/async-openai" + +[lib] +proc-macro = true + +[dependencies] +darling = "0.20" +itertools = "0.13" +proc-macro2 = "1" +quote = "1" +syn = "2" diff --git a/async-openai-macros/README.md b/async-openai-macros/README.md new file mode 100644 index 0000000..9a981f4 --- /dev/null +++ b/async-openai-macros/README.md @@ -0,0 +1,26 @@ +
+ + + +
+

async-openai-macros

+

Procedural macros for async-openai

+
+ + + + + + +
+
+Logo created by this repo itself +
+ +## Overview + +This crate contains the procedural macros for `async-openai`. It is not meant to be used directly. + +## License + +This project is licensed under [MIT license](https://github.com/64bit/async-openai/blob/main/LICENSE). diff --git a/async-openai-macros/src/lib.rs b/async-openai-macros/src/lib.rs new file mode 100644 index 0000000..b314fb0 --- /dev/null +++ b/async-openai-macros/src/lib.rs @@ -0,0 +1,98 @@ +use darling::{ast::NestedMeta, FromMeta}; +use itertools::{Either, Itertools}; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse2, parse_macro_input, Expr, FnArg, ItemFn, Meta, MetaList}; + +#[proc_macro_attribute] +pub fn extensible( + _: proc_macro::TokenStream, + item: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let item = parse_macro_input!(item as ItemFn); + extensible_impl(item) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +fn extensible_impl(mut item: ItemFn) -> syn::Result { + // Prepare a generic method with a different name + let mut extension = item.clone(); + extension.sig.ident = format_ident!("{}_ext", extension.sig.ident); + + // Remove our attributes from original method arguments + for input in &mut item.sig.inputs { + match input { + FnArg::Receiver(_) => (), + FnArg::Typed(arg) => arg.attrs.retain(|attr| match &attr.meta { + Meta::List(meta) => !attr_is_ours(meta), + _ => true, + }), + } + } + + // Gather request parameters that must be replaced by generics and their optional bounds + let mut i = 0; + let generics = extension + .sig + .inputs + .iter_mut() + .filter_map(|input| match input { + FnArg::Receiver(_) => None, + FnArg::Typed(arg) => { + let (mine, other): (Vec<_>, Vec<_>) = + arg.attrs + .clone() + .into_iter() + .partition_map(|attr| match &attr.meta { + Meta::List(meta) if attr_is_ours(meta) => Either::Left( + Request::from_list( + &NestedMeta::parse_meta_list(meta.tokens.clone()).unwrap(), + ) + .unwrap(), + ), + _ => Either::Right(attr), + }); + 
let bounds = mine.into_iter().next(); + arg.attrs = other; + bounds.map(|b| { + let ident = format_ident!("__EXTENSIBLE_REQUEST_{i}"); + arg.ty = Box::new(parse2(quote! { #ident }).unwrap()); + i += 1; + (ident, b) + }) + } + }) + .collect::<Vec<_>>(); + + // Add generics and their optional bounds to our method's generics + for (ident, Request { bounds }) in generics { + let bounds = bounds.map(|b| quote! { + #b }); + extension + .sig + .generics + .params + .push(parse2(quote! { #ident : ::serde::Serialize #bounds })?) + } + + // Make the result type generic too + extension.sig.output = parse2(quote! { -> Result<__EXTENSIBLE_RESPONSE, OpenAIError> })?; + extension.sig.generics.params.push(parse2( + quote! { __EXTENSIBLE_RESPONSE: serde::de::DeserializeOwned }, + )?); + + Ok(quote! { + #item + + #extension + }) +} + +#[derive(FromMeta)] +struct Request { + bounds: Option<Expr>, +} + +fn attr_is_ours(meta: &MetaList) -> bool { + meta.path.get_ident().map(|ident| ident.to_string()) == Some("request".to_string()) +} diff --git a/async-openai/Cargo.toml b/async-openai/Cargo.toml index 5c4cb94..720d625 100644 --- a/async-openai/Cargo.toml +++ b/async-openai/Cargo.toml @@ -6,7 +6,7 @@ categories = ["api-bindings", "web-programming", "asynchronous"] keywords = ["openai", "async", "openapi", "ai"] description = "Rust library for OpenAI" edition = "2021" -rust-version = "1.65" +rust-version = "1.70" license = "MIT" readme = "README.md" homepage = "https://github.com/64bit/async-openai" @@ -25,6 +25,7 @@ native-tls-vendored = ["reqwest/native-tls-vendored"] realtime = ["dep:tokio-tungstenite"] [dependencies] +async-openai-macros = { version = "0.1", path = "../async-openai-macros" } backoff = { version = "0.4.0", features = ["tokio"] } base64 = "0.22.1" futures = "0.3.30" diff --git a/async-openai/src/chat.rs b/async-openai/src/chat.rs index c7f9b96..d78c883 100644 --- a/async-openai/src/chat.rs +++ b/async-openai/src/chat.rs @@ -1,8 +1,11 @@ +use async_openai_macros::extensible; + use 
crate::{ config::Config, error::OpenAIError, types::{ ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse, + Streamable, }, Client, }; @@ -20,11 +23,12 @@ impl<'c, C: Config> Chat<'c, C> { } /// Creates a model response for the given chat conversation. + #[extensible] pub async fn create( &self, - request: CreateChatCompletionRequest, + #[request(bounds = Streamable)] request: CreateChatCompletionRequest, ) -> Result<CreateChatCompletionResponse, OpenAIError> { - if request.stream.is_some() && request.stream.unwrap() { + if request.stream() { return Err(OpenAIError::InvalidArgument( "When stream is true, use Chat::create_stream".into(), )); diff --git a/async-openai/src/types/chat.rs b/async-openai/src/types/chat.rs index 138d685..e2e70ac 100644 --- a/async-openai/src/types/chat.rs +++ b/async-openai/src/types/chat.rs @@ -4,7 +4,7 @@ use derive_builder::Builder; use futures::Stream; use serde::{Deserialize, Serialize}; -use crate::error::OpenAIError; +use crate::{error::OpenAIError, types::Streamable}; #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(untagged)] @@ -635,6 +635,12 @@ pub struct CreateChatCompletionRequest { pub functions: Option<Vec<ChatCompletionFunctions>>, } +impl Streamable for CreateChatCompletionRequest { + fn stream(&self) -> bool { + self.stream == Some(true) + } +} + /// Options for streaming response. Only set this when you set `stream: true`. #[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] pub struct ChatCompletionStreamOptions { diff --git a/async-openai/src/types/mod.rs b/async-openai/src/types/mod.rs index fdc0f51..659e9fa 100644 --- a/async-openai/src/types/mod.rs +++ b/async-openai/src/types/mod.rs @@ -72,3 +72,7 @@ impl From for OpenAIError { OpenAIError::InvalidArgument(value.to_string()) } } + +pub trait Streamable { + fn stream(&self) -> bool; +}