forked from anthonix/llm.c
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_llama3.py
1255 lines (1095 loc) · 56.2 KB
/
train_llama3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Reference code for LLaMA-3.1 training and inference.
Will save the model weights into files, to be read from C as initialization.
This code differs from GPT-2 very slightly, there are three main differences:
1) RoPE: LLaMA uses a different positional encoding scheme called Rotary Position Embedding (RoPE).
2) GQA: Grouped Query Attention (GQA) is used to reduce the number of key/value heads (several query heads share each key/value head).
3) SwiGLU: Swish-Gated Linear Unit (SwiGLU) is used as the activation function in the MLP.
References:
# 1) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/tokenizer.py
# 2) https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py
# 3) https://github.com/meta-llama/llama3/blob/11817d47e1ba7a4959b025eb1ca308572e0e3963/llama/generation.py
Example launches to only benchmark the speed of bfloat16 compiled GPU training:
TODO: add the actual commands
"""
import argparse
import os
import math
import glob
import inspect
from contextlib import nullcontext
from dataclasses import dataclass
from pathlib import Path
import time
from typing import (
AbstractSet,
Collection,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
cast,
)
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch._inductor.config as config
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from torch.distributed.optim import ZeroRedundancyOptimizer
import torch.distributed as dist
import tiktoken
from tiktoken.load import load_tiktoken_bpe
# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the LLaMA 3.x model
# Used in Grouped Query Attention (GQA), broadcasts the key and value tensors
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head `n_rep` times along dim 2 (GQA broadcast).

    Equivalent to ``torch.repeat_interleave(x, dim=2, repeats=n_rep)`` for a
    4-D input (batch, seq, n_kv_heads, head_dim); the result has
    ``n_kv_heads * n_rep`` heads. With ``n_rep == 1`` the input is returned
    unchanged (same object).
    """
    batch, seqlen, kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # insert a repeat axis right after the head axis, broadcast it, then fold it into the heads
    widened = x.unsqueeze(3).expand(batch, seqlen, kv_heads, n_rep, head_dim)
    return widened.reshape(batch, seqlen, kv_heads * n_rep, head_dim)
