A single-file implementation of LLaMA 3, with support for JIT compilation, KV caching, and prompting.
The original implementation can be found at https://github.com/meta-llama/llama3.
To install the latest version directly from GitHub, first install Python 3.8 or later, then open a terminal and run:
pip install git+https://github.com/lucadellalib/llama3@main#egg=llama3[all]
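To verify the installation, you can try importing the decoder class used in the usage example below (a quick sanity check):

python -c "from llama3 import LlamaDecoder; print('OK')"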
Alternatively, to install from a local copy, install Python 3.8 or later, clone or download and extract the repository, navigate to <path-to-repository>, open a terminal, and run:
# Install the package locally in editable mode (quotes prevent shell glob expansion of the brackets)
pip install -e ".[all]"
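Editable mode links the installed package to your working copy, so local code changes take effect without reinstalling. To confirm the installation, pip can report the package metadata (package name assumed from the egg fragment above):

pip show llama3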
Once installed, the model can be used as a regular PyTorch module:

import torch
from llama3 import LlamaDecoder

B, H, K = 3, 512, 30
model = LlamaDecoder(K)
print(model)

# Process 50 timesteps (input shape: batch x time x features)
x = torch.randn(B, 50, H)
output, state = model(x)
print(output.shape)

# Process 2 additional timesteps, reusing the cached state (KV cache)
x = torch.randn(B, 2, H)
output, state = model(x, state=state)
print(output.shape)

# JIT-compile the model
model_jit = model.jit()
output_jit, state_jit = model_jit(x)
print(output_jit.shape)
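If jit() returns a standard torch.jit.ScriptModule (an assumption; the return type is not specified here), the compiled model can be saved and reloaded with PyTorch's TorchScript API, e.g. for deployment without the original source code. A minimal sketch:

import torch

# Save the scripted model to disk (file name is illustrative)
torch.jit.save(model_jit, "llama_decoder.pt")

# Reload it later; the llama3 package is no longer required
# (assumes model_jit is a torch.jit.ScriptModule)
model_restored = torch.jit.load("llama_decoder.pt")
output_restored, state_restored = model_restored(x)
print(output_restored.shape)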
To run a pretrained model, first download the model weights and tokenizer (pretrained variant, e.g. Llama3.2-1B); see the official website for instructions on how to download the models.
Then navigate to <path-to-repository>, open a terminal, and run:
python main.py --checkpoint_path <path-to-checkpoint>
Running this script on a machine with at least one GPU is recommended.
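Before launching, you can check that PyTorch detects a GPU (standard PyTorch calls):

python -c "import torch; print(torch.cuda.is_available(), torch.cuda.device_count())"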