This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

feat(tracing): add tracing to llm and llm-base crates #367

Merged: 1 commit, Jul 16, 2023
130 changes: 130 additions & 0 deletions Cargo.lock

(Cargo.lock is generated; GitHub does not render its diff by default.)

4 changes: 3 additions & 1 deletion Cargo.toml
@@ -30,6 +30,8 @@ serde_json = { version = "1.0" }
 spinoff = { version = "0.7.0", default-features = false, features = ["dots2"] }
 clap = { version = "4.1.8", features = ["derive"] }
 memmap2 = "0.5.10"
+tracing-subscriber = { version = "0.3", features = ["env-filter"] }
+tracing = { version = "0.1", features = ["log"] }
 
 # Config for 'cargo dist'
 [workspace.metadata.dist]
@@ -50,4 +52,4 @@ inherits = "release"
 lto = "thin"
 
 [workspace.metadata.release]
-tag-prefix = ""
\ No newline at end of file
+tag-prefix = ""
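
The two new workspace dependencies split the job: `tracing` provides the instrumentation macros and spans, while `tracing-subscriber` provides the collector, with the `env-filter` feature enabling `RUST_LOG`-style filtering. The `log` feature on `tracing` additionally emits `log` records so `log`-based consumers still see events. A minimal sketch of how the pair is wired together (illustrative, not code from this PR):

```rust
use tracing::info;
use tracing_subscriber::EnvFilter;

fn main() {
    // Read the filter from RUST_LOG (e.g. `RUST_LOG=llm=debug`),
    // falling back to `info` when the variable is unset.
    let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));

    tracing_subscriber::fmt().with_env_filter(filter).init();

    info!(answer = 42, "subscriber installed");
}
```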
3 changes: 3 additions & 0 deletions binaries/llm-cli/Cargo.toml
@@ -27,6 +27,9 @@ num_cpus = "1.15.0"
 
 color-eyre = { version = "0.6.2", default-features = false }
 zstd = { version = "0.12", default-features = false }
+tracing-subscriber = { workspace = true }
+tracing = { workspace = true }
+tracing-appender = "0.2.2"
 
 [dev-dependencies]
 rusty-hook = "^0.11.2"
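
`tracing-appender` is added only to the CLI. It is not exercised elsewhere in this diff, but it provides rolling file writers plus a non-blocking wrapper, which would be the natural next step if the stderr output below were redirected to a log file. A hedged sketch of that pattern (directory and file names are illustrative):

```rust
use tracing_subscriber::EnvFilter;

fn main() {
    // Roll the log file daily under ./logs; names are made up here.
    let file_appender = tracing_appender::rolling::daily("logs", "llm-cli.log");

    // The non-blocking wrapper offloads writes to a worker thread; the
    // returned guard must be kept alive so buffered events get flushed.
    let (writer, _guard) = tracing_appender::non_blocking(file_appender);

    tracing_subscriber::fmt()
        .with_writer(writer)
        .with_env_filter(EnvFilter::from_default_env())
        .with_ansi(false) // no color escape codes in a file
        .init();

    tracing::info!("logging to a rolling file");
}
```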
94 changes: 53 additions & 41 deletions binaries/llm-cli/src/main.rs
@@ -1,7 +1,7 @@
 use std::{
     convert::Infallible,
     fs::File,
-    io::{BufReader, BufWriter},
+    io::{BufReader, BufWriter, IsTerminal},
 };
 
 use clap::Parser;
@@ -14,10 +14,12 @@ mod snapshot;
 mod util;
 
 fn main() -> eyre::Result<()> {
-    env_logger::builder()
-        .filter_level(log::LevelFilter::Info)
-        .parse_default_env()
+    tracing_subscriber::fmt()
+        .with_writer(std::io::stderr)
+        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
+        .with_ansi(std::io::stderr().is_terminal())
         .init();
+
     color_eyre::install()?;
 
     let args = Args::parse();
Expand All @@ -32,6 +34,7 @@ fn main() -> eyre::Result<()> {
}
}

#[tracing::instrument(skip_all)]
fn infer(args: &cli_args::Infer) -> eyre::Result<()> {
let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?;
let inference_session_config = args.generate.inference_session_config();
@@ -46,46 +49,55 @@ fn infer(args: &cli_args::Infer) -> eyre::Result<()> {
     let parameters = args.generate.inference_parameters(model.eot_token_id());
 
     let mut rng = args.generate.rng();
-    let res = session.infer::<Infallible>(
-        model.as_ref(),
-        &mut rng,
-        &llm::InferenceRequest {
-            prompt: prompt.as_str().into(),
-            parameters: &parameters,
-            play_back_previous_tokens: session_loaded,
-            maximum_token_count: args.generate.num_predict,
-        },
-        // OutputRequest
-        &mut Default::default(),
-        |r| {
-            match r {
-                llm::InferenceResponse::PromptToken(t) if !args.hide_prompt => util::print_token(t),
-                llm::InferenceResponse::InferredToken(t) => util::print_token(t),
-                _ => {}
-            }
-            Ok(llm::InferenceFeedback::Continue)
-        },
-    );
-    println!();
-
-    match res {
-        Ok(stats) => {
-            if args.stats {
-                println!();
-                println!("{}", stats);
-                println!();
-            }
-        }
-        Err(llm::InferenceError::ContextFull) => {
-            log::warn!("Context window full, stopping inference.")
-        }
-        Err(llm::InferenceError::TokenizationFailed(err)) => {
-            log::error!("A tokenization-related failure occurred: {}", err);
-        }
-        Err(llm::InferenceError::UserCallback(_)) | Err(llm::InferenceError::EndOfText) => {
-            unreachable!("cannot fail")
-        }
-    }
+
+    let span = tracing::trace_span!("infer");
+
+    span.in_scope(|| {
+        // do work inside the span...
+        let res = session.infer::<Infallible>(
+            model.as_ref(),
+            &mut rng,
+            &llm::InferenceRequest {
+                prompt: prompt.as_str().into(),
+                parameters: &parameters,
+                play_back_previous_tokens: session_loaded,
+                maximum_token_count: args.generate.num_predict,
+            },
+            // OutputRequest
+            &mut Default::default(),
+            |r| {
+                match r {
+                    llm::InferenceResponse::PromptToken(t) if !args.hide_prompt => {
+                        util::print_token(t)
+                    }
+                    llm::InferenceResponse::InferredToken(t) => util::print_token(t),
+                    _ => {}
+                }
+                Ok(llm::InferenceFeedback::Continue)
+            },
+        );
+
+        println!();
+
+        match res {
+            Ok(stats) => {
+                if args.stats {
+                    println!();
+                    println!("{}", stats);
+                    println!();
+                }
+            }
+            Err(llm::InferenceError::ContextFull) => {
+                log::warn!("Context window full, stopping inference.")
+            }
+            Err(llm::InferenceError::TokenizationFailed(err)) => {
+                log::error!("A tokenization-related failure occurred: {}", err);
+            }
+            Err(llm::InferenceError::UserCallback(_)) | Err(llm::InferenceError::EndOfText) => {
+                unreachable!("cannot fail")
+            }
+        }
+    });
 
     if let Some(session_path) = args.save_session.as_ref().or(args.persist_session.as_ref()) {
         // Write the memory to the cache file
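
The `main.rs` changes swap `env_logger` for a `tracing_subscriber::fmt` collector that writes to stderr and emits ANSI colors only when stderr is a terminal, then add spans in two ways: `#[tracing::instrument(skip_all)]` wraps all of `infer` in a span (`skip_all` suppresses argument capture, so the parameters need not implement `Debug`), while `trace_span!` plus `in_scope` delimits the inference call explicitly. The remaining `log::warn!`/`log::error!` calls should still reach the subscriber, since `tracing-subscriber`'s default `tracing-log` feature installs a `log` bridge during `init()`. A small sketch of the two span styles side by side (names are illustrative, not from this PR):

```rust
use tracing::instrument;

// Attribute form: the whole function body runs inside a span named
// after the function; `skip_all` records none of the arguments.
#[instrument(skip_all)]
fn run(prompt: &str) {
    tracing::debug!(prompt_len = prompt.len(), "starting");
    inner(prompt);
}

// Manual form: build a span, then run a closure inside it; events
// emitted in the closure inherit the span as their parent.
fn inner(prompt: &str) {
    let span = tracing::trace_span!("inner");
    span.in_scope(|| {
        tracing::trace!(prompt_len = prompt.len(), "inside the span");
    });
}
```

Note that `trace_span!` creates a span at the `trace` level, so it is filtered out unless the environment filter enables `trace` for the relevant target (e.g. `RUST_LOG=trace`).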
1 change: 1 addition & 0 deletions crates/llm-base/Cargo.toml
@@ -24,6 +24,7 @@ memmap2 = { workspace = true }
 half = "2.2.1"
 tokenizers = { version = "0.13.3", default-features = false, features = ["onig"] }
 regex = "1.8"
+tracing = { workspace = true }
 
 [features]
 tokenizers-remote = ["tokenizers/http"]
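
`llm-base` gains a dependency on `tracing` alone, not on `tracing-subscriber`: the conventional split in which a library only emits spans and events, leaving collection, formatting, and filtering to whatever subscriber the embedding binary installs. A sketch of what library-side instrumentation built on this dependency could look like (the function and fields are illustrative, not part of this diff):

```rust
use tracing::{debug, instrument};

// Events in library code are cheap no-ops unless the downstream
// binary installs a subscriber whose filter enables them.
#[instrument(skip_all, fields(n_tokens = tokens.len()))]
pub fn feed_prompt(tokens: &[u32]) {
    for (i, tok) in tokens.iter().enumerate() {
        debug!(index = i, token = *tok, "evaluating token");
    }
}
```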