Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore/issue122: add embeddings to settings integration tests #134

Merged
merged 2 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions crates/edgen_core/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ pub async fn create_project_dirs() -> Result<(), std::io::Error> {

let audio_transcriptions_dir = PathBuf::from(&audio_transcriptions_str);

let embeddings_str = SETTINGS
.read()
.await
.read()
.await
.embeddings_models_dir
.to_string();

let embeddings_dir = PathBuf::from(&embeddings_str);

if !config_dir.is_dir() {
std::fs::create_dir_all(config_dir)?;
}
Expand All @@ -74,6 +84,10 @@ pub async fn create_project_dirs() -> Result<(), std::io::Error> {
std::fs::create_dir_all(&audio_transcriptions_dir)?;
}

if !embeddings_dir.is_dir() {
std::fs::create_dir_all(&embeddings_str)?;
}

Ok(())
}

Expand Down
4 changes: 2 additions & 2 deletions crates/edgen_server/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,8 @@ async fn observe_progress(
let mut m = tokio::fs::metadata(&f.path()).await;
let mut last_size = 0;
let mut timestamp = Instant::now();
while m.is_ok() {
let s = m.unwrap().len() as u64;
while let Ok(d) = m {
let s = d.len() as u64;
let p = (s * 100) / size;

if s > last_size {
Expand Down
54 changes: 50 additions & 4 deletions crates/edgen_server/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@ pub const SMALL_LLM_REPO: &str = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF";
pub const SMALL_WHISPER_NAME: &str = "ggml-distil-small.en.bin";
pub const SMALL_WHISPER_REPO: &str = "distil-whisper/distil-small.en";

pub const SMALL_EMBEDDINGS_NAME: &str = "tinyllama-1.1b-chat-v1.0.Q2_K.gguf";
pub const SMALL_EMBEDDINGS_REPO: &str = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF";

pub const BASE_URL: &str = "http://localhost:33322/v1";
pub const CHAT_URL: &str = "/chat";
pub const COMPLETIONS_URL: &str = "/completions";
pub const AUDIO_URL: &str = "/audio";
pub const TRANSCRIPTIONS_URL: &str = "/transcriptions";
pub const EMBEDDINGS_URL: &str = "/embeddings";
pub const STATUS_URL: &str = "/status";
pub const MISC_URL: &str = "/misc";
pub const VERSION_URL: &str = "/version";
Expand Down Expand Up @@ -241,6 +245,10 @@ pub fn data_exists() {
let transcriptions = audio.join("transcriptions");
println!("exists: {:?}", transcriptions);
assert!(transcriptions.exists());

let embeddings = models.join("embeddings");
println!("exists: {:?}", embeddings);
assert!(embeddings.exists());
}

/// Edit the config file: set another model dir for the indicated endpoint.
Expand All @@ -256,7 +264,9 @@ pub fn set_model_dir(ep: Endpoint, model_dir: &str) {
Endpoint::AudioTranscriptions => {
config.audio_transcriptions_models_dir = model_dir.to_string();
}
Endpoint::Embeddings => todo!(),
Endpoint::Embeddings => {
config.embeddings_models_dir = model_dir.to_string();
}
}
write_config(&config).unwrap();

Expand All @@ -280,7 +290,10 @@ pub fn set_model(ep: Endpoint, model_name: &str, model_repo: &str) {
config.audio_transcriptions_model_name = model_name.to_string();
config.audio_transcriptions_model_repo = model_repo.to_string();
}
Endpoint::Embeddings => todo!(),
Endpoint::Embeddings => {
config.embeddings_model_name = model_name.to_string();
config.embeddings_model_repo = model_repo.to_string();
}
}
write_config(&config).unwrap();

Expand All @@ -292,7 +305,7 @@ pub fn set_model(ep: Endpoint, model_name: &str, model_repo: &str) {
Endpoint::AudioTranscriptions => {
make_url(&[BASE_URL, AUDIO_URL, TRANSCRIPTIONS_URL, STATUS_URL])
}
Endpoint::Embeddings => todo!(),
Endpoint::Embeddings => make_url(&[BASE_URL, EMBEDDINGS_URL, STATUS_URL]),
};
let stat: status::AIStatus = blocking::get(url).unwrap().json().unwrap();
assert_eq!(stat.active_model, model_name);
Expand Down Expand Up @@ -335,13 +348,22 @@ pub fn chat_completions_custom_body(model: &str) -> String {
.expect("cannot convert JSON to String")
}

/// embeddings body with custom model
pub fn embeddings_custom_body(model: &str) -> String {
serde_json::to_string(&json!({
"model": model,
"input": "what is the capital of idaho?",
}))
.expect("cannot convert JSON to String")
}

/// Spawn a thread to send a request to the indicated endpoint.
/// This allows the caller to perform another task in the caller thread.
pub fn spawn_request(ep: Endpoint, body: &str, model: &str) -> thread::JoinHandle<bool> {
match ep {
Endpoint::ChatCompletions => spawn_chat_completions_request(body),
Endpoint::AudioTranscriptions => spawn_audio_transcriptions_request(model),
Endpoint::Embeddings => todo!(),
Endpoint::Embeddings => spawn_embeddings_request(body),
}
}

Expand Down Expand Up @@ -369,6 +391,30 @@ pub fn spawn_chat_completions_request(body: &str) -> thread::JoinHandle<bool> {
})
}

pub fn spawn_embeddings_request(body: &str) -> thread::JoinHandle<bool> {
let body = body.to_string();
thread::spawn(move || {
let ep = make_url(&[BASE_URL, EMBEDDINGS_URL]);
println!("requesting {}", ep);
match blocking::Client::new()
.post(&ep)
.header("Content-Type", "application/json")
.body(body)
.timeout(Duration::from_secs(180))
.send()
{
Err(e) => {
eprintln!("cannot connect: {:?}", e);
false
}
Ok(v) => {
println!("Got {:?}", v);
v.status().is_success()
}
}
})
}

pub fn spawn_audio_transcriptions_request(model: &str) -> thread::JoinHandle<bool> {
let model = model.to_string();
let frost = Path::new("resources").join("frost.wav");
Expand Down
69 changes: 67 additions & 2 deletions crates/edgen_server/tests/settings_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@

chat_completions_status_reachable();
audio_transcriptions_status_reachable();
embeddings_status_reachable();

// ================================
common::test_message("SCENARIO 2");
Expand All @@ -85,14 +86,22 @@
common::SMALL_WHISPER_NAME,
common::SMALL_WHISPER_REPO,
);
common::set_model(
Endpoint::Embeddings,
common::SMALL_EMBEDDINGS_NAME,
common::SMALL_EMBEDDINGS_REPO,
);

// test ai endpoint and download
test_ai_endpoint_with_download(Endpoint::ChatCompletions, "default");
test_ai_endpoint_with_download(Endpoint::AudioTranscriptions, "default");

Check warning on line 97 in crates/edgen_server/tests/settings_tests.rs

View workflow job for this annotation

GitHub Actions / CI

Diff in /home/runner/work/edgen/edgen/crates/edgen_server/tests/settings_tests.rs
test_ai_endpoint_with_download(Endpoint::Embeddings, "default");


// we have downloaded, we should not download again
test_ai_endpoint_no_download(Endpoint::ChatCompletions, "default");
test_ai_endpoint_no_download(Endpoint::AudioTranscriptions, "default");
test_ai_endpoint_no_download(Endpoint::Embeddings, "default");

// ================================
common::test_message("SCENARIO 3");
Expand All @@ -119,20 +128,29 @@
path::MAIN_SEPARATOR,
"audio",
path::MAIN_SEPARATOR,
"transcriptions",

Check warning on line 131 in crates/edgen_server/tests/settings_tests.rs

View workflow job for this annotation

GitHub Actions / CI

Diff in /home/runner/work/edgen/edgen/crates/edgen_server/tests/settings_tests.rs
);

