Update the documentation
avik-pal committed Apr 24, 2024
1 parent bb61c5f commit 908c224
Showing 5 changed files with 44 additions and 82 deletions.
6 changes: 3 additions & 3 deletions Manifest.toml
@@ -497,7 +497,7 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[deps.Lux]]
deps = ["ADTypes", "Adapt", "ArrayInterface", "ChainRulesCore", "ConcreteStructs", "ConstructionBase", "FastClosures", "Functors", "GPUArraysCore", "LinearAlgebra", "LuxCore", "LuxDeviceUtils", "LuxLib", "MacroTools", "Markdown", "PrecompileTools", "Preferences", "Random", "Reexport", "Setfield", "Statistics", "WeightInitializers"]
git-tree-sha1 = "295c76513705518749fd4e151d9de77c75049d43"
git-tree-sha1 = "ae13ecbe29ee7432dfd477b233db43c462b6a4ff"
repo-rev = "ap/nested_ad"
repo-url = "https://github.com/LuxDL/Lux.jl.git"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
@@ -574,9 +574,9 @@ version = "0.1.20"

[[deps.LuxLib]]
deps = ["ArrayInterface", "ChainRulesCore", "FastBroadcast", "FastClosures", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics", "Strided"]
git-tree-sha1 = "7cb3cdf01835d508f2c81e09d2e93f309434b5d6"
git-tree-sha1 = "edbf65f5ceb15ebbfad9d03c6a846d83b9a97baf"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
version = "0.3.15"
version = "0.3.16"

[deps.LuxLib.extensions]
LuxLibAMDGPUExt = "AMDGPU"
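The Manifest entries above pin Lux to the `ap/nested_ad` development branch and bump LuxLib to 0.3.16. A minimal sketch of reproducing those pins with Pkg (the rest of the resolved dependency tree still depends on the surrounding environment):

```julia
using Pkg

# Track Lux.jl from the development branch recorded in the Manifest above.
Pkg.add(url="https://github.com/LuxDL/Lux.jl.git", rev="ap/nested_ad")

# Pull in the LuxLib release the updated Manifest records.
Pkg.add(name="LuxLib", version="0.3.16")
```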
4 changes: 2 additions & 2 deletions docs/Project.toml
@@ -1,16 +1,17 @@
[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DeepEquilibriumNetworks = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -21,7 +22,6 @@ DeepEquilibriumNetworks = "2"
Documenter = "1"
DocumenterCitations = "1"
LinearSolve = "2"
LoggingExtras = "1"
Lux = "0.5"
LuxCUDA = "0.3"
MLDataUtils = "0.5"
59 changes: 19 additions & 40 deletions docs/src/tutorials/basic_mnist_deq.md
@@ -4,7 +4,7 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack

```@example basic_mnist_deq
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, Dates, Printf
using MLDatasets: MNIST
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
@@ -20,18 +20,6 @@ const cdev = cpu_device()
const gdev = gpu_device()
```

SciMLBase introduced a warning instead of depwarn which pollutes the output. We can suppress
it with the following logger

```@example basic_mnist_deq
function remove_syms_warning(log_args)
return log_args.message !=
"The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead."
end
filtered_logger = ActiveFilteredLogger(remove_syms_warning, global_logger())
```

We can now construct our dataloader.

```@example basic_mnist_deq
@@ -94,12 +82,12 @@ function construct_model(solver; model_type::Symbol=:deq)
x = randn(rng, Float32, 28, 28, 1, 128)
y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev
model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st)
@info "warming up forward pass"
model_ = StatefulLuxLayer(model, ps, st)
@printf "[%s] warming up forward pass\n" string(now())
logitcrossentropy(model_, x, ps, y)
@info "warming up backward pass"
@printf "[%s] warming up backward pass\n" string(now())
Zygote.gradient(logitcrossentropy, model_, x, ps, y)
@info "warmup complete"
@printf "[%s] warmup complete\n" string(now())
return model, ps, st
end
@@ -121,7 +109,7 @@ classify(x) = argmax.(eachcol(x))
function accuracy(model, data, ps, st)
total_correct, total = 0, 0
st = Lux.testmode(st)
model = Lux.Experimental.StatefulLuxLayer(model, ps, st)
model = StatefulLuxLayer(model, ps, st)
for (x, y) in data
target_class = classify(cdev(y))
predicted_class = classify(cdev(model(x)))
@@ -134,48 +122,43 @@ end
function train_model(
solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test))
model, ps, st = construct_model(solver; model_type)
model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
model_st = StatefulLuxLayer(model, nothing, st)
@info "Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))"
@printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver))
opt_st = Optimisers.setup(Adam(0.001), ps)
acc = accuracy(model, data_test, ps, st) * 100
@info "Starting Accuracy: $(acc)"
@printf "[%s] Starting Accuracy: %.5f%%\n" string(now()) acc
@info "Pretrain with unrolling to a depth of 5"
@printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now())
st = Lux.update_state(st, :fixed_depth, Val(5))
model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
model_st = StatefulLuxLayer(model, ps, st)
for (i, (x, y)) in enumerate(data_train)
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
Optimisers.update!(opt_st, ps, res.grad[3])
if i % 50 == 1
@info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
end
i % 50 == 1 && @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
end
acc = accuracy(model, data_test, ps, model_st.st) * 100
@info "Pretraining complete. Accuracy: $(acc)"
@printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc
st = Lux.update_state(st, :fixed_depth, Val(0))
model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
model_st = StatefulLuxLayer(model, ps, st)
for epoch in 1:3
for (i, (x, y)) in enumerate(data_train)
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
Optimisers.update!(opt_st, ps, res.grad[3])
if i % 50 == 1
@info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
end
i % 50 == 1 && @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
end
acc = accuracy(model, data_test, ps, model_st.st) * 100
@info "Epoch: [$(epoch)/3] Accuracy: $(acc)"
@printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch 3 acc
end
@info "Training complete."
println()
@printf "[%s] Training complete.\n" string(now())
return model, ps, st
end
@@ -187,19 +170,15 @@ and end up using solvers like `Broyden`, but we can simply slap in any of the fa
from NonlinearSolve.jl. Here we will use the Newton-Krylov method:

```@example basic_mnist_deq
with_logger(filtered_logger) do
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq)
end
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq);
nothing # hide
```

We can also train a continuous DEQ by passing in an ODE solver. Here we will use `VCAB3()`,
which tends to be quite fast for continuous neural network problems.

```@example basic_mnist_deq
with_logger(filtered_logger) do
train_model(VCAB3(), :deq)
end
train_model(VCAB3(), :deq);
nothing # hide
```

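The tutorial diff above replaces the `@info` progress messages (and the LoggingExtras-based warning filter) with timestamped `@printf` statements built on `Dates` and `Printf`. A minimal sketch of that logging pattern in isolation, with placeholder values standing in for the real loss and counters:

```julia
using Dates, Printf

# Placeholder values standing in for `res.val`, the epoch counter, and the batch counter.
loss = 0.12345f0
epoch, nepochs = 1, 3
i, nbatches = 51, 400

# Same format string used by the updated training loop above.
@printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch nepochs i nbatches loss
```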
54 changes: 19 additions & 35 deletions docs/src/tutorials/reduced_dim_deq.md
@@ -6,7 +6,7 @@ same MNIST example as before, but this time we will use a reduced state size.

```@example reduced_dim_mnist
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, Dates, Printf
using MLDatasets: MNIST
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
@@ -16,13 +16,6 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true
const cdev = cpu_device()
const gdev = gpu_device()
function remove_syms_warning(log_args)
return log_args.message !=
"The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead."
end
filtered_logger = ActiveFilteredLogger(remove_syms_warning, global_logger())
function onehot(labels_raw)
return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
end
@@ -83,12 +76,12 @@ function construct_model(solver; model_type::Symbol=:regdeq)
x = randn(rng, Float32, 28, 28, 1, 128)
y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev
model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st)
@info "warming up forward pass"
model_ = StatefulLuxLayer(model, ps, st)
@printf "[%s] warming up forward pass\n" string(now())
logitcrossentropy(model_, x, ps, y)
@info "warming up backward pass"
@printf "[%s] warming up backward pass\n" string(now())
Zygote.gradient(logitcrossentropy, model_, x, ps, y)
@info "warmup complete"
@printf "[%s] warmup complete\n" string(now())
return model, ps, st
end
@@ -110,7 +103,7 @@ classify(x) = argmax.(eachcol(x))
function accuracy(model, data, ps, st)
total_correct, total = 0, 0
st = Lux.testmode(st)
model = Lux.Experimental.StatefulLuxLayer(model, ps, st)
model = StatefulLuxLayer(model, ps, st)
for (x, y) in data
target_class = classify(cdev(y))
predicted_class = classify(cdev(model(x)))
@@ -123,48 +116,43 @@ end
function train_model(
solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test))
model, ps, st = construct_model(solver; model_type)
model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
model_st = StatefulLuxLayer(model, nothing, st)
@info "Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))"
@printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver))
opt_st = Optimisers.setup(Adam(0.001), ps)
acc = accuracy(model, data_test, ps, st) * 100
@info "Starting Accuracy: $(acc)"
@printf "[%s] Starting Accuracy: %.5f%%\n" string(now()) acc
@info "Pretrain with unrolling to a depth of 5"
@printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now())
st = Lux.update_state(st, :fixed_depth, Val(5))
model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
model_st = StatefulLuxLayer(model, ps, st)
for (i, (x, y)) in enumerate(data_train)
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
Optimisers.update!(opt_st, ps, res.grad[3])
if i % 50 == 1
@info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
end
i % 50 == 1 && @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
end
acc = accuracy(model, data_test, ps, model_st.st) * 100
@info "Pretraining complete. Accuracy: $(acc)"
@printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc
st = Lux.update_state(st, :fixed_depth, Val(0))
model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
model_st = StatefulLuxLayer(model, ps, st)
for epoch in 1:3
for (i, (x, y)) in enumerate(data_train)
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
Optimisers.update!(opt_st, ps, res.grad[3])
if i % 50 == 1
@info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
end
i % 50 == 1 && @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
end
acc = accuracy(model, data_test, ps, model_st.st) * 100
@info "Epoch: [$(epoch)/3] Accuracy: $(acc)"
@printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch 3 acc
end
@info "Training complete."
println()
@printf "[%s] Training complete.\n" string(now())
return model, ps, st
end
@@ -174,15 +162,11 @@ Now we can train our model. We can't use `:regdeq` here currently, but we will s
in the future.

```@example reduced_dim_mnist
with_logger(filtered_logger) do
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :skipdeq)
end
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :skipdeq)
nothing # hide
```

```@example reduced_dim_mnist
with_logger(filtered_logger) do
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :deq)
end
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :deq)
nothing # hide
```
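Both tutorials also move from `Lux.Experimental.StatefulLuxLayer` to the exported `StatefulLuxLayer`. A self-contained sketch of that wrapper on a small stand-in model, assuming the Lux version pinned by this commit (the layer sizes here are arbitrary and only for illustration):

```julia
using Lux, Random

rng = Random.default_rng()

# A stand-in model; the tutorials wrap their DEQ models the same way.
model = Chain(Dense(2 => 8, tanh), Dense(8 => 2))
ps, st = Lux.setup(rng, model)

# StatefulLuxLayer carries the states `st` internally, so the wrapped model can be
# called with just the input (and, optionally, an explicit parameter set).
smodel = StatefulLuxLayer(model, ps, st)

x = randn(rng, Float32, 2, 16)
y = smodel(x)       # uses the stored parameters
y = smodel(x, ps)   # or pass parameters explicitly, as the training loops above do
```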
3 changes: 1 addition & 2 deletions src/layers.jl
@@ -316,8 +316,7 @@ julia> model(x, ps, st);
"""
function MultiScaleDeepEquilibriumNetwork(
main_layers::Tuple, mapping_layers::Matrix, post_fuse_layer::Union{Nothing, Tuple},
solver, scales; jacobian_regularization=nothing, kwargs...)
@assert jacobian_regularization===nothing "Jacobian Regularization is not supported yet for MultiScale Models."
solver, scales; kwargs...)
l1 = Parallel(nothing, main_layers...)
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)

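The `layers.jl` change drops the `jacobian_regularization` keyword (and its assertion) from the `MultiScaleDeepEquilibriumNetwork` constructor, leaving the positional arguments shown in the signature above. A hedged sketch of constructing such a model; the layer sizes, scales, and solver below are illustrative choices, not values taken from this commit:

```julia
using DeepEquilibriumNetworks, Lux, NonlinearSolve, Random

# One "main" layer per scale (sizes chosen only for this sketch).
main_layers = (Parallel(+, Dense(4 => 4, tanh), Dense(4 => 4, tanh)),
    Dense(3 => 3, tanh), Dense(2 => 2, tanh), Dense(1 => 1, tanh))

# Dense mappings between every pair of scales; NoOpLayer keeps a scale's own state.
mapping_layers = [NoOpLayer() Dense(4 => 3, tanh) Dense(4 => 2, tanh) Dense(4 => 1, tanh);
    Dense(3 => 4, tanh) NoOpLayer() Dense(3 => 2, tanh) Dense(3 => 1, tanh);
    Dense(2 => 4, tanh) Dense(2 => 3, tanh) NoOpLayer() Dense(2 => 1, tanh);
    Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()]

# No post-fuse layers, a Newton solver, and the per-scale state sizes.
model = MultiScaleDeepEquilibriumNetwork(main_layers, mapping_layers, nothing,
    NewtonRaphson(), ((4,), (3,), (2,), (1,)))

rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 4, 12)
model(x, ps, st)
```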
