Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/chat_completions_status #20

Merged
merged 10 commits into from
Feb 5, 2024
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/edgen_server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ hyper = { workspace = true }
hyper-util = { workspace = true }
once_cell = { workspace = true }
pin-project = { workspace = true }
reqwest = { workspace = true }
rubato = "0.14.1"
serde = { workspace = true }
serde_derive = { workspace = true }
serde_json = { workspace = true }
Expand Down
11 changes: 11 additions & 0 deletions crates/edgen_server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub mod graceful_shutdown;
mod llm;
mod model;
pub mod openai_shim;
pub mod status;
pub mod util;
mod whisper;

Expand Down Expand Up @@ -168,14 +169,24 @@ async fn start_server(args: &cli::Serve) -> EdgenResult {

async fn run_server(args: &cli::Serve) -> bool {
let http_app = Router::new()
// -- AI endpoints -----------------------------------------------------
// ---- Chat -----------------------------------------------------------
.route(
"/v1/chat/completions",
axum::routing::post(openai_shim::chat_completions),
)
// ---- Audio ----------------------------------------------------------
.route(
"/v1/audio/transcriptions",
axum::routing::post(openai_shim::create_transcription),
)
// -- AI status endpoints ----------------------------------------------
// ---- Chat -----------------------------------------------------------
.route(
"/v1/chat/completions/status",
axum::routing::get(status::chat_completions_status),
)
// -- Miscellaneous services -------------------------------------------
.route("/v1/misc/version", axum::routing::get(misc::edgen_version))
.layer(CorsLayer::permissive());

Expand Down
2 changes: 0 additions & 2 deletions crates/edgen_server/src/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ pub struct Version {

/// GET `/v1/misc/version`: returns the current version of edgend.
///
/// [openai]: https://platform.edgen.io/docs/api-reference/version
///
/// The version is returned as json value with major, minor and patch as integer
/// and build as string (which may be empty).
/// For any error, the version endpoint returns "internal server error".
Expand Down
48 changes: 45 additions & 3 deletions crates/edgen_server/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use serde_derive::Serialize;
use thiserror::Error;
use utoipa::ToSchema;

use crate::status;

#[derive(Serialize, Error, ToSchema, Debug)]
pub enum ModelError {
#[error("the provided model file name does does not exist, or isn't a file: ({0})")]
Expand Down Expand Up @@ -66,6 +68,7 @@ impl Model {
preloaded: false,
}
}

/// Checks if a file of the model is already present locally, and if not, downloads it.
pub async fn preload(&mut self) -> Result<(), ModelError> {
if self.path.is_file() {
Expand All @@ -82,16 +85,55 @@ impl Model {
.build()
.map_err(move |e| ModelError::API(e.to_string()))?;
let api = api.model(self.repo.to_string());
let path = api

// progress observer
let download = hf_hub::Cache::new(self.dir.clone())
.model(self.repo.to_string())
.get(&self.name)
.map_err(move |e| ModelError::API(e.to_string()))?;
.is_none();
let size = self.get_size(&api).await;
let progress_handle =
status::observe_chat_completions_progress(&self.dir, size, download).await;

let name = self.name.clone();
let download_handle = tokio::spawn(async move {
if download {
status::set_chat_completions_download(true).await;
}

let path = api
.get(&name)
.map_err(move |e| ModelError::API(e.to_string()));

self.path = path;
if download {
status::set_chat_completions_progress(100).await;
status::set_chat_completions_download(false).await;
}

return path;
});

let _ = progress_handle.await.unwrap();
let path = download_handle.await.unwrap();

self.path = path?;
self.preloaded = true;

Ok(())
}

/// Asks the remote host for the size, in bytes, of this model's file.
///
/// Issues a GET request against the URL that `api` reports for `self.name`
/// and returns the content length advertised by the response.
///
/// Returns `None` when the request cannot be sent or when the response does
/// not carry a content length. A failed size lookup is reported as "unknown"
/// instead of panicking, so callers can carry on without a size.
async fn get_size(&self, api: &hf_hub::api::sync::ApiRepo) -> Option<u64> {
    let response = reqwest::Client::new()
        .get(api.url(&self.name))
        // NOTE(review): `Content-Range` is normally a response header and the
        // usual `Range` syntax is `bytes=0-0`. Both headers are kept exactly
        // as they were, because altering them could change the content length
        // the server reports -- confirm the intent before touching them.
        .header("Content-Range", "bytes 0-0")
        .header("Range", "bytes 0-0")
        .send()
        .await
        // A transport error means the size is unknown, not a fatal condition.
        .ok()?;

    response.content_length()
}

/// Returns a [`PathBuf`] pointing to the local model file.
pub fn file_path(&self) -> Result<PathBuf, ModelError> {
if self.preloaded {
Expand Down
Loading
Loading