common::set_model_dir(Endpoint::ChatCompletions, &new_chat_completions_dir);
let new_embeddings_dir = my_models_dir.clone()
+ &format!(
"{}{}",
path::MAIN_SEPARATOR,
"embeddings",
);

common::set_model_dir(Endpoint::ChatCompletions, &new_chat_completions_dir);
common::set_model_dir(Endpoint::AudioTranscriptions, &new_audio_transcriptions_dir);
common::set_model_dir(Endpoint::Embeddings, &new_embeddings_dir);

test_ai_endpoint_with_download(Endpoint::ChatCompletions, "default");
test_ai_endpoint_with_download(Endpoint::AudioTranscriptions, "default");
test_ai_endpoint_with_download(Endpoint::Embeddings, "default");

assert!(path::Path::new(&my_models_dir).exists());

test_ai_endpoint_no_download(Endpoint::ChatCompletions, "default");
test_ai_endpoint_no_download(Endpoint::AudioTranscriptions, "default");
test_ai_endpoint_no_download(Endpoint::Embeddings, "default");

// ================================
common::test_message("SCENARIO 4");
Expand All @@ -142,11 +160,13 @@

test_ai_endpoint_with_download(Endpoint::ChatCompletions, "default");
test_ai_endpoint_with_download(Endpoint::AudioTranscriptions, "default");
test_ai_endpoint_with_download(Endpoint::Embeddings, "default");

