Merge pull request #35 from tracel-ai/llama
Llama
nathanielsimard authored Aug 12, 2024
2 parents 7bf239c + c3ec71e commit 7ebd9e3
Showing 17 changed files with 1,676 additions and 8 deletions.
15 changes: 8 additions & 7 deletions README.md
@@ -5,13 +5,14 @@ examples constructed using the [Burn](https://github.com/burn-rs/burn) deep lear

## Collection of Official Models

| Model | Description | Repository Link |
|-------------------------------------------------|----------------------------------------------------------|------------------------------------------------|
| [MobileNetV2](https://arxiv.org/abs/1801.04381) | A CNN model targeted at mobile devices. | [mobilenetv2-burn](mobilenetv2-burn/README.md) |
| [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/README.md) |
| [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/README.md) |
| [RoBERTa](https://arxiv.org/abs/1907.11692) | A robustly optimized BERT pretraining approach. | [bert-burn](bert-burn/README.md) |
| [YOLOX](https://arxiv.org/abs/2107.08430) | A single-stage object detector based on the YOLO series. | [yolox-burn](yolox-burn/README.md) |
| Model | Description | Repository Link |
|-------------------------------------------------|----------------------------------------------------------|---------------------------------------|
| [Llama](https://github.com/meta-llama/llama3) | Llama 3 and TinyLlama large language models. | [llama-burn](llama-burn/) |
| [MobileNetV2](https://arxiv.org/abs/1801.04381) | A CNN model targeted at mobile devices. | [mobilenetv2-burn](mobilenetv2-burn/) |
| [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/) |
| [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/) |
| [RoBERTa](https://arxiv.org/abs/1907.11692) | A robustly optimized BERT pretraining approach. | [bert-burn](bert-burn/) |
| [YOLOX](https://arxiv.org/abs/2107.08430) | A single-stage object detector based on the YOLO series. | [yolox-burn](yolox-burn/) |

## Community Contributions

50 changes: 50 additions & 0 deletions llama-burn/Cargo.toml
@@ -0,0 +1,50 @@
[package]
authors = ["guillaumelagrange <[email protected]>"]
license = "MIT OR Apache-2.0"
name = "llama-burn"
version = "0.1.0"
edition = "2021"
description = "Llama 3 large language model with Burn"

[features]
default = ["pretrained"]
pretrained = ["burn/network", "dep:dirs"]

llama3 = ["dep:tiktoken-rs", "dep:rustc-hash", "dep:base64"]
tiny = ["dep:tokenizers"]

# Example feature flags (backend selection)
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]

[dependencies]
# Note: default-features = false is needed to disable std
burn = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c", default-features = false }
burn-import = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c" }
itertools = { version = "0.12.1", default-features = false, features = [
    "use_alloc",
] }
dirs = { version = "5.0.1", optional = true }
serde = { version = "1.0.192", default-features = false, features = [
    "derive",
    "alloc",
] } # alloc is for no_std, derive is needed

# Tiktoken tokenizer (llama 3)
tiktoken-rs = { version = "0.5.8", optional = true }
base64 = { version = "0.22.1", optional = true }
rustc-hash = { version = "1.1.0", optional = true }

# SentencePiece tokenizer (tiny llama / llama 2)
tokenizers = { version = "0.19.1", default-features = false, features = [
    "onig",
], optional = true }

rand = { version = "0.8.5", default-features = false, features = [
    "std_rng",
] } # std_rng is for no_std

[dev-dependencies]
burn = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c" }
clap = { version = "4.5.4", features = ["derive"] }
14 changes: 14 additions & 0 deletions llama-burn/NOTICES.md
@@ -0,0 +1,14 @@
# NOTICES AND INFORMATION

This file contains notices and information required by the libraries that this repository copied
code from or derived from. The use of the following resources complies with the licenses provided.

## Implementation

The model implementation was adapted from the original
[Llama 3 implementation](https://github.com/meta-llama/llama3), which is distributed under the
[Meta Llama 3 Community License Agreement](https://github.com/meta-llama/llama3/blob/main/LICENSE).

The TinyLlama implementation is derived from the same code, but its weights and tokenizers were
adapted from the [original implementation](https://github.com/jzhang38/TinyLlama) distributed under
the [Apache 2.0](https://github.com/jzhang38/TinyLlama/blob/main/LICENSE) open source license.
114 changes: 114 additions & 0 deletions llama-burn/README.md
@@ -0,0 +1,114 @@
# Llama Burn

<img src="./assets/llama-burn.jpeg" alt="An image of a llama surrounded by fiery colors and a gust of fire" width="500px"/>

The popular Llama LLM is here!

This repository contains the [Llama 3](https://github.com/meta-llama/llama3) and
[TinyLlama](https://github.com/jzhang38/TinyLlama) implementations with their corresponding
tokenizers. You can find the [Burn](https://github.com/tracel-ai/burn) implementation for the Llama
variants in [src/llama.rs](src/llama.rs).

## Usage

### `Cargo.toml`

Add this to your `Cargo.toml`:

```toml
[dependencies]
llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", default-features = false }
```

If you want to use Llama 3 or TinyLlama (including pre-trained weights if default features are
active), enable the corresponding feature flag.

> **Important:** these features require `std`. Note that the weights are saved in a binary format,
> which is more compact and faster to save and load, but might not remain compatible with future
> versions if the Burn data schema evolves.

#### Llama 3

```toml
[dependencies]
llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", features = ["llama3"] }
```

#### TinyLlama

```toml
[dependencies]
llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", features = ["tiny"] }
```

### Example Usage

The [chat completion example](examples/chat.rs) initializes a Llama model from the provided weights
file and generates a sequence of text based on the input prompt. The instruction-tuned model is
loaded for dialogue applications, so the prompt is automatically formatted for chat completion.

The example can be executed on the `tch` backend (CUDA or CPU) or `wgpu`.

| Argument | Description |
| :-------------- | :------------------------------------------------------------------------------------------------------------- |
| `-p` | The prompt or question to pass to the LLM (default: `"How many helicopters can a human eat in one sitting?"`). |
| `-n` | The number of new tokens to generate (default: `50`). |
| `--top-p` | Top-p probability threshold (default: `0.9`). |
| `--temperature` | Temperature value for controlling randomness in sampling (default: `0.6`).                                      |
| `--max-seq-len` | Maximum sequence length for input text (default: `128`).                                                        |
| `--seed`        | The seed to use when generating random samples (default: `42`).                                                 |

Any of the commands below can be used with the listed arguments by appending `[-- <arguments>]`. For
example, you can provide your own prompt/question with
`-- -p "How many llamas does it take to change a lightbulb?"`.
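For programmatic use outside the example, the core of [examples/chat.rs](examples/chat.rs) boils
down to roughly the following. This is a condensed sketch of the TinyLlama path (CLI parsing, timing
and backend selection omitted); it assumes the `tiny` and `pretrained` features are enabled and a
`wgpu`-backed `burn` is available:

```rust
use burn::backend::wgpu::{Wgpu, WgpuDevice};
use llama_burn::{
    llama::LlamaConfig,
    sampling::{Sampler, TopP},
};

fn main() {
    let device = WgpuDevice::default();

    // Download and load the pre-trained TinyLlama-1.1B Chat v1.0 weights.
    let mut llama = LlamaConfig::tiny_llama_pretrained::<Wgpu>(&device).unwrap();

    // Top-p (nucleus) sampling, as used by the example whenever temperature > 0.
    let mut sampler = Sampler::TopP(TopP::new(0.9, 42));

    // Apply the chat template expected by the instruction-tuned model.
    let prompt = format!(
        "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s>\n<|user|>\n{}</s>\n<|assistant|>\n",
        "How many helicopters can a human eat in one sitting?"
    );

    // Generate up to 50 new tokens at temperature 0.6 and print the completion.
    let generated = llama.generate(&prompt, 50, 0.6, &mut sampler);
    println!("{}", generated.text);
}
```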

#### Llama 3

Using the `tch` backend with CUDA:

```sh
export TORCH_CUDA_VERSION=cu121
cargo run --release --features llama3,tch-gpu --example chat
```

Using the `tch` backend with CPU:

```sh
cargo run --release --features llama3,tch-cpu --example chat
```

Using the `wgpu` backend:

```sh
cargo run --release --features llama3,wgpu --example chat
```

**Built with Meta Llama 3.** This example uses the
[Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
instruction-tuned model. Note that the [base pre-trained Llama-3 model](./src/pretrained.rs#L77) is
also available if you wish to use it in your application.
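For reference, the prompt formatting that the example applies before generation with the
instruction-tuned Llama 3 model is reproduced below (copied from
[examples/chat.rs](examples/chat.rs)); the system message is simply the default used by the example:

```rust
/// Wrap a user prompt in the Llama 3 instruct chat template used by examples/chat.rs.
fn format_llama3_prompt(prompt: &str) -> String {
    format!(
        "<|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )
}
```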

#### TinyLlama

Using the `tch` backend with CUDA:

```sh
export TORCH_CUDA_VERSION=cu121
cargo run --release --features tiny,tch-gpu --example chat
```

Using the `tch` backend with CPU:

```sh
cargo run --release --features tiny,tch-cpu --example chat
```

Using the `wgpu` backend:

```sh
cargo run --release --features tiny,wgpu --example chat
```

This example uses the
[TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0)
instruction-tuned model based on the Llama2 architecture and tokenizer.
Binary file added llama-burn/assets/llama-burn.jpeg
169 changes: 169 additions & 0 deletions llama-burn/examples/chat.rs
@@ -0,0 +1,169 @@
use std::time::Instant;

use burn::tensor::{backend::Backend, Device};
use clap::Parser;
use llama_burn::{
    llama::{Llama, LlamaConfig},
    sampling::{Sampler, TopP},
    tokenizer::Tokenizer,
};

const DEFAULT_PROMPT: &str = "How many helicopters can a human eat in one sitting?";

#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
pub struct Config {
    /// Top-p probability threshold.
    #[arg(long, default_value_t = 0.9)]
    top_p: f64,

    /// Temperature value for controlling randomness in sampling.
    #[arg(long, default_value_t = 0.6)]
    temperature: f64,

    /// Maximum sequence length for input text.
    #[arg(long, default_value_t = 128)]
    max_seq_len: usize,

    /// The number of new tokens to generate (i.e., the number of generation steps to take).
    #[arg(long, short = 'n', default_value_t = 50)]
    sample_len: usize,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 42)]
    seed: u64,

    /// The input prompt.
    #[arg(short, long, default_value_t = String::from(DEFAULT_PROMPT))]
    prompt: String,
}

pub fn generate<B: Backend, T: Tokenizer>(
    llama: &mut Llama<B, T>,
    prompt: &str,
    sample_len: usize,
    temperature: f64,
    sampler: &mut Sampler,
) {
    let now = Instant::now();
    let generated = llama.generate(prompt, sample_len, temperature, sampler);
    let elapsed = now.elapsed().as_secs();

    println!("> {}\n", generated.text);
    println!(
        "{} tokens generated ({:.4} tokens/s)\n",
        generated.tokens,
        generated.tokens as f64 / generated.time
    );

    println!(
        "Generation completed in {}m{}s",
        (elapsed / 60),
        elapsed % 60
    );
}

pub fn chat<B: Backend>(args: Config, device: Device<B>) {
    let mut prompt = args.prompt;

    // Sampling strategy
    let mut sampler = if args.temperature > 0.0 {
        Sampler::TopP(TopP::new(args.top_p, args.seed))
    } else {
        Sampler::Argmax
    };

    #[cfg(feature = "tiny")]
    {
        // TinyLlama-1.1B Chat v1.0
        let mut llama = LlamaConfig::tiny_llama_pretrained::<B>(&device).unwrap();
        println!("Processing prompt: {}", prompt);

        // Prompt formatting for chat model
        prompt = format!(
            "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n"
        );

        generate(
            &mut llama,
            &prompt,
            args.sample_len,
            args.temperature,
            &mut sampler,
        );
    }

    #[cfg(feature = "llama3")]
    {
        // Llama-3-8B-Instruct
        let mut llama = LlamaConfig::llama3_8b_pretrained::<B>(true, &device).unwrap();
        println!("Processing prompt: {}", prompt);

        // Prompt formatting for chat model
        prompt = format!(
            "<|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        );

        generate(
            &mut llama,
            &prompt,
            args.sample_len,
            args.temperature,
            &mut sampler,
        );
    }
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use super::*;
use burn::{
backend::{libtorch::LibTorchDevice, LibTorch},
tensor::f16,
};

pub fn run(args: Config) {
#[cfg(not(target_os = "macos"))]
let device = LibTorchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;

chat::<LibTorch<f16>>(args, device);
}
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use super::*;
use burn::backend::{libtorch::LibTorchDevice, LibTorch};

pub fn run(args: Config) {
let device = LibTorchDevice::Cpu;

chat::<LibTorch>(args, device);
}
}

#[cfg(feature = "wgpu")]
mod wgpu {
use super::*;
use burn::backend::wgpu::{Wgpu, WgpuDevice};

pub fn run(args: Config) {
let device = WgpuDevice::default();

chat::<Wgpu>(args, device);
}
}

pub fn main() {
// Parse arguments
let args = Config::parse();

#[cfg(feature = "tch-gpu")]
tch_gpu::run(args);
#[cfg(feature = "tch-cpu")]
tch_cpu::run(args);
#[cfg(feature = "wgpu")]
wgpu::run(args);
}