# -----------------------------------------------------------------------------
# RoPE related
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Return a view of `freqs_cis` that broadcasts against `x`.

    `freqs_cis` must have shape ``(x.shape[1], x.shape[-1])``. The view keeps
    those sizes at dim 1 and at the last dim and uses size 1 everywhere else,
    e.g. (T, D) -> (1, T, 1, D) for a 4-D `x`.
    """
    rank = x.ndim
    assert rank > 1  # x needs both a dim 1 and a last dim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    last_axis = rank - 1
    target_shape = []
    for axis, size in enumerate(x.shape):
        target_shape.append(size if axis in (1, last_axis) else 1)
    return freqs_cis.view(*target_shape)
def apply_scaling(freqs: torch.Tensor):
    """Rescale RoPE base frequencies with the LLaMA 3.1 long-context rule.

    Each frequency is classified by its wavelength ``2*pi/freq`` relative to
    the original 8192-token context:
      * wavelength < 8192/4: kept unchanged (high frequency);
      * wavelength > 8192/1: divided by 8 (low frequency);
      * otherwise: blend ``(1 - s) * freq / 8 + s * freq`` with
        ``s = (8192 / wavelength - 1) / (4 - 1)``.
    Returns a new tensor with the dtype and device of `freqs`.
    """
    # Values obtained from grid search
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    def _scale_one(freq):
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            return freq
        if wavelen > low_freq_wavelen:
            return freq / scale_factor
        # medium band: interpolate between the scaled and the unscaled frequency
        assert low_freq_wavelen != high_freq_wavelen
        smooth = (old_context_len / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
        return (1 - smooth) * freq / scale_factor + smooth * freq

    return torch.tensor([_scale_one(f) for f in freqs], dtype=freqs.dtype, device=freqs.device)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key vectors by the precomputed RoPE phases.

    The last dimension of each input is read as consecutive (real, imag) pairs,
    multiplied by the complex `freqs_cis` (broadcast with
    `reshape_for_broadcast`, so it must be shaped (xq.shape[1], xq.shape[-1] // 2)),
    then flattened back from dim 3 on. The rotation is computed in float32 and
    each output is cast back to the dtype of its input. Assumes 4-D inputs,
    as produced by CausalSelfAttention (B, T, NH, HD).
    """

    def _as_complex(t: torch.Tensor) -> torch.Tensor:
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
):
    """Precompute the complex RoPE rotation factors.

    Returns a complex64 tensor of shape (end, dim // 2) whose entry [t, k] is
    ``exp(i * t * f_k)`` with ``f_k = 1 / theta ** (2k / dim)``. When
    `use_scaled` is true, the base frequencies ``f_k`` are first passed
    through `apply_scaling`.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    if use_scaled:
        inv_freq = apply_scaling(inv_freq)
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    return torch.polar(torch.ones_like(angles), angles)  # complex64
# -----------------------------------------------------------------------------
# LLaMA building blocks
# LLaMA reference code explicitly implemented RMSNorm so we copy pasted it
# (https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py)
# we could also use nn.RMSNorm, it has slightly different numeric properties, but equivalent
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a learnable per-feature gain.

    ``y = x / sqrt(mean(x**2 over the last dim) + eps) * weight``. The
    normalization runs in float32 and is cast back to the input dtype before
    the gain (initialized to ones) is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normalized = self._norm(x.float())
        return normalized.type_as(x) * self.weight
class CausalSelfAttention(nn.Module):
    """Causal self-attention with RoPE, grouped-query attention and an optional KV cache.

    Queries use ``config.n_head`` heads. Keys and values use ``config.n_kv_head``
    heads, each repeated ``n_head // n_kv_head`` times (GQA). When
    ``config.use_kv`` is set, a static KV cache of shape
    (max_gen_batch_size, block_size, n_kv_head, head_dim) is allocated and is
    used in eval mode whenever ``start_pos`` is a non-negative int.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_rep = self.n_head // self.n_kv_head
        self.hd = config.n_embd // config.n_head
        self.use_kv = config.use_kv
        self.flash = config.flash

        self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.hd, bias=False)  # key, query, value projections
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)  # output projection

        # static KV cache - we could alternatively allocate it outside of the model and just pass it in when needed.
        # These are plain attributes (not registered buffers), so they are excluded from state_dict and are NOT
        # moved by `model.to(...)`; forward() re-homes them lazily via _sync_kv_cache.
        if self.use_kv:
            self.cache_k = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))
            self.cache_v = torch.zeros((config.max_gen_batch_size, config.block_size, config.n_kv_head, self.hd))

    def _sync_kv_cache(self, k, v):
        """Move the KV cache onto the device/dtype of the incoming `k` / `v` if they differ."""
        if (self.cache_k.device == k.device and self.cache_k.dtype == k.dtype
                and self.cache_v.device == v.device and self.cache_v.dtype == v.dtype):
            return
        # Re-allocate outside inference mode: a cache created inside inference mode would become an
        # inference tensor and could no longer be written by later no-grad (non-inference) passes.
        with torch.inference_mode(False):
            self.cache_k = self.cache_k.to(device=k.device, dtype=k.dtype)
            self.cache_v = self.cache_v.to(device=v.device, dtype=v.dtype)

    def forward(self, x, freqs_cis=None, start_pos=None, mask=None):
        """Attend over `x` (B, T, C).

        Args:
            x: input activations of shape (B, T, n_embd).
            freqs_cis: RoPE factors for the T positions (see apply_rotary_emb).
            start_pos: absolute position of x[:, 0]; enables the KV cache in eval mode when >= 0.
            mask: boolean mask broadcastable to (T, S) where True marks positions to hide,
                S being the key length (T, or start_pos + T when the cache is used).

        Returns:
            Tensor of shape (B, T, n_embd).
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        qkv = self.c_attn(x)
        q, k, v = qkv.split([self.n_head * self.hd, self.n_kv_head * self.hd, self.n_kv_head * self.hd], dim=-1)
        q, k, v = map(lambda t: t.view(B, T, -1, self.hd), (q, k, v))  # (B, T, NH, HD)

        q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)  # rotate QK (rope) <-- 1. difference compared to GPT-2

        # use kv-caching during inference (start_pos=None means the caller is not decoding incrementally)
        if self.use_kv and not self.training and start_pos is not None and start_pos >= 0:
            self._sync_kv_cache(k, v)
            self.cache_k[:B, start_pos : start_pos + T] = k
            self.cache_v[:B, start_pos : start_pos + T] = v
            k = self.cache_k[:B, : start_pos + T]
            v = self.cache_v[:B, : start_pos + T]

        k = repeat_kv(k, self.n_rep)  # GQA <-- 2. difference compared to GPT-2
        v = repeat_kv(v, self.n_rep)

        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))  # (B, NH, T, HD)

        if self.flash:
            # flashattention
            # scaled_dot_product_attention expects a mask where True means "take part in attention";
            # our mask is the opposite, so invert it. With T == 1 nothing needs masking (and the
            # function complains otherwise); with no mask we mirror the manual path and do not mask.
            attn_mask = None if (mask is None or T == 1) else (mask == 0)
            y = F.scaled_dot_product_attention(q, k, v, attn_mask)
        else:
            # manual implementation of attention
            # this materializes the large (T,T) matrix for all the queries and keys
            scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.hd))
            if mask is not None:
                scores.masked_fill_(mask, torch.finfo(scores.dtype).min)
            att = F.softmax(scores.float(), dim=-1).type_as(q)
            y = att @ v  # (B, NH, T, T) x (B, NH, T, HD) -> (B, NH, T, HD)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        return y
class MLP(nn.Module):
    """SwiGLU feed-forward block: ``c_proj(c_fc(x) * silu(c_fc2(x)))``.

    The hidden width starts at int(2/3 of 4 * n_embd). It is optionally scaled
    by ``config.ffn_dim_multiplier`` and then rounded up to a multiple of
    ``config.multiple_of``.
    """

    def __init__(self, config):
        super().__init__()
        width = int(2 * (4 * config.n_embd) / 3)
        # custom dim factor multiplier
        if config.ffn_dim_multiplier is not None:
            width = int(config.ffn_dim_multiplier * width)
        step = config.multiple_of
        width = -(-width // step) * step  # round up to a multiple of `step`
        # layer creation order is kept (c_fc, c_fc2, c_proj) so parameter order/init RNG usage is unchanged
        self.c_fc = nn.Linear(config.n_embd, width, bias=False)
        self.c_fc2 = nn.Linear(config.n_embd, width, bias=False)
        self.c_proj = nn.Linear(width, config.n_embd, bias=False)

    def forward(self, x):
        # SwiGLU <-- 3. difference compared to GPT-2
        up = self.c_fc(x)
        gate = F.silu(self.c_fc2(x))
        return self.c_proj(up * gate)
class Block(nn.Module):
    """Pre-norm transformer layer: RMSNorm -> attention -> residual, then RMSNorm -> MLP -> residual."""

    def __init__(self, config):
        super().__init__()
        # submodule creation order is preserved (ln_1, attn, ln_2, mlp)
        self.ln_1 = RMSNorm(config.n_embd, config.norm_eps)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = RMSNorm(config.n_embd, config.norm_eps)
        self.mlp = MLP(config)

    def forward(self, x, freqs_cis=None, start_pos=None, mask=None):
        attn_out = self.attn(self.ln_1(x), freqs_cis, start_pos, mask)
        hidden = x + attn_out
        mlp_out = self.mlp(self.ln_2(hidden))
        return hidden + mlp_out
# -----------------------------------------------------------------------------
# The main LLaMA 3.1 model
@dataclass
class LlamaConfig:
    """Model hyperparameters; the defaults are what the pretrained loaders in this file use."""
    version: str = "3.1"
    block_size: int = 8192
    vocab_size: int = 128256
    n_layer: int = 32
    n_head: int = 32
    n_kv_head: int = 8
    n_embd: int = 4096
    ffn_dim_multiplier: float = 1.3
    multiple_of: int = 1024
    norm_eps: float = 1e-5
    rope_theta: float = 500000.0
    use_scaled_rope: bool = True
    max_gen_batch_size: int = 4
    use_kv: bool = True
    flash: bool = False  # use flashattention?

    def __init__(self, **kwargs):
        # Override defaults with any recognized field; unknown keyword names are silently ignored.
        for name, value in kwargs.items():
            if not hasattr(self, name):
                continue
            setattr(self, name, value)
        # attention geometry must be consistent: whole query heads per kv head, whole dims per head
        assert self.n_kv_head <= self.n_head
        assert self.n_head % self.n_kv_head == 0
        assert self.n_embd % self.n_head == 0
class LLaMA(nn.Module):
    """LLaMA 3.1 language model using GPT-2-style module names.

    Layout: token embedding ``transformer.wte`` -> ``config.n_layer`` Blocks
    (``transformer.h``) -> final RMSNorm ``transformer.ln_f`` -> ``lm_head``.
    The RoPE rotation factors are precomputed once for ``2 * block_size``
    positions.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = RMSNorm(config.n_embd, config.norm_eps),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # init all weights, use a torch rng object to be very careful
        self.init_rng = torch.Generator()
        self.init_rng.manual_seed(42)

        # NOTE: a plain attribute, not a registered buffer, so `model.to(...)` does not move it;
        # forward() moves it to the input's device on demand.
        self.freqs_cis = precompute_freqs_cis(
            config.n_embd // config.n_head,
            config.block_size * 2,
            config.rope_theta,
            config.use_scaled_rope,
        )

    def forward(self, idx, targets=None, return_logits=True, start_pos=0):
        """Run the model on token ids.

        Args:
            idx: (b, t) token ids with t <= block_size.
            targets: optional (b, t) target ids; entries equal to -1 are ignored by the loss.
            return_logits: if False, the returned logits are None.
            start_pos: absolute position of idx[:, 0]; selects the RoPE factors and, in eval
                mode with ``config.use_kv``, where the attention layers write their KV cache.

        Returns:
            (logits, loss). With targets: float32 logits for every position and the mean
            cross-entropy loss. Without targets: float32 logits for the last position only,
            shape (b, 1, vocab_size), and loss None.
        """
        _, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        # forward the LLaMA model itself
        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)

        if self.freqs_cis.device != x.device:
            # move outside inference mode so the copy stays usable by later autograd passes
            with torch.inference_mode(False):
                self.freqs_cis = self.freqs_cis.to(x.device)
        freqs_cis = self.freqs_cis[start_pos:start_pos+t]

        # True marks key positions a query must NOT attend to (strictly future tokens)
        mask = torch.triu(torch.ones((t, t), device=x.device, dtype=torch.bool), diagonal=1)
        if self.config.use_kv and not self.training and start_pos > 0:
            # the attention layers will prepend start_pos cached keys; they are all in the past,
            # so those columns are never masked (mask becomes (t, start_pos + t))
            past = torch.zeros((t, start_pos), device=x.device, dtype=torch.bool)
            mask = torch.cat([past, mask], dim=1)

        for block in self.transformer.h:
            x = block(x, freqs_cis, start_pos, mask)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x).float()
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]).float()  # note: using list [-1] to preserve the time dim
            loss = None

        # there are performance reasons why not returning logits is prudent, if not needed
        if not return_logits:
            logits = None

        return logits, loss

    @staticmethod
    def adapt_llama_state_dict_keys(checkpoint, config: LlamaConfig):
        """Rename Meta checkpoint keys to this module's names (mutates and returns `checkpoint`).

        wq/wk/wv are concatenated (in that order, along dim 0) into ``attn.c_attn``.
        """
        # Modify key names from Meta's LLaMA to our LLaMA
        # our key names are derived from GPT-2's key names
        checkpoint['transformer.wte.weight'] = checkpoint.pop('tok_embeddings.weight')

        for i in range(config.n_layer):
            for name in ['attention_norm', 'ffn_norm']:
                old_key = f'layers.{i}.{name}.weight'  # e.g. layers.x.attention_norm.weight -> transformer.h.x.ln_1.weight
                new_key = f'transformer.h.{i}.ln_{1 if name == "attention_norm" else 2}.weight'
                checkpoint[new_key] = checkpoint.pop(old_key)

        for i in range(config.n_layer):
            for name in ['attention.wq', 'attention.wk', 'attention.wv']:
                old_key = f'layers.{i}.{name}.weight'
                new_key = f'transformer.h.{i}.attn.c_attn.weight'
                if name == 'attention.wq':
                    checkpoint[new_key] = checkpoint.pop(old_key)
                else:  # merge 3 weights into transformer.h.x.attn.c_attn.weight
                    checkpoint[new_key] = torch.cat((checkpoint[new_key], checkpoint.pop(old_key)), dim=0)
            old_key = f'layers.{i}.attention.wo.weight'
            new_key = f'transformer.h.{i}.attn.c_proj.weight'
            checkpoint[new_key] = checkpoint.pop(old_key)

        ffn_map = {'w1': 'c_fc2', 'w2': 'c_proj', 'w3': 'c_fc'}
        for i in range(config.n_layer):
            for name in ['feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']:
                old_key = f'layers.{i}.{name}.weight'
                new_key = f'transformer.h.{i}.mlp.{ffn_map[name.split(".")[-1]]}.weight'
                checkpoint[new_key] = checkpoint.pop(old_key)

        checkpoint['transformer.ln_f.weight'] = checkpoint.pop('norm.weight')
        checkpoint['lm_head.weight'] = checkpoint.pop('output.weight')

        return checkpoint

    @staticmethod
    def adapt_llama_state_dict_keys_hf(checkpoint, config: LlamaConfig):
        """Rename HuggingFace checkpoint keys to this module's names (mutates and returns `checkpoint`).

        q_proj and k_proj are un-permuted back to Meta's layout before being concatenated
        with v_proj into ``attn.c_attn``. ``lm_head.weight`` already has the right name.
        """
        # Modify key names from HuggingFace's LLaMA to our LLaMA
        # our key names are derived from GPT-2's key names
        checkpoint['transformer.wte.weight'] = checkpoint.pop('model.embed_tokens.weight')

        # We need to unpermute K and V because HF script permuted the original Meta-LLaMA weights
        # see: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
        def unpermute(w, n_heads, dim1, dim2):
            return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

        for i in range(config.n_layer):
            for name in ['input_layernorm', 'post_attention_layernorm']:
                old_key = f'model.layers.{i}.{name}.weight'  # e.g. layers.x.attention_norm.weight -> transformer.h.x.ln_1.weight
                new_key = f'transformer.h.{i}.ln_{1 if name == "input_layernorm" else 2}.weight'
                checkpoint[new_key] = checkpoint.pop(old_key)

        for i in range(config.n_layer):
            for name in ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj']:
                old_key = f'model.layers.{i}.{name}.weight'
                new_key = f'transformer.h.{i}.attn.c_attn.weight'
                if name == 'self_attn.q_proj':
                    checkpoint[new_key] = unpermute(checkpoint.pop(old_key), config.n_head, config.n_embd, config.n_embd)
                else:  # merge 3 weights into transformer.h.x.attn.c_attn.weight
                    tensor = checkpoint.pop(old_key)
                    if name == 'self_attn.k_proj':
                        tensor = unpermute(tensor, config.n_kv_head, config.n_kv_head * (config.n_embd // config.n_head), config.n_embd)
                    checkpoint[new_key] = torch.cat((checkpoint[new_key], tensor), dim=0)
            old_key = f'model.layers.{i}.self_attn.o_proj.weight'
            new_key = f'transformer.h.{i}.attn.c_proj.weight'
            checkpoint[new_key] = checkpoint.pop(old_key)

        ffn_map = {'gate_proj': 'c_fc2', 'down_proj': 'c_proj', 'up_proj': 'c_fc'}
        for i in range(config.n_layer):
            for name in ['gate_proj', 'down_proj', 'up_proj']:
                old_key = f'model.layers.{i}.mlp.{name}.weight'
                new_key = f'transformer.h.{i}.mlp.{ffn_map[name]}.weight'
                checkpoint[new_key] = checkpoint.pop(old_key)

        checkpoint['transformer.ln_f.weight'] = checkpoint.pop('model.norm.weight')

        return checkpoint

    @classmethod
    def from_pretrained_llama3_hf(cls, model_id):
        """Loads pretrained LLaMA model weights from HuggingFace"""
        from transformers import AutoModelForCausalLM, AutoTokenizer
        assert model_id == "meta-llama/Meta-Llama-3.1-8B", "Only the 8B-base model is supported for now"
        model_args = LlamaConfig()

        model = AutoModelForCausalLM.from_pretrained(model_id)
        checkpoint = cls.adapt_llama_state_dict_keys_hf(model.state_dict(), model_args)

        original_default_type = torch.get_default_dtype()  # save the default type
        torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)  # much faster loading
        try:
            model = cls(model_args)
            model.load_state_dict(checkpoint, strict=False)
        finally:
            # restore the default type even if construction or loading fails
            torch.set_default_tensor_type(torch.tensor([], dtype=original_default_type, device="cpu").type())

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        tokenizer.pad_id = 128004  # this is the pad token id for LLaMA 3.1 base, we need to set this explicitly as our generate func expects it
        tokenizer.stop_tokens = [tokenizer.eos_token_id]
        model.tokenizer = tokenizer
        return model

    @classmethod
    def from_pretrained_llama3_meta(cls, ckpt_dir, tokenizer_path):
        """Loads pretrained LLaMA model weights from a checkpoint directory"""
        model_args = LlamaConfig()

        ckpt_path = sorted(Path(ckpt_dir).glob("*.pth"))[0]
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        checkpoint = cls.adapt_llama_state_dict_keys(checkpoint, model_args)

        original_default_type = torch.get_default_dtype()  # save the default type
        torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)  # much faster loading
        try:
            model = cls(model_args)
            model.load_state_dict(checkpoint, strict=False)
        finally:
            # restore the default type even if construction or loading fails
            torch.set_default_tensor_type(torch.tensor([], dtype=original_default_type, device="cpu").type())

        tokenizer = Tokenizer(model_path=tokenizer_path)
        model.tokenizer = tokenizer
        return model

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type, zero_stage):
        """Build an AdamW optimizer that weight-decays only parameters with dim >= 2.

        With ``zero_stage == 1`` the optimizer state is sharded via ZeroRedundancyOptimizer.
        The fused AdamW kernel is used when available and ``device_type == 'cuda'``.
        """
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print0(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print0(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        print0(f"using fused AdamW: {use_fused}")
        if zero_stage == 1:
            print0("using ZeroRedundancyOptimizer")
            optimizer = ZeroRedundancyOptimizer(**optim_groups[0], optimizer_class=torch.optim.AdamW,
                                                lr=learning_rate, betas=betas, fused=use_fused)
            optimizer.add_param_group(optim_groups[1])
        else:
            print0("using regular AdamW")
            optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=use_fused)
        return optimizer

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: List[List[int]],
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        echo: bool = False,
    ) -> List[List[int]]:
        """
        Generate text sequences based on provided prompts using the language generation model.

        Args:
            prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
            max_gen_len (int): Maximum length of the generated text sequence.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
                A value <= 0 selects greedy (argmax) decoding.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

        Returns:
            List[List[int]]: One token list per prompt, limited to len(prompt) + max_gen_len tokens and
            cut just before the first stop token found in the generated (non-prompt) part.

        Note:
            Requires ``self.tokenizer`` with ``pad_id`` and ``stop_tokens`` (the ``from_pretrained_*``
            constructors attach one). The model is switched to eval mode for the duration of the call
            (the attention layers only use their KV cache in eval mode) and restored afterwards.
            Without a KV cache (``config.use_kv=False``) the whole prefix is re-run at every step.
        """
        bsz = len(prompt_tokens)
        assert bsz <= self.config.max_gen_batch_size, f"Batch size {bsz} exceeds the maximum generation batch size {self.config.max_gen_batch_size}"
        device = next(self.parameters()).device

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)
        assert max_prompt_len <= self.config.block_size, f"Prompt length {max_prompt_len} exceeds the maximum block size {self.config.block_size}"
        total_len = min(self.config.block_size, max_gen_len + max_prompt_len)

        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
        for idx, t in enumerate(prompt_tokens):
            tokens[idx, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)

        prev_pos = 0
        eos_reached = torch.tensor([False] * bsz, device=device)
        input_text_mask = tokens != pad_id
        stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens), device=device)

        was_training = self.training
        self.eval()  # the attention layers only use the KV cache in eval mode
        try:
            for cur_pos in range(min_prompt_len, total_len):
                if self.config.use_kv:
                    # incremental decoding: only feed tokens not yet in the cache
                    logits, _ = self.forward(tokens[:, prev_pos:cur_pos], start_pos=prev_pos)
                else:
                    # no KV cache: recompute over the whole prefix
                    logits, _ = self.forward(tokens[:, :cur_pos], start_pos=0)
                if temperature > 0:
                    probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                    next_token = sample_top_p(probs, top_p)
                else:
                    next_token = torch.argmax(logits[:, -1], dim=-1)
                next_token = next_token.reshape(-1)
                # only replace token if prompt has already been generated
                next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
                tokens[:, cur_pos] = next_token
                eos_reached |= ~input_text_mask[:, cur_pos] & torch.isin(next_token, stop_tokens)
                prev_pos = cur_pos
                if eos_reached.all():
                    break
        finally:
            self.train(was_training)

        out_tokens = []
        for i, toks in enumerate(tokens.tolist()):
            prompt_len = len(prompt_tokens[i])
            # cut to max gen len
            start = 0 if echo else prompt_len
            toks = toks[start : prompt_len + max_gen_len]
            # cut to before the first stop token, searching only the generated part so that a
            # stop token inside an echoed prompt (e.g. a leading BOS) does not erase the output
            gen_start = prompt_len - start
            for stop_token in self.tokenizer.stop_tokens:
                try:
                    eos_idx = toks.index(stop_token, gen_start)
                    toks = toks[:eos_idx]
                except ValueError:
                    pass
            out_tokens.append(toks)
        return out_tokens
