feat: update to latest Lux and LuxCore releases #164

Merged · 9 commits · Sep 10, 2024
6 changes: 0 additions & 6 deletions .buildkite/pipeline.yml
```diff
@@ -7,9 +7,6 @@ steps:
           test_args: "--quickfail"
       - JuliaCI/julia-coverage#v1:
           codecov: true
-          dirs:
-            - src
-            - ext
     agents:
       queue: "juliagpu"
       cuda: "*"
@@ -28,9 +25,6 @@ steps:
         version: "1"
       - JuliaCI/julia-coverage#v1:
           codecov: true
-          dirs:
-            - src
-            - ext
     command: |
       julia --project --code-coverage=user --color=yes --threads=3 -e '
         println("--- :julia: Instantiating project")
```
21 changes: 6 additions & 15 deletions Project.toml
```diff
@@ -1,38 +1,30 @@
 name = "DeepEquilibriumNetworks"
 uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
 authors = ["Avik Pal <[email protected]>"]
-version = "2.2.0"
+version = "2.3.0"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
 ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
 DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
 FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
-Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
-SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
-
-[weakdeps]
-LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
-SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
-
-[extensions]
-DeepEquilibriumNetworksSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
+Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
+SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
 
 [compat]
 ADTypes = "0.2.5, 1"
 Aqua = "0.8.7"
 ChainRulesCore = "1"
 CommonSolve = "0.2.4"
 ConcreteStructs = "0.2"
 ConstructionBase = "1"
 DiffEqBase = "6.119"
 Documenter = "1.4"
 ExplicitImports = "1.6.0"
@@ -42,9 +34,8 @@ Functors = "0.4.10"
 GPUArraysCore = "0.1.6"
 Hwloc = "3"
 InteractiveUtils = "<0.0.1, 1"
-LinearSolve = "2.21.2"
-Lux = "0.5.56"
-LuxCore = "0.1.14"
+Lux = "1"
+LuxCore = "1"
 LuxTestUtils = "1"
 MLDataDevices = "1"
 NLsolve = "4.5.1"
@@ -57,7 +48,7 @@ ReTestItems = "1.23.1"
 SciMLBase = "2"
 SciMLSensitivity = "7.43"
 StableRNGs = "1.0.2"
-Statistics = "1.10"
+Static = "1.1.1"
 SteadyStateDiffEq = "2.3.2"
 Test = "1.10"
 Zygote = "0.6.69"
```
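For a downstream project, picking up this release is an ordinary compat/update cycle. A minimal sketch (standard Pkg commands; nothing below is taken from the PR itself):

```julia
using Pkg
Pkg.activate(".")  # your project's environment
# Let the resolver move to DeepEquilibriumNetworks 2.3 and the Lux/LuxCore 1.x line
Pkg.update(["DeepEquilibriumNetworks", "Lux", "LuxCore"])
Pkg.status()  # confirm the resolved versions
```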
9 changes: 4 additions & 5 deletions docs/Project.toml
```diff
@@ -6,31 +6,30 @@ DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
 LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
-MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
+OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
-Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 DeepEquilibriumNetworks = "2"
 Documenter = "1"
 DocumenterCitations = "1"
 LinearSolve = "2"
-Lux = "0.5"
+Lux = "1"
 LuxCUDA = "0.3"
-MLDataUtils = "0.5"
 MLDatasets = "0.7"
 NonlinearSolve = "3"
 Optimisers = "0.3"
 OrdinaryDiffEq = "6"
 Random = "1"
 SciMLSensitivity = "7"
-Statistics = "1"
 Zygote = "0.6"
 julia = "1.10"
```
131 changes: 64 additions & 67 deletions docs/src/tutorials/basic_mnist_deq.md
````diff
@@ -4,9 +4,10 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack
 
 ```@example basic_mnist_deq
 using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
-    Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, Dates, Printf
+    Random, Optimisers, Zygote, LinearSolve, Dates, Printf, Setfield, OneHotArrays
 using MLDatasets: MNIST
-using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
+using MLUtils: DataLoader, splitobs
+using LuxCUDA # For NVIDIA GPU support
 
 CUDA.allowscalar(false)
 ENV["DATADEPS_ALWAYS_ACCEPT"] = true
@@ -24,26 +25,23 @@ We can now construct our dataloader. We are using only limited part of the data
 demonstration.
 
 ```@example basic_mnist_deq
-function onehot(labels_raw)
-    return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
-end
-
-function loadmnist(batchsize, split)
-    # Load MNIST
-    mnist = MNIST(; split)
-    imgs, labels_raw = mnist.features, mnist.targets
-    # Process images into (H,W,C,BS) batches
-    x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))[
-        :, :, 1:1, 1:128] |> gdev
-    x_train = batchview(x_train, batchsize)
-    # Onehot and batch the labels
-    y_train = onehot(labels_raw)[:, 1:128] |> gdev
-    y_train = batchview(y_train, batchsize)
-    return x_train, y_train
+function loadmnist(batchsize, train_split)
+    N = 2500
+    dataset = MNIST(; split=:train)
+    imgs = dataset.features[:, :, 1:N]
+    labels_raw = dataset.targets[1:N]
+
+    # Process images into (H,W,C,BS) batches
+    x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))
+    y_data = onehotbatch(labels_raw, 0:9)
+    (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split)
+
+    return (
+        # Use DataLoader to automatically minibatch and shuffle the data
+        DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true),
+        # Don't shuffle the test data
+        DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false))
 end
 
-x_train, y_train = loadmnist(16, :train);
-x_test, y_test = loadmnist(16, :test);
 ```
````
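The rewritten `loadmnist` above swaps MLDataUtils (`convertlabel`, `batchview`) for MLUtils and OneHotArrays. A self-contained sketch of just those two APIs on dummy arrays (everything below is illustrative, not from the PR):

```julia
using MLUtils: DataLoader, splitobs
using OneHotArrays: onehotbatch, onecold

x = rand(Float32, 28, 28, 1, 100)     # 100 dummy grayscale "images"
y = onehotbatch(rand(0:9, 100), 0:9)  # 10×100 one-hot label matrix

# Split along the observation (last) dimension: 80% train, 20% test
(x_train, y_train), (x_test, y_test) = splitobs((x, y); at=0.8)

# Minibatch and shuffle the training split, as loadmnist does
train_loader = DataLoader(collect.((x_train, y_train)); batchsize=16, shuffle=true)

xb, yb = first(train_loader)
size(xb)          # (28, 28, 1, 16)
onecold(yb, 0:9)  # decode one-hot columns back to integer labels
```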

Construct the Lux Neural Network containing a DEQ layer.
````diff
@@ -55,9 +53,13 @@ function construct_model(solver; model_type::Symbol=:deq)
 
     # The input layer of the DEQ
     deq_model = Chain(
-        Parallel(+, Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()),
-            Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())),
-        Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()))
+        Parallel(+,
+            Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad(),
+                init_weight=truncated_normal(; std=0.01), use_bias=false),
+            Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad(),
+                init_weight=truncated_normal(; std=0.01), use_bias=false)),
+        Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad(),
+            init_weight=truncated_normal(; std=0.01), use_bias=false))
 
     if model_type === :skipdeq
         init = Conv((3, 3), 64 => 64, gelu; stride=1, pad=SamePad())
@@ -67,8 +69,8 @@ function construct_model(solver; model_type::Symbol=:deq)
         init = missing
     end
 
-    deq = DeepEquilibriumNetwork(
-        deq_model, solver; init, verbose=false, linsolve_kwargs=(; maxiters=10))
+    deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
+        linsolve_kwargs=(; maxiters=10), maxiters=10)
 
     classifier = Chain(
         GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10))
@@ -80,14 +82,13 @@ function construct_model(solver; model_type::Symbol=:deq)
     ps, st = Lux.setup(rng, model)
 
     # Warmup the forward and backward passes
-    x = randn(rng, Float32, 28, 28, 1, 128)
-    y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev
+    x = randn(rng, Float32, 28, 28, 1, 2)
+    y = onehotbatch(rand(Random.default_rng(), 0:9, 2), 0:9) |> gdev
 
-    model_ = StatefulLuxLayer(model, ps, st)
     @printf "[%s] warming up forward pass\n" string(now())
-    logitcrossentropy(model_, x, ps, y)
+    loss_function(model, ps, st, (x, y))
     @printf "[%s] warming up backward pass\n" string(now())
-    Zygote.gradient(logitcrossentropy, model_, x, ps, y)
+    Zygote.gradient(first ∘ loss_function, model, ps, st, (x, y))
     @printf "[%s] warmup complete\n" string(now())
 
     return model, ps, st
````
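The `DeepEquilibriumNetwork` call above now passes `maxiters=10` alongside `linsolve_kwargs`. As quick orientation (standard DEQ background, not text from this PR): the layer's output is the fixed point

```math
z^* = f_\theta(z^*, x),
```

where ``f_\theta`` is `deq_model`. The supplied `solver` drives the fixed-point iteration, `maxiters` caps how many solver iterations are spent per forward pass, and `linsolve_kwargs` bounds the linear solves used for the implicit gradient.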
````diff
@@ -97,73 +98,69 @@ end
 Define some helper functions to train the model.
 
 ```@example basic_mnist_deq
-logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1))
-function logitcrossentropy(model, x, ps, y)
-    l1 = logitcrossentropy(model(x, ps), y)
-    # Add in some regularization
-    l2 = mean(abs2, model.st.deq.solution.z_star .- model.st.deq.solution.u0)
-    return l1 + 10.0 * l2
+const logit_cross_entropy = CrossEntropyLoss(; logits=Val(true))
+const mse_loss = MSELoss()
+
+function loss_function(model, ps, st, (x, y))
+    ŷ, st = model(x, ps, st)
+    l1 = logit_cross_entropy(ŷ, y)
+    l2 = mse_loss(st.deq.solution.z_star, st.deq.solution.u0) # Add in some regularization
+    return l1 + eltype(l2)(0.01) * l2, st, (;)
 end
 
-classify(x) = argmax.(eachcol(x))
-
-function accuracy(model, data, ps, st)
+function accuracy(model, ps, st, dataloader)
     total_correct, total = 0, 0
     st = Lux.testmode(st)
-    model = StatefulLuxLayer(model, ps, st)
-    for (x, y) in data
-        target_class = classify(cdev(y))
-        predicted_class = classify(cdev(model(x)))
+    for (x, y) in dataloader
+        target_class = onecold(y)
+        predicted_class = onecold(first(model(x, ps, st)))
         total_correct += sum(target_class .== predicted_class)
         total += length(target_class)
     end
     return total_correct / total
 end
 
-function train_model(
-    solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test))
+function train_model(solver, model_type)
     model, ps, st = construct_model(solver; model_type)
-    model_st = StatefulLuxLayer(model, nothing, st)
 
-    @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver))
+    train_dataloader, test_dataloader = loadmnist(32, 0.8) |> gdev
 
-    opt_st = Optimisers.setup(Adam(0.001), ps)
+    tstate = Training.TrainState(model, ps, st, Adam(0.0005))
 
-    acc = accuracy(model, data_test, ps, st) * 100
-    @printf "[%s] Starting Accuracy: %.5f%%\n" string(now()) acc
+    @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver))
 
     @printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now())
-    st = Lux.update_state(st, :fixed_depth, Val(5))
-    model_st = StatefulLuxLayer(model, ps, st)
-
-    for (i, (x, y)) in enumerate(data_train)
-        res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
-        Optimisers.update!(opt_st, ps, res.grad[3])
-        i % 50 == 1 &&
-            @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
+    @set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(5))
+
+    for _ in 1:2, (i, (x, y)) in enumerate(train_dataloader)
+        _, loss, _, tstate = Training.single_train_step!(
+            AutoZygote(), loss_function, (x, y), tstate)
+        if i % 10 == 1
+            @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(train_dataloader) loss
+        end
     end
 
-    acc = accuracy(model, data_test, ps, model_st.st) * 100
+    acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
     @printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc
 
-    st = Lux.update_state(st, :fixed_depth, Val(0))
-    model_st = StatefulLuxLayer(model, ps, st)
+    @set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(0))
 
     for epoch in 1:3
-        for (i, (x, y)) in enumerate(data_train)
-            res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
-            Optimisers.update!(opt_st, ps, res.grad[3])
-            i % 50 == 1 &&
-                @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
+        for (i, (x, y)) in enumerate(train_dataloader)
+            _, loss, _, tstate = Training.single_train_step!(
+                AutoZygote(), loss_function, (x, y), tstate)
+            if i % 10 == 1
+                @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(train_dataloader) loss
+            end
        end
 
-        acc = accuracy(model, data_test, ps, model_st.st) * 100
+        acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
         @printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch 3 acc
     end
 
     @printf "[%s] Training complete.\n" string(now())
 
-    return model, ps, st
+    return model, ps, tstate.states
 end
 ```
````
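Two points in the rewritten helpers are worth a gloss (editorial, not PR text). First, the new `loss_function` optimizes cross-entropy plus a small penalty tying the converged equilibrium to its initial guess,

```math
\mathcal{L} = \mathrm{CE}(\hat{y}, y) + 0.01 \cdot \mathrm{MSE}(z^*, u_0),
```

with ``z^*`` and ``u_0`` read from `st.deq.solution`. Second, `Training.single_train_step!(AutoZygote(), loss_function, (x, y), tstate)` replaces the manual `Zygote.withgradient` plus `Optimisers.update!` pair; it computes gradients, applies the optimizer, and returns `(grads, loss, stats, tstate)`, which is why the loop destructures `_, loss, _, tstate`.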
