Llama 3.1 #38

Merged · 15 commits · Sep 3, 2024
11 changes: 7 additions & 4 deletions llama-burn/Cargo.toml
@@ -16,12 +16,15 @@ tiny = ["dep:tokenizers"]
# Example feature flags (backend selection)
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
cuda = ["burn/cuda-jit"]
wgpu = ["burn/wgpu"]

# To import pytorch weights
import = ["burn-import"]

[dependencies]
# Note: default-features = false is needed to disable std
burn = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c", default-features = false }
burn-import = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c" }
burn = { version = "0.14.0", default-features = false, features = ["std"] }
burn-import = { version = "0.14.0", optional = true }
itertools = { version = "0.12.1", default-features = false, features = [
"use_alloc",
] }
@@ -46,5 +49,5 @@ rand = { version = "0.8.5", default-features = false, features = [
] } # std_rng is for no_std

[dev-dependencies]
burn = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c" }
burn = { version = "0.14.0", default-features = false }
clap = { version = "4.5.4", features = ["derive"] }
2 changes: 2 additions & 0 deletions llama-burn/NOTICES.md
@@ -8,6 +8,8 @@ derived from. The use of the following resources complies with the licenses prov
The model implementation was adapted from the original
[Llama 3 implementation](https://github.com/meta-llama/llama3), which is distributed under the
[Meta Llama 3 Community License Agreement](https://github.com/meta-llama/llama3/blob/main/LICENSE).
The Llama 3.1 model is distributed under the
[Llama 3.1 Community License Agreement](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/LICENSE).

The TinyLlama implementation is derived from the same code, but its weights and tokenizers were
adapted from the [original implementation](https://github.com/jzhang38/TinyLlama) distributed under
32 changes: 25 additions & 7 deletions llama-burn/README.md
@@ -4,7 +4,8 @@

The popular Llama LLM is here!

This repository contains the [Llama 3](https://github.com/meta-llama/llama3) and
This repository contains the [Llama 3.1](https://github.com/meta-llama/llama-models/),
[Llama 3](https://github.com/meta-llama/llama3) and
[TinyLlama](https://github.com/jzhang38/TinyLlama) implementations with their corresponding
tokenizers. You can find the [Burn](https://github.com/tracel-ai/burn) implementation for the Llama
variants in [src/llama.rs](src/llama.rs).
@@ -23,9 +24,7 @@ llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-bur
If you want to use Llama 3 or TinyLlama (including pre-trained weights if default features are
active), enable the corresponding feature flag.

> **Important:** these features require `std`. Note that the weights have been saved in the binary
> format, which is more compact and faster to save & load, but might not be compatible in future
> versions if the Burn data schema were to evolve.
> **Important:** these features require `std`.
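
For example, a minimal sketch of such a dependency entry with the `llama3` feature enabled (the feature name is taken from the run commands below; exact defaults may differ in your setup):

```toml
# Sketch: pull in llama-burn with Llama 3 / 3.1 support enabled.
# Pretrained weight downloads come in through the crate's default features.
llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", features = ["llama3"] }
```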

#### Llama 3

@@ -47,7 +46,7 @@ The [chat completion example](examples/chat.rs) initializes a Llama model from t
file and generates a sequence of text based on the input prompt. The instruction-tuned model is
loaded for dialogue applications, so the prompt is automatically formatted for chat completion.

The example can be executed on the `tch` backend (CUDA or CPU) or `wgpu`.
The example can be executed on the `tch` backend (CUDA or CPU), `cuda` or `wgpu`.
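
As a sketch (argument names come from `examples/chat.rs`; the prompt and values shown are purely illustrative), generation options can be passed after `--`:

```sh
# Illustrative only: override the prompt and the number of generated tokens
cargo run --release --features llama3,wgpu --example chat -- \
    --prompt "What is the capital of France?" --sample-len 128
```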

| Argument | Description |
| :-------------- | :------------------------------------------------------------------------------------------------------------- |
@@ -83,9 +82,16 @@ Using the `wgpu` backend:
cargo run --release --features llama3,wgpu --example chat
```

Using the `cuda` backend:

```sh
cargo run --release --features llama3,cuda --example chat
```

**Built with Meta Llama 3.** This example uses the
[Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
instruction-tuned model. Note that the [base pre-trained Llama-3 model](./src/pretrained.rs#L77) is
[Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) (default)
and [Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
instruction-tuned models. Note that the [base pre-trained Llama-3 model](./src/pretrained.rs#L77) is
also available if you wish to use it in your application.
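
A minimal sketch of selecting the older Llama 3 checkpoint instead of the 3.1 default, assuming the `--version` flag introduced in `examples/chat.rs` by this PR:

```sh
# Hypothetical: pick the Llama-3-8B-Instruct variant explicitly
cargo run --release --features llama3,cuda --example chat -- --version llama-3-8b-instruct
```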

#### TinyLlama
@@ -109,6 +115,18 @@ Using the `wgpu` backend:
cargo run --release --features tiny,wgpu --example chat
```

Using the `cuda` backend:

```sh
cargo run --release --features tiny,cuda --example chat
```

This example uses the
[TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0)
instruction-tuned model based on the Llama2 architecture and tokenizer.

## Known Issues

Based on your hardware and the model selected, the `wgpu` backend might not be able to successfully
run the model due to the current memory management strategy. With `cuda` selected, the precision is
set to `f32` due to compilation errors with `f16`.
50 changes: 46 additions & 4 deletions llama-burn/examples/chat.rs
@@ -8,6 +8,9 @@ use llama_burn::{
tokenizer::Tokenizer,
};

#[cfg(feature = "llama3")]
use clap::ValueEnum;

const DEFAULT_PROMPT: &str = "How many helicopters can a human eat in one sitting?";

#[derive(Parser, Debug)]
@@ -26,7 +29,7 @@ pub struct Config {
max_seq_len: usize,

/// The number of new tokens to generate (i.e., the number of generation steps to take).
#[arg(long, short = 'n', default_value_t = 50)]
#[arg(long, short = 'n', default_value_t = 65)]
sample_len: usize,

/// The seed to use when generating random samples.
@@ -36,6 +39,23 @@ pub struct Config {
/// The input prompt.
#[arg(short, long, default_value_t = String::from(DEFAULT_PROMPT))]
prompt: String,

/// The Llama 3 model version.
#[cfg(feature = "llama3")]
#[arg(long, default_value = "llama-3.1-8b-instruct")]
version: Llama3,
}

#[cfg(feature = "llama3")]
#[derive(Clone, Debug, ValueEnum)]
/// Llama-3 model variants to load.
enum Llama3 {
/// Llama-3-8B-Instruct.
#[value(name = "llama-3-8b-instruct")]
V3Instruct,
/// Llama-3.1-8B-Instruct.
#[value(name = "llama-3.1-8b-instruct")]
V31Instruct,
}

pub fn generate<B: Backend, T: Tokenizer>(
@@ -76,7 +96,7 @@ pub fn chat<B: Backend>(args: Config, device: Device<B>) {
#[cfg(feature = "tiny")]
{
// TinyLlama-1.1B Chat v1.0
let mut llama = LlamaConfig::tiny_llama_pretrained::<B>(&device).unwrap();
let mut llama = LlamaConfig::tiny_llama_pretrained::<B>(args.max_seq_len, &device).unwrap();
println!("Processing prompt: {}", prompt);

// Prompt formatting for chat model
@@ -95,8 +115,15 @@ pub fn chat<B: Backend>(args: Config, device: Device<B>) {

#[cfg(feature = "llama3")]
{
// Llama-3-8B-Instruct
let mut llama = LlamaConfig::llama3_8b_pretrained::<B>(true, &device).unwrap();
// Llama-3-8B-Instruct or Llama-3.1-8B-Instruct
let mut llama = match args.version {
Llama3::V3Instruct => {
LlamaConfig::llama3_8b_pretrained::<B>(args.max_seq_len, &device).unwrap()
}
Llama3::V31Instruct => {
LlamaConfig::llama3_1_8b_pretrained::<B>(args.max_seq_len, &device).unwrap()
}
};
println!("Processing prompt: {}", prompt);

// Prompt formatting for chat model
@@ -156,6 +183,19 @@ mod wgpu {
}
}

#[cfg(feature = "cuda")]
mod cuda {
use super::*;
use burn::backend::{cuda_jit::CudaDevice, CudaJit};

pub fn run(args: Config) {
let device = CudaDevice::default();

// NOTE: compilation errors in f16
chat::<CudaJit<f32, i32>>(args, device);
}
}

pub fn main() {
// Parse arguments
let args = Config::parse();
@@ -166,4 +206,6 @@ pub fn main() {
tch_cpu::run(args);
#[cfg(feature = "wgpu")]
wgpu::run(args);
#[cfg(feature = "cuda")]
cuda::run(args);
}
11 changes: 0 additions & 11 deletions llama-burn/src/cache.rs
@@ -1,11 +1,5 @@
use burn::tensor::{backend::Backend, Tensor};

/// All Llama-3 models support sequence length up to 8192 tokens.
pub(crate) const MAX_SEQ_LEN: usize = 8192;

// /// All Llama-2 models support sequence length up to 4096 tokens.
// pub(crate) const MAX_SEQ_LEN_V2: usize = 4096;

// Adapted from `burn::nn::cache`
enum CacheState<T> {
Value(T),
@@ -39,11 +33,6 @@ pub(crate) struct AutoregressiveCache<B: Backend> {
impl<B: Backend> AutoregressiveCache<B> {
/// Creates a new empty cache.
pub fn new(max_seq_len: usize) -> Self {
assert!(
max_seq_len <= MAX_SEQ_LEN,
"Maximum sequence length must not exceed {MAX_SEQ_LEN}"
);

Self {
cache: TensorCache::empty(),
max_seq_len,