Skip to content

Commit

Permalink
Merge pull request #20 from edgenai/feat/issue10
Browse files Browse the repository at this point in the history
Feat/chat_completions_status
  • Loading branch information
toschoo authored Feb 5, 2024
2 parents 4f2a901 + 1fa4d42 commit 0c5308f
Show file tree
Hide file tree
Showing 7 changed files with 504 additions and 5 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/edgen_server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ hyper = { workspace = true }
hyper-util = { workspace = true }
once_cell = { workspace = true }
pin-project = { workspace = true }
reqwest = { workspace = true }
rubato = "0.14.1"
serde = { workspace = true }
serde_derive = { workspace = true }
serde_json = { workspace = true }
Expand Down
11 changes: 11 additions & 0 deletions crates/edgen_server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub mod graceful_shutdown;
mod llm;
mod model;
pub mod openai_shim;
pub mod status;
pub mod util;
mod whisper;

Expand Down Expand Up @@ -168,14 +169,24 @@ async fn start_server(args: &cli::Serve) -> EdgenResult {

async fn run_server(args: &cli::Serve) -> bool {
let http_app = Router::new()
// -- AI endpoints -----------------------------------------------------
// ---- Chat -----------------------------------------------------------
.route(
"/v1/chat/completions",
axum::routing::post(openai_shim::chat_completions),
)
// ---- Audio ----------------------------------------------------------
.route(
"/v1/audio/transcriptions",
axum::routing::post(openai_shim::create_transcription),
)
// -- AI status endpoints ----------------------------------------------
// ---- Chat -----------------------------------------------------------
.route(
"/v1/chat/completions/status",
axum::routing::get(status::chat_completions_status),
)
// -- Miscellaneous services -------------------------------------------
.route("/v1/misc/version", axum::routing::get(misc::edgen_version))
.layer(CorsLayer::permissive());

Expand Down
2 changes: 0 additions & 2 deletions crates/edgen_server/src/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ pub struct Version {

/// GET `/v1/misc/version`: returns the current version of edgend.
///
/// [openai]: https://platform.edgen.io/docs/api-reference/version
///
/// The version is returned as a JSON value with major, minor and patch as integers
/// and build as a string (which may be empty).
/// For any error, the version endpoint returns "internal server error".
Expand Down
48 changes: 45 additions & 3 deletions crates/edgen_server/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use serde_derive::Serialize;
use thiserror::Error;
use utoipa::ToSchema;

use crate::status;

#[derive(Serialize, Error, ToSchema, Debug)]
pub enum ModelError {
#[error("the provided model file name does does not exist, or isn't a file: ({0})")]
Expand Down Expand Up @@ -66,6 +68,7 @@ impl Model {
preloaded: false,
}
}

/// Checks if a file of the model is already present locally, and if not, downloads it.
pub async fn preload(&mut self) -> Result<(), ModelError> {
if self.path.is_file() {
Expand All @@ -82,16 +85,55 @@ impl Model {
.build()
.map_err(move |e| ModelError::API(e.to_string()))?;
let api = api.model(self.repo.to_string());
let path = api

// progress observer
let download = hf_hub::Cache::new(self.dir.clone())
.model(self.repo.to_string())
.get(&self.name)
.map_err(move |e| ModelError::API(e.to_string()))?;
.is_none();
let size = self.get_size(&api).await;
let progress_handle =
status::observe_chat_completions_progress(&self.dir, size, download).await;

let name = self.name.clone();
let download_handle = tokio::spawn(async move {
if download {
status::set_chat_completions_download(true).await;
}

let path = api
.get(&name)
.map_err(move |e| ModelError::API(e.to_string()));

self.path = path;
if download {
status::set_chat_completions_progress(100).await;
status::set_chat_completions_download(false).await;
}

return path;
});

let _ = progress_handle.await.unwrap();
let path = download_handle.await.unwrap();

self.path = path?;
self.preloaded = true;

Ok(())
}

/// Asks the remote server for the size of this model's file.
///
/// Sends a GET request to the URL that `api` reports for `self.name` and
/// returns the content length of the response.
///
/// Returns `None` when the request cannot be sent or fails, or when the
/// response does not carry a known content length. A failure here must not
/// panic: the return type already lets callers treat the size as unknown.
async fn get_size(&self, api: &hf_hub::api::sync::ApiRepo) -> Option<u64> {
    // NOTE(review): "bytes 0-0" is not the standard `Range` syntax
    // ("bytes=0-0"), and `Content-Range` is normally a response header.
    // Both are kept unchanged because a valid one-byte range request would
    // presumably make the server report a length of 1 instead of the full
    // file size -- confirm before touching these headers.
    let response = reqwest::Client::new()
        .get(api.url(&self.name))
        .header("Content-Range", "bytes 0-0")
        .header("Range", "bytes 0-0")
        .send()
        .await
        .ok()?;

    response.content_length()
}

/// Returns a [`PathBuf`] pointing to the local model file.
pub fn file_path(&self) -> Result<PathBuf, ModelError> {
if self.preloaded {
Expand Down
Loading

0 comments on commit 0c5308f

Please sign in to comment.