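# train.jl — fine-tune a pretrained RWKV-169M language model on a JSONL text corpus.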
include("data.jl")
include("rwkv.jl")
include("utils.jl")
using Flux
using Flux.Losses
using OneHotArrays
using Wandb, Dates, Logging
using CUDA
using BSON: @save
using Optimisers
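# Checkpointing: move the model to CPU before serializing so the BSON file loads without a GPU.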
function save_model(model, path)
    model = model |> cpu
    @save path model
end
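# To reload a checkpoint later (a minimal sketch; @load restores the `model`
# binding that @save wrote, and "<timestamp>" stands in for the actual filename):
#   using BSON: @load
#   @load "rwkv-169m-<timestamp>.bson" model

# Training hyperparameters. n_vocab = 50277 matches the GPT-NeoX (Pile)
# tokenizer vocabulary used by the RWKV-4 checkpoints.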
cfg = (
    learning_rate = 3e-4,
    epochs = 5,
    dataset = "astroph_combined.jsonl",
    use_cuda = true,
    ctx_len = 256,
    batch_size = 8,
    log_per_nbatch = 5,
    save_per_nbatch = 1000,
    n_vocab = 50277,
    n_layer = 12,
    n_head = 12,
)
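# Stream run metadata and metrics to Weights & Biases; the @info calls below
# are routed through this logger once it is installed as the global logger.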
lg = WandbLogger(
    project = "nano-llm.jl",
    name = "rwkv-169m-$(now())",
    config = Dict(
        "learning_rate" => cfg.learning_rate,
        "dataset" => cfg.dataset,
        "n_layer" => cfg.n_layer,
        "n_head" => cfg.n_head,
    ),
)
global_logger(lg)
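# Data pipeline: windows are ctx_len + 1 tokens long so each window yields an
# (input, target) pair offset by one token. The second TextSplitter argument
# (16) is assumed to be the overlap between consecutive windows (see data.jl).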
data_file = cfg.dataset
tokenizer = get_tokenizer()
ts = TextSplitter(cfg.ctx_len+1, 16, tokenizer)
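# Token-level cross-entropy on the logits; targets are one-hot encoded over the vocabulary.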
function loss_func(y_pred, y; n_vocab=cfg.n_vocab)
    return logitcrossentropy(y_pred, onehotbatch(y, 1:n_vocab))
end
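# Device selection: prefer the GPU when requested and available; disabling
# scalar indexing surfaces accidental slow CPU fallbacks as errors.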
if cfg.use_cuda && CUDA.functional()
    CUDA.allowscalar(false)
    @info "Using CUDA"
    device = gpu
else
    @info "Using CPU"
    device = identity
end
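# Restrict training to the token-mixing sublayer of each block; the embeddings,
# output head, and the blocks' remaining sublayers keep their pretrained weights.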
Flux.trainable(m::RWKV) = (blocks=m.blocks,)
Flux.trainable(m::Block) = (token_mixing=m.token_mixing,)
# device = identity # debug
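# Load the pretrained RWKV-4 Pile 169M checkpoint and cast its weights to Float32.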
model = rwkv_from_pth("RWKV-4-Pile-169M-20220807-8023.pth"; cfg.n_layer) |> f32
model = model |> device
# Set up Adam optimizer state; only the parameters marked trainable above receive updates.
opt = Optimisers.setup(Optimisers.Adam(cfg.learning_rate), model)
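# Training loop: stream batches from the JSONL file, compute loss and gradients,
# update the trainable parameters, and periodically log and checkpoint.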
i_batch = 1
for epoch in 1:cfg.epochs
    global i_batch
    batches = jsonl_reader(data_file) |> ch -> batch_sampler(ch, ts; batch_size=cfg.batch_size)
    for (x, y) in batches
        x = x |> device
        y = y |> device
        # Forward pass and gradients with respect to the model parameters.
        val, grads = Flux.withgradient(model) do m
            y_pred = m(x)
            loss_func(y_pred, y)
        end
        Flux.update!(opt, model, grads[1])
        if i_batch % cfg.log_per_nbatch == 0
            @info "metrics" loss=val
            println("Ep $(epoch) Batch $(i_batch) Loss $(val)")
        end
        if i_batch % cfg.save_per_nbatch == 0
            @info "saving model"
            save_model(model, "rwkv-169m-$(now()).bson")
        end
        i_batch += 1
    end
end
close(lg)