# -----------------------------------------------------------------------------
# sampling utils
def sample_top_p(probs, p):
    """
    Draw one token index per row using top-p (nucleus) sampling.

    Args:
        probs (torch.Tensor): Probability distribution(s); the last dim indexes tokens.
        p (float): Cumulative-probability threshold.

    Returns:
        torch.Tensor: Sampled token indices, with a trailing dim of size 1.

    Note:
        Tokens are ranked by probability. A token is kept when the total probability of
        the tokens ranked above it is at most `p` (so for p >= 0 the top token is always
        kept); the kept probabilities are renormalized before sampling.
    """
    ranked_probs, ranked_ids = torch.sort(probs, dim=-1, descending=True)
    mass_above = torch.cumsum(ranked_probs, dim=-1) - ranked_probs
    ranked_probs = ranked_probs.masked_fill(mass_above > p, 0.0)
    ranked_probs = ranked_probs / ranked_probs.sum(dim=-1, keepdim=True)
    picked = torch.multinomial(ranked_probs, num_samples=1)
    return torch.gather(ranked_ids, -1, picked)
# -----------------------------------------------------------------------------
# Llama 3.1 Tokenizer
# The tiktoken tokenizer can handle <=400k chars without
# pyo3_runtime.PanicException.
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
# https://github.com/openai/tiktoken/issues/195
# Here we iterate over subsequences and split if we exceed the limit
# of max consecutive non-whitespace or whitespace characters.
MAX_NO_WHITESPACES_CHARS = 25_000
class Tokenizer:
    """
    LLaMA 3.1 tokenizer: a tiktoken BPE encoder plus 256 special tokens.
    """

    special_tokens: Dict[str, int]

    num_reserved_special_tokens = 256

    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path: str):
        """
        Build the tiktoken encoder from a BPE ranks file.

        Args:
            model_path (str): Path to the tiktoken BPE ranks file.
        """
        assert os.path.isfile(model_path), model_path

        mergeable_ranks = load_tiktoken_bpe(model_path)
        num_base_tokens = len(mergeable_ranks)
        named_specials = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|finetune_right_pad_id|>",
            "<|step_id|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|eom_id|>",  # end of message
            "<|eot_id|>",  # end of turn
            "<|python_tag|>",
        ]
        # pad the special-token list up to num_reserved_special_tokens with numbered reserved tokens
        n_filler = self.num_reserved_special_tokens - len(named_specials)
        filler = [f"<|reserved_special_token_{2 + i}|>" for i in range(n_filler)]
        all_specials = named_specials + filler
        # special token ids come right after the base BPE vocabulary
        self.special_tokens = {tok: num_base_tokens + offset for offset, tok in enumerate(all_specials)}
        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )

        self.n_words: int = num_base_tokens + len(all_specials)
        # BOS / EOS token IDs
        ids = self.special_tokens
        self.bos_id: int = ids["<|begin_of_text|>"]
        self.eos_id: int = ids["<|end_of_text|>"]
        self.eot_id: int = ids["<|eot_id|>"]
        self.eom_id: int = ids["<|eom_id|>"]
        self.python_tag_id = ids["<|python_tag|>"]
        self.pad_id: int = ids["<|finetune_right_pad_id|>"]
        # hardcoded stop tokens for the base model
        self.stop_tokens = [
            ids["<|begin_of_text|>"],
            ids["<|end_of_text|>"],
        ]

    def encode(
        self,
        s: str,
        *,
        bos: bool,
        eos: bool,
        allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
        disallowed_special: Union[Literal["all"], Collection[str]] = (),
    ) -> List[int]:
        """
        Encode a string into a list of token IDs.

        Args:
            s (str): The text to encode (must be exactly of type str).
            bos (bool): Prepend the beginning-of-sequence token.
            eos (bool): Append the end-of-sequence token.
            allowed_special ("all"|set[str]): special tokens that may be encoded as special tokens
                when they appear in `s` (None means none).
            disallowed_special ("all"|set[str]): special tokens that raise an error when they appear in `s`.

        Returns:
            list[int]: The token IDs.

        With the default disallowed_special=(), text that looks like a special token is
        encoded as ordinary text instead of raising; passing allowed_special="all" encodes
        such text as the corresponding special tokens. The input is processed in windows of
        TIKTOKEN_MAX_ENCODE_CHARS characters, each further split so no piece holds more than
        MAX_NO_WHITESPACES_CHARS consecutive whitespace or non-whitespace characters.
        """
        if allowed_special is None:
            allowed_special = set()
        assert type(s) is str

        token_ids: List[int] = []
        for window_start in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
            window = s[window_start : window_start + TIKTOKEN_MAX_ENCODE_CHARS]
            for piece in self._split_whitespaces_or_nonwhitespaces(window, MAX_NO_WHITESPACES_CHARS):
                token_ids.extend(
                    self.model.encode(
                        piece,
                        allowed_special=allowed_special,
                        disallowed_special=disallowed_special,
                    )
                )

        if bos:
            token_ids.insert(0, self.bos_id)
        if eos:
            token_ids.append(self.eos_id)
        return token_ids

    def decode(self, t: Sequence[int]) -> str:
        """Decode a sequence of token IDs back into text."""
        # tiktoken only iterates over the ids, so treating any Sequence as a list is safe
        token_list = cast(List[int], t)
        return self.model.decode(token_list)

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Yield consecutive pieces of `s` (which concatenate back to `s`) such that no piece
        contains a run of more than `max_consecutive_slice_len` whitespace or non-whitespace
        characters. The empty string yields a single empty piece.
        """
        run_len = 0
        in_space = s[0].isspace() if len(s) > 0 else False
        piece_start = 0

        for pos, ch in enumerate(s):
            ch_is_space = ch.isspace()
            if ch_is_space != in_space:
                # the character class flipped: a new run begins at this character
                in_space = ch_is_space
                run_len = 1
                continue
            run_len += 1
            if run_len > max_consecutive_slice_len:
                yield s[piece_start:pos]
                piece_start = pos
                run_len = 1
        yield s[piece_start:]
# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader
def _peek_data_shard(filename):
# only reads the header, returns header data
with open(filename, "rb") as f:
# first read the header, which is 256 int32 integers (4 bytes each)
header = np.frombuffer(f.read(256*4), dtype=np.int32)
if header[0] != 20240801:
print("ERROR: magic number mismatch in the data .bin file!")
exit(1)
assert header[1] == 7, "unsupported version"
ntok = header[2] # number of tokens (claimed)
return ntok # for now just return the number of tokens
def _load_data_shard(filename):
    """
    Load all tokens of a .bin data shard.

    File layout: a header of 256 int32 integers (slot 0 = magic 20240801,
    slot 1 = version 7, slot 2 = token count), followed by the tokens stored
    as uint32.

    Args:
        filename: path to the shard file.

    Returns:
        A 1-D NumPy uint32 array of tokens. It is built with np.frombuffer over
        the bytes read, so it is read-only.

    Raises:
        AssertionError: if the header is truncated, the magic number or version
            is wrong, or the number of tokens does not match the header.
    """
    header_nbytes = 256 * 4
    with open(filename, "rb") as f:
        # first read the header, which is 256 int32 integers (4 bytes each)
        header_bytes = f.read(header_nbytes)
        assert len(header_bytes) == header_nbytes, f"truncated header in data .bin file {filename}"
        header = np.frombuffer(header_bytes, dtype=np.int32)
        assert header[0] == 20240801, "magic number mismatch in the data .bin file"
        assert header[1] == 7, "unsupported version"
        ntok = int(header[2])  # number of tokens (claimed)
        # the rest of the file is the tokens, stored as uint32
        tokens = np.frombuffer(f.read(), dtype=np.uint32)
    assert len(tokens) == ntok, "number of tokens read does not match header?"
    return tokens
class DistributedShardedDataLoader:
    """
    This DataLoader is both:
    - distributed (works correctly in case of multiple processes in DDP)
    - sharded (supports datasets that are broken up into multiple data shards)
    It is not *permuted*, meaning that it iterates over the data in the order
    of the dataset on disk, so the user should make sure to shuffle their examples
    during the creation of their data shards for best performance.

    Within a shard, rank `process_rank` reads windows of B*T+1 tokens starting
    at offset process_rank*B*T. Every call to next_batch moves the position
    forward by num_processes*B*T tokens, so the ranks' input windows interleave.
    """

    def __init__(self, filename_pattern, B, T, process_rank, num_processes):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.B = B
        self.T = T
        # glob files that match the pattern
        self.files = sorted(glob.glob(filename_pattern))
        assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
        # validate all data shard headers, count number of tokens in total
        ntok_total = 0
        for fname in self.files:
            # Python int: summing NumPy int32 scalars onto 0 stays int32 under
            # NumPy 2 promotion rules and could overflow past ~2.1B tokens
            shard_ntok = int(_peek_data_shard(fname))
            # every shard must fit at least one full step across all ranks (+1 target token)
            assert shard_ntok >= num_processes * B * T + 1, (
                f"shard {fname} has {shard_ntok} tokens, "
                f"need at least {num_processes * B * T + 1}"
            )
            ntok_total += shard_ntok
        self.ntok_total = ntok_total
        print0(f"DataLoader: total number of tokens: {ntok_total:,} across {len(self.files)} files")
        # kick things off
        self.current_shard = None
        self.reset()

    def reset(self):
        """Rewind to the start of shard 0 for this rank, reloading it only if needed."""
        # we're being a bit clever here: if we already had shard 0 loaded,
        # then don't do the work to reload it, just reset the pointer
        if self.current_shard != 0:
            self.current_shard = 0
            self.tokens = _load_data_shard(self.files[self.current_shard])
        self.current_position = self.process_rank * self.B * self.T

    def advance(self):
        """Move to the next data shard (wrapping around) and rewind to this rank's offset."""
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def next_batch(self):
        """
        Return (x, y): two (B, T) int64 tensors. y holds the same tokens as x,
        shifted left by one position (the next-token targets).
        """
        B = self.B
        T = self.T
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        # shard tokens are uint32; convert in NumPy first because torch releases
        # before 2.3 refuse to build a tensor from a uint32 ndarray
        buf = torch.from_numpy(buf.astype(np.int64))
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)  # targets
        # advance the start pointer in current shard
        self.current_position += B * T * self.num_processes
        # if loading the next batch would be out of bounds advance the shard.
        # NOTE(review): this reserves a full all-rank step (num_processes*B*T+1
        # tokens) past the current position, although this rank's next read only
        # needs B*T+1. With num_processes > 1, a rank may therefore switch shards
        # while a valid batch is still left for it. Kept as-is to preserve data order.
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.advance()
        return x, y
