
change the export code of Llama 3 to be very GPT-2 friendly, using a combination of 3 hacks. this will make it so that we have to change very little code on the C side
karpathy committed Sep 13, 2024
1 parent 01bc4c6 commit b883560
Showing 1 changed file with 38 additions and 2 deletions.
40 changes: 38 additions & 2 deletions train_llama3.py
@@ -847,25 +847,61 @@ def write_bf16(tensor, file):

def write_tensors(model_tensors, L, file, dtype):
    # writes the LLaMA 3 model's weights to a binary file
    # things get a bit more complicated though:
    # 1) GPT-2 has biases, and the C code supports finetuning just the biases;
    #    Llama 3 has none, but we want to touch as little code as possible.
    #    => We will generate biases of all zeros and write them here. It's very little data.
    # 2) We want to exactly preserve the GPT-2 code paths, so we can't have SwiGLU using two
    #    separate nn.Linear layers c_fc and c_fc2. We will merge them into a single c_fc slot,
    #    and later, in the C code, do pointer arithmetic to recover them, fully internal to
    #    the SwiGLU layer.
    # 3) Llama 3 does not use a position embeddings table, so we'd have to remove it. At the
    #    same time, very conveniently, Llama 3 does not share the output projection weights
    #    with the token embeddings table, so we'd have to add those. Instead of removing and
    #    adding, we write the output projection weights into the slot previously used for the
    #    position embeddings table. Everyone is happy, and very little code changes from GPT-2.
    assert dtype in {"float32", "bfloat16"}
    write_fun = write_fp32 if dtype == "float32" else write_bf16
    write_fun(model_tensors["transformer.wte.weight"], file) # (V, C)
    write_fun(model_tensors["lm_head.weight"], file) # (V, C) # <--- hack (3) here!
    for i in range(L): # (L, C)
        write_fun(model_tensors[f"transformer.h.{i}.ln_1.weight"], file)
    for i in range(L): # (L, C)
        # see hack (1) above for these
        # yes i know this is inefficient and dumb, i'm just matching the train_gpt2.py code format
        write_fun(torch.zeros_like(model_tensors[f"transformer.h.{i}.ln_1.weight"]), file)
    for i in range(L): # (L, 3C, C)
        write_fun(model_tensors[f"transformer.h.{i}.attn.c_attn.weight"], file)
    for i in range(L): # (L, 3C)
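        # hack (1) again: w.size(0) is c_attn's output dim, so this zero bias has
        # exactly the GPT-2 bias shape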
        w = model_tensors[f"transformer.h.{i}.attn.c_attn.weight"]
        write_fun(torch.zeros(w.size(0), dtype=w.dtype), file)
    for i in range(L): # (L, C, C)
        write_fun(model_tensors[f"transformer.h.{i}.attn.c_proj.weight"], file)
    for i in range(L): # (L, C)
        w = model_tensors[f"transformer.h.{i}.attn.c_proj.weight"]
        write_fun(torch.zeros(w.size(0), dtype=w.dtype), file)
    for i in range(L): # (L, C)
        write_fun(model_tensors[f"transformer.h.{i}.ln_2.weight"], file)
    for i in range(L): # (L, C)
        write_fun(torch.zeros_like(model_tensors[f"transformer.h.{i}.ln_2.weight"]), file)
    # now for hack (2) here... inline model surgery to concat c_fc and c_fc2
    # -------------------------------------------
    for i in range(L): # (L, 4C, C)
        # simply write the two weights in sequence: first c_fc for every layer...
        write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc.weight"], file)
    for i in range(L): # (L, 4C, C)
        # ...then c_fc2 for every layer
        write_fun(model_tensors[f"transformer.h.{i}.mlp.c_fc2.weight"], file)
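    # the merged fc slot now holds all L c_fc matrices followed by all L c_fc2
    # matrices; in C, layer l's c_fc2 can be found L*(4C*C) elements past its c_fc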
    for i in range(L): # (L, 8C)
        w1 = model_tensors[f"transformer.h.{i}.mlp.c_fc.weight"]
        w2 = model_tensors[f"transformer.h.{i}.mlp.c_fc2.weight"]
        write_fun(torch.zeros(w1.size(0) + w2.size(0), dtype=w1.dtype), file)
    # -------------------------------------------
    for i in range(L): # (L, C, 4C)
        write_fun(model_tensors[f"transformer.h.{i}.mlp.c_proj.weight"], file)
    for i in range(L): # (L, C)
        w = model_tensors[f"transformer.h.{i}.mlp.c_proj.weight"]
        write_fun(torch.zeros(w.size(0), dtype=w.dtype), file)
    write_fun(model_tensors["transformer.ln_f.weight"], file) # (C, )
    write_fun(torch.zeros_like(model_tensors["transformer.ln_f.weight"]), file) # (C, )

def write_model(model, filename, dtype):
    # everything we need to instantiate the model
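For reference, here is a minimal sketch of the consuming side, written in Python standing in for the C code: it reads the GPT-2-ordered slots back and undoes the three hacks. Everything in it (the read_fp32 helper, the shape arguments V, C, QKV, FF, L, and the float32 assumption) is illustrative only and not part of the commit; the real consumer is the C code, which does the equivalent with pointer arithmetic, and the bfloat16 export path differs only in element size.

import numpy as np

def read_fp32(f, *shape):
    # read one float32 tensor of the given shape from the current position of file f
    n = int(np.prod(shape))
    return np.frombuffer(f.read(4 * n), dtype=np.float32).reshape(shape)

def read_llama3_as_gpt2(f, V, C, QKV, FF, L):
    # V: vocab size, C: model dim, QKV: rows of attn.c_attn, FF: ffn hidden dim, L: layers
    p = {}
    p["wte"] = read_fp32(f, V, C)
    p["lm_head"] = read_fp32(f, V, C)        # hack (3): stored in the old wpe slot
    p["ln1w"] = read_fp32(f, L, C)
    p["ln1b"] = read_fp32(f, L, C)           # hack (1): all zeros
    p["qkvw"] = read_fp32(f, L, QKV, C)
    p["qkvb"] = read_fp32(f, L, QKV)         # zeros
    p["attprojw"] = read_fp32(f, L, C, C)
    p["attprojb"] = read_fp32(f, L, C)       # zeros
    p["ln2w"] = read_fp32(f, L, C)
    p["ln2b"] = read_fp32(f, L, C)           # zeros
    fc_all = read_fp32(f, 2 * L, FF, C)      # hack (2): c_fc for all layers, then c_fc2
    p["c_fc"], p["c_fc2"] = fc_all[:L], fc_all[L:]
    p["fcb"] = read_fp32(f, L, 2 * FF)       # zeros
    p["fcprojw"] = read_fp32(f, L, C, FF)
    p["fcprojb"] = read_fp32(f, L, C)        # zeros
    p["lnfw"] = read_fp32(f, C)
    p["lnfb"] = read_fp32(f, C)              # zeros
    return p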
