Showing 23 changed files with 926 additions and 1,169 deletions.
@@ -1451,4 +1451,7 @@ openhathi
sarvam
subtask
acc
OCRVQA
OCRVQADataCollator
ocrvqa
langchain
@@ -1,47 +1,29 @@
# Llama Recipes: Examples to get started using the Llama models from Meta
<!-- markdown-link-check-disable -->
The 'llama-recipes' repository is a companion to the [Meta Llama](https://github.com/meta-llama/llama-models) models. We support the latest version, [Llama 3.1](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md), in this repository. The goal is to provide a scalable library for fine-tuning Meta Llama models, along with some example scripts and notebooks to quickly get started with using the models in a variety of use cases, including fine-tuning for domain adaptation and building LLM-based applications with Llama and other tools in the LLM ecosystem. The examples here showcase how to run Llama locally, in the cloud, and on-prem.
The 'llama-recipes' repository is a companion to the [Meta Llama](https://github.com/meta-llama/llama-models) models. We support the latest versions, [Llama 3.2 Vision](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/MODEL_CARD_VISION.md) and [Llama 3.2 Text](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/MODEL_CARD.md), in this repository. This repository contains example scripts and notebooks to get started with the models in a variety of use cases, including fine-tuning for domain adaptation and building LLM-based applications with Llama and other tools in the LLM ecosystem. The examples here show how to use Llama locally, in the cloud, and on-prem.

<!-- markdown-link-check-enable -->
> [!IMPORTANT]
> Meta Llama 3.1 has a new prompt template and special tokens.
> Llama 3.2 follows the same prompt template as Llama 3.1, with a new special token `<|image|>` representing the input image for the multimodal models.
>
> | Token | Description |
> |---|---|
> | `<\|begin_of_text\|>` | Specifies the start of the prompt. |
> | `<\|image\|>` | Represents the image tokens passed as an input to Llama. |
> | `<\|eot_id\|>` | Signifies the end of a turn, i.e. the end of the model's interaction with either the user or the tool executor. |
> | `<\|eom_id\|>` | End of Message. A message represents a possible stopping point where the model can inform the execution environment that a tool call needs to be made. |
> | `<\|python_tag\|>` | A special tag used in the model's response to signify a tool call. |
> | `<\|finetune_right_pad_id\|>` | Used for padding text sequences in a batch to the same length. |
> | `<\|start_header_id\|>{role}<\|end_header_id\|>` | These tokens enclose the role for a particular message. The possible roles are: system, user, assistant and ipython. |
> | `<\|end_of_text\|>` | Equivalent to the EOS token. It is usually unused in multi-turn conversations and is expected to be generated only by the base models. |
>
> A multi-turn conversation with Meta Llama 3.1 that includes tool calling follows this structure:
> ```
> <|begin_of_text|><|start_header_id|>system<|end_header_id|>
>
> {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|>
>
> {{ user_message_1 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
>
> <|python_tag|>{{ model_tool_call_1 }}<|eom_id|><|start_header_id|>ipython<|end_header_id|>
>
> {{ tool_response }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
>
> {{ model_response_based_on_tool_response }}<|eot_id|>
> ```
> Each message is followed by an `<|eot_id|>` token before a new header starts, signaling a role change.
>
> More details on the new tokenizer and prompt template can be found [here](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1).
> More details on the prompt templates for image reasoning, tool calling and code interpreter can be found [on the documentation website](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_2).
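> For illustration, a minimal single-turn image prompt for the Llama 3.2 multimodal models looks roughly like the sketch below (assembled from the Llama 3.1 format plus the `<|image|>` token; see the linked documentation for the authoritative template):
> ```
> <|begin_of_text|><|start_header_id|>user<|end_header_id|>
>
> <|image|>Describe this image in two sentences.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
> ```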

>
> [!NOTE]
> The llama-recipes repository was recently refactored to provide a better developer experience when using the examples. Some files have been moved to new locations. The `src/` folder has NOT been modified, so the functionality of this repo and package is not impacted.
>
> Make sure you update your local clone by running `git pull origin main`.

## Table of Contents

- [Llama Recipes: Examples to get started using the Meta Llama models from Meta](#llama-recipes-examples-to-get-started-using-the-llama-models-from-meta)
- [Llama Recipes: Examples to get started using the Llama models from Meta](#llama-recipes-examples-to-get-started-using-the-llama-models-from-meta)
- [Table of Contents](#table-of-contents)
- [Getting Started](#getting-started)
- [Prerequisites](#prerequisites)

@@ -117,23 +99,21 @@ pip install -e .[tests,auditnlg,vllm]
```

### Getting the Meta Llama models
You can find Meta Llama models on Hugging Face hub [here](https://huggingface.co/meta-llama), **where models with `hf` in the name are already converted to Hugging Face checkpoints so no further conversion is needed**. The conversion step below is only for original model weights from Meta that are hosted on Hugging Face model hub as well.
### Getting the Llama models
You can find Llama models on the Hugging Face hub [here](https://huggingface.co/meta-llama), **where models with `hf` in the name are already converted to Hugging Face checkpoints, so no further conversion is needed**. The conversion step below is only for original model weights from Meta that are also hosted on the Hugging Face model hub.

#### Model conversion to Hugging Face
The recipes and notebooks in this folder use the Meta Llama model definition provided by Hugging Face's transformers library.
Given that the original checkpoint resides under models/7B, you can install all requirements and convert the checkpoint with:
If you have the model checkpoints downloaded from the Meta website, you can convert them to the Hugging Face format with:

```bash
## Install Hugging Face Transformers from source
pip freeze | grep transformers ## verify it is version 4.31.0 or higher
pip freeze | grep transformers ## verify it is version 4.45.0 or higher

git clone git@github.com:huggingface/transformers.git
cd transformers
pip install protobuf
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
    --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
    --input_dir /path/to/downloaded/llama/weights --model_size 3B --output_dir /output/path
```
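After conversion, one quick way to sanity-check the exported checkpoint is to load it back with `transformers` (a minimal sketch; `/output/path` is the output directory used in the command above):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the converted Hugging Face checkpoint from the conversion output directory.
tokenizer = AutoTokenizer.from_pretrained("/output/path")
model = AutoModelForCausalLM.from_pretrained("/output/path")

print(model.config.model_type)  # should report a Llama-family model type
```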

@@ -196,6 +176,8 @@ Please read [CONTRIBUTING.md](CONTRIBUTING.md) for details on our code of conduct
## License
<!-- markdown-link-check-disable -->

See the License file for Meta Llama 3.2 [here](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/LICENSE) and the Acceptable Use Policy [here](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/USE_POLICY.md)

See the License file for Meta Llama 3.1 [here](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/LICENSE) and the Acceptable Use Policy [here](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/USE_POLICY.md)

See the License file for Meta Llama 3 [here](https://github.com/meta-llama/llama-models/blob/main/models/llama3/LICENSE) and the Acceptable Use Policy [here](https://github.com/meta-llama/llama-models/blob/main/models/llama3/USE_POLICY.md)
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.


import copy
import itertools

import torch
from datasets import load_dataset


# Check whether any system or user prompt header token sequence appears in the current token list.
def check_header(targets, seq):
    for i in range(len(seq) - 3):
        if seq[i:i + 3] in targets:
            return True
    return False


# Mask (replace with -100) every occurrence of the 3-token target sequence in seq.
def replace_target(target, seq):
    for i in range(len(seq) - 3):
        if seq[i:i + 3] == target:
            seq[i], seq[i + 1], seq[i + 2] = -100, -100, -100
    return seq


def tokenize_dialogs(dialogs, images, processor):
    text_prompt = processor.apply_chat_template(dialogs)
    batch = processor(images=images, text=text_prompt, padding=True, return_tensors="pt")
    label_list = []
    for i in range(len(batch["input_ids"])):
        dialog_tokens = batch["input_ids"][i].tolist()
        labels = copy.copy(dialog_tokens)
        eot_indices = [i for i, n in enumerate(labels) if n == 128009]
        last_idx = 0
        # The system prompt header "<|start_header_id|>system<|end_header_id|>" is tokenized to [128006, 9125, 128007];
        # the user prompt header "<|start_header_id|>user<|end_header_id|>" is tokenized to [128006, 882, 128007].
        prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
        for n, idx in enumerate(eot_indices):
            current_seq = labels[last_idx:idx + 1]
            if check_header(prompt_header_seqs, current_seq):
                # Found a prompt header, indicating that this sequence should be masked.
                labels[last_idx:idx + 1] = [-100] * (idx - last_idx + 1)
            else:
                last_idx = idx + 1
        # Mask the assistant header "<|start_header_id|>assistant<|end_header_id|>", which is tokenized to [128006, 78191, 128007].
        assistant_header_seq = [128006, 78191, 128007]
        labels = replace_target(assistant_header_seq, labels)
        # Mask the padding tokens and the image token (token id 128256).
        for i in range(len(labels)):
            if labels[i] == processor.tokenizer.pad_token_id or labels[i] == 128256:
                labels[i] = -100
        label_list.append(labels)
    batch["labels"] = torch.tensor(label_list)
    return batch


def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # load_dataset returns a DatasetDict that contains all the data in the train split.
    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")
    dataset = dataset_dict["train"]
    # Comment out the following line to use the full dataset; only 2000 samples are selected for quick testing.
    dataset = dataset.select(range(2000))
    dataset = dataset.train_test_split(test_size=1 - split_ratio, shuffle=True, seed=42)[split]
    return dataset


class OCRVQADataCollator:
    def __init__(self, processor):
        self.processor = processor
        self.processor.tokenizer.padding_side = "right"  # during training, padding is always on the right

    def __call__(self, samples):
        dialogs, images = [], []
        for sample in samples:
            image_list, sample_list = sample["images"], sample["texts"]
            if len(image_list) > 1:
                raise ValueError("Only one image per sample is supported")
            image = image_list[0].convert("RGB")  # only use the first image
            dialog = []
            for sample_dict in sample_list:
                if not dialog:
                    # Only attach the image to the first user turn.
                    dialog += [
                        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": sample_dict["user"].strip()}]},
                        {"role": "assistant", "content": [{"type": "text", "text": sample_dict["assistant"].strip()}]},
                    ]
                else:
                    dialog += [
                        {"role": "user", "content": [{"type": "text", "text": sample_dict["user"].strip()}]},
                        {"role": "assistant", "content": [{"type": "text", "text": sample_dict["assistant"].strip()}]},
                    ]
            dialogs.append(dialog)
            images.append([image])
        return tokenize_dialogs(dialogs, images, self.processor)


def get_data_collator(processor):
    return OCRVQADataCollator(processor)
@@ -0,0 +1,33 @@
## Fine-Tuning Meta Llama Multimodal Models recipe
This recipe steps you through fine-tuning a Llama 3.2 vision model on the OCR VQA task using the [OCRVQA](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron/viewer/ocrvqa?row=0) dataset.

**Disclaimer**: Our vision models already have very good OCR ability, so the OCRVQA dataset is used here only to demonstrate the steps required to fine-tune our vision models with llama-recipes.

### Fine-tuning steps

We created an example script, [ocrvqa_dataset.py](./datasets/ocrvqa_dataset.py), that loads the OCRVQA dataset with a `get_custom_dataset` function and provides an `OCRVQADataCollator` class to process the image dataset.

For **full fine-tuning with FSDP**, run:

```bash
torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py" --run_validation True --batching_strategy padding
```

For **LoRA fine-tuning with FSDP**, run:

```bash
torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py" --run_validation True --batching_strategy padding --use_peft --peft_method lora
```

**Note**: `--batching_strategy padding` is needed because the vision model does not work with the `packing` method.

For more details about the fine-tuning configurations, please read the [finetuning readme](./README.md).

### How to use a custom dataset to fine-tune the vision model

To use a custom dataset, follow the steps below (a minimal skeleton is sketched after the list):

1. Create a new dataset python file under the `recipes/quickstart/finetuning/datasets` folder.
2. In this python file, define a `get_custom_dataset(dataset_config, processor, split, split_ratio=0.9)` function that handles the data loading.
3. In this python file, define a `get_data_collator(processor)` function that returns a custom data collator that can be used by the PyTorch DataLoader.
4. The custom data collator class must have a `__call__(self, samples)` function that converts the image and text samples into the actual inputs the vision model expects.
5. Run the `torchrun` command from the section above, changing `--custom_dataset.file` to the new dataset python file and adjusting the learning rate accordingly.
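As a reference, a minimal custom dataset module could look like the sketch below. The dataset name, field names, and prompt text are hypothetical placeholders; only the `get_custom_dataset` and `get_data_collator` entry points and the collator's `__call__(self, samples)` signature are what the fine-tuning script expects.

```python
# my_vision_dataset.py -- hypothetical example of a custom dataset module
from datasets import load_dataset


def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # Load your own dataset here; "my_org/my_vqa_dataset" is a placeholder name.
    dataset = load_dataset("my_org/my_vqa_dataset")["train"]
    dataset = dataset.train_test_split(test_size=1 - split_ratio, shuffle=True, seed=42)[split]
    return dataset


class MyVisionDataCollator:
    def __init__(self, processor):
        self.processor = processor
        self.processor.tokenizer.padding_side = "right"

    def __call__(self, samples):
        # Convert each sample into an image plus a chat-style dialog, then tokenize the batch.
        images, texts = [], []
        for sample in samples:
            images.append([sample["image"].convert("RGB")])  # assumes an "image" field
            dialog = [
                {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": sample["question"]}]},
                {"role": "assistant", "content": [{"type": "text", "text": sample["answer"]}]},
            ]
            texts.append(self.processor.apply_chat_template(dialog))
        batch = self.processor(images=images, text=texts, padding=True, return_tensors="pt")
        # For a real run you would mask prompt, padding and image tokens in the labels
        # (see ocrvqa_dataset.py above); here the labels simply mirror input_ids to keep the sketch short.
        batch["labels"] = batch["input_ids"].clone()
        return batch


def get_data_collator(processor):
    return MyVisionDataCollator(processor)
```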
recipes/quickstart/inference/local_inference/multi_modal_infer.py (66 additions, 0 deletions)
@@ -0,0 +1,66 @@
import os
import sys
import argparse

import torch
from PIL import Image as PIL_Image
from transformers import MllamaForConditionalGeneration, MllamaProcessor


# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"


def load_model_and_processor(model_name: str, hf_token: str):
    """
    Load the model and processor based on the 11B or 90B model.
    """
    model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16, token=hf_token)
    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token)
    return model, processor


def process_image(image_path: str) -> PIL_Image.Image:
    """
    Open an image from the specified path and convert it to RGB.
    """
    if not os.path.exists(image_path):
        print(f"The image file '{image_path}' does not exist.")
        sys.exit(1)
    with open(image_path, "rb") as f:
        return PIL_Image.open(f).convert("RGB")


def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
    """
    Generate text from an image using the model and processor.
    """
    conversation = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    inputs = processor(prompt, image, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
    return processor.decode(output[0])[len(prompt):]


def main(image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str, hf_token: str):
    """
    Load the model, process the image, and print the generated text.
    """
    model, processor = load_model_and_processor(model_name, hf_token)
    image = process_image(image_path)
    result = generate_text_from_image(model, processor, image, prompt_text, temperature, top_p)
    print("Generated Text: " + result)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate text from an image and prompt using the Llama 3.2 multimodal model.")
    parser.add_argument("--image_path", type=str, help="Path to the image file")
    parser.add_argument("--prompt_text", type=str, help="Prompt text to describe the image")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation (default: 0.7)")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)")
    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help=f"Model name (default: '{DEFAULT_MODEL}')")
    parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token for authentication")

    args = parser.parse_args()
    main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)
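For reference, a typical invocation of this script could look like the following; the image path and token value are placeholders:

```bash
python recipes/quickstart/inference/local_inference/multi_modal_infer.py \
    --image_path ./sample_image.jpg \
    --prompt_text "Describe this image in two sentences." \
    --hf_token <your_hf_token>
```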