train_gpt.py (forked from ulzee/dkbehrt)
# %%
import argparse
import os, sys
#%%
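# Command-line options: decoder architecture (llama or mistral), model depth and heads,
# batch size, learning rate, epochs, GPU selection, and a wandb toggle.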
parser = argparse.ArgumentParser()
parser.add_argument('--arch', type=str, default='llama')
parser.add_argument('--layers', type=int, default=2)
parser.add_argument('--heads', type=int, default=8)
parser.add_argument('--batch_size', type=int, default=48)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--gpus', type=str, default='0')
parser.add_argument('--nowandb', action='store_true', default=True)  # wandb logging is off by default; with default=True the flag itself is a no-op
args = parser.parse_args()
#%%
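# Environment setup: wandb project settings (when enabled) and GPU visibility,
# done before torch is imported so CUDA_VISIBLE_DEVICES takes effect.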
if not args.nowandb:
    os.environ["WANDB_PROJECT"] = "icd"
    os.environ["WANDB_LOG_MODEL"] = "end"
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
import torch
import pickle as pk
import numpy as np
from transformers import MistralConfig, MistralForCausalLM, LlamaConfig, LlamaForCausalLM
from transformers import AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from torch.utils.data import Dataset
import utils
#%%
config_class, model_class = dict(
    mistral=(MistralConfig, MistralForCausalLM),
    llama=(LlamaConfig, LlamaForCausalLM),
)[args.arch]
#%%
with open('saved/diagnoses.pk', 'rb') as fl:
    dxs = pk.load(fl)
#%%
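# The GPT tokenizer ships without a pad token, so EOS is reused for padding, with left-side padding.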
tokenizer = AutoTokenizer.from_pretrained('./saved/tokenizers/gpt')
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
#%%
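# A small decoder-only model: 192-dim hidden states, a 1024-dim MLP, and depth/heads taken from the CLI arguments.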
mdlconfig = config_class(
    vocab_size=len(tokenizer.vocab),
    hidden_size=192,
    num_hidden_layers=args.layers,
    num_attention_heads=args.heads,
    intermediate_size=1024,
)
model = model_class(mdlconfig)
#%%
# NOTE: this optimizer is constructed but never handed to the Trainer below;
# the Trainer builds its own AdamW from training_args.learning_rate.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.lr,
)
# %%
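# Patient ID splits from artifacts/splits/; the validation split is thinned
# (every 10th ID, capped at 1024) so periodic evaluation stays fast.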
phase_ids = { phase: np.genfromtxt(f'artifacts/splits/{phase}_ids.txt') for phase in ['train', 'val', 'test'] }
phase_ids['val'] = phase_ids['val'][::10][:1024]
datasets = { phase: utils.ICDDataset(dxs, tokenizer, ids, separator='<v>') for phase, ids in phase_ids.items() }
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False
)
training_args = TrainingArguments(
    output_dir=f'runs/gpt-{args.arch}',
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=16,
    learning_rate=args.lr,
    num_train_epochs=args.epochs,
    # 'none' (not None) is what TrainingArguments expects to disable logging integrations.
    report_to='wandb' if not args.nowandb else 'none',
    evaluation_strategy='steps',
    run_name=f'gpt-{args.arch}',
    eval_steps=500,
    save_steps=1000,
)
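# Evaluation metric: top-k accuracy of the language-model predictions, reported as top01/top05/top10.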
def compute_metrics(eval_pred, mask_value=-100, topks=(1, 5, 10)):
    logits, labels = eval_pred
    bsize, seqlen = labels.shape
    logits = torch.from_numpy(np.reshape(logits, (bsize*seqlen, -1)))
    labels = torch.from_numpy(np.reshape(labels, (bsize*seqlen,)))
    # Only score positions with real labels; the collator marks padding with -100.
    where_prediction = labels != mask_value
    # NOTE: logits and labels are compared at the same position; strict next-token
    # accuracy would compare logits[:, :-1] against labels[:, 1:].
    topaccs = utils.topk_accuracy(logits[where_prediction], labels[where_prediction], topk=topks)
    return { f'top{n:02d}': acc for n, acc in zip(topks, topaccs) }
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    train_dataset=datasets['train'],
    eval_dataset=datasets['val'],
    compute_metrics=compute_metrics,
)
# %%
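# Evaluate once before training to get an untrained baseline, then train and save the final weights.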
trainer.evaluate()
trainer.train()
# %%
torch.save(model.state_dict(), 'saved/gpt_basic.pth')
# %%