# -----------------------------------------------------------------------------
# Python -> C bridge utilities for saving params/grads/activations to .bin files
def write_fp32(tensor, file):
    """
    Append `tensor` to the binary `file` as raw float32 values in C (row-major)
    order, using the host's native byte order.
    """
    as_fp32 = tensor.detach().cpu().to(torch.float32)
    payload = as_fp32.numpy().tobytes()
    file.write(payload)
def write_bf16(tensor, file):
    """
    Append `tensor` to the binary `file` as raw bfloat16 values (2 bytes each),
    using the host's native byte order.
    """
    as_bf16 = tensor.detach().cpu().to(torch.bfloat16)
    # NumPy has no bfloat16 dtype, so reinterpret the same 16-bit patterns as
    # int16 (a bit-level view, not a numeric conversion) before going to NumPy.
    bit_patterns = as_bf16.view(torch.int16).numpy()
    file.write(bit_patterns.tobytes())
def write_tensors(model_tensors, L, file, dtype):
    """
    Write the named LLaMA 3 tensors in `model_tensors` to `file` in the fixed
    on-disk order. The tensors can be parameters, or same-named gradients.

    Order: token embedding; then, for each per-layer tensor kind, that kind
    for layers 0..L-1; then the final norm weight and the LM head.

    Args:
        model_tensors: mapping from parameter name to tensor.
        L: number of transformer layers.
        file: binary file object opened for writing.
        dtype: "float32" (write_fp32) or "bfloat16" (write_bf16).
    """
    assert dtype in {"float32", "bfloat16"}
    emit = write_bf16 if dtype != "float32" else write_fp32
    # Per-layer tensor kinds in on-disk order. Each kind is written for all
    # layers before moving on to the next kind.
    # NOTE(review): earlier shape comments ((L, 3C, C) for c_attn, (L, 4C, C)
    # for c_fc/c_fc2) look like GPT-2 shapes. With n_kv_head and
    # ffn_dim_multiplier in the config, the actual shapes probably differ —
    # confirm against the model definition.
    layer_suffixes = (
        "ln_1.weight",
        "attn.c_attn.weight",
        "attn.c_proj.weight",
        "ln_2.weight",
        "mlp.c_fc.weight",
        "mlp.c_fc2.weight",
        "mlp.c_proj.weight",
    )
    emit(model_tensors["transformer.wte.weight"], file)  # token embedding
    for suffix in layer_suffixes:
        for layer_idx in range(L):
            emit(model_tensors[f"transformer.h.{layer_idx}.{suffix}"], file)
    emit(model_tensors["transformer.ln_f.weight"], file)  # final norm
    emit(model_tensors["lm_head.weight"], file)  # output projection
