feat: finish the implementation
avik-pal committed Nov 11, 2024
1 parent 86cfc5f commit 844a274
Showing 3 changed files with 172 additions and 59 deletions.
16 changes: 16 additions & 0 deletions examples/NanoGPT/Project.toml
@@ -13,3 +13,19 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
Comonicon = "1"
DataDeps = "0.7"
Enzyme = "0.13.14"
JLD2 = "0.5"
Lux = "1.2.3"
MLUtils = "0.4"
NNlib = "0.9.24"
OneHotArrays = "0.2.5"
Optimisers = "0.4.1"
Printf = "1.10"
Random = "1.10"
Reactant = "0.2.5"
Statistics = "1.10"
StatsBase = "0.34.3"
58 changes: 58 additions & 0 deletions examples/NanoGPT/README.md
@@ -0,0 +1,58 @@
# NanoGPT using Lux & Reactant

## Requirements

* Install [Julia](https://julialang.org/)
* In the Julia REPL, instantiate the `Project.toml` in this directory (see the snippet below)
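
For example, from the repository root (a minimal sketch assuming the standard Pkg workflow):

```julia
using Pkg
Pkg.activate("examples/NanoGPT")  # this example's environment
Pkg.instantiate()                 # install the pinned dependencies
```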

## Training

To train a model, run `main.jl`. The defaults listed under Usage below can be overridden on the command line.

```bash
julia --startup=no --project=examples/NanoGPT --threads=auto examples/NanoGPT/main.jl
```

## Inference

To run inference with a trained model, pass the `--inference` flag and the path to a saved checkpoint.

```bash
julia --startup=no --project=examples/NanoGPT --threads=auto examples/NanoGPT/main.jl \
--inference \
--model-path=<path to model checkpoint>
```

## Usage

```bash
main

Usage

main [options] [flags]

Options

--n-embed <64::Int>
--n-hidden <256::Int>
--n-heads <4::Int>
--qk-dim <16::Int>
--v-dim <16::Int>
--n-layers <6::Int>
--sequence-length <64::Int>
--batchsize <128::Int>
--dropout-rate <0.0::Float32>
--test-split <0.1::Float64>
--lr <0.01::Float64>
--epochs <100::Int>
--model-path <::String>
--seed <::Union{String, Vector{String}}>
--output-length <1024::Int>

Flags

--inference
-h, --help Print this help message.
--version Print version.
```
157 changes: 98 additions & 59 deletions examples/NanoGPT/main.jl
@@ -1,6 +1,6 @@
# Taken from https://github.com/FluxML/model-zoo/pull/410
using MLUtils, Lux, Random, Optimisers, Printf, Statistics, NNlib, DataDeps, StatsBase,
OneHotArrays
OneHotArrays, JLD2
using Reactant, Enzyme
using Comonicon: @main

@@ -51,42 +51,57 @@ function GPT(;
token_embedding=Embedding(n_vocab => n_embed),
position_embedding=Embedding(sequence_length => n_embed),
drop=Dropout(dropout_rate),
blocks=ntuple(n_layers) do i
blocks=Chain(ntuple(n_layers) do i
return gpt_block(; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout_rate)
end,
end...),
ln=LayerNorm((n_embed, 1)),
output_layer=Dense(n_embed => n_vocab)) do tokens
te = token_embedding(tokens)
pe = position_embedding(1:size(tokens, 1))
x = drop(te .+ pe)
for blk in blocks
x = blk(x)
end
x = ln(x)
x = output_layer(x)
@return x
x = drop(token_embedding(tokens) .+ position_embedding(1:size(tokens, 1)))
x = blocks(x)
@return output_layer(ln(x))
end
end
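
# NOTE: splatting the `ntuple` of transformer blocks into `Chain` (instead of
# looping over a tuple of layers by hand) makes the whole stack a single
# callable layer. A minimal sketch of the pattern, with illustrative sizes
# unrelated to the model above:
#
#   blocks = Chain(ntuple(_ -> Dense(8 => 8, gelu), 3)...)
#   ps, st = Lux.setup(Random.default_rng(), blocks)
#   y, _ = blocks(randn(Float32, 8, 4), ps, st)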

# Use the model to generate some text.
# function generate(model, seed, outlen)
# seqlen = context_length(model)
# if isempty(seed)
# seed = "_"
# end
# x = map(c -> findfirst(==(c), model.alphabet)::Int64, collect(seed))
# while length(x) < outlen
# tail = x[max(1, end-seqlen+1):end]
# tail = reshape(tail, length(tail), 1)
# y = model(tail |> device) |> cpu
# p = softmax(y[:,end,1])
# j = sample(1:length(model.alphabet), Weights(p))
# #j = argmax(p)
# #x = vcat(x, [j])
# push!(x, j)
# end
# String(map(j -> model.alphabet[j], x))
# end
function generate_text(
model, ps, st, seed; alphabet, output_length, sequence_length
)
dev = get_device((ps, st))
@assert !(dev isa ReactantDevice) "Currently we don't support running inference with \
dynamically sized tensors."

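# Left-pad shorter seeds with '_' so every prompt in the batch has the same
# length; the padding (and the extra characters it forces us to generate)
# is stripped from the results before returning.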
seed = copy(seed)
seed_len = maximum(length, seed)
extra_letters = zeros(Int, length(seed))
for (i, s) in enumerate(seed)
if seed_len != length(s)
extra_letters[i] = seed_len - length(s)
seed[i] = "_"^extra_letters[i] * s
end
end
original_output_length = output_length
output_length += maximum(extra_letters)

st = Lux.testmode(st)

x = zeros(Int, output_length, length(seed))
for (i, s) in enumerate(seed), j in 1:seed_len
x[j, i] = findfirst(==(s[j]), alphabet)
end
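# Autoregressive loop: condition on at most the previous `sequence_length - 1`
# tokens and sample the next character from the softmax distribution
# (weighted sampling rather than greedy argmax).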
for i in (seed_len + 1):output_length
tail = x[max(1, i - sequence_length + 1):(i - 1), :] |> dev
y = model(tail, ps, st)[1] |> cpu_device()
p = softmax(y[:, end, 1])
x[i, :] .= sample(1:length(alphabet), Weights(p))
end

res = [String(map(Base.Fix1(getindex, alphabet), x[:, i])) for i in axes(x, 2)]
for i in eachindex(res)
res[i] = res[i][(extra_letters[i] + 1):end][1:original_output_length]
end

return res
end
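
# Hypothetical usage sketch (assumes `model`, `ps`, `st`, and `alphabet` were
# restored from a checkpoint, as in the inference branch of `main` below):
#
#   texts = generate_text(model, ps, st, ["The", "Julia"];
#       alphabet, output_length=256, sequence_length=64)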

# Load data from input file, and partition into training and testing subsets.
function get_nanogpt_data(; sequence_length, test_split)
@@ -121,32 +121,62 @@ function get_nanogpt_data(; sequence_length, test_split)
return alphabet, Array(trainX), Array(trainY), Array(testX), Array(testY)
end

@main function train_nanogpt(;
@main function main(;
n_embed::Int=64, n_hidden::Int=256, n_heads::Int=4, qk_dim::Int=16,
v_dim::Int=16, n_layers::Int=6, sequence_length::Int=64, batchsize::Int=128,
dropout_rate::Float32=0.0f0, test_split::Float64=0.1, lr::Float64=1e-2,
epochs::Int=20
epochs::Int=100,
# Only inference options
inference::Bool=false, model_path::String="",
seed::Union{String, Vector{String}}=["_", "The", "Julia", "Lux.jl"],
output_length::Int=1024
)
alphabet, trainX, trainY, testX, testY = get_nanogpt_data(; sequence_length, test_split)

@printf "[Info] Alphabet size: %d\n" length(alphabet)
@printf "[Info] Training size: %d sequences.\n" size(trainX, 2)
@printf "[Info] Testing size: %d sequences.\n\n" size(testX, 2)

rng = Random.default_rng()
Random.seed!(rng, 1234)

dev = reactant_device()
cdev = cpu_device()

if inference
@printf "[Info] Inference mode enabled.\n"

@assert !isempty(model_path) "Please provide a path to a model checkpoint."

@printf "[Info] Loading model from %s.\n" model_path
model_config = JLD2.load(model_path, "model_config")
model = GPT(; model_config...)
ps = JLD2.load(model_path, "parameters")
st = JLD2.load(model_path, "states")
alphabet = JLD2.load(model_path, "alphabet")
sequence_length = model_config.sequence_length

texts = generate_text(
model, ps, st, seed; alphabet, output_length, sequence_length
)

for (i, (text, s)) in enumerate(zip(texts, seed))
@printf "[Info] Seed [%d]: %s\n" i s
@printf "[Generated Text] %s\n\n" text
end

return
end

alphabet, trainX, trainY, testX, testY = get_nanogpt_data(; sequence_length, test_split)

@printf "[Info] Alphabet size: %d\n" length(alphabet)
@printf "[Info] Training size: %d sequences.\n" size(trainX, 2)
@printf "[Info] Testing size: %d sequences.\n\n" size(testX, 2)

train_loader = DataLoader(
(trainX, trainY); batchsize, shuffle=true, parallel=true
) |> dev
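# Applying the Reactant device to the DataLoader wraps it so that each batch
# is transferred to the device as it is iterated.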

model = GPT(;
model_config = (;
n_vocab=length(alphabet), n_embed, sequence_length, n_hidden,
n_layers, dropout_rate, n_heads, qk_dim, v_dim
)
model = GPT(; model_config...)
ps, st = Lux.setup(rng, model) |> dev
@printf "[Info] Number of parameters: %d\n" Lux.parameterlength(ps)
@printf "[Info] Number of states: %d\n\n" Lux.statelength(st)
@@ -156,9 +201,12 @@ end

@printf "[Info] Compiling Inference Model...\n"
testX, testY = (testX, testY) |> dev
start_time = time()
model_compiled = @compile model(testX, ps, Lux.testmode(st))
time_to_compile = time() - start_time
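# `@compile` traces the call on these concrete arguments and returns a
# compiled executable specialized to their shapes; timing it separates the
# one-time compilation cost from the per-epoch training loop below.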
best_test_loss = Inf

@printf "[Info] Time taken to compile inference model: %0.5fs\n" time_to_compile
@printf "[Info] Starting Model Training...\n\n"

loss_fn = CrossEntropyLoss(; logits=Val(true))
@@ -185,7 +233,15 @@ end
)
@printf "[Test] Epoch %3d\tTest Loss %.8e\n" epoch test_loss

# XXX: Also generate some text here...
# Generate some text here...
texts = generate_text(
model, ps |> cdev, st |> cdev, seed;
alphabet, output_length, sequence_length
)
for (i, (text, s)) in enumerate(zip(texts, seed))
@printf "[Info] Seed [%d]: %s\n" i s
@printf "[Generated Text] %s\n\n" text
end

if test_loss < best_test_loss
best_test_loss = test_loss
@@ -195,26 +251,9 @@ end
joinpath(@__DIR__, "nanogpt.jld2");
parameters=train_state.parameters |> cdev,
states=train_state.states |> cdev,
alphabet=alphabet
alphabet=alphabet,
model_config=model_config
)
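# The keys saved here (parameters, states, alphabet, model_config) are
# exactly what the inference branch of `main` loads back via `JLD2.load`.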
end
end
end

# # Load a model from a checkpoint (see `jldsave` above).
# function load_model(filename)
# args = JLD2.load(filename, "args")
# alphabet = JLD2.load(filename, "alphabet")
# model = GPT(args, alphabet)
# model_state = JLD2.load(filename, "model_state")
# model = Flux.loadmodel!(model, model_state);
# return args, model
# end

# if true
# args, model = train()
# else
# args, model = load_model("model-checkpoint.jld2") |> device
# end

# generate(model, "The", 50)
