From 2d883c93df3a9c0bcdf6f98d77bc174c9841c552 Mon Sep 17 00:00:00 2001 From: francis2tm Date: Thu, 1 Feb 2024 12:29:35 +0000 Subject: [PATCH 01/24] docs: fixed URLs in readme --- README.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 94ebd7d..d0d35e5 100644 --- a/README.md +++ b/README.md @@ -11,10 +11,10 @@

| - Documentation | - Blog | - Discord | - Roadmap | + Documentation | + Blog | + Discord | + Roadmap |

@@ -81,10 +81,10 @@ Then open your favorite IDE from the shell, and you're ready to go! ## Communication Channels -- [Edgen Discord server](https://discord.gg/MMUcgBtV): Real time discussions with the ⚡Edgen team and other users. -- [GitHub issues](https://github.com/binedge/edgen/issues): Feature requests, bugs. -- [GitHub discussions](https://github.com/binedge/edgen/discussions/): Q&A. -- [Blog](https://binedge.ai): Big announcements. +- [Edgen Discord server](https://discord.gg/QUXbwqdMRs): Real time discussions with the ⚡Edgen team and other users. +- [GitHub issues](https://github.com/edgenai/edgen/issues): Feature requests, bugs. +- [GitHub discussions](https://github.com/edgenai/edgen/discussions/): Q&A. +- [Blog](https://blog.edgen.co): Big announcements. ## Special Thanks From a6d6cd00e0c79d4e096fd1ca82e71d8d46ee75c1 Mon Sep 17 00:00:00 2001 From: Pedro Valente Date: Thu, 1 Feb 2024 12:35:46 +0000 Subject: [PATCH 02/24] updated llama and whisper dependencies --- Cargo.lock | 111 +++++++++++++++++++++++++++++++---------------------- Cargo.toml | 8 +--- 2 files changed, 67 insertions(+), 52 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5c797a5..a70a483 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1129,9 +1129,9 @@ dependencies = [ [[package]] name = "darling" -version = "0.20.3" +version = "0.20.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0209d94da627ab5605dcccf08bb18afa5009cfbef48d8a8b7d7bdbc79be25c5e" +checksum = "fc5d6b04b3fd0ba9926f945895de7d806260a2d7431ba82e7edaecb043c4c6b8" dependencies = [ "darling_core", "darling_macro", @@ -1139,9 +1139,9 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.20.3" +version = "0.20.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "177e3443818124b357d8e76f53be906d60937f0d3a90773a664fa63fa253e621" +checksum = "04e48a959bcd5c761246f5d090ebc2fbf7b9cd527a492b07a67510c108f1e7e3" dependencies = [ "fnv", "ident_case", @@ -1153,9 
+1153,9 @@ dependencies = [ [[package]] name = "darling_macro" -version = "0.20.3" +version = "0.20.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "836a9bbc7ad63342d6d6e7b815ccab164bc77a2d95d84bc3117a8c0d5c98e2d5" +checksum = "1d1545d67a2149e1d93b7e5c7752dce5a7426eb5d1357ddcfd89336b94444f77" dependencies = [ "darling_core", "quote", @@ -1416,7 +1416,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", - "toml_edit 0.21.0", + "toml_edit 0.21.1", "tower-http", "tracing", "utoipa", @@ -1441,7 +1441,7 @@ dependencies = [ "cc", "memchr", "rustc_version", - "toml 0.8.8", + "toml 0.8.9", "vswhom", "winreg 0.51.0", ] @@ -2118,7 +2118,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.11", - "indexmap 2.2.1", + "indexmap 2.2.2", "slab", "tokio", "tokio-util", @@ -2137,7 +2137,7 @@ dependencies = [ "futures-sink", "futures-util", "http 1.0.0", - "indexmap 2.2.1", + "indexmap 2.2.2", "slab", "tokio", "tokio-util", @@ -2387,9 +2387,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdea9aac0dbe5a9240d68cfd9501e2db94222c6dc06843e06640b9e07f0fdc67" +checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa" dependencies = [ "bytes", "futures-channel", @@ -2495,9 +2495,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.1" +version = "2.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "433de089bd45971eecf4668ee0ee8f4cec17db4f8bd8f7bc3197a6ce37aa7d9b" +checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520" dependencies = [ "equivalent", "hashbrown 0.14.3", @@ -2741,9 +2741,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.152" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" +checksum = 
"9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libloading" @@ -2803,7 +2803,7 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "llama_cpp" version = "0.3.0" -source = "git+https://github.com/edgenai/llama_cpp-rs?branch=build-overhaul#f59e4cdd27b2ca628df3283d986e83d199a6dbc8" +source = "git+https://github.com/edgenai/llama_cpp-rs?branch=main#d7c895d051780aea216eb81f40b9e849cdb61e3e" dependencies = [ "ctor", "derive_more", @@ -2818,7 +2818,7 @@ dependencies = [ [[package]] name = "llama_cpp_sys" version = "0.3.0" -source = "git+https://github.com/edgenai/llama_cpp-rs?branch=build-overhaul#f59e4cdd27b2ca628df3283d986e83d199a6dbc8" +source = "git+https://github.com/edgenai/llama_cpp-rs?branch=main#d7c895d051780aea216eb81f40b9e849cdb61e3e" dependencies = [ "bindgen", "cc", @@ -3156,6 +3156,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" version = "0.1.45" @@ -3611,7 +3617,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5699cc8a63d1aa2b1ee8e12b9ad70ac790d65788cd36101fa37f87ea46c4cef" dependencies = [ "base64 0.21.7", - "indexmap 2.2.1", + "indexmap 2.2.2", "line-wrap", "quick-xml", "serde", @@ -3959,9 +3965,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "reqwest" -version = "0.11.23" +version = "0.11.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37b1ae8d9ac08420c66222fb9096fc5de435c3c48542bc5336c51892cffafb41" +checksum = "c6920094eb85afde5e4a138be3f2de8bbdf28000f0029e72c45025a56b042251" dependencies = [ "base64 0.21.7", "bytes", @@ -3981,9 +3987,11 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", + "rustls-pemfile", 
"serde", "serde_json", "serde_urlencoded", + "sync_wrapper", "system-configuration", "tokio", "tokio-native-tls", @@ -4149,6 +4157,15 @@ dependencies = [ "sct", ] +[[package]] +name = "rustls-pemfile" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +dependencies = [ + "base64 0.21.7", +] + [[package]] name = "rustls-webpki" version = "0.101.7" @@ -4344,15 +4361,15 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.5.1" +version = "3.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5c9fdb6b00a489875b22efd4b78fe2b363b72265cc5f6eb2e2b9ee270e6140c" +checksum = "1b0ed1662c5a68664f45b76d18deb0e234aff37207086803165c961eb695e981" dependencies = [ "base64 0.21.7", "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.2.1", + "indexmap 2.2.2", "serde", "serde_json", "serde_with_macros", @@ -4361,9 +4378,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.5.1" +version = "3.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbff351eb4b33600a2e138dfa0b10b65a238ea8ff8fb2387c422c5022a3e8298" +checksum = "568577ff0ef47b879f736cd66740e022f3672788cdf002a05a4e609ea5a6fb15" dependencies = [ "darling", "proc-macro2", @@ -4377,7 +4394,7 @@ version = "0.9.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adf8a49373e98a4c5f0ceb5d05aa7c648d75f63774981ed95b7c7443bbd50c6e" dependencies = [ - "indexmap 2.2.1", + "indexmap 2.2.2", "itoa 1.0.10", "ryu", "serde", @@ -4881,7 +4898,7 @@ dependencies = [ "cfg-expr 0.15.6", "heck 0.4.1", "pkg-config", - "toml 0.8.8", + "toml 0.8.9", "version-compare 0.1.1", ] @@ -5218,12 +5235,13 @@ dependencies = [ [[package]] name = "time" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f657ba42c3f86e7680e53c8cd3af8abbe56b5491790b46e22e19c0d57463583e" 
+checksum = "fe80ced77cbfb4cb91a94bf72b378b4b6791a0d9b7f09d0be747d1bdff4e68bd" dependencies = [ "deranged", "itoa 1.0.10", + "num-conv", "powerfmt", "serde", "time-core", @@ -5238,10 +5256,11 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26197e33420244aeb70c3e8c78376ca46571bc4e701e4791c2cd9f57dcb3a43f" +checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774" dependencies = [ + "num-conv", "time-core", ] @@ -5360,14 +5379,14 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1a195ec8c9da26928f773888e0742ca3ca1040c6cd859c919c9f59c1954ab35" +checksum = "c6a4b9e8023eb94392d3dca65d717c53abc5dad49c07cb65bb8fcd87115fa325" dependencies = [ "serde", "serde_spanned", "toml_datetime", - "toml_edit 0.21.0", + "toml_edit 0.21.1", ] [[package]] @@ -5385,7 +5404,7 @@ version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ - "indexmap 2.2.1", + "indexmap 2.2.2", "serde", "serde_spanned", "toml_datetime", @@ -5394,11 +5413,11 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.21.0" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d34d383cd00a163b4a5b85053df514d45bc330f6de7737edfe0a93311d1eaa03" +checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1" dependencies = [ - "indexmap 2.2.1", + "indexmap 2.2.2", "serde", "serde_spanned", "toml_datetime", @@ -5676,7 +5695,7 @@ version = "4.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "272ebdfbc99111033031d2f10e018836056e4d2c8e2acda76450ec7974269fa7" dependencies = [ - "indexmap 
2.2.1", + "indexmap 2.2.2", "serde", "serde_json", "serde_yaml", @@ -5859,9 +5878,9 @@ checksum = "4d91413b1c31d7539ba5ef2451af3f0b833a005eb27a631cec32bc0635a8602b" [[package]] name = "wasm-streams" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4609d447824375f43e1ffbc051b50ad8f4b3ae8219680c94452ea05eb240ac7" +checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129" dependencies = [ "futures-util", "js-sys", @@ -5995,7 +6014,7 @@ dependencies = [ [[package]] name = "whisper_cpp" version = "0.2.0" -source = "git+https://github.com/edgenai/whisper_cpp-rs?branch=build-overhaul#2809975f809f1ac1dc0cba92a83dea7e2da6e093" +source = "git+https://github.com/edgenai/whisper_cpp-rs?branch=main#7257315a3ab7f462d6f45e0df252dc7d8462fe84" dependencies = [ "derive_more", "thiserror", @@ -6006,7 +6025,7 @@ dependencies = [ [[package]] name = "whisper_cpp_sys" version = "0.2.0" -source = "git+https://github.com/edgenai/whisper_cpp-rs?branch=build-overhaul#2809975f809f1ac1dc0cba92a83dea7e2da6e093" +source = "git+https://github.com/edgenai/whisper_cpp-rs?branch=main#7257315a3ab7f462d6f45e0df252dc7d8462fe84" dependencies = [ "bindgen", "cmake", @@ -6381,9 +6400,9 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "winnow" -version = "0.5.35" +version = "0.5.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1931d78a9c73861da0134f453bb1f790ce49b2e30eba8410b4b79bac72b46a2d" +checksum = "818ce546a11a9986bc24f93d0cdf38a8a1a400f1473ea8c82e59f6e0ffab9249" dependencies = [ "memchr", ] diff --git a/Cargo.toml b/Cargo.toml index 51e9c15..dd6d3f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,10 +56,6 @@ toml_edit = "0.21.0" tracing = "0.1.40" uuid = "1.6.1" utoipa = { version = "4", features = ["yaml"] } -llama_cpp = { git = "https://github.com/edgenai/llama_cpp-rs", branch = "build-overhaul" } -whisper_cpp = { git = 
"https://github.com/edgenai/whisper_cpp-rs", branch = "build-overhaul" } +llama_cpp = { git = "https://github.com/edgenai/llama_cpp-rs", branch = "main" } +whisper_cpp = { git = "https://github.com/edgenai/whisper_cpp-rs", branch = "main" } tauri = { version = "1.5.4", features = [] } - -[patch.crates-io] -llama_cpp = { git = "https://github.com/edgenai/llama_cpp-rs", branch = "build-overhaul" } -llama_cpp_sys = { git = "https://github.com/edgenai/llama_cpp-rs", branch = "build-overhaul" } From 4777197e1a3ec692cf62a7f7334c74f6154b3315 Mon Sep 17 00:00:00 2001 From: Pedro Valente Date: Thu, 1 Feb 2024 12:41:52 +0000 Subject: [PATCH 03/24] updated release GitHub action --- .github/workflows/release.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9e70290..d7c7114 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,8 +1,10 @@ name: "release" on: push: - branches: - - main + branches: [ "*" ] + pull_request: + branches: [ "*" ] + # This is the example from the readme. # On each push to the `main` branch it will create or update a GitHub release, build your app, and upload the artifacts to the release. From 535d6aa80a3630609e0d431feebe89bbb9c63efd Mon Sep 17 00:00:00 2001 From: francis2tm Date: Thu, 1 Feb 2024 14:20:25 +0000 Subject: [PATCH 04/24] docs: updated readme --- README.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d0d35e5..189371f 100644 --- a/README.md +++ b/README.md @@ -25,9 +25,14 @@ - [x] **Optimized Inference**: You don't need to take a PhD in AI optimization. ⚡Edgen abstracts the complexity of optimizing inference for different hardware, platforms and models. - [x] **Modular**: ⚡Edgen is **model** and **runtime** agnostic. 
New models can be added easily and ⚡Edgen can select the best runtime for the user's hardware: you don't need to keep up about the latest models and ML runtimes - **⚡Edgen will do that for you**. - [x] **Model Caching**: ⚡Edgen caches foundational models locally, so 1 model can power hundreds of different apps - users don't need to download the same model multiple times. +- [x] **Native**: ⚡Edgen is built in 🦀Rust and is natively compiled to all popular platforms. No docker required. - [x] **OpenAI Compliant API**: ⚡Edgen is a drop-in replacement for OpenAI. -⚡Edgen lets you use GenAI in your app, completely **locally** on your user's devices, for **free** and with **data-privacy**. It's a drop-in replacement for OpenAI (it uses the a compatible API), supports various functions like text and image generation, speech-to-text, and text-to-speech, and works on Windows, Linux, and MacOS. +⚡Edgen lets you use GenAI in your app, completely **locally** on your user's devices, for **free** and with **data-privacy**. It's a drop-in replacement for OpenAI (it uses a compatible API), supports various functions like text generation, speech-to-text and works on Windows, Linux, and MacOS. + +### Features + +- [x] Session Caching: ⚡Edgen maintains top performance with big contexts (big chat histories), by caching sessions. Sessions are auto-detected based on the chat history. 
### Endpoints From dd6cda2eb80f296c7f8ce2ab5831afdb2eb8e535 Mon Sep 17 00:00:00 2001 From: francis2tm Date: Thu, 1 Feb 2024 19:59:51 +0000 Subject: [PATCH 05/24] chore: update docs base path --- docs/next.config.mjs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/next.config.mjs b/docs/next.config.mjs index 85ada68..90797e3 100644 --- a/docs/next.config.mjs +++ b/docs/next.config.mjs @@ -16,7 +16,7 @@ const withMDX = nextMDX({ /** @type {import('next').NextConfig} */ const nextConfig = { output: 'export', - basePath: '', + basePath: 'edgen/', pageExtensions: ['js', 'jsx', 'ts', 'tsx', 'mdx'], } From c2757acc8b434dbdd1ff2beed6aba133bccee4ac Mon Sep 17 00:00:00 2001 From: Pedro Valente Date: Fri, 2 Feb 2024 11:24:15 +0000 Subject: [PATCH 06/24] feat: non-streaming chat completions --- crates/edgen_core/src/llm.rs | 2 +- crates/edgen_server/src/openai_shim.rs | 87 +++++++++++++++---- .../edgen_server/src/util/stopping_stream.rs | 2 +- 3 files changed, 73 insertions(+), 18 deletions(-) diff --git a/crates/edgen_core/src/llm.rs b/crates/edgen_core/src/llm.rs index 4a2973a..3b47be7 100644 --- a/crates/edgen_core/src/llm.rs +++ b/crates/edgen_core/src/llm.rs @@ -44,7 +44,7 @@ pub struct CompletionArgs { pub frequency_penalty: f32, } -/// A large language language model endpoint, that is, an object that provides various ways to interact with a large +/// A large language model endpoint, that is, an object that provides various ways to interact with a large /// language model. 
pub trait LLMEndpoint { /// Given a prompt with several arguments, return a [`Box`]ed [`Future`] which may eventually contain the prompt diff --git a/crates/edgen_server/src/openai_shim.rs b/crates/edgen_server/src/openai_shim.rs index 7fbf3eb..9d01a60 100644 --- a/crates/edgen_server/src/openai_shim.rs +++ b/crates/edgen_server/src/openai_shim.rs @@ -17,7 +17,6 @@ use std::borrow::Cow; use std::collections::HashMap; -use std::convert::Infallible; use std::fmt::{Display, Formatter}; use axum::http::StatusCode; @@ -26,10 +25,8 @@ use axum::response::{IntoResponse, Response, Sse}; use axum::Json; use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; use derive_more::{Deref, DerefMut, From}; -use edgen_core::settings::SETTINGS; -use edgen_core::settings::{get_audio_transcriptions_model_dir, get_chat_completions_model_dir}; use either::Either; -use futures::StreamExt; +use futures::{Stream, StreamExt, TryStream}; use serde_derive::{Deserialize, Serialize}; use thiserror::Error; use time::OffsetDateTime; @@ -38,6 +35,9 @@ use tracing::error; use utoipa::ToSchema; use uuid::Uuid; +use edgen_core::settings::SETTINGS; +use edgen_core::settings::{get_audio_transcriptions_model_dir, get_chat_completions_model_dir}; + use crate::model::{Model, ModelKind}; use crate::whisper::WhisperEndpointError; @@ -521,6 +521,29 @@ impl IntoResponse for ChatCompletionError { } } +/// The return type of [`chat_completions`]. Contains either a [`Stream`] of [`Event`]s or a [`Json`] +/// of a [`ChatCompletion`]. 
+enum ChatCompletionResponse<'a, S> +where + S: TryStream + Send + 'static, +{ + Stream(Sse), + Full(Json>), +} + +impl<'a, S, E> IntoResponse for ChatCompletionResponse<'a, S> +where + S: Stream> + Send + 'static, + E: Into, +{ + fn into_response(self) -> Response { + match self { + ChatCompletionResponse::Stream(stream) => stream.into_response(), + ChatCompletionResponse::Full(full) => full.into_response(), + } + } +} + /// POST `/v1/chat/completions`: generate chat completions for the provided context, optionally /// streaming those completions in real-time. /// @@ -590,12 +613,18 @@ pub async fn chat_completions( let untokenized_context = format!("{}<|ASSISTANT|>", req.messages); - let completions_stream = crate::llm::chat_completion_stream(model, untokenized_context) - .await? - .map(|chunk| { - let fp = format!("edgen-{}", cargo_crate_version!()); - Event::default() - .json_data(ChatCompletionChunk { + let stream_response = if let Some(stream) = req.stream { + stream + } else { + false + }; + + let fp = format!("edgen-{}", cargo_crate_version!()); + let response = if stream_response { + let completions_stream = crate::llm::chat_completion_stream(model, untokenized_context) + .await? 
+ .map(move |chunk| { + Event::default().json_data(ChatCompletionChunk { id: Uuid::new_v4().to_string().into(), choices: tiny_vec![ChatCompletionChunkChoice { index: 0, @@ -610,11 +639,37 @@ pub async fn chat_completions( system_fingerprint: Cow::Borrowed(&fp), // use macro for version object: Cow::Borrowed("text_completion"), }) - .expect("Could not serialize JSON; this should never happen") - }) - .map(Ok::); - - Ok(Sse::new(completions_stream)) + }); + + ChatCompletionResponse::Stream(Sse::new(completions_stream)) + } else { + let content_str = crate::llm::chat_completion(model, untokenized_context).await?; + let response = ChatCompletion { + id: Uuid::new_v4().to_string().into(), + choices: vec![ChatCompletionChoice { + message: ChatMessage::Assistant { + content: Some(Cow::Owned(content_str)), + name: None, + tool_calls: None, + }, + finish_reason: None, + index: 0, + }], + created: OffsetDateTime::now_utc().unix_timestamp(), + model: Cow::Borrowed("main"), + object: Cow::Borrowed("text_completion"), + system_fingerprint: Cow::Owned(fp), // use macro for version + usage: ChatCompletionUsage { + completion_tokens: 0, + prompt_tokens: 0, + total_tokens: 0, + }, + }; + + ChatCompletionResponse::Full(Json(response)) + }; + + Ok(response) } /// A request to transcribe an audio file into text in either the specified language, or whichever @@ -662,7 +717,7 @@ pub struct CreateTranscriptionRequest { /// /// See [the original OpenAI API specification][openai], which this endpoint is compatible with. /// -/// [openai]: https://platform.openai.com/docs/api-reference/auddio/createTranscription +/// [openai]: https://platform.openai.com/docs/api-reference/audio/createTranscription /// /// On failure, may raise a `500 Internal Server Error` with a JSON-encoded [`WhisperEndpointError`] /// to the peer. 
diff --git a/crates/edgen_server/src/util/stopping_stream.rs b/crates/edgen_server/src/util/stopping_stream.rs index de1b4f6..a493794 100644 --- a/crates/edgen_server/src/util/stopping_stream.rs +++ b/crates/edgen_server/src/util/stopping_stream.rs @@ -28,7 +28,7 @@ pub struct StoppingStream { /// The stop words (phrases) that this stream should stop at. /// /// These are never emitted downstream, and the stream will yield with `Pending` until it - /// is is impossible for any stop word to be generated. + /// is impossible for any stop word to be generated. stop_words: Vec, /// If this stream is uncertain whether it's collecting a stop word, this buffer contains From 65e4685cd147ff1b7f7da99e4ca9167fa48ec5d2 Mon Sep 17 00:00:00 2001 From: Pedro Valente Date: Fri, 2 Feb 2024 12:58:26 +0000 Subject: [PATCH 07/24] temp: re-enabled Windows terminal for release builds --- edgen/src-tauri/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/edgen/src-tauri/src/main.rs b/edgen/src-tauri/src/main.rs index d2ad9f5..e64c6fa 100644 --- a/edgen/src-tauri/src/main.rs +++ b/edgen/src-tauri/src/main.rs @@ -11,7 +11,7 @@ */ // Tauri - prevents additional console window on Windows in release, DO NOT REMOVE!! 
-#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] +//#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] #[cfg(not(feature = "no_gui"))] mod gui; From 1b14252e4d587fcdca0778deab4e76eae22b6773 Mon Sep 17 00:00:00 2001 From: Pedro Valente Date: Fri, 2 Feb 2024 13:40:32 +0000 Subject: [PATCH 08/24] wip: added config.toml to tauri src --- edgen/src-tauri/.cargo/config.toml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 edgen/src-tauri/.cargo/config.toml diff --git a/edgen/src-tauri/.cargo/config.toml b/edgen/src-tauri/.cargo/config.toml new file mode 100644 index 0000000..0640a29 --- /dev/null +++ b/edgen/src-tauri/.cargo/config.toml @@ -0,0 +1,4 @@ +[build] +rustflags = [ + "--cfg", "tokio_unstable", +] From fad0b0ea16b85b7c93039ecefeb61a6f7890da71 Mon Sep 17 00:00:00 2001 From: Pedro Valente Date: Fri, 2 Feb 2024 14:42:16 +0000 Subject: [PATCH 09/24] fix: don't use console_subscriber in release builds --- crates/edgen_server/src/lib.rs | 6 +++++- edgen/src-tauri/.cargo/config.toml | 4 ---- 2 files changed, 5 insertions(+), 5 deletions(-) delete mode 100644 edgen/src-tauri/.cargo/config.toml diff --git a/crates/edgen_server/src/lib.rs b/crates/edgen_server/src/lib.rs index 870d167..3773c1b 100644 --- a/crates/edgen_server/src/lib.rs +++ b/crates/edgen_server/src/lib.rs @@ -155,7 +155,11 @@ fn serve(args: &cli::Serve) -> EdgenResult { #[tokio::main] async fn start_server(args: &cli::Serve) -> EdgenResult { - console_subscriber::init(); + // The console is disabled in release builds, so there is no need to initialise this + #[cfg(debug_assertions)] + { + console_subscriber::init(); + } SETTINGS .write() diff --git a/edgen/src-tauri/.cargo/config.toml b/edgen/src-tauri/.cargo/config.toml deleted file mode 100644 index 0640a29..0000000 --- a/edgen/src-tauri/.cargo/config.toml +++ /dev/null @@ -1,4 +0,0 @@ -[build] -rustflags = [ - "--cfg", "tokio_unstable", -] From 60f93762d465c44035176c521065104be471fffc Mon Sep 17 00:00:00 
2001 From: Pedro Valente Date: Fri, 2 Feb 2024 14:55:31 +0000 Subject: [PATCH 10/24] re-disabled terminal in release builds --- edgen/src-tauri/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/edgen/src-tauri/src/main.rs b/edgen/src-tauri/src/main.rs index e64c6fa..d2ad9f5 100644 --- a/edgen/src-tauri/src/main.rs +++ b/edgen/src-tauri/src/main.rs @@ -11,7 +11,7 @@ */ // Tauri - prevents additional console window on Windows in release, DO NOT REMOVE!! -//#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] +#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] #[cfg(not(feature = "no_gui"))] mod gui; From cd171b4936cdfa5486b13d83eaa3c5c710e79bfb Mon Sep 17 00:00:00 2001 From: francis2tm Date: Fri, 2 Feb 2024 17:32:32 +0000 Subject: [PATCH 11/24] chore: changed organization to EdgenAI in ProjectDirs --- crates/edgen_core/src/settings.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/edgen_core/src/settings.rs b/crates/edgen_core/src/settings.rs index d2c3b02..6d953f2 100644 --- a/crates/edgen_core/src/settings.rs +++ b/crates/edgen_core/src/settings.rs @@ -33,7 +33,7 @@ pub static SETTINGS: Lazy> = Lazy::new(Default::default); /// The configuration, and data directories for Edgen. 
pub static PROJECT_DIRS: Lazy = - Lazy::new(|| ProjectDirs::from("com", "Binedge", "Edgen").unwrap()); + Lazy::new(|| ProjectDirs::from("com", "EdgenAI", "Edgen").unwrap()); pub static CONFIG_FILE: Lazy = Lazy::new(|| build_config_file_path()); pub static CHAT_COMPLETIONS_MODEL_DIR: Lazy = Lazy::new(|| build_chat_completions_model_dir()); From 41016f60557c7ed13e9277339b295f0492474d4c Mon Sep 17 00:00:00 2001 From: Pedro Valente Date: Fri, 2 Feb 2024 18:33:52 +0000 Subject: [PATCH 12/24] chore: small corrections and cleanup --- Cargo.lock | 53 +---------- crates/edgen_rt_whisper_cpp/Cargo.toml | 9 +- crates/edgen_rt_whisper_cpp/src/lib.rs | 123 ------------------------- crates/edgen_server/src/lib.rs | 6 +- crates/edgen_server/src/openai_shim.rs | 7 +- crates/edgen_server/src/whisper.rs | 4 + 6 files changed, 12 insertions(+), 190 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a70a483..2ef251b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -65,28 +65,6 @@ dependencies = [ "alloc-no-stdlib", ] -[[package]] -name = "alsa" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce34de545ad29bcc00cb1b87a94c132256dcf83aa7eeb9674482568405a6ff0a" -dependencies = [ - "alsa-sys", - "bitflags 2.4.2", - "libc", - "nix 0.26.4", -] - -[[package]] -name = "alsa-sys" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db8fee663d06c4e303404ef5f40488a53e062f89ba8bfed81f42325aafad1527" -dependencies = [ - "libc", - "pkg-config", -] - [[package]] name = "android-tzdata" version = "0.1.1" @@ -1370,8 +1348,6 @@ dependencies = [ name = "edgen_rt_whisper_cpp" version = "0.1.0" dependencies = [ - "alsa", - "axum 0.7.4", "dashmap", "derive_more", "edgen_core", @@ -1379,7 +1355,6 @@ dependencies = [ "once_cell", "thiserror", "tokio", - "wav", "whisper_cpp", ] @@ -2027,7 +2002,7 @@ dependencies = [ "libc", "lockfree", "log", - "nix 0.23.2", + "nix", "pin-project-lite", "rlimit", "scoped-tls", @@ 
-3082,17 +3057,6 @@ dependencies = [ "memoffset 0.6.5", ] -[[package]] -name = "nix" -version = "0.26.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b" -dependencies = [ - "bitflags 1.3.2", - "cfg-if", - "libc", -] - [[package]] name = "nodrop" version = "0.1.14" @@ -4039,12 +4003,6 @@ dependencies = [ "windows 0.37.0", ] -[[package]] -name = "riff" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9b1a3d5f46d53f4a3478e2be4a5a5ce5108ea58b100dcd139830eae7f79a3a1" - [[package]] name = "ring" version = "0.17.7" @@ -5889,15 +5847,6 @@ dependencies = [ "web-sys", ] -[[package]] -name = "wav" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a65e199c799848b4f997072aa4d673c034f80f40191f97fe2f0a23f410be1609" -dependencies = [ - "riff", -] - [[package]] name = "web-sys" version = "0.3.67" diff --git a/crates/edgen_rt_whisper_cpp/Cargo.toml b/crates/edgen_rt_whisper_cpp/Cargo.toml index 0922ddd..71c3e2a 100644 --- a/crates/edgen_rt_whisper_cpp/Cargo.toml +++ b/crates/edgen_rt_whisper_cpp/Cargo.toml @@ -7,16 +7,11 @@ edition = "2021" [dependencies] dashmap = { workspace = true } -edgen_core = { version = "^0.1.0", path = "../edgen_core" } +edgen_core = { path = "../edgen_core" } futures = { workspace = true } derive_more = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["sync"] } -once_cell = "1.18.0" +once_cell = { workspace = true } whisper_cpp = { workspace = true, features = ["native"] } -[dev-dependencies] -axum = { workspace = true, features = ["tokio", "http1"] } -alsa = "0.8.1" -tokio = { workspace = true, features = ["full"] } -wav = "1.0.0" diff --git a/crates/edgen_rt_whisper_cpp/src/lib.rs b/crates/edgen_rt_whisper_cpp/src/lib.rs index 92fa959..46ec890 100644 --- a/crates/edgen_rt_whisper_cpp/src/lib.rs +++ 
b/crates/edgen_rt_whisper_cpp/src/lib.rs @@ -132,126 +132,3 @@ impl WhisperRunner for WhisperCppRunner { Ok(res) } } - -#[cfg(test)] -mod tests { - /* - use alsa::pcm::{Access, Format, HwParams}; - use alsa::{Direction, ValueOr, PCM}; - use tokio::sync::mpsc::error::TryRecvError; - - use crate::*; - - #[derive(Error, Debug)] - enum TestError { - #[error("whisper error: {0}")] - Whisper(#[from] edgen_core::whisper::WhisperError), - #[error("decode session error: {0}")] - Session(#[from] edgen_core::whisper::DecodeSessionError), - #[error("whisper.cpp error: {0}")] - WhisperCpp(#[from] WhisperCppError), - #[error("alsa error: {0}")] - Alsa(#[from] alsa::Error), - #[error("failed to write test file: {0}")] - File(#[from] std::io::Error), - } - - #[tokio::test] - async fn live_transcription() -> Result<(), TestError> { - - //TODO change to use env variables - let model = WhisperCpp::load("/home/pedro/dev/models/ggml-base.en.bin")?; - - let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); - let mut session = model - .new_session(Box::new(move |str| { - println!("{str}"); - tx.send(str).unwrap(); - })) - .await?; - - let _rate; - let pcm = PCM::new("default", Direction::Capture, false)?; - { - // For this example, we assume 44100Hz, one channel, 16 bit audio. - let hwp = HwParams::any(&pcm)?; - hwp.set_channels(1)?; - hwp.set_rate(16000, ValueOr::Nearest)?; - hwp.set_format(Format::s16())?; - hwp.set_access(Access::RWInterleaved)?; - pcm.hw_params(&hwp)?; - - _rate = hwp.get_rate()?; // is there any side effect that justifies this assignment? 
- } - pcm.start()?; - - let io = pcm.io_i16()?; - let mut buf = [0i16; 8192 * 10]; - - loop { - let mut stop = false; - - tokio::time::sleep(tokio::time::Duration::from_millis(1)).await; - let _size_read = io.readi(&mut buf)?; - let samples: Vec<_> = buf[..].iter().map(|v| *v as f32 / 32768.).collect(); - - session.push(&samples)?; - println!("aaaaaa"); - - loop { - println!("huh?"); - let output = rx.try_recv(); - match output { - Ok(str) => { - println!("{str}"); - if str.contains("stop") | str.contains("Stop") | str.contains("STOP") { - session.end().await?; - stop = true; - break; - } - } - Err(TryRecvError::Empty) => break, - Err(_) => { - panic!() - } - } - } - - if stop { - break; - } - } - Ok(()) - } - */ - - /* - static ENDPOINT: OnceCell>>> = OnceCell::const_new(); - - #[tokio::test] - async fn server() -> Result<(), TestError> { - use axum::Router; - use axum::routing::post; - use axum::http::StatusCode; - use axum::Json; - - async fn creation_helper() -> Arc>> { - Arc::new(RwLock::new(WhisperEndpoint::default())) - } - - async fn request_helper(payload: Json) -> (StatusCode, String) { - let mut locked = ENDPOINT.get_or_init(creation_helper).await.write().await; - locked.standalone(payload).await - } - - let app = Router::new() - .route("/", post(request_helper)); - - let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); - axum::serve(listener, app).await.unwrap(); - - Ok(()) - } - - */ -} diff --git a/crates/edgen_server/src/lib.rs b/crates/edgen_server/src/lib.rs index 3773c1b..870d167 100644 --- a/crates/edgen_server/src/lib.rs +++ b/crates/edgen_server/src/lib.rs @@ -155,11 +155,7 @@ fn serve(args: &cli::Serve) -> EdgenResult { #[tokio::main] async fn start_server(args: &cli::Serve) -> EdgenResult { - // The console is disabled in release builds, so there is no need to initialise this - #[cfg(debug_assertions)] - { - console_subscriber::init(); - } + console_subscriber::init(); SETTINGS .write() diff --git 
a/crates/edgen_server/src/openai_shim.rs b/crates/edgen_server/src/openai_shim.rs index 9d01a60..ef5606e 100644 --- a/crates/edgen_server/src/openai_shim.rs +++ b/crates/edgen_server/src/openai_shim.rs @@ -521,8 +521,9 @@ impl IntoResponse for ChatCompletionError { } } -/// The return type of [`chat_completions`]. Contains either a [`Stream`] of [`Event`]s or a [`Json`] -/// of a [`ChatCompletion`]. +/// The return type of [`chat_completions`]. +/// +/// Contains either a [`Stream`] of [`Event`]s or the [`Json`] of a [`ChatCompletion`]. enum ChatCompletionResponse<'a, S> where S: TryStream + Send + 'static, @@ -686,7 +687,7 @@ pub async fn chat_completions( pub struct CreateTranscriptionRequest { /// The audio file object (not file name) to transcribe, in one of the following formats: /// **`aac`**, **`flac`**, **`mp3`**, **`m4a`**, **`m4b`**, **`ogg`**, **`oga`**, **`mogg`**, - /// **`wav`**, **`webm`**. TODO check working formats. + /// **`wav`**. TODO check working formats. webm #[form_data(limit = "unlimited")] #[schema(value_type = Vec < u8 >)] pub file: FieldData, diff --git a/crates/edgen_server/src/whisper.rs b/crates/edgen_server/src/whisper.rs index 9cae89f..239a71c 100644 --- a/crates/edgen_server/src/whisper.rs +++ b/crates/edgen_server/src/whisper.rs @@ -18,6 +18,7 @@ use rubato::Resampler; use serde_derive::Serialize; use thiserror::Error; use time::Duration; +use tracing::info; use utoipa::ToSchema; use uuid::Uuid; @@ -201,6 +202,8 @@ fn to_pcm(audio_file: &[u8]) -> Result, AudioError> { /// The optimal sample rate for whisper models. const OPTIMAL_SAMPLE_RATE: u32 = 16000; + info!("Parsing audio file ({} bytes)", audio_file.len()); + // Initialisation. 
let cursor = std::io::Cursor::new(audio_file.to_vec()); let stream = MediaSourceStream::new(Box::new(cursor), Default::default()); @@ -210,6 +213,7 @@ fn to_pcm(audio_file: &[u8]) -> Result, AudioError> { let meta_opts: MetadataOptions = Default::default(); let fmt_opts: FormatOptions = Default::default(); + // TODO this gets stuck in a loop for some invalid files let probed = symphonia::default::get_probe() .format(&hint, stream, &fmt_opts, &meta_opts) .map_err(move |e| AudioError::Parse(format!("failed to probe audio data: {e}")))?; From 3d0bf5977f4eff1788b51228e2c032304188aa83 Mon Sep 17 00:00:00 2001 From: francis2tm Date: Fri, 2 Feb 2024 21:00:31 +0000 Subject: [PATCH 13/24] chore: publish docs --- docs/next.config.mjs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/next.config.mjs b/docs/next.config.mjs index 90797e3..8ae6b90 100644 --- a/docs/next.config.mjs +++ b/docs/next.config.mjs @@ -16,7 +16,7 @@ const withMDX = nextMDX({ /** @type {import('next').NextConfig} */ const nextConfig = { output: 'export', - basePath: 'edgen/', + // basePath: 'edgen/', pageExtensions: ['js', 'jsx', 'ts', 'tsx', 'mdx'], } From ca3c5eb662f6907ae2ea15dae186fec77d56856a Mon Sep 17 00:00:00 2001 From: francis2tm Date: Sat, 3 Feb 2024 14:29:45 +0000 Subject: [PATCH 14/24] chore: updated README --- README.md | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 189371f..972ee59 100644 --- a/README.md +++ b/README.md @@ -62,9 +62,27 @@ ## Quickstart -1. [Download](https://edgen.co/download) ⚡Edgen +1. [Download](https://edgen.co/download) and start ⚡Edgen 2. Chat with ⚡[EdgenChat](https://chat.edgen.co) +⚡Edgen usage: + +``` +Usage: edgen [] [] + +Toplevel CLI commands and options. Subcommands are optional. If no command is provided "serve" will be invoked with default options. + +Options: + --help display usage information + +Commands: + serve Starts the edgen server. 
This is the default command when no + command is provided. + config Configuration-related subcommands. + version Prints the edgen version to stdout. + oasgen Generates the Edgen OpenAPI specification. +``` + # Developers The following sections are for people looking to contribute to ⚡Edgen. From cba17798ad4bd487b3261c8633bc0cde426a8e41 Mon Sep 17 00:00:00 2001 From: francis2tm Date: Sat, 3 Feb 2024 15:54:55 +0000 Subject: [PATCH 15/24] chore: renamed release -> build worflow --- .github/workflows/{release.yml => build.yml} | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) rename .github/workflows/{release.yml => build.yml} (93%) diff --git a/.github/workflows/release.yml b/.github/workflows/build.yml similarity index 93% rename from .github/workflows/release.yml rename to .github/workflows/build.yml index d7c7114..007062b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/build.yml @@ -1,16 +1,15 @@ -name: "release" +name: "build" on: push: - branches: [ "*" ] + branches: ["*"] pull_request: - branches: [ "*" ] - + branches: ["*"] # This is the example from the readme. -# On each push to the `main` branch it will create or update a GitHub release, build your app, and upload the artifacts to the release. +# On each push to the `main` branch it will create or update a GitHub build, build your app, and upload the artifacts to the build. 
jobs: - release: + build: permissions: contents: write strategy: From 7eda9dca42b201842684b782d3a924327db20de2 Mon Sep 17 00:00:00 2001 From: francis2tm Date: Mon, 5 Feb 2024 11:41:46 +0000 Subject: [PATCH 16/24] chore: corrected typos --- crates/edgen_server/src/cli.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/edgen_server/src/cli.rs b/crates/edgen_server/src/cli.rs index 43a1b9d..02c1bbd 100644 --- a/crates/edgen_server/src/cli.rs +++ b/crates/edgen_server/src/cli.rs @@ -59,7 +59,7 @@ pub struct Serve { #[argh(option, short = 'b')] pub uri: Vec, /// if present, edgen will not start the GUI; - /// the default behaviour is to start the GUI. + /// the default behavior is to start the GUI. #[argh(switch, short = 'g')] pub nogui: bool, } @@ -110,7 +110,7 @@ pub struct Oasgen { #[argh(switch, short = 'y')] pub yaml: bool, /// if present, edgen will generate the OpenAPI spec in JSON format; - /// the default behaviour is to generate yaml output. + /// the default behavior is to generate yaml output. #[argh(switch, short = 'j')] pub json: bool, } From 6e24c861c5ac3b19c770107e9c76360f8025e412 Mon Sep 17 00:00:00 2001 From: francis2tm Date: Mon, 5 Feb 2024 11:42:39 +0000 Subject: [PATCH 17/24] updated README --- .vscode/settings.json | 2 +- README.md | 24 ++++++++++++++++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 0152b53..5eeac91 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,3 +1,3 @@ { - "cSpell.files": ["docs/"] + "cSpell.files": ["docs/", "**/*.md"] } diff --git a/README.md b/README.md index 972ee59..89a92c1 100644 --- a/README.md +++ b/README.md @@ -22,11 +22,14 @@

⚡Edgen architecture overview

+- [x] **OpenAI Compliant API**: ⚡Edgen implements the same API as OpenAI, making it a drop-in replacement. +- [x] **Multi-Endpoint Support**: ⚡Edgen exposes multiple AI endpoints such as chat completions (LLMs) and Speech-to-Text (Whisper) for audio transcriptions. +- [x] **Model Agnostic**: LLMs (Llama2, Mistral, Mixtral...), Speech-to-text (whisper) and many others. - [x] **Optimized Inference**: You don't need to take a PhD in AI optimization. ⚡Edgen abstracts the complexity of optimizing inference for different hardware, platforms and models. - [x] **Modular**: ⚡Edgen is **model** and **runtime** agnostic. New models can be added easily and ⚡Edgen can select the best runtime for the user's hardware: you don't need to keep up about the latest models and ML runtimes - **⚡Edgen will do that for you**. - [x] **Model Caching**: ⚡Edgen caches foundational models locally, so 1 model can power hundreds of different apps - users don't need to download the same model multiple times. -- [x] **Native**: ⚡Edgen is build in 🦀Rust and is natively compiled to all popular platforms. No docker required. -- [x] **OpenAI Compliant API**: ⚡Edgen is a drop-in replacement for OpenAI. +- [x] **Native**: ⚡Edgen is build in 🦀Rust and is natively compiled to all popular platforms: **Windows, MacOS and Linux**. No docker required. +- [ ] **Graphical Interface**: A graphical user interface to help users efficiently manage their models, endpoints and permissions. ⚡Edgen lets you use GenAI in your app, completely **locally** on your user's devices, for **free** and with **data-privacy**. It's a drop-in replacement for OpenAI (it uses the a compatible API), supports various functions like text generation, speech-to-text and works on Windows, Linux, and MacOS. @@ -83,6 +86,23 @@ Commands: oasgen Generates the Edgen OpenAPI specification. ``` +`edgen serve` usage: + +``` +Usage: edgen serve [-b ] [-g] + +Starts the edgen server. This is the default command when no command is provided. 
+ +Options: + -b, --uri if present, one or more URIs/hosts to bind the server to. + `unix://` (on Linux), `http://`, and `ws://` are supported. + For use in scripts, it is recommended to explicitly add this + option to make your scripts future-proof. + -g, --nogui if present, edgen will not start the GUI; the default + behavior is to start the GUI. + --help display usage information +``` + # Developers The following sections are for people looking to contribute to ⚡Edgen. From 9d26b6f4d54b9b456840bf4323cace8e960381b4 Mon Sep 17 00:00:00 2001 From: francis2tm Date: Mon, 5 Feb 2024 11:42:55 +0000 Subject: [PATCH 18/24] chore: removed .envrc --- .envrc | 7 ------- flake.lock | 6 +++--- 2 files changed, 3 insertions(+), 10 deletions(-) delete mode 100644 .envrc diff --git a/.envrc b/.envrc deleted file mode 100644 index 5e62b23..0000000 --- a/.envrc +++ /dev/null @@ -1,7 +0,0 @@ -if ! has nix_direnv_version || ! nix_direnv_version 2.4.0; then - source_url "https://raw.githubusercontent.com/nix-community/nix-direnv/2.4.0/direnvrc" "sha256-17G+Mvt/JsyJrwsf7bqMr7ho7liHP+0Lo4RMIHgp0F8=" -fi - -# watching all nix-files and re-evaluate on change -watch_file $(find . 
-name "*.nix" -printf '"%p" ') -use flake diff --git a/flake.lock b/flake.lock index 85cb9e1..191a359 100644 --- a/flake.lock +++ b/flake.lock @@ -38,11 +38,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1705697961, - "narHash": "sha256-XepT3WS516evSFYkme3GrcI3+7uwXHqtHbip+t24J7E=", + "lastModified": 1706925685, + "narHash": "sha256-hVInjWMmgH4yZgA4ZtbgJM1qEAel72SYhP5nOWX4UIM=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "e5d1c87f5813afde2dda384ac807c57a105721cc", + "rev": "79a13f1437e149dc7be2d1290c74d378dad60814", "type": "github" }, "original": { From 2df0897871c2273889bb4a292b694426165872da Mon Sep 17 00:00:00 2001 From: Pedro Valente Date: Mon, 5 Feb 2024 11:49:24 +0000 Subject: [PATCH 19/24] chore: moved Windows terminal to feature --- .github/workflows/build.yml | 7 ++++--- crates/edgen_server/src/lib.rs | 6 +----- edgen/src-tauri/Cargo.toml | 5 +++-- edgen/src-tauri/src/main.rs | 6 ++++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 007062b..f9d1702 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,9 +1,9 @@ name: "build" on: push: - branches: ["*"] + branches: [ "*" ] pull_request: - branches: ["*"] + branches: [ "*" ] # This is the example from the readme. # On each push to the `main` branch it will create or update a GitHub build, build your app, and upload the artifacts to the build. 
@@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - platform: [macos-latest, ubuntu-20.04, windows-latest] + platform: [ macos-latest, ubuntu-20.04, windows-latest ] runs-on: ${{ matrix.platform }} steps: @@ -59,6 +59,7 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} TAURI_PRIVATE_KEY: ${{ secrets.TAURI_KEY }} TAURI_KEY_PASSWORD: ${{ secrets.TAURI_KEY_PASSWORD }} + RUSTFLAGS: "--cfg tokio_unstable" with: tagName: v__VERSION__ # the action automatically replaces \_\_VERSION\_\_ with the app version releaseName: "v__VERSION__" diff --git a/crates/edgen_server/src/lib.rs b/crates/edgen_server/src/lib.rs index 3773c1b..870d167 100644 --- a/crates/edgen_server/src/lib.rs +++ b/crates/edgen_server/src/lib.rs @@ -155,11 +155,7 @@ fn serve(args: &cli::Serve) -> EdgenResult { #[tokio::main] async fn start_server(args: &cli::Serve) -> EdgenResult { - // The console is disabled in release builds, so there is no need to initialise this - #[cfg(debug_assertions)] - { - console_subscriber::init(); - } + console_subscriber::init(); SETTINGS .write() diff --git a/edgen/src-tauri/Cargo.toml b/edgen/src-tauri/Cargo.toml index 1b4ddd5..b6128fc 100644 --- a/edgen/src-tauri/Cargo.toml +++ b/edgen/src-tauri/Cargo.toml @@ -20,11 +20,12 @@ serde_json = "1.0" tokio = { workspace = true, features = ["full", "tracing"] } tracing = { workspace = true } opener = "0.6.1" -edgen_server = { path = "../../crates/edgen_server"} -edgen_core = { path = "../../crates/edgen_core"} +edgen_server = { path = "../../crates/edgen_server" } +edgen_core = { path = "../../crates/edgen_core" } [features] no_gui = [] # this feature is used for production builds or when `devPath` points to the filesystem # DO NOT REMOVE!! 
custom-protocol = ["tauri/custom-protocol"] +enable-windows-terminal = [] diff --git a/edgen/src-tauri/src/main.rs b/edgen/src-tauri/src/main.rs index d2ad9f5..7082412 100644 --- a/edgen/src-tauri/src/main.rs +++ b/edgen/src-tauri/src/main.rs @@ -10,8 +10,10 @@ * limitations under the License. */ -// Tauri - prevents additional console window on Windows in release, DO NOT REMOVE!! -#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] +#![cfg_attr( + not(feature = "enable-windows-terminal"), + windows_subsystem = "windows" +)] #[cfg(not(feature = "no_gui"))] mod gui; From 33008ad1f2972ac9d64683c71b4da5fa4fe89c07 Mon Sep 17 00:00:00 2001 From: Pedro Valente Date: Mon, 5 Feb 2024 12:18:36 +0000 Subject: [PATCH 20/24] reverted unintentional change --- crates/edgen_server/src/lib.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/crates/edgen_server/src/lib.rs b/crates/edgen_server/src/lib.rs index 870d167..3773c1b 100644 --- a/crates/edgen_server/src/lib.rs +++ b/crates/edgen_server/src/lib.rs @@ -155,7 +155,11 @@ fn serve(args: &cli::Serve) -> EdgenResult { #[tokio::main] async fn start_server(args: &cli::Serve) -> EdgenResult { - console_subscriber::init(); + // The console is disabled in release builds, so there is no need to initialise this + #[cfg(debug_assertions)] + { + console_subscriber::init(); + } SETTINGS .write() From 4db7c34a1c0f6c1f60d79e5f1e780a3ae88eaa2c Mon Sep 17 00:00:00 2001 From: Pedro Valente Date: Mon, 5 Feb 2024 16:51:41 +0000 Subject: [PATCH 21/24] refactored whisper endpoint --- Cargo.lock | 76 ++-- crates/edgen_core/Cargo.toml | 2 + crates/edgen_core/src/lib.rs | 191 +--------- crates/edgen_core/src/llm.rs | 28 +- crates/edgen_core/src/whisper.rs | 487 +++++++------------------ crates/edgen_rt_llama_cpp/src/lib.rs | 5 +- crates/edgen_rt_whisper_cpp/Cargo.toml | 2 + crates/edgen_rt_whisper_cpp/src/lib.rs | 327 ++++++++++++----- crates/edgen_server/src/lib.rs | 7 +- crates/edgen_server/src/llm.rs | 34 +- 
crates/edgen_server/src/openai_shim.rs | 64 +++- crates/edgen_server/src/whisper.rs | 292 +-------------- 12 files changed, 524 insertions(+), 991 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2ef251b..10d98ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -437,9 +437,9 @@ dependencies = [ [[package]] name = "axum-test" -version = "14.2.2" +version = "14.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2d15e9969313df61a64e25ce39cc8e586d42432696a0c8e0cfac1d377013d9c" +checksum = "fc431b62ab307c833af24700936485eb5f9a8ac18a19347fe37dd4f7ae3dffe9" dependencies = [ "anyhow", "async-trait", @@ -526,17 +526,17 @@ checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" [[package]] name = "bindgen" -version = "0.69.2" +version = "0.69.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4c69fae65a523209d34240b60abe0c42d33d1045d445c0839d8a4894a736e2d" +checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" dependencies = [ "bitflags 2.4.2", "cexpr", "clang-sys", + "itertools 0.12.1", "lazy_static", "lazycell", "log", - "peeking_take_while", "prettyplease", "proc-macro2", "quote", @@ -561,9 +561,9 @@ checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" [[package]] name = "bitmaps" -version = "3.2.0" +version = "3.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "703642b98a00b3b90513279a8ede3fcfa479c126c5fb46e78f3051522f021403" +checksum = "a1d084b0137aaa901caf9f1e8b21daa6aa24d41cd806e111335541eff9683bd6" [[package]] name = "blake3" @@ -1313,9 +1313,11 @@ dependencies = [ "notify", "num_cpus", "once_cell", + "rubato", "serde", "serde_yaml", "smol", + "symphonia", "tempfile", "thiserror", "time", @@ -1355,6 +1357,8 @@ dependencies = [ "once_cell", "thiserror", "tokio", + "tracing", + "uuid", "whisper_cpp", ] @@ -2382,9 +2386,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.59" 
+version = "0.1.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6a67363e2aa4443928ce15e57ebae94fd8949958fd1223c4cfc0cd473ad7539" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -2554,6 +2558,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "0.4.8" @@ -2778,7 +2791,7 @@ checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "llama_cpp" version = "0.3.0" -source = "git+https://github.com/edgenai/llama_cpp-rs?branch=main#d7c895d051780aea216eb81f40b9e849cdb61e3e" +source = "git+https://github.com/edgenai/llama_cpp-rs?branch=main#f4f56beac06c6e35993b1ca94a28dc3786409bda" dependencies = [ "ctor", "derive_more", @@ -2793,7 +2806,7 @@ dependencies = [ [[package]] name = "llama_cpp_sys" version = "0.3.0" -source = "git+https://github.com/edgenai/llama_cpp-rs?branch=main#d7c895d051780aea216eb81f40b9e849cdb61e3e" +source = "git+https://github.com/edgenai/llama_cpp-rs?branch=main#f4f56beac06c6e35993b1ca94a28dc3786409bda" dependencies = [ "bindgen", "cc", @@ -2945,9 +2958,9 @@ checksum = "933dca44d65cdd53b355d0b73d380a2ff5da71f87f036053188bf1eab6a19881" [[package]] name = "miniz_oxide" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" dependencies = [ "adler", "simd-adler32", @@ -3379,12 +3392,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"8835116a5c179084a830efb3adc117ab007512b535bc1a21c991d3b32a6b44dd" -[[package]] -name = "peeking_take_while" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" - [[package]] name = "percent-encoding" version = "2.3.1" @@ -3734,7 +3741,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "efb6c9a1dd1def8e2124d17e83a20af56f1570d6c2d2bd9e266ccb768df3840e" dependencies = [ "anyhow", - "itertools", + "itertools 0.11.0", "proc-macro2", "quote", "syn 2.0.48", @@ -4092,9 +4099,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.30" +version = "0.38.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "322394588aaf33c24007e8bb3238ee3e4c5c09c084ab32bc73890b99ff326bca" +checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" dependencies = [ "bitflags 2.4.2", "errno", @@ -5193,9 +5200,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.32" +version = "0.3.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe80ced77cbfb4cb91a94bf72b378b4b6791a0d9b7f09d0be747d1bdff4e68bd" +checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749" dependencies = [ "deranged", "itoa 1.0.10", @@ -5240,9 +5247,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.35.1" +version = "1.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" +checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" dependencies = [ "backtrace", "bytes", @@ -5531,9 +5538,9 @@ dependencies = [ [[package]] name = "treediff" -version = "4.0.2" +version = "4.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"52984d277bdf2a751072b5df30ec0377febdb02f7696d64c2d7d54630bac4303" +checksum = "4d127780145176e2b5d16611cc25a900150e86e9fd79d3bde6ff3a37359c9cb5" dependencies = [ "serde_json", ] @@ -5906,9 +5913,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.25.3" +version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10" +checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" [[package]] name = "webview2-com" @@ -5963,18 +5970,19 @@ dependencies = [ [[package]] name = "whisper_cpp" version = "0.2.0" -source = "git+https://github.com/edgenai/whisper_cpp-rs?branch=main#7257315a3ab7f462d6f45e0df252dc7d8462fe84" +source = "git+https://github.com/edgenai/whisper_cpp-rs?branch=main#904715e1c29fa01facd3a63bdc7265ea4090d4ed" dependencies = [ "derive_more", "thiserror", "tokio", + "tracing", "whisper_cpp_sys", ] [[package]] name = "whisper_cpp_sys" version = "0.2.0" -source = "git+https://github.com/edgenai/whisper_cpp-rs?branch=main#7257315a3ab7f462d6f45e0df252dc7d8462fe84" +source = "git+https://github.com/edgenai/whisper_cpp-rs?branch=main#904715e1c29fa01facd3a63bdc7265ea4090d4ed" dependencies = [ "bindgen", "cmake", @@ -6349,9 +6357,9 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "winnow" -version = "0.5.36" +version = "0.5.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "818ce546a11a9986bc24f93d0cdf38a8a1a400f1473ea8c82e59f6e0ffab9249" +checksum = "a7cad8365489051ae9f054164e459304af2e7e9bb407c958076c8bf4aef52da5" dependencies = [ "memchr", ] diff --git a/crates/edgen_core/Cargo.toml b/crates/edgen_core/Cargo.toml index 130a67b..8a306ec 100644 --- a/crates/edgen_core/Cargo.toml +++ b/crates/edgen_core/Cargo.toml @@ -12,9 +12,11 @@ edgen_async_compat = { path = "../edgen_async_compat", features = ["runtime-toki notify = { workspace = true } num_cpus = 
{ workspace = true } once_cell = { workspace = true } +rubato = "0.14.1" serde = { workspace = true, features = ["derive"] } serde_yaml = { workspace = true } smol = { workspace = true } +symphonia = { version = "0.5.3", features = ["all-codecs", "all-formats"] } time = { workspace = true } tracing = { workspace = true } futures = { workspace = true } diff --git a/crates/edgen_core/src/lib.rs b/crates/edgen_core/src/lib.rs index d5dca18..e5199e9 100644 --- a/crates/edgen_core/src/lib.rs +++ b/crates/edgen_core/src/lib.rs @@ -17,19 +17,9 @@ //! arbitrary [`EnvelopeHandler`]. extern crate alloc; -/* -use alloc::sync::Arc; -use dashmap::DashMap; -use smol::channel::{Receiver, Sender}; -use smol::future::Future; -use smol::lock::{OnceCell, Semaphore}; -use time::OffsetDateTime; -use tracing::info; -use uuid::Uuid; +use std::time::Duration; -use edgen_proto::Envelope; -*/ pub mod llm; pub mod whisper; @@ -37,178 +27,9 @@ pub mod settings; pub mod perishable; -/* - -/// A [sans-IO][sans-io], fully-asynchronous, task-oriented server for routing [`Envelope`]s to and -/// from [`Session`]s. -/// -/// [sans-io]: https://sans-io.readthedocs.io/ -#[derive(Clone)] -pub struct Server { - _inner: Arc, -} - -impl Server { - /// Opens a new session on this server, dispatching incoming envelopes to the given handler. - pub fn open_session_with_handler(&self, _handler: impl EnvelopeHandler) -> Session { - let session_id = Uuid::new_v4(); - - info!("New virtual session: {}", session_id); - todo!(); - /*let session = Session { - inner: Arc::new(SessionState { - uuid: session_id, - bound_server: self.clone(), - task_semaphore: Arc::new(Semaphore::new(Session::MAX_CONCURRENT_TASKS)), - }) - }; - - self.inner.open_sessions.insert(session_id, session.clone()); - - session*/ - } - - /// Removes a session from this server. 
- #[allow(unused)] - pub(crate) fn drop_session(&self, session_id: Uuid) { - self._inner._open_sessions.remove(&session_id); - } -} - -struct ServerInner { - _open_sessions: DashMap, -} - -struct SessionState { - uuid: Uuid, - - /// The server that contains this session. - _bound_server: Server, - - /// A semaphore limiting the number of concurrent asynchronous tasks that this session can - /// spawn to [`Session::MAX_CONCURRENT_TASKS`]. - _task_semaphore: Arc, - - /// When set, the timestamp at which the session was closed. When present, downstream tasks - /// should gracefully terminate. - dead_from: OnceCell, - - /// The channel from the main thread to the primary task thread for this session. - process_tx: Sender, - - /// The channel from the primary task thread for this session to the main thread. - outgoing_rx: Receiver, - - /// If present, bound callback handlers waiting for one or more messages on a given channel. - _bound_channels: DashMap, ahash::RandomState>, +/// Return the [`Duration`] that cleanup threads should wait before looking for and freeing unused +/// resources, after last doing so. +pub fn cleanup_interval() -> Duration { + // TODO this should come from the settings + Duration::from_secs(20) } - -/// A transport-agnostic session within a [`Server`]. -/// -/// Typically, you'll want to wire this into some other sender/receiver pair for -/// [`Envelope`][edgen_proto::Envelope]s, such as a WebSocket or HTTP connection, via -/// [`Session::receive_incoming`][Session::receive_incoming] and [`Session::next_outgoing`]. -/// -/// When this structure is `Drop`ped, all downstream tasks associated with it are cancelled; -/// **the server only maintains a weak reference to session channels, and this structure is the only -/// thing actually keeping the session alive.** -pub struct Session { - /// The shared inner state of this session. 
- inner: Arc, -} - -impl Session { - /// The maximum number of concurrent tasks this session can have outstanding on its [`Server`]. - /// - /// Beyond this limit, new requests may start to block as old ones finish. - pub const MAX_CONCURRENT_TASKS: usize = 32; - - /// Returns the UUID of this session. - pub fn id(&self) -> Uuid { - self.inner.uuid - } - - /// Resolves, eventually, to the next outgoing [`Envelope`] for this session. - /// - /// This envelope is outbound from this session, and is destined for another peer. - /// - /// Resolves to `None` if the session has been dropped. - pub async fn next_outgoing(&self) -> Option { - self.inner.outgoing_rx.recv().await.ok() - } - - /// Forwards an incoming envelope to this session's server, dispatching upstream routing - /// handlers on a new task thread. - /// - /// This function may yield if [`MAX_CONCURRENT_TASKS`] would be exceeded, until at least one - /// such task thread is available. - pub async fn receive_incoming(&self, envelope: Envelope) { - self.inner.process_tx.send(envelope).await.ok(); - } -} - -impl Drop for Session { - fn drop(&mut self) { - let _ = self.inner.dead_from.set(OffsetDateTime::now_utc()); - } -} - -/// Whether a handler has consumed an envelope, or wishes to pass it to the next handler. -pub enum SinkMode { - /// This handler has successfully processed the envelope, and no subsequent handlers should be - /// invoked. - Sink, - - /// This handler has indicated that this envelope should be passed to the next downstream - /// handler. - Pass(Envelope), -} - -/// An asynchronous handler for incoming [`Envelope`]s. 
-pub trait EnvelopeHandler { - fn handle(&self, envelope: Envelope) -> impl Future + Send; -} - -impl EnvelopeHandler for F - where - F: Fn(Envelope) -> Fut + Send + Sync + 'static, - Fut: Future + Send + 'static, -{ - #[inline] - fn handle(&self, envelope: Envelope) -> impl Future + Send { - (self)(envelope) - } -} - -pub struct OrElse { - a: A, - b: B, -} - -impl EnvelopeHandler for OrElse where A: EnvelopeHandler, B: EnvelopeHandler, OrElse: Send + Sync + 'static { - #[inline] - fn handle(&self, envelope: Envelope) -> impl Future + Send { - async move { - match self.a.handle(envelope).await { - SinkMode::Pass(envelope) => self.b.handle(envelope).await, - SinkMode::Sink => SinkMode::Sink, - } - } - } -} - -pub trait EnvelopeHandlerExt: EnvelopeHandler { - #[inline] - fn or_else(self, other: B) -> OrElse where Self: Sized, B: EnvelopeHandler { - OrElse { a: self, b: other } - } -} - -impl EnvelopeHandlerExt for T where T: EnvelopeHandler {} - -#[cfg(test)] -mod test { - use super::*; -} - -*/ diff --git a/crates/edgen_core/src/llm.rs b/crates/edgen_core/src/llm.rs index 3b47be7..ef56de6 100644 --- a/crates/edgen_core/src/llm.rs +++ b/crates/edgen_core/src/llm.rs @@ -44,19 +44,20 @@ pub struct CompletionArgs { pub frequency_penalty: f32, } -/// A large language model endpoint, that is, an object that provides various ways to interact with a large -/// language model. +/// A large language model endpoint, that is, an object that provides various ways to interact with +/// a large language model. pub trait LLMEndpoint { - /// Given a prompt with several arguments, return a [`Box`]ed [`Future`] which may eventually contain the prompt - /// completion in [`String`] form. + /// Given a prompt with several arguments, return a [`Box`]ed [`Future`] which may eventually + /// contain the prompt completion in [`String`] form. 
fn chat_completions<'a>( &'a self, model_path: impl AsRef + Send + 'a, args: CompletionArgs, ) -> Box> + Send + Unpin + 'a>; - /// Given a prompt with several arguments, return a [`Box`]ed [`Future`] which may eventually contain a [`Stream`] - /// of [`String`] chunks of the prompt completion, acquired as they get processed. + /// Given a prompt with several arguments, return a [`Box`]ed [`Future`] which may eventually + /// contain a [`Stream`] of [`String`] chunks of the prompt completion, acquired as they get + /// processed. fn stream_chat_completions<'a>( &'a self, model_path: impl AsRef + Send + 'a, @@ -72,23 +73,16 @@ pub trait LLMEndpoint { fn reset(&self); } -/// Return the [`Duration`] for which a large language model lives while not being used before being unloaded from -/// memory. +/// Return the [`Duration`] for which a large language model lives while not being used before +/// being unloaded from memory. pub fn inactive_llm_ttl() -> Duration { // TODO this should come from the settings Duration::from_secs(5 * 60) } -/// Return the [`Duration`] for which a large language model session lives while not being used before being unloaded -/// from memory. +/// Return the [`Duration`] for which a large language model session lives while not being used +/// before being unloaded from memory. pub fn inactive_llm_session_ttl() -> Duration { // TODO this should come from the settings Duration::from_secs(2 * 60) } - -/// Return the [`Duration`] that cleanup threads should wait before looking for and freeing unused resources, after -/// last doing so. -pub fn cleanup_interval() -> Duration { - // TODO this should come from the settings - Duration::from_secs(20) -} diff --git a/crates/edgen_core/src/whisper.rs b/crates/edgen_core/src/whisper.rs index e5de5b4..65571c6 100644 --- a/crates/edgen_core/src/whisper.rs +++ b/crates/edgen_core/src/whisper.rs @@ -10,394 +10,189 @@ * limitations under the License. 
*/ -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; +use std::path::Path; +use std::time::Duration; -use futures::executor::block_on; +use rubato::Resampler; use serde::Serialize; use smol::future::Future; use thiserror::Error; -use tokio::spawn; -use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; -use tokio::sync::{mpsc, RwLock}; -use tokio::task::JoinHandle; +use tracing::info; use utoipa::ToSchema; +use uuid::Uuid; + +#[derive(Serialize, Error, Debug)] +pub enum WhisperEndpointError { + #[error("failed to advance context: {0}")] + Advance(String), + #[error("failed to decode result: {0}")] + Decode(String), + #[error("failed to load the model: {0}")] + Load(String), + #[error("failed to create a session: {0}")] + Session(String), + #[error("failed to parse audio file data: {0}")] + Audio(#[from] AudioError), +} -#[derive(Serialize, Error, ToSchema, Debug)] -pub enum WhisperError { - #[error("{mime:?} mime is unsupported, this executor supports: {supported:?}")] - UnsupportedMime { mime: String, supported: String }, - #[error("audio has unsupported sample rate {value:?}, executor supports: {supported:?}")] - UnsupportedSampleRate { value: u32, supported: String }, - #[error("could not parse input data: {0}")] - Parsing(String), - #[error("failed to run the internal executor: {0}")] - Internal(String), - #[error("failed to load model: {0}")] - ModelInitialization(String), - #[error("failed to create a new session: {0}")] - SessionInitialization(String), - #[error("failed to prepare the execution: {0}")] - Other(String), +pub struct TranscriptionArgs { + pub file: Vec, + pub language: Option, + pub prompt: Option, + pub temperature: Option, + pub session: Option, } -pub trait Whisper { - fn decode<'a>( +pub trait WhisperEndpoint { + /// Given an audio segment with several arguments, return a [`Box`]ed [`Future`] which may + /// eventually contain its transcription in [`String`] form. 
+ fn transcription<'a>( &'a self, - data: &'a [f32], - ) -> Box> + Send + Unpin + 'a>; + model_path: impl AsRef + Send + 'a, + args: TranscriptionArgs, + ) -> Box> + Send + Unpin + 'a>; - fn new_session<'a>( - &'a self, - callback: Box, - ) -> Box> + Send + Unpin + 'a>; + /// Unloads everything from memory. + fn reset(&self); } -pub trait WhisperRunner { - fn forward_decode( - &mut self, - data: &[f32], - ) -> impl Future> + Send; +/// Return the [`Duration`] for which a whisper model lives while not being used before being +/// unloaded from memory. +pub fn inactive_whisper_ttl() -> Duration { + // TODO this should come from the settings + Duration::from_secs(5 * 60) } -#[derive(Serialize, Error, ToSchema, Debug)] -pub enum DecodeSessionError { - #[error("the session has already been closed or aborted")] - SessionOver, - #[error("failed to send input to worker thread: {0}")] - Send(String), - #[error("error occurred in the session runner: {0}")] - SessionRunner(#[from] SessionRunnerError), - #[error("failed to join runner thread: {0}")] - Join(String), +/// Return the [`Duration`] for which a whisper model session lives while not being used before +/// being unloaded from memory. +pub fn inactive_whisper_session_ttl() -> Duration { + // TODO this should come from the settings + Duration::from_secs(2 * 60) } -pub struct WhisperSession { - current_result: Arc>, - tx: Option>>, - rx: Option>, - runner: Option>>, - closed: Arc, +#[derive(Serialize, Error, ToSchema, Debug)] +pub enum AudioError { + #[error("failed to parse mime data: {0}")] + Parse(String), + #[error("failed to initialise resampler: {0}")] + ResamplerInit(String), + #[error("failed to resample the input audio: {0}")] + Resample(String), } -// TODO turn session into a receiver-transmitter pair -impl WhisperSession { - /// Creates a new [`WhisperSession`]. 
- /// - /// ## Arguments - /// * `runner` - A *Whisper* backend instance that implements [`WhisperRunner`] - /// * `callback` - A callback function (of type [`Fn(String)`]) that is called for every item - /// processed, receiving the result of the item as input - pub fn new(runner: impl WhisperRunner + Send + 'static, callback: F) -> Self - where - F: Fn(String) + Send + 'static, - { - let (tx, rx) = mpsc::unbounded_channel(); - let (status_tx, status_rx) = mpsc::unbounded_channel(); - let mut session = Self { - current_result: Arc::new(RwLock::new("".to_string())), - tx: Some(tx), - rx: Some(status_rx), - runner: None, - closed: Arc::new(AtomicBool::new(false)), - }; - - session.runner = Some(spawn(run_session( - runner, - rx, - status_tx, - callback, - session.current_result.clone(), - session.closed.clone(), - ))); - - session - } - - /// Push an item to the session queue to be processed. - /// - /// Does nothing if the provided slice is empty. - /// - /// ## Arguments - /// * `data` - A [`f32`] normalized slice of *PCM* data to be pushed to the queue - pub fn push(&self, data: &[f32]) -> Result<(), DecodeSessionError> { - if data.is_empty() { - return Ok(()); - } - - self.unchecked_push(data) - } - - /// Push an item to the session queue to be processed. - /// - /// ## Arguments - /// * `data` -A [`f32`] normalized slice of *PCM* data to be pushed to the queue - fn unchecked_push(&self, data: &[f32]) -> Result<(), DecodeSessionError> { - if let Some(tx) = &self.tx { - tx.send(data.to_vec()) - .map_err(move |e| DecodeSessionError::Send(e.to_string()))?; - Ok(()) - } else { - Err(DecodeSessionError::SessionOver) - } - } - - /// Returns the current cumulative [`String`] result of the session. - pub async fn result(&self) -> String { - let locked = self.current_result.read().await; - locked.clone() - } - - /// Wait for every item submitted to the session queue up to this point to processed. 
- pub async fn sync(&mut self) -> Result<(), DecodeSessionError> { - if self.closed() { - return Err(DecodeSessionError::SessionOver); - } - - let empty: Vec = vec![]; - self.unchecked_push(&empty)?; +pub fn parse_pcm(audio_file: &[u8]) -> Result, AudioError> { + use rubato::{SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction}; + use symphonia::core::audio::Signal; + use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL}; + use symphonia::core::errors::Error; + use symphonia::core::formats::FormatOptions; + use symphonia::core::io::MediaSourceStream; + use symphonia::core::meta::MetadataOptions; + use symphonia::core::probe::Hint; - if let Some(rx) = &mut self.rx { - let _ = rx.recv().await; - } else { - // Should never happen, since we check if the session is already closed above. - return Err(DecodeSessionError::SessionOver); - } + /// The optimal sample rate for whisper models. + const OPTIMAL_SAMPLE_RATE: u32 = 16000; - Ok(()) - } - - /// Terminates the session and waits for the remaining items in the queue to be processed and - /// joining with the worker thread. - pub async fn end(&mut self) -> Result<(), DecodeSessionError> { - { - let _ = self.rx.take(); - } + info!("Parsing audio file ({} bytes)", audio_file.len()); - let empty: Vec = vec![]; - self.unchecked_push(&empty)?; + // Initialisation. 
+ let cursor = std::io::Cursor::new(audio_file.to_vec()); + let stream = MediaSourceStream::new(Box::new(cursor), Default::default()); - if let Some(tx) = self.tx.take() { - tx.closed().await; - } + let hint = Hint::new(); - if let Some(runner) = self.runner.take() { - runner - .await - .map_err(move |e| DecodeSessionError::Join(e.to_string()))??; - } + let meta_opts: MetadataOptions = Default::default(); + let fmt_opts: FormatOptions = Default::default(); - Ok(()) - } + // TODO this gets stuck in a loop for some invalid files + let probed = symphonia::default::get_probe() + .format(&hint, stream, &fmt_opts, &meta_opts) + .map_err(move |e| AudioError::Parse(format!("failed to probe audio data: {e}")))?; - /// Abruptly ends the session, joining with the worker thread as soon as possible. Any remaining - /// items in the queue are ignored. - pub async fn abort(&mut self) -> Result<(), DecodeSessionError> { - if self.closed() { - return Err(DecodeSessionError::SessionOver); - } + let mut format = probed.format; + let track = format + .tracks() + .iter() + .find(|t| t.codec_params.codec != CODEC_TYPE_NULL) + .ok_or(AudioError::Parse("codec is null".to_string()))?; - self.closed.store(true, Ordering::Relaxed); + let dec_opts: DecoderOptions = Default::default(); - if let Some(tx) = self.tx.take() { - // Make sure that the runner iterates at least once more - tx.send(vec![]) - .map_err(move |e| DecodeSessionError::Send(e.to_string()))?; - } + let mut decoder = symphonia::default::get_codecs() + .make(&track.codec_params, &dec_opts) + .map_err(move |e| AudioError::Parse(format!("failed to initialize decoder: {e}")))?; - if let Some(runner) = self.runner.take() { - runner - .await - .map_err(move |e| DecodeSessionError::Join(e.to_string()))??; - } + let track_id = track.id; - Ok(()) - } + let sample_rate = track + .codec_params + .sample_rate + .ok_or_else(move || AudioError::Parse("could not get sample rate".to_string()))?; - /// Checks if the session has been terminated. 
- pub fn closed(&self) -> bool { - self.tx.is_none() - || self.rx.is_none() - || self.runner.is_none() - || self.closed.load(Ordering::Relaxed) - } -} + let mut samples = vec![]; -impl Drop for WhisperSession { - fn drop(&mut self) { - let e = block_on(self.abort()); - match e { - // Nothing to do, session got terminated previously - Err(DecodeSessionError::SessionOver) => { /* nothing */ } - Err(DecodeSessionError::Send(_)) => { - println!("Failed to send end signal: {e:?}"); + // Decoding loop. + loop { + let packet = match format.next_packet() { + Ok(packet) => packet, + Err(Error::ResetRequired) => { + break; } - Err(DecodeSessionError::SessionRunner(_)) => { - println!("Error occurred while session was running: {e:?}"); + Err(Error::IoError(e)) => { + // TODO this isnt ideal, but gonna have to wait for symphonia to be updated + // https://github.com/pdeljanov/Symphonia/issues/134#issuecomment-1146990539 + if e.kind() == std::io::ErrorKind::UnexpectedEof && e.to_string() == "end of stream" + { + break; + } else { + return Err(AudioError::Parse(format!("unexpected end of file: {e:#?}"))); + } } - Err(DecodeSessionError::Join(_)) => { - // todo should this panic? - println!("Failed to join runner thread: {e:?}"); + Err(e) => { + return Err(AudioError::Parse(format!( + "failed to acquire next packet: {e}" + ))); } - _ => { /* nothing */ } - } - } -} - -#[derive(Serialize, Error, ToSchema, Debug)] -pub enum SessionRunnerError { - #[error("could not run the executor: {0}")] - Executor(#[from] WhisperError), -} - -/// The runtime of the session runner thread. 
-/// -/// ## Arguments -/// * `runner` - -/// * `rx` - -/// * `tx` - -/// * `callback` - -/// * `current_result` - -/// * `closed` - -async fn run_session( - mut runner: R, - mut rx: UnboundedReceiver>, - tx: UnboundedSender, - callback: F, - current_result: Arc>, - closed: Arc, -) -> Result<(), SessionRunnerError> -where - R: WhisperRunner, - F: Fn(String), -{ - while let Some(data) = rx.recv().await { - if closed.load(Ordering::Relaxed) { - rx.close(); - break; - } - - if data.is_empty() { - if tx.send(0.0).is_err() { - // session.end() was called, probably - break; - } else { - continue; - } - } - - let segment = runner.forward_decode(&data).await?; - { - let mut locked = current_result.write().await; - let len = locked.len(); - locked.insert_str(len, &segment); - } - callback(segment); - } - - closed.store(true, Ordering::Relaxed); - rx.close(); - - Ok(()) -} - -#[cfg(test)] -mod tests { - use std::future::Future; - - use thiserror::Error; - - use crate::whisper::{ - DecodeSessionError, Whisper, WhisperError, WhisperRunner, WhisperSession, - }; - - #[derive(Error, Debug)] - enum TestError { - #[error("whisper error: {0}")] - Whisper(#[from] WhisperError), - #[error("decode session error: {0}")] - Session(#[from] DecodeSessionError), - } - - struct TestWhisper {} - - impl TestWhisper { - async fn async_decode(&self, _data: &[f32]) -> Result { - Ok("decode".to_string()) - } - - async fn async_new_session( - &self, - callback: Box, - ) -> Result { - Ok(WhisperSession::new(TestRunner {}, callback)) - } - } - - impl Whisper for TestWhisper { - fn decode<'a>( - &'a self, - data: &'a [f32], - ) -> Box> + Send + Unpin + 'a> { - let fut = Box::pin(self.async_decode(data)); - Box::new(fut) - } + }; - fn new_session<'a>( - &'a self, - callback: Box, - ) -> Box> + Send + Unpin + 'a> - { - let fut = Box::pin(self.async_new_session(callback)); - Box::new(fut) + if packet.track_id() != track_id { + continue; } - } - struct TestRunner {} + let decoded = decoder + 
.decode(&packet) + .map_err(move |e| AudioError::Parse(format!("failed to decode packet: {e}")))?; - impl WhisperRunner for TestRunner { - async fn forward_decode(&mut self, _data: &[f32]) -> Result { - Ok("forward".to_string()) - } + let mut sample_slice = decoded.make_equivalent::(); + decoded.convert(&mut sample_slice); + samples.extend_from_slice(sample_slice.chan(0)); } - #[tokio::test] - async fn decode() -> Result<(), TestError> { - let test = TestWhisper {}; - - let e: Vec = vec![]; - let res = test.decode(&e).await?; + // Resample the pcm data if necessary. + if sample_rate != OPTIMAL_SAMPLE_RATE { + let params = SincInterpolationParameters { + sinc_len: 256, + f_cutoff: 0.95, + interpolation: SincInterpolationType::Linear, + oversampling_factor: 256, + window: WindowFunction::BlackmanHarris2, + }; - assert_eq!(res, "decode".to_string()); + let mut resampler = SincFixedIn::::new( + OPTIMAL_SAMPLE_RATE as f64 / sample_rate as f64, + 2.0, + params, + samples.len(), + 1, + ) + .map_err(move |e| AudioError::ResamplerInit(e.to_string()))?; - Ok(()) + let pre: Vec<_> = samples.drain(..).map(move |x| x as f64).collect(); + let mut resampled = resampler + .process(&[pre], None) + .map_err(move |e| AudioError::Resample(e.to_string()))?; + samples = resampled[0].drain(..).map(move |x| x as f32).collect(); } - #[tokio::test] - async fn session() -> Result<(), TestError> { - let test = TestWhisper {}; - - let mut session = test - .new_session(Box::new(move |e| assert_eq!(e, "forward".to_string()))) - .await?; - - let e: Vec = vec![1.0]; - - for _ in 0..3 { - session.push(&e)?; - } - session.sync().await?; - assert_eq!(session.result().await, "forwardforwardforward".to_string()); - - for _ in 0..2 { - session.push(&e)?; - } - session.end().await?; - assert_eq!( - session.result().await, - "forwardforwardforwardforwardforward".to_string() - ); - - Ok(()) - } + Ok(samples) } diff --git a/crates/edgen_rt_llama_cpp/src/lib.rs b/crates/edgen_rt_llama_cpp/src/lib.rs index 
8aa28c7..3a08932 100644 --- a/crates/edgen_rt_llama_cpp/src/lib.rs +++ b/crates/edgen_rt_llama_cpp/src/lib.rs @@ -28,9 +28,10 @@ use tokio::time::{interval, MissedTickBehavior}; use tokio::{select, spawn}; use tracing::{error, info}; +use edgen_core::cleanup_interval; use edgen_core::llm::{ - cleanup_interval, inactive_llm_session_ttl, inactive_llm_ttl, CompletionArgs, LLMEndpoint, - LLMEndpointError, ASSISTANT_TAG, SYSTEM_TAG, TOOL_TAG, USER_TAG, + inactive_llm_session_ttl, inactive_llm_ttl, CompletionArgs, LLMEndpoint, LLMEndpointError, + ASSISTANT_TAG, SYSTEM_TAG, TOOL_TAG, USER_TAG, }; use edgen_core::perishable::{ActiveSignal, Perishable, PerishableReadGuard, PerishableWriteGuard}; use edgen_core::settings::SETTINGS; diff --git a/crates/edgen_rt_whisper_cpp/Cargo.toml b/crates/edgen_rt_whisper_cpp/Cargo.toml index 71c3e2a..6e5bbec 100644 --- a/crates/edgen_rt_whisper_cpp/Cargo.toml +++ b/crates/edgen_rt_whisper_cpp/Cargo.toml @@ -13,5 +13,7 @@ derive_more = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["sync"] } once_cell = { workspace = true } +tracing = { workspace = true } +uuid = { workspace = true, features = ["v4"] } whisper_cpp = { workspace = true, features = ["native"] } diff --git a/crates/edgen_rt_whisper_cpp/src/lib.rs b/crates/edgen_rt_whisper_cpp/src/lib.rs index 46ec890..d9be530 100644 --- a/crates/edgen_rt_whisper_cpp/src/lib.rs +++ b/crates/edgen_rt_whisper_cpp/src/lib.rs @@ -11,124 +11,263 @@ */ use std::future::Future; +use std::path::{Path, PathBuf}; +use std::sync::Arc; -use edgen_core::settings::SETTINGS; -use thiserror::Error; -use whisper_cpp::{WhisperModel, WhisperParams, WhisperSampling}; +use dashmap::DashMap; +use futures::executor::block_on; +use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; +use tokio::task::JoinHandle; +use tokio::time::{interval, MissedTickBehavior}; +use tokio::{select, spawn}; +use tracing::info; +use uuid::Uuid; +use whisper_cpp::{WhisperModel, 
WhisperParams, WhisperSampling, WhisperSession}; -use edgen_core::whisper::{Whisper, WhisperError, WhisperRunner, WhisperSession}; +use edgen_core::cleanup_interval; +use edgen_core::perishable::{ActiveSignal, Perishable, PerishableReadGuard, PerishableWriteGuard}; +use edgen_core::settings::SETTINGS; +use edgen_core::whisper::{ + inactive_whisper_session_ttl, inactive_whisper_ttl, parse_pcm, TranscriptionArgs, + WhisperEndpoint, WhisperEndpointError, +}; -#[derive(Error, Debug)] -enum WhisperCppError { - #[error("failed initialize whisper.cpp model: {0}")] - Initialization(#[from] whisper_cpp::WhisperError), -} +/// A large language model endpoint, implementing [`WhisperEndpoint`] using a [`whisper_cpp`] backend. +pub struct WhisperCppEndpoint { + /// A map of the models currently loaded into memory, with their path as the key. + models: Arc>, -pub struct WhisperCpp { - model: WhisperModel, + /// A background thread that periodically removes models from the `models` collection, if they + /// are not loaded at the time. + cleanup_thread: JoinHandle<()>, } -impl WhisperCpp { - pub fn load

(model_path: P) -> Result - where - P: AsRef, - { - Ok(Self { - model: WhisperModel::new_from_file(model_path, false) - .map_err(move |e| WhisperError::ModelInitialization(e.to_string()))?, - }) - } - - async fn async_decode(&self, data: &[f32]) -> Result { - let mut session: whisper_cpp::WhisperSession = self - .model - .new_session() - .await - .map_err(move |e| WhisperError::SessionInitialization(e.to_string()))?; +impl WhisperCppEndpoint { + /// Gets the [`UnloadingModel`] loaded from the specified path. If the model isn't already + /// loaded, first initialise it and add it to the `models` collection. + async fn get( + &self, + model_path: impl AsRef, + ) -> dashmap::mapref::one::Ref { + let key = model_path.as_ref().to_string_lossy().to_string(); - let mut params = WhisperParams::new(WhisperSampling::default_greedy()); - params.thread_count = SETTINGS.read().await.read().await.auto_threads(false); - - session - .full(params, data) - .await - .map_err(move |e| WhisperError::Internal(e.to_string()))?; - - let mut res = "".to_string(); - for i in 0..session.segment_count() { - res += &*session - .segment_text(i) - .map_err(move |e| WhisperError::Internal(e.to_string()))?; + if !self.models.contains_key(&key) { + let model = UnloadingModel::new(model_path).await; + self.models.insert(key.clone(), model); } - Ok(res) + // PANIC SAFETY: Just inserted the element if it isn't already inside the map, so must be present in the map + self.models.get(&key).unwrap() } - async fn async_new_session( + async fn async_transcription( &self, - callback: Box, - ) -> Result { - Ok(WhisperSession::new( - WhisperCppRunner { - session: self - .model - .new_session() - .await - .map_err(move |e| WhisperError::SessionInitialization(e.to_string()))?, - }, - callback, - )) + model_path: impl AsRef, + args: TranscriptionArgs, + ) -> Result { + let pcm = parse_pcm(&args.file)?; + let model = self.get(model_path).await; + model.transcription(args.session, pcm).await } } -impl Whisper for 
WhisperCpp { - fn decode<'a>( +impl WhisperEndpoint for WhisperCppEndpoint { + fn transcription<'a>( &'a self, - data: &'a [f32], - ) -> Box> + Send + Unpin + 'a> { - // todo are the 2 boxes really needed? - let fut = Box::pin(self.async_decode(data)); - Box::new(fut) + model_path: impl AsRef + Send + 'a, + args: TranscriptionArgs, + ) -> Box> + Send + Unpin + 'a> { + let pinned = Box::pin(self.async_transcription(model_path, args)); + Box::new(pinned) } - fn new_session<'a>( - &'a self, - callback: Box, - ) -> Box> + Send + Unpin + 'a> { - // todo are the 2 boxes really needed? - let fut = Box::pin(self.async_new_session(callback)); - Box::new(fut) + fn reset(&self) { + self.models.clear(); + } +} + +impl Default for WhisperCppEndpoint { + fn default() -> Self { + let models: Arc> = Default::default(); + let models_clone = models.clone(); + let cleanup_thread = spawn(async move { + let mut interval = interval(cleanup_interval()); + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + interval.tick().await; + models_clone.retain(move |_, model| block_on(model.loaded())); + } + }); + + Self { + models, + cleanup_thread, + } } } -struct WhisperCppRunner { - session: whisper_cpp::WhisperSession, +impl Drop for WhisperCppEndpoint { + fn drop(&mut self) { + self.cleanup_thread.abort() + } } -impl WhisperRunner for WhisperCppRunner { - async fn forward_decode(&mut self, data: &[f32]) -> Result { - let params = WhisperParams::new(WhisperSampling::default_greedy()); - - self.session - .full(params, data) - .await - .map_err(move |e| WhisperError::Internal(e.to_string()))?; - - let _segment_count = self.session.segment_count(); - /*self - .session - .segment_text(segment_count - 1) - .map_err(move |e| WhisperError::Internal(e.to_string()))*/ - - let mut res = "".to_string(); - for i in 0..self.session.segment_count() { - // we should review if this is correct! 
- res += &*self - .session - .segment_text(i) - .map_err(move |e| WhisperError::Internal(e.to_string()))?; +/// A [`WhisperModel`] (as well as its associated [`WhisperSession`]s) that unloads itself from +/// memory after not being used for a period of time. +struct UnloadingModel { + model: Perishable, + path: PathBuf, + sessions: Arc>>, + maintenance_thread: JoinHandle<()>, + finished_tx: UnboundedSender<(Uuid, Perishable)>, +} + +impl UnloadingModel { + /// Creates a new instance of this model, provided it's [`Path`]. + /// + /// This function is lazy and does not actually load the model into system memory, the model must be accessed in + /// order to be loaded. + async fn new(model_path: impl AsRef) -> Self { + let sessions: Arc>> = Default::default(); + let (tx, mut rx) = unbounded_channel(); + + let sessions_clone = sessions.clone(); + let maintenance_thread = spawn(async move { + let mut interval = interval(cleanup_interval()); + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + select! { + _ = interval.tick() => sessions_clone.retain(move |_, session| block_on(session.is_alive())), + item = rx.recv() => { + if let Some((id, session)) = item { + sessions_clone.insert(id, session); + } + } + } + } + }); + + Self { + model: Perishable::with_ttl(inactive_whisper_ttl()), + path: model_path.as_ref().to_path_buf(), + sessions, + maintenance_thread, + finished_tx: tx, + } + } + + /// Returns **`true`** if this model is currently loaded in system memory, **`false`** otherwise. + async fn loaded(&self) -> bool { + self.model.is_alive().await + } + + /// Either takes an existing chat [`WhisperSession`] matching the provided [`Uuid`], or creates + /// a new one. 
+ async fn take_session(&self, uuid: Uuid) -> Perishable { + let session_perishable = if let Some((_, session)) = self.sessions.remove(&uuid) { + info!("Matching session found, continuing"); + session + } else { + info!("No matching session found, creating new one"); + Perishable::with_ttl(inactive_whisper_session_ttl()) + }; + + session_perishable + } + + async fn transcription( + &self, + uuid: Option, + pcm: Vec, + ) -> Result { + let (_model_signal, model_guard) = get_or_init_model(&self.model, &self.path).await?; + + let mut params = WhisperParams::new(WhisperSampling::default_greedy()); + let threads = SETTINGS.read().await.read().await.auto_threads(false); + + params.thread_count = threads; + + if let Some(uuid) = uuid { + let session = self.take_session(uuid).await; + + let (_session_signal, mut session_guard) = { + let (session_signal, mut session_guard) = + get_or_init_session(&session, model_guard.clone()).await?; + + (session_signal, session_guard) + }; + + session_guard + .full(params, &pcm) + .await + .map_err(move |e| WhisperEndpointError::Advance(e.to_string()))?; + + let mut res = "".to_string(); + for i in 0..session_guard.segment_count() { + res += &*session_guard + .segment_text(i) + .map_err(move |e| WhisperEndpointError::Decode(e.to_string()))?; + } + + Ok(res) + } else { + let mut session = model_guard + .new_session() + .await + .map_err(move |e| WhisperEndpointError::Session(e.to_string()))?; + + session + .full(params, &pcm) + .await + .map_err(move |e| WhisperEndpointError::Advance(e.to_string()))?; + + let mut res = "".to_string(); + for i in 0..session.segment_count() { + res += &*session + .segment_text(i) + .map_err(move |e| WhisperEndpointError::Decode(e.to_string()))?; + } + + Ok(res) } + } +} - Ok(res) +impl Drop for UnloadingModel { + fn drop(&mut self) { + self.maintenance_thread.abort() } } + +/// Helper function to acquire a read guard to a [`WhisperModel`] (and its associated +/// [`ActiveSignal`]). 
+async fn get_or_init_model( + model: &Perishable, + path: impl AsRef, +) -> Result<(ActiveSignal, PerishableReadGuard), WhisperEndpointError> { + let path = path.as_ref().to_path_buf(); + model + .get_or_try_init(move || async move { + WhisperModel::new_from_file(path, false) + .map_err(move |e| WhisperEndpointError::Load(e.to_string())) + }) + .await +} + +/// Helper function to acquire a write guard to a [`WhisperSession`] (and its associated +/// [`ActiveSignal`]). +async fn get_or_init_session( + session: &Perishable, + model: WhisperModel, +) -> Result<(ActiveSignal, PerishableWriteGuard), WhisperEndpointError> { + session + .get_or_try_init_mut(move || async move { + model + .new_session() + .await + .map_err(move |e| WhisperEndpointError::Session(e.to_string())) + }) + .await +} diff --git a/crates/edgen_server/src/lib.rs b/crates/edgen_server/src/lib.rs index 870d167..743d6d6 100644 --- a/crates/edgen_server/src/lib.rs +++ b/crates/edgen_server/src/lib.rs @@ -32,7 +32,6 @@ use utoipa::OpenApi; use edgen_core::settings; use edgen_core::settings::SETTINGS; -use edgen_core::whisper::{DecodeSessionError, SessionRunnerError, WhisperError}; use openai_shim as chat; use openai_shim as audio; @@ -76,11 +75,7 @@ mod whisper; openai_shim::AssistantFunctionStub, openai_shim::AssistantToolCall, openai_shim::CreateTranscriptionRequest, - whisper::WhisperEndpointError, - whisper::AudioError, - WhisperError, - DecodeSessionError, - SessionRunnerError, + openai_shim::TranscriptionError, model::ModelError, model::ModelKind, )) diff --git a/crates/edgen_server/src/llm.rs b/crates/edgen_server/src/llm.rs index f6915a8..98c4ff7 100644 --- a/crates/edgen_server/src/llm.rs +++ b/crates/edgen_server/src/llm.rs @@ -12,10 +12,8 @@ use futures::Stream; use once_cell::sync::Lazy; -use serde_derive::Serialize; -use thiserror::Error; -use edgen_core::llm::{CompletionArgs, LLMEndpoint}; +use edgen_core::llm::{CompletionArgs, LLMEndpoint, LLMEndpointError}; use 
edgen_rt_llama_cpp::LlamaCppEndpoint; use crate::model::Model; @@ -23,20 +21,6 @@ use crate::util::StoppingStream; static ENDPOINT: Lazy = Lazy::new(Default::default); -#[derive(Serialize, Error, Debug)] -pub enum LLMEndpointError { - #[error("the provided model file name does does not exist, or isn't a file: ({0})")] - FileNotFound(String), - #[error("there is no session associated with the provided uuid ({0})")] - SessionNotFound(String), - #[error("failed to run inference: {0}")] - Inference(#[from] edgen_core::llm::LLMEndpointError), - #[error("failed to load model: {0}")] - Model(#[from] crate::model::ModelError), -} - -// TODO use this -#[allow(dead_code)] pub async fn chat_completion(model: Model, context: String) -> Result { let args = CompletionArgs { prompt: context, @@ -44,7 +28,14 @@ pub async fn chat_completion(model: Model, context: String) -> Result where S: TryStream + Send + 'static, @@ -566,7 +568,7 @@ post, path = "/chat/completions", request_body = CreateChatCompletionRequest, responses( -(status = 200, description = "OK", body = ChatCompletion), +(status = 200, description = "OK", body = ChatCompletionResponse), (status = 500, description = "unexpected internal server error", body = ChatCompletionError) ), )] @@ -728,12 +730,12 @@ path = "/audio/transcriptions", request_body = CreateTranscriptionRequest, responses( (status = 200, description = "OK", body = String), -(status = 500, description = "unexpected internal server error", body = WhisperEndpointError) +(status = 500, description = "unexpected internal server error", body = TranscriptionError) ), )] pub async fn create_transcription( req: TypedMultipart, -) -> Result { +) -> Result { // For MVP1, the model string in the request is *always* ignored. 
let model_name = SETTINGS .read() @@ -754,7 +756,10 @@ pub async fn create_transcription( // invalid if model_name.is_empty() { - return Err(WhisperEndpointError::FileNotFound(model_name)); + return Err(TranscriptionError::ProhibitedName { + model_name, + reason: Cow::Borrowed("Empty name"), + }); } let mut model = Model::new( @@ -766,11 +771,6 @@ pub async fn create_transcription( model.preload().await?; - model - .preload() - .await - .map_err(move |_| WhisperEndpointError::FileNotFound(model_name))?; - let res = crate::whisper::create_transcription( &req.file.contents, model, @@ -783,7 +783,45 @@ pub async fn create_transcription( Ok(res.into_boxed_str()) } -impl IntoResponse for WhisperEndpointError { +/// An error condition raised by the audio transcription API. +/// +/// This is **not normative** with OpenAI's specification, which does not document any specific +/// failure modes. +#[derive(Serialize, Error, ToSchema, Debug)] +#[serde(rename_all = "snake_case")] +#[serde(tag = "error")] +pub enum TranscriptionError { + /// The provided model could not be found on the local system. + #[error("no such model: {model_name}")] + NoSuchModel { + /// The name of the model. + model_name: String, + }, + + /// The provided model name contains prohibited characters. + #[error("model {model_name} could not be fetched from the system: {reason}")] + ProhibitedName { + /// The name of the model provided. + model_name: String, + + /// A human-readable error message. + reason: Cow<'static, str>, + }, + + /// The provided model could not be preloaded. + #[error("failed to preload the model: {0}")] + Preload(#[from] ModelError), + + /// An error occurred on the other side of an FFI boundary. + #[error("an error occurred on the other side of a C FFI boundary; check `tracing`")] + Ffi, + + /// An error occurred while processing the request to this endpoint. 
+ #[error("an error occurred while processing the request: {0}")] + Endpoint(#[from] WhisperEndpointError), +} + +impl IntoResponse for TranscriptionError { fn into_response(self) -> Response { (StatusCode::INTERNAL_SERVER_ERROR, Json(self)).into_response() } diff --git a/crates/edgen_server/src/whisper.rs b/crates/edgen_server/src/whisper.rs index 239a71c..6428758 100644 --- a/crates/edgen_server/src/whisper.rs +++ b/crates/edgen_server/src/whisper.rs @@ -10,30 +10,14 @@ * limitations under the License. */ -use std::sync::Arc; - -use dashmap::DashMap; use once_cell::sync::Lazy; -use rubato::Resampler; -use serde_derive::Serialize; -use thiserror::Error; -use time::Duration; -use tracing::info; -use utoipa::ToSchema; -use uuid::Uuid; -use edgen_core::whisper::{DecodeSessionError, Whisper, WhisperError, WhisperSession}; -use edgen_rt_whisper_cpp::WhisperCpp; +use edgen_core::whisper::{TranscriptionArgs, WhisperEndpoint, WhisperEndpointError}; +use edgen_rt_whisper_cpp::WhisperCppEndpoint; use crate::model::Model; -use crate::model::ModelError; -use crate::util::{Perishable, PerishableReadGuard}; - -static ENDPOINT: Lazy = Lazy::new(Default::default); -/// The number of seconds that a `whisper` model will remain loaded for before being automatically -/// unloaded. 
-pub const WHISPER_INACTIVE_TTL: Duration = Duration::seconds(5 * 60); +static ENDPOINT: Lazy = Lazy::new(Default::default); pub async fn create_transcription( file: &[u8], @@ -42,266 +26,24 @@ pub async fn create_transcription( prompt: Option<&str>, temperature: Option, ) -> Result { + let args = TranscriptionArgs { + file: file.to_vec(), + language: language.map(move |s| s.to_string()), + prompt: prompt.map(move |s| s.to_string()), + temperature, + session: None, + }; + ENDPOINT - .standalone_decode(file, model, language, prompt, temperature) + .transcription( + model + .file_path() + .map_err(move |e| WhisperEndpointError::Load(e.to_string()))?, + args, + ) .await } pub async fn reset_environment() { ENDPOINT.reset() } - -#[derive(Serialize, Error, ToSchema, Debug)] -pub enum WhisperEndpointError { - #[error("the provided model file name does does not exist, or isn't a file: ({0})")] - FileNotFound(String), - #[error("there is no session associated with the provided uuid ({0})")] - SessionNotFound(String), - #[error("internal error: {0}")] - Internal(#[from] WhisperError), - #[error("error in decode session: {0}")] - Session(#[from] DecodeSessionError), - #[error("failed to parse audio data: {0}")] - Audio(#[from] AudioError), - #[error("failed to load model: {0}")] - Model(#[from] ModelError), -} - -struct WhisperInstance { - model: Box, - _sessions: DashMap, -} - -impl WhisperInstance { - fn new(model: impl Whisper + Send + Sync + 'static) -> Self { - Self { - model: Box::new(model), - _sessions: DashMap::new(), - } - } - - async fn decode(&self, data: &[f32]) -> Result { - Ok(self.model.decode(data).await?) 
- } - - #[allow(dead_code)] - async fn new_session( - &self, - callback: impl Fn(String) + Send + 'static, - ) -> Result { - let session = self.model.new_session(Box::new(callback)).await?; - - let uuid = Uuid::new_v4(); - - self._sessions.insert(uuid, session); - - Ok(uuid) - } - - #[allow(dead_code)] - async fn advance_decode(&self, uuid: Uuid, data: &[f32]) -> Result<(), WhisperEndpointError> { - let session = self - ._sessions - .get(&uuid) - .ok_or(WhisperEndpointError::SessionNotFound(uuid.to_string()))?; - - session.push(data)?; - - Ok(()) - } -} - -#[derive(Default)] -pub struct WhisperEndpoint { - instances: DashMap>>, -} - -impl WhisperEndpoint { - fn get_or_create( - &self, - model: &Model, - ) -> Result>, WhisperEndpointError> { - let path = model.file_path()?; - let key = path.to_string_lossy(); - - if !self.instances.contains_key(key.as_ref()) { - self.instances.insert( - key.to_string(), - Arc::new(Perishable::with_ttl(WHISPER_INACTIVE_TTL)), - ); - } - - Ok(self.instances.get(key.as_ref()).expect("Model instance not found. 
This should never happen as a new instance is added if there isn't one").value().clone()) - } - - async fn lock<'a>( - instance: &'a Arc>, - model: &Model, - ) -> Result, WhisperEndpointError> { - let path = model.file_path()?; - - instance - .get_or_try_init(move || async move { - //let model_path = PROJECT_DIRS.config_dir().join(&path); - - if !path.is_file() { - return Err(WhisperEndpointError::FileNotFound( - path.to_string_lossy().to_string(), - )); - } - - // TODO change this when more backends exist - let model = WhisperCpp::load(path)?; - - Ok(WhisperInstance::new(model)) - }) - .await - } - - pub async fn standalone_decode( - &self, - audio: &[u8], - model: Model, - _language: Option<&str>, // currently not used - _prompt: Option<&str>, // currently not used - _temperature: Option, // currently not used - ) -> Result { - let instance = self.get_or_create(&model)?; - let locked = Self::lock(&instance, &model).await?; - - let pcm = to_pcm(audio)?; - - locked.decode(&pcm).await - } - - fn reset(&self) { - self.instances.clear() - } -} - -#[derive(Serialize, Error, ToSchema, Debug)] -pub enum AudioError { - #[error("failed to parse mime data: {0}")] - Parse(String), - #[error("failed to initialise resampler: {0}")] - ResamplerInit(String), - #[error("failed to resample the input audio: {0}")] - Resample(String), -} - -fn to_pcm(audio_file: &[u8]) -> Result, AudioError> { - use rubato::{SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction}; - use symphonia::core::audio::Signal; - use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL}; - use symphonia::core::errors::Error; - use symphonia::core::formats::FormatOptions; - use symphonia::core::io::MediaSourceStream; - use symphonia::core::meta::MetadataOptions; - use symphonia::core::probe::Hint; - - /// The optimal sample rate for whisper models. - const OPTIMAL_SAMPLE_RATE: u32 = 16000; - - info!("Parsing audio file ({} bytes)", audio_file.len()); - - // Initialisation. 
- let cursor = std::io::Cursor::new(audio_file.to_vec()); - let stream = MediaSourceStream::new(Box::new(cursor), Default::default()); - - let hint = Hint::new(); - - let meta_opts: MetadataOptions = Default::default(); - let fmt_opts: FormatOptions = Default::default(); - - // TODO this gets stuck in a loop for some invalid files - let probed = symphonia::default::get_probe() - .format(&hint, stream, &fmt_opts, &meta_opts) - .map_err(move |e| AudioError::Parse(format!("failed to probe audio data: {e}")))?; - - let mut format = probed.format; - let track = format - .tracks() - .iter() - .find(|t| t.codec_params.codec != CODEC_TYPE_NULL) - .ok_or(AudioError::Parse("codec is null".to_string()))?; - - let dec_opts: DecoderOptions = Default::default(); - - let mut decoder = symphonia::default::get_codecs() - .make(&track.codec_params, &dec_opts) - .map_err(move |e| AudioError::Parse(format!("failed to initialize decoder: {e}")))?; - - let track_id = track.id; - - let sample_rate = track - .codec_params - .sample_rate - .ok_or_else(move || AudioError::Parse("could not get sample rate".to_string()))?; - - let mut samples = vec![]; - - // Decoding loop. 
- loop { - let packet = match format.next_packet() { - Ok(packet) => packet, - Err(Error::ResetRequired) => { - break; - } - Err(Error::IoError(e)) => { - // TODO this isnt ideal, but gonna have to wait for symphonia to be updated - // https://github.com/pdeljanov/Symphonia/issues/134#issuecomment-1146990539 - if e.kind() == std::io::ErrorKind::UnexpectedEof && e.to_string() == "end of stream" - { - break; - } else { - return Err(AudioError::Parse(format!("unexpected end of file: {e:#?}"))); - } - } - Err(e) => { - return Err(AudioError::Parse(format!( - "failed to acquire next packet: {e}" - ))); - } - }; - - if packet.track_id() != track_id { - continue; - } - - let decoded = decoder - .decode(&packet) - .map_err(move |e| AudioError::Parse(format!("failed to decode packet: {e}")))?; - - let mut sample_slice = decoded.make_equivalent::(); - decoded.convert(&mut sample_slice); - samples.extend_from_slice(sample_slice.chan(0)); - } - - // Resample the pcm data if necessary. - if sample_rate != OPTIMAL_SAMPLE_RATE { - let params = SincInterpolationParameters { - sinc_len: 256, - f_cutoff: 0.95, - interpolation: SincInterpolationType::Linear, - oversampling_factor: 256, - window: WindowFunction::BlackmanHarris2, - }; - - let mut resampler = SincFixedIn::::new( - OPTIMAL_SAMPLE_RATE as f64 / sample_rate as f64, - 2.0, - params, - samples.len(), - 1, - ) - .map_err(move |e| AudioError::ResamplerInit(e.to_string()))?; - - let pre: Vec<_> = samples.drain(..).map(move |x| x as f64).collect(); - let mut resampled = resampler - .process(&[pre], None) - .map_err(move |e| AudioError::Resample(e.to_string()))?; - samples = resampled[0].drain(..).map(move |x| x as f32).collect(); - } - - Ok(samples) -} From 8fa7eccda0e2ddc44b0bfc4bddd376022de729d9 Mon Sep 17 00:00:00 2001 From: Pedro Valente Date: Mon, 5 Feb 2024 16:59:28 +0000 Subject: [PATCH 22/24] chore: docs --- crates/edgen_core/src/whisper.rs | 2 ++ crates/edgen_rt_whisper_cpp/src/lib.rs | 3 +++ 2 files changed, 5 
insertions(+) diff --git a/crates/edgen_core/src/whisper.rs b/crates/edgen_core/src/whisper.rs index 65571c6..d889d53 100644 --- a/crates/edgen_core/src/whisper.rs +++ b/crates/edgen_core/src/whisper.rs @@ -80,6 +80,8 @@ pub enum AudioError { Resample(String), } +/// Parse an audio file and convert it into a *PCM* audio segment, using the optimal sample rate +/// for whisper models. pub fn parse_pcm(audio_file: &[u8]) -> Result, AudioError> { use rubato::{SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction}; use symphonia::core::audio::Signal; diff --git a/crates/edgen_rt_whisper_cpp/src/lib.rs b/crates/edgen_rt_whisper_cpp/src/lib.rs index d9be530..7f6d4b1 100644 --- a/crates/edgen_rt_whisper_cpp/src/lib.rs +++ b/crates/edgen_rt_whisper_cpp/src/lib.rs @@ -60,6 +60,8 @@ impl WhisperCppEndpoint { self.models.get(&key).unwrap() } + /// Helper `async` function that returns the transcription for the specified model and + /// [`TranscriptionArgs`] async fn async_transcription( &self, model_path: impl AsRef, @@ -177,6 +179,7 @@ impl UnloadingModel { session_perishable } + /// Computes the full transcription for the provided *PCM*; async fn transcription( &self, uuid: Option, From 40ae6844de66dd6a76eed83f66a507b32709176b Mon Sep 17 00:00:00 2001 From: Pedro Valente Date: Mon, 5 Feb 2024 17:02:14 +0000 Subject: [PATCH 23/24] chore: removed unused dependencies --- Cargo.lock | 2 -- crates/edgen_server/Cargo.toml | 2 -- 2 files changed, 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 10d98ff..27fd752 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1384,11 +1384,9 @@ dependencies = [ "levenshtein", "once_cell", "pin-project", - "rubato", "serde", "serde_derive", "serde_json", - "symphonia", "thiserror", "time", "tinyvec", diff --git a/crates/edgen_server/Cargo.toml b/crates/edgen_server/Cargo.toml index 3c50598..3c4c4b9 100644 --- a/crates/edgen_server/Cargo.toml +++ b/crates/edgen_server/Cargo.toml @@ -20,11 +20,9 @@ hyper = { workspace = 
true } hyper-util = { workspace = true } once_cell = { workspace = true } pin-project = { workspace = true } -rubato = "0.14.1" serde = { workspace = true } serde_derive = { workspace = true } serde_json = { workspace = true } -symphonia = { version = "0.5.3", features = ["all-codecs", "all-formats"] } time = { workspace = true } tinyvec = { workspace = true, features = ["serde"] } thiserror = { workspace = true } From e3760666c975a42da43c9336acbb96260018badc Mon Sep 17 00:00:00 2001 From: Pedro Valente Date: Mon, 5 Feb 2024 17:10:09 +0000 Subject: [PATCH 24/24] some corrections --- crates/edgen_rt_whisper_cpp/src/lib.rs | 42 ++++++++++++++++---------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/crates/edgen_rt_whisper_cpp/src/lib.rs b/crates/edgen_rt_whisper_cpp/src/lib.rs index 7f6d4b1..06fd265 100644 --- a/crates/edgen_rt_whisper_cpp/src/lib.rs +++ b/crates/edgen_rt_whisper_cpp/src/lib.rs @@ -20,7 +20,7 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; use tokio::task::JoinHandle; use tokio::time::{interval, MissedTickBehavior}; use tokio::{select, spawn}; -use tracing::info; +use tracing::{error, info}; use uuid::Uuid; use whisper_cpp::{WhisperModel, WhisperParams, WhisperSampling, WhisperSession}; @@ -195,24 +195,34 @@ impl UnloadingModel { if let Some(uuid) = uuid { let session = self.take_session(uuid).await; - let (_session_signal, mut session_guard) = { - let (session_signal, mut session_guard) = - get_or_init_session(&session, model_guard.clone()).await?; + let res = { + let (_session_signal, mut session_guard) = { + let (session_signal, session_guard) = + get_or_init_session(&session, model_guard.clone()).await?; - (session_signal, session_guard) - }; + (session_signal, session_guard) + }; - session_guard - .full(params, &pcm) - .await - .map_err(move |e| WhisperEndpointError::Advance(e.to_string()))?; + session_guard + .full(params, &pcm) + .await + .map_err(move |e| WhisperEndpointError::Advance(e.to_string()))?; - let 
mut res = "".to_string(); - for i in 0..session_guard.segment_count() { - res += &*session_guard - .segment_text(i) - .map_err(move |e| WhisperEndpointError::Decode(e.to_string()))?; - } + let mut res = "".to_string(); + for i in 0..session_guard.segment_count() { + res += &*session_guard + .segment_text(i) + .map_err(move |e| WhisperEndpointError::Decode(e.to_string()))?; + } + + res + }; + + self.finished_tx + .send((uuid, session)) + .unwrap_or_else(move |e| { + error!("Failed to send session to maintenance thread: {e}") + }); Ok(res) } else {