def write_model(model, filename, dtype):
    """
    Write `model` to `filename`: a 1024-byte header (256 int32 slots) followed
    by all parameters in write_tensors order.

    Header slots: 0 = magic 20240803, 1 = checkpoint version (3: all tensors
    fp32, 5: all tensors bf16), 2..13 = config fields, 14/15 = major/minor
    parsed from model.config.version. All remaining slots are zero.

    Args:
        model: model exposing `.config` and `.named_parameters()`.
        filename: output path; overwritten if it exists.
        dtype: "float32" or "bfloat16"; selects the tensor encoding and version.
    """
    # everything we need to instantiate the model
    # 1) header is: version int, LLaMAConfig ints, padding to 1024 bytes
    assert dtype in {"float32", "bfloat16"}
    version = {
        "float32": 3,  # 3: all tensors are fp32
        "bfloat16": 5,  # 5: all tensors are bf16
    }[dtype]
    config = model.config
    header = torch.zeros(256, dtype=torch.int32)
    header[0] = 20240803  # magic
    header[1] = version  # checkpoint version
    header[2] = config.block_size
    header[3] = config.vocab_size
    header[4] = config.n_layer
    header[5] = config.n_head
    header[6] = config.n_kv_head
    header[7] = config.n_embd
    # NOTE(review): the header is int32, so assigning a non-integer value keeps
    # only its integer part. A fractional ffn_dim_multiplier or norm_eps would
    # lose information here. Left unchanged because the C-side reader defines
    # this format — confirm what it expects before changing the encoding.
    header[8] = config.ffn_dim_multiplier
    header[9] = config.multiple_of
    header[10] = config.norm_eps
    header[11] = config.rope_theta
    header[12] = config.use_scaled_rope
    header[13] = config.max_gen_batch_size
    version_parts = config.version.split('.')
    header[14] = int(version_parts[0])  # major version
    header[15] = int(version_parts[1])  # minor version
    # 2) the parameters follow the header. Pass the live parameters:
    # write_fp32/write_bf16 move each tensor to the CPU as it is written, so a
    # CPU copy of the whole model is never held at once.
    params = dict(model.named_parameters())
    # now write to file
    with open(filename, "wb") as file:
        file.write(header.numpy().tobytes())  # header
        write_tensors(params, config.n_layer, file, dtype)  # params
    print(f"wrote {filename}")
