Skip to content

Commit

Permalink
In memory files (#154)
Browse files Browse the repository at this point in the history
* Added ability to use in-memory files (Bytes, vec[u8])

* Removed unnecessary trait impls

* Polished example
  • Loading branch information
prosammer authored Nov 25, 2023
1 parent 923d03a commit 136a463
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 39 deletions.
46 changes: 33 additions & 13 deletions async-openai/src/types/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ use crate::{
download::{download_url, save_b64},
error::OpenAIError,
util::{create_all_dir, create_file_part},
types::InputSource,
};

use bytes::Bytes;
use reqwest::Body;

use super::{
AudioInput, AudioResponseFormat, ChatCompletionFunctionCall, ChatCompletionFunctions,
ChatCompletionNamedToolChoice, ChatCompletionRequestAssistantMessage,
Expand Down Expand Up @@ -91,29 +95,45 @@ impl_default!(Prompt);
impl_default!(ModerationInput);
impl_default!(EmbeddingInput);

macro_rules! file_path_input {

impl Default for InputSource {
fn default() -> Self {
InputSource::Path {
path: PathBuf::new(),
}
}
}

macro_rules! impl_input {
($for_typ:ty) => {
impl $for_typ {
pub fn new<P: AsRef<Path>>(path: P) -> Self {
pub fn from_bytes(filename: String, bytes: Bytes) -> Self {
Self {
source: InputSource::Bytes { filename, bytes },
}
}

pub fn from_vec_u8(filename: String, vec: Vec<u8>) -> Self {
Self {
path: PathBuf::from(path.as_ref()),
source: InputSource::VecU8 { filename, vec },
}
}
}

impl<P: AsRef<Path>> From<P> for $for_typ {
fn from(path: P) -> Self {
let path_buf = path.as_ref().to_path_buf();
Self {
path: PathBuf::from(path.as_ref()),
source: InputSource::Path { path: path_buf },
}
}
}
};
}

file_path_input!(ImageInput);
file_path_input!(FileInput);
file_path_input!(AudioInput);
impl_input!(AudioInput);
impl_input!(FileInput);
impl_input!(ImageInput);

impl Display for ImageSize {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand Down Expand Up @@ -574,7 +594,7 @@ impl async_convert::TryFrom<CreateTranscriptionRequest> for reqwest::multipart::
type Error = OpenAIError;

async fn try_from(request: CreateTranscriptionRequest) -> Result<Self, Self::Error> {
let audio_part = create_file_part(&request.file.path).await?;
let audio_part = create_file_part(request.file.source).await?;

let mut form = reqwest::multipart::Form::new()
.part("file", audio_part)
Expand All @@ -600,7 +620,7 @@ impl async_convert::TryFrom<CreateTranslationRequest> for reqwest::multipart::Fo
type Error = OpenAIError;

async fn try_from(request: CreateTranslationRequest) -> Result<Self, Self::Error> {
let audio_part = create_file_part(&request.file.path).await?;
let audio_part = create_file_part(request.file.source).await?;

let mut form = reqwest::multipart::Form::new()
.part("file", audio_part)
Expand All @@ -626,14 +646,14 @@ impl async_convert::TryFrom<CreateImageEditRequest> for reqwest::multipart::Form
type Error = OpenAIError;

async fn try_from(request: CreateImageEditRequest) -> Result<Self, Self::Error> {
let image_part = create_file_part(&request.image.path).await?;
let image_part = create_file_part(request.image.source).await?;

let mut form = reqwest::multipart::Form::new()
.part("image", image_part)
.text("prompt", request.prompt);

if let Some(mask) = request.mask {
let mask_part = create_file_part(&mask.path).await?;
let mask_part = create_file_part(mask.source).await?;
form = form.part("mask", mask_part);
}

Expand Down Expand Up @@ -668,7 +688,7 @@ impl async_convert::TryFrom<CreateImageVariationRequest> for reqwest::multipart:
type Error = OpenAIError;

async fn try_from(request: CreateImageVariationRequest) -> Result<Self, Self::Error> {
let image_part = create_file_part(&request.image.path).await?;
let image_part = create_file_part(request.image.source).await?;

let mut form = reqwest::multipart::Form::new().part("image", image_part);

Expand Down Expand Up @@ -703,7 +723,7 @@ impl async_convert::TryFrom<CreateFileRequest> for reqwest::multipart::Form {
type Error = OpenAIError;

async fn try_from(request: CreateFileRequest) -> Result<Self, Self::Error> {
let file_part = create_file_part(&request.file.path).await?;
let file_part = create_file_part(request.file.source).await?;
let form = reqwest::multipart::Form::new()
.part("file", file_part)
.text("purpose", request.purpose);
Expand Down
14 changes: 11 additions & 3 deletions async-openai/src/types/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{collections::HashMap, path::PathBuf, pin::Pin};
use bytes::Bytes;
use derive_builder::Builder;
use futures::Stream;
use reqwest::{Body};
use serde::{Deserialize, Serialize};

use crate::error::OpenAIError;
Expand Down Expand Up @@ -392,9 +393,16 @@ pub struct ImagesResponse {
pub data: Vec<std::sync::Arc<Image>>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum InputSource {
Path { path: PathBuf },
Bytes { filename: String, bytes: Bytes },
VecU8 { filename: String, vec: Vec<u8> },
}

#[derive(Debug, Default, Clone, PartialEq)]
pub struct ImageInput {
pub path: PathBuf,
pub source: InputSource,
}

#[derive(Debug, Clone, Default, Builder, PartialEq)]
Expand Down Expand Up @@ -583,7 +591,7 @@ pub struct CreateModerationResponse {

#[derive(Debug, Default, Clone, PartialEq)]
pub struct FileInput {
pub path: PathBuf,
pub source: InputSource,
}

#[derive(Debug, Default, Clone, Builder, PartialEq)]
Expand Down Expand Up @@ -1597,7 +1605,7 @@ pub struct CreateChatCompletionStreamResponse {

#[derive(Debug, Default, Clone, PartialEq)]
pub struct AudioInput {
pub path: PathBuf,
pub source: InputSource,
}

#[derive(Debug, Serialize, Default, Clone, Copy, PartialEq)]
Expand Down
58 changes: 35 additions & 23 deletions async-openai/src/util.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,49 @@
use std::path::Path;

use reqwest::Body;
use tokio::fs::File;
use tokio_util::codec::{BytesCodec, FramedRead};

use crate::error::OpenAIError;
use crate::types::InputSource;

pub(crate) async fn file_stream_body<P: AsRef<Path>>(path: P) -> Result<Body, OpenAIError> {
let file = tokio::fs::File::open(path.as_ref())
.await
.map_err(|e| OpenAIError::FileReadError(e.to_string()))?;
let stream = FramedRead::new(file, BytesCodec::new());
let body = Body::wrap_stream(stream);
pub(crate) async fn file_stream_body(source: InputSource) -> Result<Body, OpenAIError> {
let body = match source {
InputSource::Path{ path } => {
let file = File::open(path)
.await
.map_err(|e| OpenAIError::FileReadError(e.to_string()))?;
let stream = FramedRead::new(file, BytesCodec::new());
Body::wrap_stream(stream)
}
_ => return Err(OpenAIError::FileReadError("Cannot create stream from non-file source".to_string())),
};
Ok(body)
}

/// Creates the part for the given image file for multipart upload.
pub(crate) async fn create_file_part<P: AsRef<Path>>(
path: P,
/// Creates the part for the given file for multipart upload.
pub(crate) async fn create_file_part(
source: InputSource,
) -> Result<reqwest::multipart::Part, OpenAIError> {
let file_name = path
.as_ref()
.file_name()
.ok_or_else(|| {
OpenAIError::FileReadError(format!(
"cannot extract file name from {}",
path.as_ref().display()
))
})?
.to_str()
.unwrap()
.to_string();

let file_part = reqwest::multipart::Part::stream(file_stream_body(path.as_ref()).await?)
let (stream, file_name) = match source {
InputSource::Path{ path } => {
let file_name = path.file_name()
.ok_or_else(|| OpenAIError::FileReadError(format!("cannot extract file name from {}", path.display())))?
.to_str()
.unwrap()
.to_string();

(file_stream_body(InputSource::Path{ path }).await?, file_name)
}
InputSource::Bytes{ filename, bytes } => {
(Body::from(bytes), filename)
}
InputSource::VecU8{ filename, vec } => {
(Body::from(vec), filename)
}
};

let file_part = reqwest::multipart::Part::stream(stream)
.file_name(file_name)
.mime_str("application/octet-stream")
.unwrap();
Expand Down
10 changes: 10 additions & 0 deletions examples/in-memory-file/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[package]
name = "in-memory-file"
version = "0.1.0"
edition = "2021"
publish = false

[dependencies]
async-openai = {path = "../../async-openai"}
tokio = { version = "1.25.0", features = ["full"] }
bytes = "1.5.0"
3 changes: 3 additions & 0 deletions examples/in-memory-file/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
### Output

> Hello, I'm David Attenborough. I'm speaking to you from my home because, like many of you, I've spent much of the last year indoors, away from friends, family, and access to the natural world. It's been a challenging few months for many of us, but the reaction to these extraordinary times has proved that when we work together, there is no limit to what we can accomplish. Today, we are experiencing environmental change as never before. And the need to take action has never been more urgent. This year, the world will gather in Glasgow for the United Nations Climate Change Conference. It's a crucial moment in our history. This could be a year for positive change for ourselves, for our planet, and for the wonderful creatures with which we share it. A year the world could remember proudly and say, we made a difference. As we make our New Year's resolutions, let's think about what each of us could do. What positive changes could we make in our own lives? So here's to a brighter year ahead. Let's make 2021 a happy New Year for all the inhabitants of our perfect planet.
Binary file not shown.
28 changes: 28 additions & 0 deletions examples/in-memory-file/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use async_openai::{types::CreateTranscriptionRequestArgs, Client};
use std::error::Error;
use std::fs;
use async_openai::types::AudioInput;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let filename = "A Message From Sir David Attenborough A Perfect Planet BBC Earth_320kbps.mp3".to_string();
let file_contents = fs::read(format!("./audio/{}", filename))?;

let bytes = bytes::Bytes::from(file_contents);

// To pass in in-memory files, you can pass either bytes::Bytes or vec[u8] to AudioInputs, FileInputs, and ImageInputs.
let audio_input = AudioInput::from_bytes(filename, bytes);

let client = Client::new();
// Credits and Source for audio: https://www.youtube.com/watch?v=oQnDVqGIv4s
let request = CreateTranscriptionRequestArgs::default()
.file(audio_input)
.model("whisper-1")
.build()?;

let response = client.audio().transcribe(request).await?;

println!("{}", response.text);

Ok(())
}

0 comments on commit 136a463

Please sign in to comment.