assert!(path::Path::new(&my_models_dir).exists());

test_ai_endpoint_no_download(Endpoint::ChatCompletions, "default");
test_ai_endpoint_no_download(Endpoint::AudioTranscriptions, "default");
test_ai_endpoint_no_download(Endpoint::Embeddings, "default");

// ================================
common::test_message("SCENARIO 5");
Expand All @@ -163,28 +183,37 @@
common::SMALL_WHISPER_NAME,
common::SMALL_WHISPER_REPO,
);
common::set_model(
Endpoint::Embeddings,
common::SMALL_EMBEDDINGS_NAME,
common::SMALL_EMBEDDINGS_REPO,
);

// make sure we read from the old directories again
remove_dir_all(&my_models_dir).unwrap();
assert!(!path::Path::new(&my_models_dir).exists());

test_ai_endpoint_no_download(Endpoint::ChatCompletions, "default");
test_ai_endpoint_no_download(Endpoint::AudioTranscriptions, "default");
test_ai_endpoint_no_download(Endpoint::Embeddings, "default");

// ================================
common::test_message("SCENARIO 6");
// ================================
let chat_model = "TheBloke/phi-2-GGUF/phi-2.Q2_K.gguf";
let audio_model = "distil-whisper/distil-medium.en/ggml-medium-32-2.en.bin";
let embeddings_model = "TheBloke/phi-2-GGUF/phi-2.Q2_K.gguf";

test_ai_endpoint_with_download(Endpoint::ChatCompletions, chat_model);
test_ai_endpoint_with_download(Endpoint::AudioTranscriptions, audio_model);
test_ai_endpoint_with_download(Endpoint::Embeddings, embeddings_model);

// ================================
common::test_message("SCENARIO 7");
// ================================
test_ai_endpoint_no_download(Endpoint::ChatCompletions, chat_model);
test_ai_endpoint_no_download(Endpoint::AudioTranscriptions, audio_model);
test_ai_endpoint_no_download(Endpoint::Embeddings, embeddings_model);

// ================================
common::test_message("SCENARIO 8");
Expand All @@ -200,6 +229,10 @@
"audio/transcriptions",
);
test_ai_endpoint_no_download(Endpoint::AudioTranscriptions, ".whisper-medium-32-2.en.bin");

let source = "models--TheBloke--phi-2-GGUF/blobs";
common::copy_model(source, ".phi-2.Q2_K.gguf", "embeddings");
test_ai_endpoint_no_download(Endpoint::Embeddings, ".phi-2.Q2_K.gguf");
})
}

Expand Down Expand Up @@ -243,6 +276,25 @@
});
}

fn embeddings_status_reachable() {
common::test_message("embeddings status is reachable");
assert!(match blocking::get(common::make_url(&[
common::BASE_URL,
common::EMBEDDINGS_URL,
common::STATUS_URL,
])) {
Err(e) => {
eprintln!("cannot connect: {:?}", e);
false
}
Ok(v) => {
assert!(v.status().is_success());
println!("have: '{}'", v.text().unwrap());
true
}
});
}

fn test_config_reset() {
common::test_message("test resetting config");
common::reset_config();
Expand Down Expand Up @@ -289,9 +341,22 @@
common::STATUS_URL,
]),
"".to_string(),
)

Check warning on line 344 in crates/edgen_server/tests/settings_tests.rs

View workflow job for this annotation

GitHub Actions / CI

Diff in /home/runner/work/edgen/edgen/crates/edgen_server/tests/settings_tests.rs
}
Endpoint::Embeddings => todo!(),
Endpoint::Embeddings => {
common::test_message(&format!(
"embeddints endpoint with download: {}",
download
));
(
common::make_url(&[
common::BASE_URL,
common::EMBEDDINGS_URL,
common::STATUS_URL,
]),
common::embeddings_custom_body(model),
)
}
};
let handle = common::spawn_request(endpoint, &body, model);
if download {
Expand Down
Loading