From 3d4c74935ba9fccb04d16bfd1a1d078ffac5e0ed Mon Sep 17 00:00:00 2001 From: Tobias Schoofs Date: Tue, 26 Mar 2024 13:16:47 +0000 Subject: [PATCH] [chore/issue127] embeddings status --- crates/edgen_server/src/lib.rs | 3 + crates/edgen_server/src/routes.rs | 3 + crates/edgen_server/src/status.rs | 102 ++++++++++++++++++++++++++++++ 3 files changed, 108 insertions(+) diff --git a/crates/edgen_server/src/lib.rs b/crates/edgen_server/src/lib.rs index 3301446..5d9ac9a 100644 --- a/crates/edgen_server/src/lib.rs +++ b/crates/edgen_server/src/lib.rs @@ -209,6 +209,9 @@ async fn run_server(args: &cli::Serve) -> Result { ) .await; + status::set_embeddings_active_model(&SETTINGS.read().await.read().await.embeddings_model_name) + .await; + let http_app = routes::routes() .layer(CorsLayer::permissive()) .layer(DefaultBodyLimit::max( diff --git a/crates/edgen_server/src/routes.rs b/crates/edgen_server/src/routes.rs index f9f1fe5..ceb432e 100644 --- a/crates/edgen_server/src/routes.rs +++ b/crates/edgen_server/src/routes.rs @@ -49,6 +49,9 @@ pub fn routes() -> Router { "/v1/audio/transcriptions/status", get(status::audio_transcriptions_status), ) + // ---- Embeddings ----------------------------------------------------- + .route("/v1/embeddings/status", get(status::embeddings_status)) + // -- Model Manager ---------------------------------------------------- // -- Model Manager ---------------------------------------------------- .route("/v1/models", get(model_man::list_models)) .route("/v1/models/:model", get(model_man::retrieve_model)) diff --git a/crates/edgen_server/src/status.rs b/crates/edgen_server/src/status.rs index 9e00e3b..1ac15b8 100644 --- a/crates/edgen_server/src/status.rs +++ b/crates/edgen_server/src/status.rs @@ -42,6 +42,15 @@ pub async fn audio_transcriptions_status() -> Response { Json(state.clone()).into_response() } +/// GET `/v1/embeddings/status`: returns the current status of the /embeddings endpoint. +/// +/// The status is returned as json value AIStatus.
+/// For any error, this endpoint returns "internal server error". +pub async fn embeddings_status() -> Response { + let state = get_embeddings_status().read().await; + Json(state.clone()).into_response() +} + /// Current Endpoint status. #[derive(ToSchema, Deserialize, Serialize, Clone, Debug, PartialEq, Eq)] pub struct AIStatus { @@ -72,6 +81,7 @@ static AISTATES: Lazy = Lazy::new(Default::default); const EP_CHAT_COMPLETIONS: usize = 0; const EP_AUDIO_TRANSCRIPTIONS: usize = 1; +const EP_EMBEDDINGS: usize = 2; const MAX_ERRORS: usize = 32; @@ -87,6 +97,12 @@ pub fn get_audio_transcriptions_status() -> &'static RwLock { get_status(EP_AUDIO_TRANSCRIPTIONS) } +/// Get a protected embeddings status. +/// Call read() or write() on the returned value to get either read or write access. +pub fn get_embeddings_status() -> &'static RwLock { + get_status(EP_EMBEDDINGS) +} + fn get_status(idx: usize) -> &'static RwLock { &AISTATES.endpoints[idx] } @@ -101,6 +117,11 @@ pub async fn reset_audio_transcriptions_status() { reset_status(EP_AUDIO_TRANSCRIPTIONS).await; } +/// Reset the embeddings status to its defaults +pub async fn reset_embeddings_status() { + reset_status(EP_EMBEDDINGS).await; +} + async fn reset_status(idx: usize) { let mut status = get_status(idx).write().await; *status = AIStatus::default(); @@ -116,6 +137,11 @@ pub async fn set_audio_transcriptions_active_model(model: &str) { set_active_model(EP_AUDIO_TRANSCRIPTIONS, model).await; } +/// Set embeddings active model +pub async fn set_embeddings_active_model(model: &str) { + set_active_model(EP_EMBEDDINGS, model).await; +} + async fn set_active_model(idx: usize, model: &str) { let mut state = get_status(idx).write().await; state.active_model = model.to_string(); @@ -141,6 +167,16 @@ pub async fn set_audio_transcriptions_download(ongoing: bool) { set_download(EP_AUDIO_TRANSCRIPTIONS, ongoing).await; } +/// Set embeddings download ongoing +pub async fn set_embeddings_download(ongoing: bool) { + if
ongoing { + info!("starting embeddings model download"); + } else { + info!("embeddings model download finished"); + }; + set_download(EP_EMBEDDINGS, ongoing).await; +} + async fn set_download(idx: usize, ongoing: bool) { let mut state = get_status(idx).write().await; state.download_ongoing = ongoing; @@ -156,6 +192,11 @@ pub async fn set_audio_transcriptions_progress(progress: u64) { set_progress(EP_AUDIO_TRANSCRIPTIONS, progress).await; } +/// Set embeddings download progress +pub async fn set_embeddings_progress(progress: u64) { + set_progress(EP_EMBEDDINGS, progress).await; +} + async fn set_progress(idx: usize, progress: u64) { let mut state = get_status(idx).write().await; state.download_progress = progress; @@ -179,6 +220,15 @@ pub async fn observe_audio_transcriptions_progress( observe_progress(EP_AUDIO_TRANSCRIPTIONS, datadir, size, download).await } +/// Observe embeddings download progress +pub async fn observe_embeddings_progress( + datadir: &PathBuf, + size: Option, + download: bool, +) -> tokio::task::JoinHandle<()> { + observe_progress(EP_EMBEDDINGS, datadir, size, download).await +} + /// Add an error to the last errors in chat completions pub async fn add_chat_completions_error(e: E) where @@ -217,6 +267,7 @@ impl Default for AIStates { endpoints: vec![ RwLock::new(Default::default()), RwLock::new(Default::default()), + RwLock::new(Default::default()), ], } } @@ -603,4 +654,55 @@ mod tests { assert!(response.text().len() > 0); assert_eq!(response.json::().active_model, model); } + + #[tokio::test] + async fn test_embeddings_status() { + reset_embeddings_status().await; + + // default + let mut expected = AIStatus::default(); + + { + let status = get_embeddings_status().read().await; + assert_eq!(*status, AIStatus::default()); + } + + // download ongoing + expected.download_ongoing = true; + set_embeddings_download(true).await; + + { + let status = get_embeddings_status().read().await; + assert_eq!(*status, expected); + } + + // download progress + 
expected.download_progress = 42; + set_embeddings_progress(42).await; + + { + let status = get_embeddings_status().read().await; + assert_eq!(*status, expected); + } + + // axum router + let router = Router::new().route("/v1/embeddings/status", get(embeddings_status)); + + let server = TestServer::new(router).expect("cannot instantiate TestServer"); + + let response = server.get("/v1/embeddings/status").await; + + response.assert_status_ok(); + assert!(response.text().len() > 0); + assert_eq!(response.json::().active_model, "unknown"); + + let model = "shes-a-model-and-shes-looking-good".to_string(); + set_embeddings_active_model(&model).await; + + let response = server.get("/v1/embeddings/status").await; + + response.assert_status_ok(); + assert!(response.text().len() > 0); + assert_eq!(response.json::().active_model, model); + } }