Skip to content

Commit

Permalink
Generate extensible versions of methods by allowing generic requests …
Browse files Browse the repository at this point in the history
…and response

Fixes 64bit#280
  • Loading branch information
Sufflope committed Nov 13, 2024
1 parent adaf26e commit d0cad81
Show file tree
Hide file tree
Showing 8 changed files with 168 additions and 6 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[workspace]
members = [ "async-openai", "examples/*" ]
members = [ "async-openai", "async-openai-macros", "examples/*" ]
# Only check / build main crates by default (check all with `--workspace`)
default-members = ["async-openai"]
default-members = ["async-openai", "async-openai-macros"]
resolver = "2"
23 changes: 23 additions & 0 deletions async-openai-macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
[package]
name = "async-openai-macros"
version = "0.1.0"
authors = ["Jean-Sébastien Bour <[email protected]>"]
categories = ["api-bindings", "web-programming", "asynchronous"]
keywords = ["openai", "async", "openapi", "ai"]
description = "Procedural macros for async-openai"
edition = "2021"
rust-version = "1.70"
license = "MIT"
readme = "README.md"
homepage = "https://github.com/64bit/async-openai"
repository = "https://github.com/64bit/async-openai"

[lib]
proc-macro = true

[dependencies]
darling = "0.20"
itertools = "0.13"
proc-macro2 = "1"
quote = "1"
syn = "2"
26 changes: 26 additions & 0 deletions async-openai-macros/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
<div align="center">
<a href="https://docs.rs/async-openai-macros">
<img width="50px" src="https://raw.githubusercontent.com/64bit/async-openai/assets/create-image-b64-json/img-1.png" />
</a>
</div>
<h1 align="center"> async-openai-macros </h1>
<p align="center"> Procedural macros for async-openai </p>
<div align="center">
<a href="https://crates.io/crates/async-openai-macros">
<img src="https://img.shields.io/crates/v/async-openai-macros.svg" />
</a>
<a href="https://docs.rs/async-openai-macros">
<img src="https://docs.rs/async-openai-macros/badge.svg" />
</a>
</div>
<div align="center">
<sub>Logo created by this <a href="https://github.com/64bit/async-openai/tree/main/examples/create-image-b64-json">repo itself</a></sub>
</div>

## Overview

This crate contains the procedural macros for `async-openai`. It is not meant to be used directly.

## License

