example_tiny_shakspear.py
import math

import torch
from tokenizers import Tokenizer

from trainer.SFTTrainer import train
from model.args import MOEModelArgs
from model.KANaMoEv1 import KANamav5
def lr_lambda(current_step: int, max_steps: int = 50000, warmup_steps: int = 40, lr_scheduler_type: str = "cosine"):
    """Learning-rate multiplier for LambdaLR: linear warmup followed by cosine (or sine) annealing."""
    if current_step < warmup_steps:
        return current_step / warmup_steps
    annealing_steps = max_steps - warmup_steps
    if annealing_steps <= 0:
        annealing_steps = 1
    progress = (current_step - warmup_steps) / annealing_steps
    if lr_scheduler_type == "cosine":
        new_learning_rate = 0.5 * (1.0 + math.cos(math.pi * progress))
    elif lr_scheduler_type == "sinus":
        new_learning_rate = 0.5 * (1.0 + math.sin(math.pi * progress))
    else:
        new_learning_rate = 1.0
    return new_learning_rate
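# Illustration (not part of the original script): LambdaLR multiplies the optimizer's base
# learning rate by the value returned above, so with the base lr of 5e-4 used below the
# effective rate ramps linearly from 0 to 5e-4 over the first 40 steps and then decays along
# the cosine curve. A quick way to inspect the schedule:
#
#     for step in (0, 20, 40, 25000, 50000):
#         print(step, 0.0005 * lr_lambda(step))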
print("[LOADING TOKENIZER]")
tokenizer = Tokenizer.from_file("custom_tokenizer.json")
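# Illustrative note: `tokenizer` is a Hugging Face `tokenizers.Tokenizer`; `encode()` returns an
# Encoding object whose `.ids` attribute is what gets turned into tensors below, e.g.
#
#     enc = tokenizer.encode("ROMEO: But soft!")
#     print(enc.tokens, enc.ids)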
MOEModelArgs.vocab_size = tokenizer.get_vocab_size()
MOEModelArgs.pad_id = tokenizer.token_to_id("[PAD]") if "[PAD]" in tokenizer.get_vocab() else None
MOEModelArgs.max_batch_size = 4
MOEModelArgs.max_seq_len = 20
MOEModelArgs.n_layers = 12
MOEModelArgs.dim = 64
# MOEModelArgs.use_kan = False
# MOEModelArgs.use_softmax_temp_proj = False
print("[LOADING DATASET]")
with open("datasets/tiny-shakespear.txt", "r") as file:
dataset = file.read()
dataset = tokenizer.encode(dataset)
data = torch.LongTensor(dataset.ids).unsqueeze(0)
n = int(0.9 * len(data[0]))
train_data = data[:, :n]
val_data = data[:, n:]
print("[LOADING MODEL]")
model = KANamav5(MOEModelArgs)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
new_model = train(
model=model,
optimizer=optimizer,
train_data=train_data,
val_data=val_data,
scheduler=scheduler,
save_model_name="shakespear",
max_steps=100000,
loss_interval=1,
eval_interval=2000,
device="cpu"
)
line_19826 = """ROMEO:\nI pay thy poverty, """
first_tokens = tokenizer.encode(line_19826)
input_tokens = torch.LongTensor(first_tokens.ids).unsqueeze(0)


def inference(model: torch.nn.Module, tokens: torch.Tensor, max_new_tokens: int):
    """Autoregressively sample tokens from the model and print them as they are generated."""
    model.eval()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            tokens_conditioned = tokens[:, -MOEModelArgs.max_seq_len:]  # keep only the last max_seq_len tokens as context
            logits, _ = model(tokens_conditioned)
            probabilities = torch.softmax(logits[:, -1], dim=-1)  # distribution over the next token
            next_token = torch.multinomial(probabilities, num_samples=1)
            tokens = torch.cat((tokens, next_token), dim=1)
            print(tokenizer.decode(next_token.squeeze(dim=1).tolist(), skip_special_tokens=True), end="", flush=True)


inference(new_model, input_tokens, 100)
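# Variation (not part of the original script): for deterministic output, the multinomial
# sampling above can be swapped for greedy decoding, e.g.
#
#     next_token = torch.argmax(probabilities, dim=-1, keepdim=True)
#
# which always picks the highest-probability next token instead of sampling from the distribution.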