def write_state(model, x, y, logits, loss, filename):
    """
    Write a debugging snapshot used to check the C implementation's computation.

    Layout: a 256-int32 header (slot 0 = magic 20240803, slot 1 = B = x.size(0),
    slot 2 = T = x.size(1)), then x and y as int32, logits as float32, loss as
    float32, and finally every parameter gradient as float32 in write_tensors
    order.

    Args:
        model: model whose parameters all hold gradients (after backward).
        x, y: input and target token tensors of shape (B, T).
        logits: forward-pass output.
        loss: scalar loss tensor.
        filename: output path; overwritten if it exists.

    Raises:
        ValueError: if any parameter has no gradient. This is checked before the
            file is created, so no partial file is left behind.
    """
    header = torch.zeros(256, dtype=torch.int32)
    header[0] = 20240803  # magic
    header[1] = x.size(0)  # batch size of the batch, B
    header[2] = x.size(1)  # temporal extent of the batch, T
    named_params = list(model.named_parameters())
    missing = [name for name, param in named_params if param.grad is None]
    if missing:
        raise ValueError(f"write_state: parameters without gradients: {missing}")
    # Keep gradients on their device: write_fp32 copies each one to the CPU as it
    # is written, instead of holding a CPU copy of every gradient at once.
    grads = {name: param.grad for name, param in named_params}
    with open(filename, "wb") as file:
        # header
        file.write(header.numpy().tobytes())
        # input x
        file.write(x.cpu().numpy().astype("int32").tobytes())  # (B, T)
        # targets y
        file.write(y.cpu().numpy().astype("int32").tobytes())  # (B, T)
        # logits (result of the model forward pass)
        write_fp32(logits.cpu(), file)
        # loss (single float, result of the cross entropy loss)
        write_fp32(loss.cpu(), file)
        # gradients
        write_tensors(grads, model.config.n_layer, file, "float32")
    print(f"wrote {filename}")
