This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

feat(test): check hyperparameters can roundtrip

philpax committed Jul 2, 2023
1 parent c0b197f commit 00624a4
Showing 10 changed files with 183 additions and 87 deletions.
230 changes: 147 additions & 83 deletions binaries/llm-test/src/main.rs
@@ -156,6 +156,9 @@ async fn test_model(
     download_dir: &Path,
     results_dir: &Path,
 ) -> anyhow::Result<()> {
+    // Load the model
+    let architecture = llm::ModelArchitecture::from_str(&test_config.architecture)?;
+
     let local_path = if test_config.filename.is_file() {
         // If this filename points towards a valid file, use it
         test_config.filename.clone()
@@ -173,99 +176,134 @@
     // Download the model if necessary
     download_file(&test_config.url, &local_path).await?;
 
-    let start_time = Instant::now();
-
-    // Load the model
-    let architecture = llm::ModelArchitecture::from_str(&test_config.architecture)?;
-    let model = {
-        let model = llm::load_dynamic(
-            Some(architecture),
-            &local_path,
-            llm::TokenizerSource::Embedded,
-            llm::ModelParameters {
-                prefer_mmap: model_config.mmap,
-                ..Default::default()
-            },
-            |progress| {
-                let print = !matches!(&progress,
-                    llm::LoadProgress::TensorLoaded { current_tensor, tensor_count }
-                    if current_tensor % (tensor_count / 10) != 0
-                );
-
-                if print {
-                    log::info!("loading: {:?}", progress);
-                }
-            },
-        );
-
-        match model {
-            Ok(m) => m,
-            Err(err) => {
-                write_report(
-                    test_config,
-                    results_dir,
-                    &Report::LoadFail {
-                        error: format!("Failed to load model: {}", err),
-                    },
-                )?;
-
-                return Err(err.into());
-            }
-        }
-    };
-
-    log::info!(
-        "Model fully loaded! Elapsed: {}ms",
-        start_time.elapsed().as_millis()
-    );
-
-    // Confirm that the model can be sent to a thread, then sent back
-    let model = std::thread::spawn(move || model).join().unwrap();
-
-    // Run the test cases
-    let mut test_case_reports = vec![];
-    for test_case in &test_config.test_cases {
-        match test_case {
-            TestCase::Inference {
-                input,
-                output,
-                maximum_token_count,
-            } => test_case_reports.push(tests::inference(
-                model.as_ref(),
-                model_config,
-                input,
-                output,
-                *maximum_token_count,
-            )?),
-        }
-    }
-    let first_error: Option<String> =
-        test_case_reports
-            .iter()
-            .find_map(|report: &TestCaseReport| match &report.meta {
-                TestCaseReportMeta::Error { error } => Some(error.clone()),
-                _ => None,
-            });
-
-    // Save the results
-    // Serialize the report to a JSON string
-    write_report(
-        test_config,
-        results_dir,
-        &Report::LoadSuccess {
-            test_cases: test_case_reports,
-        },
-    )?;
-
-    // Optionally, panic if there was an error
-    if let Some(err) = first_error {
-        panic!("Error: {}", err);
-    }
-
-    log::info!(
-        "Successfully tested architecture `{}`!",
-        test_config.architecture
-    );
+    struct TestVisitor<'a> {
+        model_config: &'a ModelConfig,
+        test_config: &'a TestConfig,
+        results_dir: &'a Path,
+        local_path: &'a Path,
+    }
+    impl<'a> llm::ModelArchitectureVisitor<anyhow::Result<()>> for TestVisitor<'a> {
+        fn visit<M: llm::KnownModel + 'static>(&mut self) -> anyhow::Result<()> {
+            let Self {
+                model_config,
+                test_config,
+                results_dir,
+                local_path,
+            } = *self;
+
+            let start_time = Instant::now();
+
+            let model = {
+                let model = llm::load::<M>(
+                    local_path,
+                    llm::TokenizerSource::Embedded,
+                    llm::ModelParameters {
+                        prefer_mmap: model_config.mmap,
+                        ..Default::default()
+                    },
+                    |progress| {
+                        let print = !matches!(&progress,
+                            llm::LoadProgress::TensorLoaded { current_tensor, tensor_count }
+                            if current_tensor % (tensor_count / 10) != 0
+                        );
+
+                        if print {
+                            log::info!("loading: {:?}", progress);
+                        }
+                    },
+                );
+
+                match model {
+                    Ok(m) => m,
+                    Err(err) => {
+                        write_report(
+                            test_config,
+                            results_dir,
+                            &Report::LoadFail {
+                                error: format!("Failed to load model: {}", err),
+                            },
+                        )?;
+
+                        return Err(err.into());
+                    }
+                }
+            };
+
+            log::info!(
+                "Model fully loaded! Elapsed: {}ms",
+                start_time.elapsed().as_millis()
+            );
+
+            //
+            // Non-model-specific tests
+            //
+
+            // Confirm that the model can be sent to a thread, then sent back
+            let model = tests::can_send(model)?;
+
+            // Confirm that the hyperparameters can be roundtripped
+            tests::can_roundtrip_hyperparameters(&model)?;
+
+            //
+            // Model-specific tests
+            //
+
+            // Run the test cases
+            let mut test_case_reports = vec![];
+            for test_case in &test_config.test_cases {
+                match test_case {
+                    TestCase::Inference {
+                        input,
+                        output,
+                        maximum_token_count,
+                    } => test_case_reports.push(tests::can_infer(
+                        &model,
+                        model_config,
+                        input,
+                        output,
+                        *maximum_token_count,
+                    )?),
+                }
+            }
+            let first_error: Option<String> =
+                test_case_reports
+                    .iter()
+                    .find_map(|report: &TestCaseReport| match &report.meta {
+                        TestCaseReportMeta::Error { error } => Some(error.clone()),
+                        _ => None,
+                    });
+
+            // Save the results
+            // Serialize the report to a JSON string
+            write_report(
+                test_config,
+                results_dir,
+                &Report::LoadSuccess {
+                    test_cases: test_case_reports,
+                },
+            )?;
+
+            // Optionally, panic if there was an error
+            if let Some(err) = first_error {
+                panic!("Error: {}", err);
+            }
+
+            log::info!(
+                "Successfully tested architecture `{}`!",
+                test_config.architecture
+            );
+
+            Ok(())
+        }
+    }
+
+    architecture.visit(&mut TestVisitor {
+        model_config,
+        test_config,
+        results_dir,
+        local_path: &local_path,
+    })?;
 
     Ok(())
 }
@@ -283,7 +321,33 @@ fn write_report(
 
 mod tests {
     use super::*;
-    pub(super) fn inference(
+
+    pub(super) fn can_send<M: llm::KnownModel + 'static>(model: M) -> anyhow::Result<M> {
+        std::thread::spawn(move || model)
+            .join()
+            .map_err(|e| anyhow::anyhow!("Failed to join thread: {e:?}"))
+    }
+
+    pub(super) fn can_roundtrip_hyperparameters<M: llm::KnownModel + 'static>(
+        model: &M,
+    ) -> anyhow::Result<()> {
+        fn test_hyperparameters<M: llm::Hyperparameters>(
+            hyperparameters: &M,
+        ) -> anyhow::Result<()> {
+            let mut data = vec![];
+            hyperparameters.write_ggml(&mut data)?;
+            let new_hyperparameters =
+                <M as llm::Hyperparameters>::read_ggml(&mut std::io::Cursor::new(data))?;
+
+            assert_eq!(hyperparameters, &new_hyperparameters);
+
+            Ok(())
+        }
+
+        test_hyperparameters(model.hyperparameters())
+    }
+
+    pub(super) fn can_infer(
         model: &dyn llm::Model,
         model_config: &ModelConfig,
         input: &str,
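Note on the restructuring above: `llm::load_dynamic` returns a type-erased `Box<dyn llm::Model>`, which is enough for inference but cannot name the associated `Hyperparameters` type that the new roundtrip test needs. Routing the tests through `llm::ModelArchitectureVisitor` hands the test body a concrete `M: llm::KnownModel` instead. The snippet below is a self-contained sketch of that dispatch pattern using hypothetical stand-in types, not the actual `llm` definitions:

```rust
// Hypothetical stand-ins for the real `llm` types; the actual definitions
// live in the `llm` crates and differ in detail.
trait KnownModel {
    fn name() -> &'static str;
}

struct Llama;
impl KnownModel for Llama {
    fn name() -> &'static str {
        "llama"
    }
}

struct Gpt2;
impl KnownModel for Gpt2 {
    fn name() -> &'static str {
        "gpt2"
    }
}

enum ModelArchitecture {
    Llama,
    Gpt2,
}

// The visitor receives the architecture as a *type parameter* rather than a
// runtime value, so its body can use associated items of the concrete model.
trait ModelArchitectureVisitor<R> {
    fn visit<M: KnownModel + 'static>(&mut self) -> R;
}

impl ModelArchitecture {
    // Each match arm supplies a different concrete `M`; the visitor body is
    // monomorphized once per architecture.
    fn visit<R>(&self, visitor: &mut impl ModelArchitectureVisitor<R>) -> R {
        match self {
            Self::Llama => visitor.visit::<Llama>(),
            Self::Gpt2 => visitor.visit::<Gpt2>(),
        }
    }
}

struct NamePrinter;
impl ModelArchitectureVisitor<String> for NamePrinter {
    fn visit<M: KnownModel + 'static>(&mut self) -> String {
        M::name().to_string()
    }
}

fn main() {
    let arch = ModelArchitecture::Gpt2;
    println!("visited: {}", arch.visit(&mut NamePrinter)); // prints "visited: gpt2"
}
```

Because each arm monomorphizes the visitor body, architecture-specific code like the hyperparameter roundtrip is checked against the concrete model type at compile time.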
5 changes: 4 additions & 1 deletion crates/llm-base/src/model/mod.rs
@@ -65,6 +65,9 @@ pub trait KnownModel: Send + Sync {
         output_request: &mut OutputRequest,
     );
 
+    /// Get the hyperparameters for this model.
+    fn hyperparameters(&self) -> &Self::Hyperparameters;
+
     /// Get the tokenizer for this model.
     fn tokenizer(&self) -> &Tokenizer;
 
@@ -150,7 +153,7 @@ impl<H: Hyperparameters, M: KnownModel<Hyperparameters = H>> Model for M {
 
 /// Implemented by model hyperparameters for interacting with hyperparameters
 /// without knowing what they are, as well as writing/reading them as required.
-pub trait Hyperparameters: Sized + Default + Debug {
+pub trait Hyperparameters: Sized + Default + Debug + PartialEq + Eq {
     /// Read the parameters in GGML format from a reader.
     fn read_ggml(reader: &mut dyn BufRead) -> Result<Self, LoadError>;
 
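The new `PartialEq + Eq` bounds exist so the roundtrip test can `assert_eq!` the hyperparameters it wrote against the ones it read back. A toy, self-contained illustration of that write/read contract follows — the field names and the little-endian layout are made up for the sketch, not the real GGML format in `llm-base`:

```rust
use std::io::{BufRead, Cursor, Write};

// Toy hyperparameters with a made-up wire layout.
#[derive(Debug, Default, PartialEq, Eq)]
struct Hyperparameters {
    n_vocab: u32,
    n_embd: u32,
}

impl Hyperparameters {
    fn write_ggml(&self, writer: &mut dyn Write) -> std::io::Result<()> {
        writer.write_all(&self.n_vocab.to_le_bytes())?;
        writer.write_all(&self.n_embd.to_le_bytes())
    }

    fn read_ggml(reader: &mut dyn BufRead) -> std::io::Result<Self> {
        let mut buf = [0u8; 4];
        reader.read_exact(&mut buf)?;
        let n_vocab = u32::from_le_bytes(buf);
        reader.read_exact(&mut buf)?;
        let n_embd = u32::from_le_bytes(buf);
        Ok(Self { n_vocab, n_embd })
    }
}

fn main() -> std::io::Result<()> {
    let original = Hyperparameters { n_vocab: 32_000, n_embd: 4_096 };

    // Write, then read back, then compare — the same shape as
    // `tests::can_roundtrip_hyperparameters` above.
    let mut data = vec![];
    original.write_ggml(&mut data)?;
    let roundtripped = Hyperparameters::read_ggml(&mut Cursor::new(data))?;

    // `assert_eq!` only needs `PartialEq`; the extra `Eq` bound documents
    // that equality is total for hyperparameter values.
    assert_eq!(original, roundtripped);
    Ok(())
}
```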
4 changes: 2 additions & 2 deletions crates/llm/src/lib.rs
@@ -78,8 +78,8 @@ use std::{
 // This is the "user-facing" API, and GGML may not always be our backend.
 pub use llm_base::{
     feed_prompt_callback, ggml::format as ggml_format, load, load_progress_callback_stdout,
-    quantize, samplers, ElementType, FileType, FileTypeFormat, InferenceError, InferenceFeedback,
-    InferenceParameters, InferenceRequest, InferenceResponse, InferenceSession,
+    quantize, samplers, ElementType, FileType, FileTypeFormat, Hyperparameters, InferenceError,
+    InferenceFeedback, InferenceParameters, InferenceRequest, InferenceResponse, InferenceSession,
     InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, InferenceStats,
     InvalidTokenBias, KnownModel, LoadError, LoadProgress, Loader, Model, ModelKVMemoryType,
     ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress, Sampler,
4 changes: 4 additions & 0 deletions crates/models/bloom/src/lib.rs
@@ -369,6 +369,10 @@ impl KnownModel for Bloom {
         common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len);
     }
 
+    fn hyperparameters(&self) -> &Self::Hyperparameters {
+        &self.hyperparameters
+    }
+
     fn tokenizer(&self) -> &Tokenizer {
         &self.tokenizer
     }
6 changes: 5 additions & 1 deletion crates/models/falcon/src/lib.rs
@@ -328,6 +328,10 @@ impl KnownModel for Falcon {
         common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len);
     }
 
+    fn hyperparameters(&self) -> &Self::Hyperparameters {
+        &self.hyperparameters
+    }
+
     fn tokenizer(&self) -> &Tokenizer {
         &self.tokenizer
     }
@@ -354,7 +358,7 @@ impl KnownModel for Falcon {
 }
 
 /// Falcon [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
-#[derive(Debug, Default, PartialEq, Clone, Copy)]
+#[derive(Debug, Default, PartialEq, Clone, Copy, Eq)]
 pub struct Hyperparameters {
     /// Size of the model's vocabulary
     n_vocab: usize,
4 changes: 4 additions & 0 deletions crates/models/gpt2/src/lib.rs
@@ -323,6 +323,10 @@ impl KnownModel for Gpt2 {
         common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len);
     }
 
+    fn hyperparameters(&self) -> &Self::Hyperparameters {
+        &self.hyperparameters
+    }
+
     fn tokenizer(&self) -> &Tokenizer {
         &self.tokenizer
     }
4 changes: 4 additions & 0 deletions crates/models/gptj/src/lib.rs
@@ -291,6 +291,10 @@ impl KnownModel for GptJ {
         common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len);
     }
 
+    fn hyperparameters(&self) -> &Self::Hyperparameters {
+        &self.hyperparameters
+    }
+
     fn tokenizer(&self) -> &Tokenizer {
         &self.tokenizer
     }
4 changes: 4 additions & 0 deletions crates/models/gptneox/src/lib.rs
@@ -337,6 +337,10 @@ impl KnownModel for GptNeoX {
         common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, n);
     }
 
+    fn hyperparameters(&self) -> &Self::Hyperparameters {
+        &self.hyperparameters
+    }
+
     fn tokenizer(&self) -> &Tokenizer {
         &self.tokenizer
    }
4 changes: 4 additions & 0 deletions crates/models/llama/src/lib.rs
@@ -321,6 +321,10 @@ impl KnownModel for Llama {
         common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len);
     }
 
+    fn hyperparameters(&self) -> &Self::Hyperparameters {
+        &self.hyperparameters
+    }
+
     fn tokenizer(&self) -> &Tokenizer {
         &self.tokenizer
     }
5 changes: 5 additions & 0 deletions crates/models/mpt/src/lib.rs
@@ -271,6 +271,10 @@ impl KnownModel for Mpt {
         common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, n);
     }
 
+    fn hyperparameters(&self) -> &Self::Hyperparameters {
+        &self.hyperparameters
+    }
+
     fn tokenizer(&self) -> &Tokenizer {
         &self.tokenizer
     }
@@ -316,6 +320,7 @@ pub struct Hyperparameters {
     /// file_type
     file_type: FileType,
 }
+impl Eq for Hyperparameters {}
 
 impl llm_base::Hyperparameters for Hyperparameters {
     fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
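The manual `impl Eq for Hyperparameters {}` in `mpt` (instead of adding `Eq` to the derive list, as `falcon` does) is presumably because the struct contains `f32` fields: `f32` is `PartialEq` but not `Eq` (NaN is not equal to itself), so `#[derive(Eq)]` would not compile. A minimal sketch with a hypothetical float field:

```rust
// `f32` implements PartialEq but not Eq, so #[derive(Eq)] fails on any
// struct containing one. A manual impl asserts that, for the values we
// actually store, equality is total (e.g. no NaNs come out of a model file).
#[derive(Debug, PartialEq)]
struct Hyperparameters {
    n_vocab: usize,
    alibi_bias_max: f32, // hypothetical float field; blocks #[derive(Eq)]
}

impl Eq for Hyperparameters {}

fn main() {
    let a = Hyperparameters { n_vocab: 32_000, alibi_bias_max: 8.0 };
    let b = Hyperparameters { n_vocab: 32_000, alibi_bias_max: 8.0 };
    assert_eq!(a, b);
    println!("hyperparameters compare equal");
}
```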
