From c344592eb734b5376f9bc246a6b69bd72f7b56cf Mon Sep 17 00:00:00 2001 From: Radu Matei Date: Fri, 14 Jul 2023 12:08:23 +0200 Subject: [PATCH] feat(tracing): add tracing and update CLI to use tracing Signed-off-by: Radu Matei --- Cargo.lock | 130 +++++++++++++++++++++++ Cargo.toml | 4 +- binaries/llm-cli/Cargo.toml | 3 + binaries/llm-cli/src/main.rs | 94 +++++++++------- crates/llm-base/Cargo.toml | 1 + crates/llm-base/src/inference_session.rs | 9 ++ crates/llm-base/src/loader.rs | 11 ++ crates/llm/Cargo.toml | 1 + crates/models/llama/Cargo.toml | 4 +- crates/models/llama/src/lib.rs | 1 + 10 files changed, 215 insertions(+), 43 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d449a216..82a142fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1276,6 +1276,7 @@ dependencies = [ "serde", "serde_json", "spinoff", + "tracing", ] [[package]] @@ -1293,6 +1294,7 @@ dependencies = [ "serde_bytes", "thiserror", "tokenizers", + "tracing", ] [[package]] @@ -1318,6 +1320,9 @@ dependencies = [ "rusty-hook", "rustyline", "spinoff", + "tracing", + "tracing-appender", + "tracing-subscriber", "zstd 0.12.3+zstd.1.5.2", ] @@ -1355,6 +1360,7 @@ name = "llm-llama" version = "0.2.0-dev" dependencies = [ "llm-base", + "tracing", ] [[package]] @@ -1414,6 +1420,15 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d" +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata", +] + [[package]] name = "memchr" version = "2.5.0" @@ -1546,6 +1561,16 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + 
"overload", + "winapi", +] + [[package]] name = "num_cpus" version = "1.16.0" @@ -1643,6 +1668,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "owo-colors" version = "3.5.0" @@ -1888,6 +1919,15 @@ dependencies = [ "regex-syntax 0.7.2", ] +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", +] + [[package]] name = "regex-syntax" version = "0.6.29" @@ -2141,6 +2181,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "900fba806f70c630b0a382d0d825e17a0f19fcd059a2ade1ff237bcddf446b31" +dependencies = [ + "lazy_static", +] + [[package]] name = "shlex" version = "1.1.0" @@ -2304,14 +2353,26 @@ dependencies = [ "syn 2.0.22", ] +[[package]] +name = "thread_local" +version = "1.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" +dependencies = [ + "cfg-if", + "once_cell", +] + [[package]] name = "time" version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea9e1b3cf1243ae005d9e74085d4d542f3125458f3a81af210d901dcd7411efd" dependencies = [ + "itoa", "serde", "time-core", + "time-macros", ] [[package]] @@ -2320,6 +2381,15 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" +[[package]] +name = "time-macros" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"372950940a5f07bf38dbe211d7283c9e6d7327df53794992d293e534c733d09b" +dependencies = [ + "time-core", +] + [[package]] name = "tinyvec" version = "1.6.0" @@ -2446,10 +2516,34 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" dependencies = [ "cfg-if", + "log", "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-appender" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09d48f71a791638519505cefafe162606f706c25592e4bde4d97600c0195312e" +dependencies = [ + "crossbeam-channel", + "time", + "tracing-subscriber", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.22", +] + [[package]] name = "tracing-core" version = "0.1.31" @@ -2457,6 +2551,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922" +dependencies = [ + "lazy_static", + "log", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", ] [[package]] @@ -2536,6 +2660,12 @@ version = "0.2.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = "vcpkg" version = "0.2.15" diff --git a/Cargo.toml b/Cargo.toml index 793b27a4..59ad9021 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,8 @@ serde_json = { version = "1.0" } spinoff = { version = "0.7.0", default-features = false, features = ["dots2"] } clap = { version = "4.1.8", features = ["derive"] } memmap2 = "0.5.10" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +tracing = { version = "0.1", features = ["log"] } # Config for 'cargo dist' [workspace.metadata.dist] @@ -50,4 +52,4 @@ inherits = "release" lto = "thin" [workspace.metadata.release] -tag-prefix = "" \ No newline at end of file +tag-prefix = "" diff --git a/binaries/llm-cli/Cargo.toml b/binaries/llm-cli/Cargo.toml index 13e67291..bb5ab08b 100644 --- a/binaries/llm-cli/Cargo.toml +++ b/binaries/llm-cli/Cargo.toml @@ -27,6 +27,9 @@ num_cpus = "1.15.0" color-eyre = { version = "0.6.2", default-features = false } zstd = { version = "0.12", default-features = false } +tracing-subscriber = {workspace = true } +tracing = { workspace = true} +tracing-appender = "0.2.2" [dev-dependencies] rusty-hook = "^0.11.2" diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index 1c55824f..57735e10 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -1,7 +1,7 @@ use std::{ convert::Infallible, fs::File, - io::{BufReader, BufWriter}, + io::{BufReader, BufWriter, IsTerminal}, }; use clap::Parser; @@ -14,10 +14,12 @@ mod snapshot; mod util; fn main() -> eyre::Result<()> { - env_logger::builder() - .filter_level(log::LevelFilter::Info) - .parse_default_env() + tracing_subscriber::fmt() + 
.with_writer(std::io::stderr) + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_ansi(std::io::stderr().is_terminal()) .init(); + color_eyre::install()?; let args = Args::parse(); @@ -32,6 +34,7 @@ fn main() -> eyre::Result<()> { } } +#[tracing::instrument(skip_all)] fn infer(args: &cli_args::Infer) -> eyre::Result<()> { let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?; let inference_session_config = args.generate.inference_session_config(); @@ -46,46 +49,55 @@ fn infer(args: &cli_args::Infer) -> eyre::Result<()> { let parameters = args.generate.inference_parameters(model.eot_token_id()); let mut rng = args.generate.rng(); - let res = session.infer::<Infallible>( - model.as_ref(), - &mut rng, - &llm::InferenceRequest { - prompt: prompt.as_str().into(), - parameters: &parameters, - play_back_previous_tokens: session_loaded, - maximum_token_count: args.generate.num_predict, - }, - // OutputRequest - &mut Default::default(), - |r| { - match r { - llm::InferenceResponse::PromptToken(t) if !args.hide_prompt => util::print_token(t), - llm::InferenceResponse::InferredToken(t) => util::print_token(t), - _ => {} + + let span = tracing::trace_span!("infer"); + + span.in_scope(|| { + // do work inside the span... 
+ let res = session.infer::<Infallible>( + model.as_ref(), + &mut rng, + &llm::InferenceRequest { + prompt: prompt.as_str().into(), + parameters: &parameters, + play_back_previous_tokens: session_loaded, + maximum_token_count: args.generate.num_predict, + }, + // OutputRequest + &mut Default::default(), + |r| { + match r { + llm::InferenceResponse::PromptToken(t) if !args.hide_prompt => { + util::print_token(t) + } + llm::InferenceResponse::InferredToken(t) => util::print_token(t), + _ => {} + } + Ok(llm::InferenceFeedback::Continue) + }, + ); + + println!(); + + match res { + Ok(stats) => { + if args.stats { + println!(); + println!("{}", stats); + println!(); + } } - Ok(llm::InferenceFeedback::Continue) - }, - ); - println!(); - - match res { - Ok(stats) => { - if args.stats { - println!(); - println!("{}", stats); - println!(); + Err(llm::InferenceError::ContextFull) => { + log::warn!("Context window full, stopping inference.") + } + Err(llm::InferenceError::TokenizationFailed(err)) => { + log::error!("A tokenization-related failure occurred: {}", err); + } + Err(llm::InferenceError::UserCallback(_)) | Err(llm::InferenceError::EndOfText) => { + unreachable!("cannot fail") } } - Err(llm::InferenceError::ContextFull) => { - log::warn!("Context window full, stopping inference.") - } - Err(llm::InferenceError::TokenizationFailed(err)) => { - log::error!("A tokenization-related failure occurred: {}", err); - } - Err(llm::InferenceError::UserCallback(_)) | Err(llm::InferenceError::EndOfText) => { - unreachable!("cannot fail") - } - } + }); if let Some(session_path) = args.save_session.as_ref().or(args.persist_session.as_ref()) { // Write the memory to the cache file diff --git a/crates/llm-base/Cargo.toml b/crates/llm-base/Cargo.toml index 66ac16d1..51a0949e 100644 --- a/crates/llm-base/Cargo.toml +++ b/crates/llm-base/Cargo.toml @@ -24,6 +24,7 @@ memmap2 = { workspace = true } half = "2.2.1" tokenizers = {version="0.13.3", default-features=false, features=["onig"]} regex = "1.8" 
+tracing = { workspace = true } [features] tokenizers-remote = ["tokenizers/http"] diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index caff5e67..dfd4a474 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -2,6 +2,7 @@ use ggml::{Buffer, ComputationGraph, Context, Tensor}; use serde::Serialize; use std::{fmt::Display, sync::Arc}; use thiserror::Error; +use tracing::{instrument, log}; #[cfg(feature = "metal")] use ggml::metal::MetalContext; @@ -280,6 +281,7 @@ impl InferenceSession { } /// Feed a prompt to the model for this session. + #[instrument(skip_all)] pub fn feed_prompt<'a, E: std::error::Error + Send + Sync + 'static, P: Into>>( &mut self, model: &dyn Model, @@ -329,6 +331,7 @@ impl InferenceSession { self.decoded_tokens.append(&mut token); } } + log::trace!("Finished feed prompt"); Ok(()) } @@ -361,6 +364,7 @@ impl InferenceSession { } /// Infer the next token for this session. + #[instrument(level = "trace", skip_all)] pub fn infer_next_token( &mut self, model: &dyn Model, @@ -407,6 +411,7 @@ impl InferenceSession { /// generated (specified by [InferenceRequest::maximum_token_count]). /// /// This is a wrapper around [Self::feed_prompt] and [Self::infer_next_token]. + #[instrument(skip_all)] pub fn infer( &mut self, model: &dyn Model, @@ -431,6 +436,10 @@ impl InferenceSession { } } } + log::trace!( + "Starting inference request with max_token_count: {}", + maximum_token_count + ); let mut stats = InferenceStats::default(); let start_at = std::time::SystemTime::now(); diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index 9d01e0a1..e9154d19 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -18,6 +18,7 @@ use ggml::{ }; use memmap2::Mmap; use thiserror::Error; +use tracing::log; #[derive(Debug, PartialEq, Clone, Copy, Eq, Default)] /// Information about the file. 
@@ -436,12 +437,14 @@ pub fn load( path: path.to_owned(), })?; let mut reader = BufReader::new(&file); + log::trace!("Read model file from {:?}", path); let tokenizer = tokenizer_source.retrieve(path)?; let mut loader = Loader::new(tokenizer, load_progress_callback); ggml::format::load(&mut reader, &mut loader) .map_err(|err| LoadError::from_format_error(err, path.to_owned()))?; + log::trace!("Loaded GGML model from reader"); let Loader { hyperparameters, @@ -469,6 +472,10 @@ pub fn load( } else { quantization_version }; + log::trace!( + "Determined quantization version of model as {:?}", + quantization_version + ); // TODO: this is temporary while we figure out how to handle this if tensors.values().any(|t| t.element_type.is_quantized()) { @@ -482,6 +489,7 @@ pub fn load( .values() .map(|ti| ti.calc_absolute_size(use_mmap)) .sum::(); + log::trace!("Context size: {:?}", ctx_size); let mut lora_adapters: Option> = None; if let Some(lora_paths) = &params.lora_adapters { @@ -508,6 +516,7 @@ pub fn load( .filter_map(|k| Some(k.rsplit_once('.')?.0.to_owned())) .collect(); + log::trace!("Loaded LoRA weights"); // Return the LoRA patches Ok::<_, LoadError>(LoraAdapter { scaling: lora_loader.hyperparameters.calculate_scaling(), @@ -551,6 +560,8 @@ pub fn load( tensor_count: tensors_len, }); + log::trace!("Loaded model"); + Ok(model) } diff --git a/crates/llm/Cargo.toml b/crates/llm/Cargo.toml index 1d7f688f..0f395f5a 100644 --- a/crates/llm/Cargo.toml +++ b/crates/llm/Cargo.toml @@ -18,6 +18,7 @@ llm-mpt = { path = "../models/mpt", optional = true, version = "0.2.0-dev" } llm-falcon = { path = "../models/falcon", optional = true, version = "0.2.0-dev" } serde = { workspace = true } +tracing = { workspace = true } [dev-dependencies] bytesize = { workspace = true } diff --git a/crates/models/llama/Cargo.toml b/crates/models/llama/Cargo.toml index b7c3bdbf..aa599ab0 100644 --- a/crates/models/llama/Cargo.toml +++ b/crates/models/llama/Cargo.toml @@ -8,4 +8,6 @@ edition = "2021" 
readme = "../../../README.md" [dependencies] -llm-base = { path = "../../llm-base", version = "0.2.0-dev" } \ No newline at end of file +llm-base = { path = "../../llm-base", version = "0.2.0-dev" } +tracing = { version = "0.1", features = ["log"] } + diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index 94585218..fab91ec0 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -99,6 +99,7 @@ impl KnownModel for Llama { ) } + #[tracing::instrument(level = "trace", skip_all)] fn evaluate( &self, session: &mut InferenceSession,