Generate extensible versions of methods by allowing generic requests and responses

Fixes 64bit#280
Sufflope committed Nov 14, 2024
1 parent adaf26e commit de31d8c
Showing 11 changed files with 354 additions and 9 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
@@ -1,5 +1,5 @@
[workspace]
members = [ "async-openai", "examples/*" ]
members = [ "async-openai", "async-openai-macros", "examples/*" ]
# Only check / build main crates by default (check all with `--workspace`)
default-members = ["async-openai"]
default-members = ["async-openai", "async-openai-macros"]
resolver = "2"
23 changes: 23 additions & 0 deletions async-openai-macros/Cargo.toml
@@ -0,0 +1,23 @@
[package]
name = "async-openai-macros"
version = "0.1.0"
authors = ["Jean-Sébastien Bour <[email protected]>"]
categories = ["api-bindings", "web-programming", "asynchronous"]
keywords = ["openai", "async", "openapi", "ai"]
description = "Procedural macros for async-openai"
edition = "2021"
rust-version = "1.70"
license = "MIT"
readme = "README.md"
homepage = "https://github.com/64bit/async-openai"
repository = "https://github.com/64bit/async-openai"

[lib]
proc-macro = true

[dependencies]
darling = "0.20"
itertools = "0.13"
proc-macro2 = "1"
quote = "1"
syn = "2"
26 changes: 26 additions & 0 deletions async-openai-macros/README.md
@@ -0,0 +1,26 @@
<div align="center">
<a href="https://docs.rs/async-openai-macros">
<img width="50px" src="https://raw.githubusercontent.com/64bit/async-openai/assets/create-image-b64-json/img-1.png" />
</a>
</div>
<h1 align="center"> async-openai-macros </h1>
<p align="center"> Procedural macros for async-openai </p>
<div align="center">
<a href="https://crates.io/crates/async-openai-macros">
<img src="https://img.shields.io/crates/v/async-openai-macros.svg" />
</a>
<a href="https://docs.rs/async-openai-macros">
<img src="https://docs.rs/async-openai-macros/badge.svg" />
</a>
</div>
<div align="center">
<sub>Logo created by this <a href="https://github.com/64bit/async-openai/tree/main/examples/create-image-b64-json">repo itself</a></sub>
</div>

## Overview

This crate contains the procedural macros for `async-openai`. It is not meant to be used directly.
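
For illustration, this is roughly how `async-openai` itself applies the attribute (adapted from `async-openai/src/chat.rs` in this repository); the `#[request(bounds = ...)]` argument marks which parameter becomes generic in the generated `_ext` method:

```rust
use async_openai_macros::extensible;

// Adapted from async-openai/src/chat.rs: alongside the unchanged `create`,
// the attribute emits a sibling `create_ext` whose request type is generic
// over `Serialize` plus the given bounds, and whose response type is generic
// over `DeserializeOwned`.
#[extensible]
pub async fn create(
    &self,
    #[request(bounds = RequestForStream)] request: CreateChatCompletionRequest,
) -> Result<CreateChatCompletionResponse, OpenAIError> {
    // ...
}
```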

## License

This project is licensed under [MIT license](https://github.com/64bit/async-openai/blob/main/LICENSE).
107 changes: 107 additions & 0 deletions async-openai-macros/src/lib.rs
@@ -0,0 +1,107 @@
use darling::{ast::NestedMeta, FromMeta};
use itertools::{Either, Itertools};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{parse2, parse_macro_input, Expr, FnArg, ItemFn, Meta, MetaList};

#[proc_macro_attribute]
pub fn extensible(
    _: proc_macro::TokenStream,
    item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let item = parse_macro_input!(item as ItemFn);
    extensible_impl(item)
        .unwrap_or_else(syn::Error::into_compile_error)
        .into()
}

fn extensible_impl(mut item: ItemFn) -> syn::Result<TokenStream> {
    // Stream variants use a special result type
    let is_stream = item.sig.ident.to_string().ends_with("_stream");

    // Prepare a generic method with a different name
    let mut extension = item.clone();
    extension.sig.ident = format_ident!("{}_ext", extension.sig.ident);

    // Remove our attributes from original method arguments
    for input in &mut item.sig.inputs {
        match input {
            FnArg::Receiver(_) => (),
            FnArg::Typed(arg) => arg.attrs.retain(|attr| match &attr.meta {
                Meta::List(meta) => !attr_is_ours(meta),
                _ => true,
            }),
        }
    }

    // Gather request parameters that must be replaced by generics and their optional bounds
    let mut i = 0;
    let generics = extension
        .sig
        .inputs
        .iter_mut()
        .filter_map(|input| match input {
            FnArg::Receiver(_) => None,
            FnArg::Typed(arg) => {
                let (mine, other): (Vec<_>, Vec<_>) =
                    arg.attrs
                        .clone()
                        .into_iter()
                        .partition_map(|attr| match &attr.meta {
                            Meta::List(meta) if attr_is_ours(meta) => Either::Left(
                                Request::from_list(
                                    &NestedMeta::parse_meta_list(meta.tokens.clone()).unwrap(),
                                )
                                .unwrap(),
                            ),
                            _ => Either::Right(attr),
                        });
                let bounds = mine.into_iter().next();
                arg.attrs = other;
                bounds.map(|b| {
                    let ident = format_ident!("__EXTENSIBLE_REQUEST_{i}");
                    arg.ty = Box::new(parse2(quote! { #ident }).unwrap());
                    i += 1;
                    (ident, b)
                })
            }
        })
        .collect::<Vec<_>>();

    // Add generics and their optional bounds to our method's generics
    for (ident, Request { bounds }) in generics {
        let bounds = bounds.map(|b| quote! { + #b });
        extension
            .sig
            .generics
            .params
            .push(parse2(quote! { #ident : ::serde::Serialize #bounds })?)
    }

    // Make the result type generic too
    let result = if is_stream {
        quote! { std::pin::Pin<Box<dyn futures::Stream<Item = Result<__EXTENSIBLE_RESPONSE, OpenAIError>> + Send>>}
    } else {
        quote! { __EXTENSIBLE_RESPONSE }
    };
    extension.sig.output = parse2(quote! { -> Result<#result, OpenAIError> })?;
    let send_and_static = is_stream.then_some(quote! { + Send + 'static });
    extension.sig.generics.params.push(parse2(
        quote! { __EXTENSIBLE_RESPONSE: serde::de::DeserializeOwned #send_and_static },
    )?);

    Ok(quote! {
        #item

        #extension
    })
}

#[derive(FromMeta)]
struct Request {
    bounds: Option<Expr>,
}

fn attr_is_ours(meta: &MetaList) -> bool {
    meta.path.get_ident().map(|ident| ident.to_string()) == Some("request".to_string())
}
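
Applied to `Chat::create` (see `async-openai/src/chat.rs` below), the macro keeps the original method and additionally emits a sibling whose request and response types are generic. The following is a hand-written sketch of what that expansion roughly looks like, not the literal generated code:

```rust
// Sketch of the generated method. One generic parameter is introduced per
// `#[request(...)]` argument, bounded by `Serialize` plus any user-supplied
// bounds; the response type is bounded by `DeserializeOwned`. For `_stream`
// methods, the response additionally gets `+ Send + 'static` and the return
// type becomes a pinned boxed `Stream` of `Result<_, OpenAIError>`.
pub async fn create_ext<
    __EXTENSIBLE_REQUEST_0: ::serde::Serialize + RequestForStream,
    __EXTENSIBLE_RESPONSE: serde::de::DeserializeOwned,
>(
    &self,
    request: __EXTENSIBLE_REQUEST_0,
) -> Result<__EXTENSIBLE_RESPONSE, OpenAIError> {
    // ...body copied unchanged from the original `create`
}
```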
4 changes: 3 additions & 1 deletion async-openai/Cargo.toml
@@ -6,7 +6,7 @@ categories = ["api-bindings", "web-programming", "asynchronous"]
keywords = ["openai", "async", "openapi", "ai"]
description = "Rust library for OpenAI"
edition = "2021"
rust-version = "1.65"
rust-version = "1.70"
license = "MIT"
readme = "README.md"
homepage = "https://github.com/64bit/async-openai"
@@ -25,6 +25,7 @@ native-tls-vendored = ["reqwest/native-tls-vendored"]
realtime = ["dep:tokio-tungstenite"]

[dependencies]
+async-openai-macros = { version = "0.1", path = "../async-openai-macros" }
backoff = { version = "0.4.0", features = ["tokio"] }
base64 = "0.22.1"
futures = "0.3.30"
@@ -50,6 +51,7 @@ eventsource-stream = "0.2.3"
tokio-tungstenite = { version = "0.24.0", optional = true, default-features = false }

[dev-dependencies]
+actix-web = "4"
tokio-test = "0.4.4"

[package.metadata.docs.rs]
50 changes: 50 additions & 0 deletions async-openai/README.md
@@ -41,6 +41,7 @@
- Requests (except SSE streaming) including form submissions are retried with exponential backoff when [rate limited](https://platform.openai.com/docs/guides/rate-limits).
- Ergonomic builder pattern for all request objects.
- Microsoft Azure OpenAI Service (only for APIs matching OpenAI spec)
+- Extensible to other providers via extension methods where you provide the request/response types

## Usage

@@ -108,6 +109,55 @@ async fn main() -> Result<(), Box<dyn Error>> {
<sub>Scaled up for README, actual size 256x256</sub>
</div>

## Other providers

Alternative OpenAI providers that expose the same endpoints but with different requests/responses (e.g. Azure OpenAI allows applying content filters on chat completions and returns the results as an additional field in the response) are supported through `_ext` methods (currently only for the chat completion API), which let you "bring your own types":

```rust
use std::error::Error;

use async_openai::{config::AzureConfig, types::RequestForStream, Client};
use serde::{Deserialize, Serialize};

#[derive(Default, Serialize)]
struct AzureRequest {
    specific_azure_field: String,
    stream: Option<bool>,
}

impl RequestForStream for AzureRequest {
    fn is_request_for_stream(&self) -> bool {
        self.stream == Some(true)
    }

    fn set_request_for_stream(&mut self, stream: bool) {
        self.stream = Some(stream)
    }
}

#[derive(Deserialize)]
struct AzureResponse {
    specific_azure_result: String,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let client = Client::with_config(AzureConfig::new());

    let request = AzureRequest::default();

    let response: AzureResponse = client
        .chat()
        // Use the extensible method which allows you to bring your own types
        .create_ext(request)
        .await?;

    println!("Specific azure result: {}", response.specific_azure_result);

    Ok(())
}
```
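
The same attribute on `create_stream` generates a `create_stream_ext` whose items deserialize into your own chunk type. A minimal sketch, reusing the `AzureRequest` type from above and assuming a hypothetical `AzureStreamChunk` type you define to match your provider's SSE payload:

```rust
use futures::StreamExt; // for `next()` on the returned stream

// Hypothetical chunk type; shape it after your provider's streaming payload.
#[derive(serde::Deserialize)]
struct AzureStreamChunk {
    specific_azure_delta: String,
}

async fn stream_example(
    client: &async_openai::Client<async_openai::config::AzureConfig>,
) -> Result<(), Box<dyn std::error::Error>> {
    // Turbofish selects the request and response generics in declaration order.
    let mut stream = client
        .chat()
        .create_stream_ext::<AzureRequest, AzureStreamChunk>(AzureRequest::default())
        .await?;

    while let Some(chunk) = stream.next().await {
        println!("{}", chunk?.specific_azure_delta);
    }

    Ok(())
}
```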

## Contributing

Thank you for taking the time to contribute and improve the project. I'd be happy to have you!
Expand Down
15 changes: 10 additions & 5 deletions async-openai/src/chat.rs
@@ -1,8 +1,11 @@
+use async_openai_macros::extensible;
+
use crate::{
    config::Config,
    error::OpenAIError,
    types::{
        ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse,
+        RequestForStream,
    },
    Client,
};
@@ -20,11 +23,12 @@ impl<'c, C: Config> Chat<'c, C> {
    }

    /// Creates a model response for the given chat conversation.
+    #[extensible]
    pub async fn create(
        &self,
-        request: CreateChatCompletionRequest,
+        #[request(bounds = RequestForStream)] request: CreateChatCompletionRequest,
    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
-        if request.stream.is_some() && request.stream.unwrap() {
+        if request.is_request_for_stream() {
            return Err(OpenAIError::InvalidArgument(
                "When stream is true, use Chat::create_stream".into(),
            ));
@@ -37,17 +41,18 @@ impl<'c, C: Config> Chat<'c, C> {
    /// partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message.
    ///
    /// [ChatCompletionResponseStream] is a parsed SSE stream until a \[DONE\] is received from server.
+    #[extensible]
    pub async fn create_stream(
        &self,
-        mut request: CreateChatCompletionRequest,
+        #[request(bounds = RequestForStream)] mut request: CreateChatCompletionRequest,
    ) -> Result<ChatCompletionResponseStream, OpenAIError> {
-        if request.stream.is_some() && !request.stream.unwrap() {
+        if !request.is_request_for_stream() {
            return Err(OpenAIError::InvalidArgument(
                "When stream is false, use Chat::create".into(),
            ));
        }

-        request.stream = Some(true);
+        request.set_request_for_stream(true);

        Ok(self.client.post_stream("/chat/completions", request).await)
    }
46 changes: 46 additions & 0 deletions async-openai/src/lib.rs
@@ -73,6 +73,52 @@
//! # });
//!```
//!
//! ## Other providers
//!
//! You can use alternative providers that extend the OpenAI API spec by calling the `_ext` variants of methods, which let you bring your own request/response types:
//!
//! ```
//! # tokio_test::block_on(async {
//! use std::error::Error;
//!
//! use async_openai::{config::AzureConfig, types::RequestForStream, Client};
//! use serde::{Deserialize, Serialize};
//!
//! #[derive(Default, Serialize)]
//! struct AzureRequest {
//!     specific_azure_field: String,
//!     stream: Option<bool>,
//! }
//!
//! impl RequestForStream for AzureRequest {
//!     fn is_request_for_stream(&self) -> bool {
//!         self.stream == Some(true)
//!     }
//!
//!     fn set_request_for_stream(&mut self, stream: bool) {
//!         self.stream = Some(stream)
//!     }
//! }
//!
//! #[derive(Deserialize)]
//! struct AzureResponse {
//!     specific_azure_result: String,
//! }
//!
//! let client = Client::with_config(AzureConfig::new());
//!
//! let request = AzureRequest::default();
//!
//! let response: AzureResponse = client
//!     .chat()
//!     // Use the extensible method which allows you to bring your own types
//!     .create_ext(request)
//!     .await?;
//!
//! println!("Specific azure result: {}", response.specific_azure_result);
//! # });
//! ```
//!
//! ## Examples
//! For full working examples for all supported features see [examples](https://github.com/64bit/async-openai/tree/main/examples) directory in the repository.
//!
12 changes: 11 additions & 1 deletion async-openai/src/types/chat.rs
@@ -4,7 +4,7 @@ use derive_builder::Builder;
use futures::Stream;
use serde::{Deserialize, Serialize};

-use crate::error::OpenAIError;
+use crate::{error::OpenAIError, types::RequestForStream};

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
@@ -635,6 +635,16 @@ pub struct CreateChatCompletionRequest {
pub functions: Option<Vec<ChatCompletionFunctions>>,
}

+impl RequestForStream for CreateChatCompletionRequest {
+    fn is_request_for_stream(&self) -> bool {
+        self.stream == Some(true)
+    }
+
+    fn set_request_for_stream(&mut self, stream: bool) {
+        self.stream = Some(stream)
+    }
+}

/// Options for streaming response. Only set this when you set `stream: true`.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
pub struct ChatCompletionStreamOptions {
5 changes: 5 additions & 0 deletions async-openai/src/types/mod.rs
@@ -72,3 +72,8 @@ impl From<UninitializedFieldError> for OpenAIError {
        OpenAIError::InvalidArgument(value.to_string())
    }
}
+
+pub trait RequestForStream {
+    fn is_request_for_stream(&self) -> bool;
+    fn set_request_for_stream(&mut self, stream: bool);
+}