Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit

Permalink
feat(test): improve output quality
Browse files Browse the repository at this point in the history
  • Loading branch information
philpax committed Jul 2, 2023
1 parent b130f5c commit 6e2362b
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions binaries/llm-test/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,13 @@ mod tests {
use super::*;

pub(super) fn can_send<M: llm::KnownModel + 'static>(model: M) -> anyhow::Result<M> {
std::thread::spawn(move || model)
let model = std::thread::spawn(move || model)
.join()
.map_err(|e| anyhow::anyhow!("Failed to join thread: {e:?}"))
.map_err(|e| anyhow::anyhow!("Failed to join thread: {e:?}"));

log::info!("`can_send` test passed!");

model
}

pub(super) fn can_roundtrip_hyperparameters<M: llm::KnownModel + 'static>(
Expand All @@ -341,6 +345,8 @@ mod tests {

assert_eq!(hyperparameters, &new_hyperparameters);

log::info!("`can_roundtrip_hyperparameters` test passed!");

Ok(())
}

Expand All @@ -354,13 +360,21 @@ mod tests {
expected_output: &str,
maximum_token_count: usize,
) -> anyhow::Result<TestCaseReport> {
let (actual_output, res) = run_inference(model, model_config, input, maximum_token_count);
let mut session = model.start_session(Default::default());
let (actual_output, res) = run_inference(
model,
model_config,
&mut session,
input,
maximum_token_count,
);

// Process the results
Ok(TestCaseReport {
meta: match res {
Ok(inference_stats) => {
if expected_output == actual_output {
log::info!("`can_infer` test passed!");
TestCaseReportMeta::Success { inference_stats }
} else {
TestCaseReportMeta::Error {
Expand All @@ -384,11 +398,10 @@ mod tests {
fn run_inference(
model: &dyn llm::Model,
model_config: &ModelConfig,
session: &mut llm::InferenceSession,
input: &str,
maximum_token_count: usize,
) -> (String, Result<InferenceStats, llm::InferenceError>) {
let mut session = model.start_session(Default::default());

let mut actual_output: String = String::new();
let res = session.infer::<Infallible>(
model,
Expand Down

0 comments on commit 6e2362b

Please sign in to comment.