This project is licensed under [MIT license](https://github.com/64bit/async-openai/blob/main/LICENSE).
98 changes: 98 additions & 0 deletions async-openai-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use darling::{ast::NestedMeta, FromMeta};
use itertools::{Either, Itertools};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{parse2, parse_macro_input, Expr, FnArg, ItemFn, Meta, MetaList};

#[proc_macro_attribute]
pub fn extensible(
_: proc_macro::TokenStream,
item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let item = parse_macro_input!(item as ItemFn);
extensible_impl(item)
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}

fn extensible_impl(mut item: ItemFn) -> syn::Result<TokenStream> {
// Prepare a generic method with a different name
let mut extension = item.clone();
extension.sig.ident = format_ident!("{}_ext", extension.sig.ident);

// Remove our attributes from original method arguments
for input in &mut item.sig.inputs {
match input {
FnArg::Receiver(_) => (),
FnArg::Typed(arg) => arg.attrs.retain(|attr| match &attr.meta {
Meta::List(meta) => !attr_is_ours(meta),
_ => true,
}),
}
}

// Gather request parameters that must be replaced by generics and their optional bounds
let mut i = 0;
let generics = extension
.sig
.inputs
.iter_mut()
.filter_map(|input| match input {
FnArg::Receiver(_) => None,
FnArg::Typed(arg) => {
let (mine, other): (Vec<_>, Vec<_>) =
arg.attrs
.clone()
.into_iter()
.partition_map(|attr| match &attr.meta {
Meta::List(meta) if attr_is_ours(meta) => Either::Left(
Request::from_list(
&NestedMeta::parse_meta_list(meta.tokens.clone()).unwrap(),
)
.unwrap(),
),
_ => Either::Right(attr),
});
let bounds = mine.into_iter().next();
arg.attrs = other;
bounds.map(|b| {
let ident = format_ident!("__EXTENSIBLE_REQUEST_{i}");
arg.ty = Box::new(parse2(quote! { #ident }).unwrap());
i += 1;
(ident, b)
})
}
})
.collect::<Vec<_>>();

// Add generics and their optional bounds to our method's generics
for (ident, Request { bounds }) in generics {
let bounds = bounds.map(|b| quote! { + #b });
extension
.sig
.generics
.params
.push(parse2(quote! { #ident : ::serde::Serialize #bounds })?)
}

// Make the result type generic too
extension.sig.output = parse2(quote! { -> Result<__EXTENSIBLE_RESPONSE, OpenAIError> })?;
extension.sig.generics.params.push(parse2(
quote! { __EXTENSIBLE_RESPONSE: serde::de::DeserializeOwned },
)?);

Ok(quote! {
#item

#extension
})
}

#[derive(FromMeta)]
struct Request {
bounds: Option<Expr>,
}

fn attr_is_ours(meta: &MetaList) -> bool {
meta.path.get_ident().map(|ident| ident.to_string()) == Some("request".to_string())
}
3 changes: 2 additions & 1 deletion async-openai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ categories = ["api-bindings", "web-programming", "asynchronous"]
keywords = ["openai", "async", "openapi", "ai"]
description = "Rust library for OpenAI"
edition = "2021"
rust-version = "1.65"
rust-version = "1.70"
license = "MIT"
readme = "README.md"
homepage = "https://github.com/64bit/async-openai"
Expand All @@ -25,6 +25,7 @@ native-tls-vendored = ["reqwest/native-tls-vendored"]
realtime = ["dep:tokio-tungstenite"]

[dependencies]
async-openai-macros = { version = "0.1", path = "../async-openai-macros" }
backoff = { version = "0.4.0", features = ["tokio"] }
base64 = "0.22.1"
futures = "0.3.30"
Expand Down
8 changes: 6 additions & 2 deletions async-openai/src/chat.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use async_openai_macros::extensible;

use crate::{
config::Config,
error::OpenAIError,
types::{
ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse,
Streamable,
},
Client,
};
Expand All @@ -20,11 +23,12 @@ impl<'c, C: Config> Chat<'c, C> {
}

/// Creates a model response for the given chat conversation.
#[extensible]
pub async fn create(
&self,
request: CreateChatCompletionRequest,
#[request(bounds = Streamable)] request: CreateChatCompletionRequest,
) -> Result<CreateChatCompletionResponse, OpenAIError> {
if request.stream.is_some() && request.stream.unwrap() {
if request.stream() {
return Err(OpenAIError::InvalidArgument(
"When stream is true, use Chat::create_stream".into(),
));
Expand Down
8 changes: 7 additions & 1 deletion async-openai/src/types/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use derive_builder::Builder;
use futures::Stream;
use serde::{Deserialize, Serialize};

use crate::error::OpenAIError;
use crate::{error::OpenAIError, types::Streamable};

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
Expand Down Expand Up @@ -635,6 +635,12 @@ pub struct CreateChatCompletionRequest {
pub functions: Option<Vec<ChatCompletionFunctions>>,
}

impl Streamable for CreateChatCompletionRequest {
fn stream(&self) -> bool {
self.stream == Some(true)
}
}

/// Options for streaming response. Only set this when you set `stream: true`.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)]
pub struct ChatCompletionStreamOptions {
Expand Down
4 changes: 4 additions & 0 deletions async-openai/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,7 @@ impl From<UninitializedFieldError> for OpenAIError {
OpenAIError::InvalidArgument(value.to_string())
}
}

pub trait Streamable {
fn stream(&self) -> bool;
}

0 comments on commit d0cad81

Please sign in to comment.