# -----------------------------------------------------------------------------
# int main
def print0(*args, **kwargs):
    """
    Forward to print() only on the master process (RANK 0). When RANK is unset
    (a non-distributed run) this behaves exactly like print().
    """
    rank = int(os.environ.get("RANK", 0))
    if rank != 0:
        return
    print(*args, **kwargs)
if __name__ == "__main__":
print0(f"Running pytorch {torch.version.__version__}")
# default settings will overfit a tiny batch of data
# and save model weights and debug state to disk on the first iteration
parser = argparse.ArgumentParser()
parser.add_argument("--use_hf", type=int, default=1, help="use HuggingFace (default) or use Meta's model")
parser.add_argument("--ckpt_dir", type=str, default=None, help="path to llama3 model checkpoint (needed if use_hf=0)")
parser.add_argument("--tokenizer_path", type=str, default=None, help="path to llama3 tokenizer (needed if use_hf=0)")
# file system input / output
parser.add_argument("--input_bin", type=str, default="dev/data/tinyshakespeare/tiny_shakespeare_val.bin", help="input .bin to train on")
parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on")
parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints")
parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-8B", help="chose the llama model")
# token layout for each step of the optimization
parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions")
parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
parser.add_argument("--total_batch_size", type=int, default=256, help="total desired batch size, in units of #tokens")
# workload (number of steps)
parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
parser.add_argument("--inference_only", type=int, default=0, help="only run inference")
# optimization
parser.add_argument("--learning_rate", type=float, default=1e-5, help="learning rate warmup iterations")
parser.add_argument("--warmup_iters", type=int, default=0, help="learning rate warmup iterations")
parser.add_argument("--learning_rate_decay_frac", type=float, default=1.0, help="learning rate warmup iterations")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay")
parser.add_argument("--grad_clip", type=float, default=1.0, help="maximum gradient magnitude")
# evaluation
parser.add_argument("--val_loss_every", type=int, default=0, help="every how mant steps to evaluate val loss?")
parser.add_argument("--val_max_steps", type=int, default=20, help="how many batches of val to average?")
parser.add_argument("--sample_every", type=int, default=0, help="how often to sample from the model?")
# debugging
parser.add_argument("--overfit_single_batch", type=int, default=1, help="overfit just one batch of data")
# numerics
parser.add_argument("--tensorcores", type=int, default=0, help="use tensorcores")
# memory management
parser.add_argument("--device", type=str, default="", help="by default we autodetect, or set it here")
parser.add_argument("--compile", type=int, default=0, help="torch.compile the model")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|float16|bfloat16")
parser.add_argument("--zero_stage", type=int, default=0, help="zero redundancy optimizer stage (0/1/2/3)")
# python -> C bridge
parser.add_argument("--write_tensors", type=int, default=0, help="write tensors to disk")
args = parser.parse_args()
# args error checking and convenience variables
B, T = args.batch_size, args.sequence_length
assert 1 <= T <= 8192, "sequence length must be between 1 and 8192"
assert args.dtype in {"float32", "float16", "bfloat16"}
assert args.model in {"meta-llama/Meta-Llama-3.1-8B"} # only 8B base model supported for now
# create the logging directory if it does not exist
logfile = None
if args.output_dir:
os.makedirs(args.output_dir, exist_ok=True)
logfile = os.path.join(args.output_dir, "main.log")
# create the log file "main.log" inside it, and wipe it clean
with open(logfile, "w") as f:
pass
# set up DDP (distributed data parallel). torchrun sets this env variable
ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
if ddp:
# use of DDP atm demands CUDA, we set the device appropriately according to rank
assert torch.cuda.is_available(), "for now i think we need CUDA for DDP"
init_process_group(backend='nccl')