diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 14462213..b088891c 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -7,9 +7,6 @@ steps: test_args: "--quickfail" - JuliaCI/julia-coverage#v1: codecov: true - dirs: - - src - - ext agents: queue: "juliagpu" cuda: "*" @@ -28,9 +25,6 @@ steps: version: "1" - JuliaCI/julia-coverage#v1: codecov: true - dirs: - - src - - ext command: | julia --project --code-coverage=user --color=yes --threads=3 -e ' println("--- :julia: Instantiating project") diff --git a/Project.toml b/Project.toml index 6639452f..af197535 100644 --- a/Project.toml +++ b/Project.toml @@ -1,14 +1,13 @@ name = "DeepEquilibriumNetworks" uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" authors = ["Avik Pal "] -version = "2.2.0" +version = "2.3.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" @@ -16,15 +15,9 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" - -[weakdeps] -LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" - -[extensions] -DeepEquilibriumNetworksSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"] +Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" [compat] ADTypes = "0.2.5, 1" @@ -32,7 +25,6 @@ Aqua = "0.8.7" ChainRulesCore = "1" CommonSolve = "0.2.4" ConcreteStructs = "0.2" -ConstructionBase = "1" DiffEqBase = "6.119" Documenter = "1.4" ExplicitImports = "1.6.0" @@ -42,9 +34,8 @@ Functors = "0.4.10" GPUArraysCore = "0.1.6" Hwloc = "3" InteractiveUtils = "<0.0.1, 1" -LinearSolve = "2.21.2" -Lux = "0.5.56" -LuxCore = "0.1.14" +Lux = "1" +LuxCore = "1" LuxTestUtils = "1" MLDataDevices = "1" NLsolve = "4.5.1" @@ -57,7 +48,7 @@ ReTestItems = "1.23.1" SciMLBase = "2" SciMLSensitivity = "7.43" StableRNGs = "1.0.2" -Statistics = "1.10" +Static = "1.1.1" SteadyStateDiffEq = "2.3.2" Test = "1.10" Zygote = "0.6.69" diff --git a/docs/Project.toml b/docs/Project.toml index 0ad07734..78d3d014 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,15 +6,16 @@ DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" -MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" +OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Setfield = 
"efcf1570-3423-57d1-acb7-fd33fddbac46" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] @@ -22,15 +23,13 @@ DeepEquilibriumNetworks = "2" Documenter = "1" DocumenterCitations = "1" LinearSolve = "2" -Lux = "0.5" +Lux = "1" LuxCUDA = "0.3" -MLDataUtils = "0.5" MLDatasets = "0.7" NonlinearSolve = "3" Optimisers = "0.3" OrdinaryDiffEq = "6" Random = "1" SciMLSensitivity = "7" -Statistics = "1" Zygote = "0.6" julia = "1.10" diff --git a/docs/src/tutorials/basic_mnist_deq.md b/docs/src/tutorials/basic_mnist_deq.md index 815e89d0..4027c42d 100644 --- a/docs/src/tutorials/basic_mnist_deq.md +++ b/docs/src/tutorials/basic_mnist_deq.md @@ -4,9 +4,10 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack ```@example basic_mnist_deq using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq, - Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, Dates, Printf + Random, Optimisers, Zygote, LinearSolve, Dates, Printf, Setfield, OneHotArrays using MLDatasets: MNIST -using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview +using MLUtils: DataLoader, splitobs +using LuxCUDA # For NVIDIA GPU support CUDA.allowscalar(false) ENV["DATADEPS_ALWAYS_ACCEPT"] = true @@ -24,26 +25,23 @@ We can now construct our dataloader. We are using only limited part of the data demonstration. ```@example basic_mnist_deq -function onehot(labels_raw) - return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9))) -end +function loadmnist(batchsize, train_split) + N = 2500 + dataset = MNIST(; split=:train) + imgs = dataset.features[:, :, 1:N] + labels_raw = dataset.targets[1:N] -function loadmnist(batchsize, split) - # Load MNIST - mnist = MNIST(; split) - imgs, labels_raw = mnist.features, mnist.targets # Process images into (H,W,C,BS) batches - x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))[ - :, :, 1:1, 1:128] |> gdev - x_train = batchview(x_train, batchsize) - # Onehot and batch the labels - y_train = onehot(labels_raw)[:, 1:128] |> gdev - y_train = batchview(y_train, batchsize) - return x_train, y_train + x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) + y_data = onehotbatch(labels_raw, 0:9) + (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split) + + return ( + # Use DataLoader to automatically minibatch and shuffle the data + DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true), + # Don't shuffle the test data + DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false)) end - -x_train, y_train = loadmnist(16, :train); -x_test, y_test = loadmnist(16, :test); ``` Construct the Lux Neural Network containing a DEQ layer. 
@@ -55,9 +53,13 @@ function construct_model(solver; model_type::Symbol=:deq) # The input layer of the DEQ deq_model = Chain( - Parallel(+, Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()), - Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())), - Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())) + Parallel(+, + Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad(), + init_weight=truncated_normal(; std=0.01), use_bias=false), + Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad(), + init_weight=truncated_normal(; std=0.01), use_bias=false)), + Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad(), + init_weight=truncated_normal(; std=0.01), use_bias=false)) if model_type === :skipdeq init = Conv((3, 3), 64 => 64, gelu; stride=1, pad=SamePad()) @@ -67,8 +69,8 @@ function construct_model(solver; model_type::Symbol=:deq) init = missing end - deq = DeepEquilibriumNetwork( - deq_model, solver; init, verbose=false, linsolve_kwargs=(; maxiters=10)) + deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false, + linsolve_kwargs=(; maxiters=10), maxiters=10) classifier = Chain( GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10)) @@ -80,14 +82,13 @@ function construct_model(solver; model_type::Symbol=:deq) ps, st = Lux.setup(rng, model) # Warmup the forward and backward passes - x = randn(rng, Float32, 28, 28, 1, 128) - y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev + x = randn(rng, Float32, 28, 28, 1, 2) + y = onehotbatch(rand(Random.default_rng(), 0:9, 2), 0:9) |> gdev - model_ = StatefulLuxLayer(model, ps, st) @printf "[%s] warming up forward pass\n" string(now()) - logitcrossentropy(model_, x, ps, y) + loss_function(model, ps, st, (x, y)) @printf "[%s] warming up backward pass\n" string(now()) - Zygote.gradient(logitcrossentropy, model_, x, ps, y) + Zygote.gradient(first ∘ loss_function, model, ps, st, (x, y)) @printf "[%s] warmup complete\n" string(now()) return model, ps, st @@ -97,73 +98,69 @@ end Define some helper functions to train the model. 
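The loss rewrite in the next hunk composes Lux's built-in `CrossEntropyLoss` (on logits) with an `MSELoss` regularizer between the DEQ fixed point `z_star` and its initial guess `u0`. A minimal sketch of that composition on random stand-in arrays (the shapes and toy data are assumptions for illustration; the 0.01 weight matches the patch):

```julia
using Lux, OneHotArrays, Random

const logit_cross_entropy = CrossEntropyLoss(; logits=Val(true))
const mse_loss = MSELoss()

rng = Random.default_rng()
ŷ = randn(rng, Float32, 10, 4)            # raw classifier logits for 4 samples
y = onehotbatch(rand(rng, 0:9, 4), 0:9)   # matching one-hot targets
z_star = randn(rng, Float32, 64, 4)       # stand-in for the DEQ fixed point
u0 = zeros(Float32, 64, 4)                # stand-in for the DEQ initial condition

# classification loss plus a small penalty keeping the fixed point near its init
loss = logit_cross_entropy(ŷ, y) + 0.01f0 * mse_loss(z_star, u0)
```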
```@example basic_mnist_deq -logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1)) -function logitcrossentropy(model, x, ps, y) - l1 = logitcrossentropy(model(x, ps), y) - # Add in some regularization - l2 = mean(abs2, model.st.deq.solution.z_star .- model.st.deq.solution.u0) - return l1 + 10.0 * l2 +const logit_cross_entropy = CrossEntropyLoss(; logits=Val(true)) +const mse_loss = MSELoss() + +function loss_function(model, ps, st, (x, y)) + ŷ, st = model(x, ps, st) + l1 = logit_cross_entropy(ŷ, y) + l2 = mse_loss(st.deq.solution.z_star, st.deq.solution.u0) # Add in some regularization + return l1 + eltype(l2)(0.01) * l2, st, (;) end -classify(x) = argmax.(eachcol(x)) - -function accuracy(model, data, ps, st) +function accuracy(model, ps, st, dataloader) total_correct, total = 0, 0 st = Lux.testmode(st) - model = StatefulLuxLayer(model, ps, st) - for (x, y) in data - target_class = classify(cdev(y)) - predicted_class = classify(cdev(model(x))) + for (x, y) in dataloader + target_class = onecold(y) + predicted_class = onecold(first(model(x, ps, st))) total_correct += sum(target_class .== predicted_class) total += length(target_class) end return total_correct / total end -function train_model( - solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test)) +function train_model(solver, model_type) model, ps, st = construct_model(solver; model_type) - model_st = StatefulLuxLayer(model, nothing, st) - @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver)) + train_dataloader, test_dataloader = loadmnist(32, 0.8) |> gdev - opt_st = Optimisers.setup(Adam(0.001), ps) + tstate = Training.TrainState(model, ps, st, Adam(0.0005)) - acc = accuracy(model, data_test, ps, st) * 100 - @printf "[%s] Starting Accuracy: %.5f%%\n" string(now()) acc + @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver)) @printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now()) - st = Lux.update_state(st, :fixed_depth, Val(5)) - model_st = StatefulLuxLayer(model, ps, st) - - for (i, (x, y)) in enumerate(data_train) - res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y) - Optimisers.update!(opt_st, ps, res.grad[3]) - i % 50 == 1 && - @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val + @set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(5)) + + for _ in 1:2, (i, (x, y)) in enumerate(train_dataloader) + _, loss, _, tstate = Training.single_train_step!( + AutoZygote(), loss_function, (x, y), tstate) + if i % 10 == 1 + @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(train_dataloader) loss + end end - acc = accuracy(model, data_test, ps, model_st.st) * 100 + acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100 @printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc - st = Lux.update_state(st, :fixed_depth, Val(0)) - model_st = StatefulLuxLayer(model, ps, st) + @set! 
tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(0)) for epoch in 1:3 - for (i, (x, y)) in enumerate(data_train) - res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y) - Optimisers.update!(opt_st, ps, res.grad[3]) - i % 50 == 1 && - @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val + for (i, (x, y)) in enumerate(train_dataloader) + _, loss, _, tstate = Training.single_train_step!( + AutoZygote(), loss_function, (x, y), tstate) + if i % 10 == 1 + @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(train_dataloader) loss + end end - acc = accuracy(model, data_test, ps, model_st.st) * 100 + acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100 @printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch 3 acc end @printf "[%s] Training complete.\n" string(now()) - return model, ps, st + return model, ps, tstate.states end ``` diff --git a/docs/src/tutorials/reduced_dim_deq.md b/docs/src/tutorials/reduced_dim_deq.md index 1e01b3a5..df344e08 100644 --- a/docs/src/tutorials/reduced_dim_deq.md +++ b/docs/src/tutorials/reduced_dim_deq.md @@ -6,9 +6,10 @@ same MNIST example as before, but this time we will use a reduced state size. ```@example reduced_dim_mnist using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq, - Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, Dates, Printf + Random, Optimisers, Zygote, LinearSolve, Dates, Printf, Setfield, OneHotArrays using MLDatasets: MNIST -using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview +using MLUtils: DataLoader, splitobs +using LuxCUDA # For NVIDIA GPU support CUDA.allowscalar(false) ENV["DATADEPS_ALWAYS_ACCEPT"] = true @@ -16,26 +17,23 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true const cdev = cpu_device() const gdev = gpu_device() -function onehot(labels_raw) - return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9))) -end +function loadmnist(batchsize, train_split) + N = 2500 + dataset = MNIST(; split=:train) + imgs = dataset.features[:, :, 1:N] + labels_raw = dataset.targets[1:N] -function loadmnist(batchsize, split) - # Load MNIST - mnist = MNIST(; split) - imgs, labels_raw = mnist.features, mnist.targets # Process images into (H,W,C,BS) batches - x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |> - gdev - x_train = batchview(x_train, batchsize) - # Onehot and batch the labels - y_train = onehot(labels_raw) |> gdev - y_train = batchview(y_train, batchsize) - return x_train, y_train + x_data = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) + y_data = onehotbatch(labels_raw, 0:9) + (x_train, y_train), (x_test, y_test) = splitobs((x_data, y_data); at=train_split) + + return ( + # Use DataLoader to automatically minibatch and shuffle the data + DataLoader(collect.((x_train, y_train)); batchsize, shuffle=true), + # Don't shuffle the test data + DataLoader(collect.((x_test, y_test)); batchsize, shuffle=false)) end - -x_train, y_train = loadmnist(128, :train); -x_test, y_test = loadmnist(128, :test); ``` Now we will define the construct model function. 
Here we will use Dense Layers and @@ -46,23 +44,29 @@ function construct_model(solver; model_type::Symbol=:regdeq) down = Chain(FlattenLayer(), Dense(784 => 512, gelu)) # The input layer of the DEQ - deq_model = Chain(Parallel(+, Dense(128 => 64, tanh), # Reduced dim of `128` - Dense(512 => 64, tanh)), # Original dim of `512` - Dense(64 => 64, tanh), Dense(64 => 128)) # Return the reduced dim of `128` + deq_model = Chain( + Parallel(+, + Dense( + 128 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01)), # Reduced dim of `128` + Dense( + 512 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01))), # Original dim of `512` + Dense(64 => 64, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01)), + Dense(64 => 128; use_bias=false, init_weight=truncated_normal(; std=0.01))) # Return the reduced dim of `128` if model_type === :skipdeq - init = Dense(512 => 128, tanh) + init = Dense( + 512 => 128, tanh; use_bias=false, init_weight=truncated_normal(; std=0.01)) elseif model_type === :regdeq error(":regdeq is not supported for reduced dim models") else # This should preferably done via `ChainRulesCore.@ignore_derivatives`. But here # we are only using Zygote so this is fine. - init = WrappedFunction{:direct_call}(x -> Zygote.@ignore(fill!( + init = WrappedFunction(x -> Zygote.@ignore(fill!( similar(x, 128, size(x, 2)), false))) end - deq = DeepEquilibriumNetwork( - deq_model, solver; init, verbose=false, linsolve_kwargs=(; maxiters=10)) + deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false, + linsolve_kwargs=(; maxiters=10), maxiters=10) classifier = Chain(Dense(128 => 128, gelu), Dense(128, 10)) @@ -73,14 +77,13 @@ function construct_model(solver; model_type::Symbol=:regdeq) ps, st = Lux.setup(rng, model) # Warmup the forward and backward passes - x = randn(rng, Float32, 28, 28, 1, 128) - y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev + x = randn(rng, Float32, 28, 28, 1, 2) + y = onehotbatch(rand(Random.default_rng(), 0:9, 2), 0:9) |> gdev - model_ = StatefulLuxLayer(model, ps, st) @printf "[%s] warming up forward pass\n" string(now()) - logitcrossentropy(model_, x, ps, y) + loss_function(model, ps, st, (x, y)) @printf "[%s] warming up backward pass\n" string(now()) - Zygote.gradient(logitcrossentropy, model_, x, ps, y) + Zygote.gradient(first ∘ loss_function, model, ps, st, (x, y)) @printf "[%s] warmup complete\n" string(now()) return model, ps, st @@ -90,73 +93,69 @@ end Define some helper functions to train the model. 
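Both tutorials now drive optimization through Lux's `Training` API (`TrainState` plus `single_train_step!`) and rebind the immutable train state with `Setfield.@set!`. A standalone sketch of that loop on a toy `Dense` model (the model, data, and step count are illustrative assumptions; the call signatures mirror the hunks below):

```julia
using Lux, Optimisers, Zygote, Random, Setfield
using ADTypes: AutoZygote

rng = Random.default_rng()
model = Dense(4 => 2)
ps, st = Lux.setup(rng, model)
tstate = Training.TrainState(model, ps, st, Adam(0.0005))

# The loss must return (loss, updated_state, stats) for the Training API
function loss_fn(model, ps, st, data)
    x, y = data
    ŷ, st = model(x, ps, st)
    return MSELoss()(ŷ, y), st, (;)
end

function train_toy(tstate, data, nsteps)
    for _ in 1:nsteps
        # single_train_step! returns (gradients, loss, stats, updated train state)
        _, loss, _, tstate = Training.single_train_step!(
            AutoZygote(), loss_fn, data, tstate)
    end
    return tstate
end

x, y = randn(rng, Float32, 4, 8), randn(rng, Float32, 2, 8)
tstate = train_toy(tstate, (x, y), 10)

# TrainState is immutable; Setfield.@set! rebinds it with a modified field,
# which is how the tutorials toggle the pretraining (:fixed_depth) state
@set! tstate.states = Lux.testmode(tstate.states)
```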
```@example reduced_dim_mnist -logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1)) -function logitcrossentropy(model, x, ps, y) - l1 = logitcrossentropy(model(x, ps), y) - # Add in some regularization - l2 = mean(abs2, model.st.deq.solution.z_star .- model.st.deq.solution.u0) - return l1 + 0.1f0 * l2 +const logit_cross_entropy = CrossEntropyLoss(; logits=Val(true)) +const mse_loss = MSELoss() + +function loss_function(model, ps, st, (x, y)) + ŷ, st = model(x, ps, st) + l1 = logit_cross_entropy(ŷ, y) + l2 = mse_loss(st.deq.solution.z_star, st.deq.solution.u0) # Add in some regularization + return l1 + eltype(l2)(0.01) * l2, st, (;) end -classify(x) = argmax.(eachcol(x)) - -function accuracy(model, data, ps, st) +function accuracy(model, ps, st, dataloader) total_correct, total = 0, 0 st = Lux.testmode(st) - model = StatefulLuxLayer(model, ps, st) - for (x, y) in data - target_class = classify(cdev(y)) - predicted_class = classify(cdev(model(x))) + for (x, y) in dataloader + target_class = onecold(y) + predicted_class = onecold(first(model(x, ps, st))) total_correct += sum(target_class .== predicted_class) total += length(target_class) end return total_correct / total end -function train_model( - solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test)) +function train_model(solver, model_type) model, ps, st = construct_model(solver; model_type) - model_st = StatefulLuxLayer(model, nothing, st) - @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver)) + train_dataloader, test_dataloader = loadmnist(32, 0.8) |> gdev - opt_st = Optimisers.setup(Adam(0.001), ps) + tstate = Training.TrainState(model, ps, st, Adam(0.0005)) - acc = accuracy(model, data_test, ps, st) * 100 - @printf "[%s] Starting Accuracy: %.5f%%\n" string(now()) acc + @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver)) @printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now()) - st = Lux.update_state(st, :fixed_depth, Val(5)) - model_st = StatefulLuxLayer(model, ps, st) - - for (i, (x, y)) in enumerate(data_train) - res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y) - Optimisers.update!(opt_st, ps, res.grad[3]) - i % 50 == 1 && - @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val + @set! tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(5)) + + for _ in 1:2, (i, (x, y)) in enumerate(train_dataloader) + _, loss, _, tstate = Training.single_train_step!( + AutoZygote(), loss_function, (x, y), tstate) + if i % 10 == 1 + @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(train_dataloader) loss + end end - acc = accuracy(model, data_test, ps, model_st.st) * 100 + acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100 @printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc - st = Lux.update_state(st, :fixed_depth, Val(0)) - model_st = StatefulLuxLayer(model, ps, st) + @set! 
tstate.states = Lux.update_state(tstate.states, :fixed_depth, Val(0)) for epoch in 1:3 - for (i, (x, y)) in enumerate(data_train) - res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y) - Optimisers.update!(opt_st, ps, res.grad[3]) - i % 50 == 1 && - @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val + for (i, (x, y)) in enumerate(train_dataloader) + _, loss, _, tstate = Training.single_train_step!( + AutoZygote(), loss_function, (x, y), tstate) + if i % 10 == 1 + @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(train_dataloader) loss + end end - acc = accuracy(model, data_test, ps, model_st.st) * 100 + acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100 @printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch 3 acc end @printf "[%s] Training complete.\n" string(now()) - return model, ps, st + return model, ps, tstate.states end ``` diff --git a/ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl b/ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl deleted file mode 100644 index b76f5749..00000000 --- a/ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl +++ /dev/null @@ -1,20 +0,0 @@ -module DeepEquilibriumNetworksSciMLSensitivityExt - -# Linear Solve is a dependency of SciMLSensitivity, so we only need to load SciMLSensitivity -# to load this extension -using LinearSolve: SimpleGMRES -using SciMLBase: SteadyStateProblem, ODEProblem -using SciMLSensitivity: SteadyStateAdjoint, GaussAdjoint, ZygoteVJP -using DeepEquilibriumNetworks: DEQs - -@inline function DEQs.__default_sensealg(prob::SteadyStateProblem) - # We want to avoid the cost for cache construction for linsolve = nothing - # For small problems we should use concrete jacobian but we assume users want to solve - # large problems with this package so we default to GMRES and avoid runtime dispatches - linsolve = SimpleGMRES{true}(; blocksize=prod(size(prob.u0)[1:(end - 1)])) - linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3) - return SteadyStateAdjoint(; linsolve, linsolve_kwargs, autojacvec=ZygoteVJP()) -end -@inline DEQs.__default_sensealg(::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP()) - -end diff --git a/src/DeepEquilibriumNetworks.jl b/src/DeepEquilibriumNetworks.jl index 98efeea2..382c7439 100644 --- a/src/DeepEquilibriumNetworks.jl +++ b/src/DeepEquilibriumNetworks.jl @@ -4,17 +4,19 @@ using ADTypes: AutoFiniteDiff, AutoForwardDiff, AutoZygote using ChainRulesCore: ChainRulesCore using CommonSolve: solve using ConcreteStructs: @concrete -using ConstructionBase: ConstructionBase using DiffEqBase: DiffEqBase, AbsNormTerminationMode using FastClosures: @closure -using Lux: Lux, BranchLayer, Chain, NoOpLayer, Parallel, RepeatedLayer, StatefulLuxLayer, - WrappedFunction -using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer -using NNlib: ⊠ using Random: Random, AbstractRNG, randn! 
using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, AbstractODEAlgorithm, NonlinearSolution, ODESolution, ODEFunction, ODEProblem, SteadyStateProblem, _unwrap_val +using SciMLSensitivity: SteadyStateAdjoint, GaussAdjoint, ZygoteVJP +using Static: StaticSymbol, StaticInt, known, static + +using Lux: Lux, LuxOps, BranchLayer, Chain, NoOpLayer, Parallel, RepeatedLayer, + StatefulLuxLayer, WrappedFunction +using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxContainerLayer, AbstractLuxWrapperLayer +using NNlib: ⊠ using SteadyStateDiffEq: DynamicSS, SSRootfind # Useful Constants diff --git a/src/layers.jl b/src/layers.jl index b4e700c0..4d12f807 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -51,19 +51,17 @@ function Base.show(io::IO, sol::DeepEquilibriumSolution) end # Core Model -@concrete struct DeepEquilibriumNetwork{pType} <: - AbstractExplicitContainerLayer{(:model, :init)} +@concrete struct DeepEquilibriumNetwork <: AbstractLuxContainerLayer{(:model, :init)} init model solver jacobian_regularization kwargs + kind <: StaticSymbol end const DEQ = DeepEquilibriumNetwork -ConstructionBase.constructorof(::Type{<:DEQ{pType}}) where {pType} = DEQ{pType} - function LuxCore.initialstates(rng::AbstractRNG, deq::DEQ) rng = LuxCore.replicate(rng) randn(rng, 1) @@ -71,11 +69,11 @@ function LuxCore.initialstates(rng::AbstractRNG, deq::DEQ) init=LuxCore.initialstates(rng, deq.init), solution=DeepEquilibriumSolution(), rng) end -(deq::DEQ)(x, ps, st::NamedTuple) = deq(x, ps, st, __check_unrolled_mode(st)) +(deq::DEQ)(x, ps, st::NamedTuple) = deq(x, ps, st, check_unrolled_mode(st)) ## Pretraining function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{true}) - z, st = __get_initial_condition(deq, x, ps, st) + z, st = get_initial_condition(deq, x, ps, st) repeated_model = RepeatedLayer(deq.model; repeats=st.fixed_depth) z_star, st_ = repeated_model((z, x), ps.model, st.model) @@ -83,19 +81,19 @@ function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{true}) resid = CRC.ignore_derivatives(z_star .- model((z_star, x))) rng = LuxCore.replicate(st.rng) - jac_loss = __estimate_jacobian_trace( - __getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng) + jac_loss = estimate_jacobian_trace( + LuxOps.getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng) solution = DeepEquilibriumSolution( z_star, z, resid, zero(eltype(x)), _unwrap_val(st.fixed_depth), jac_loss) - res = __split_and_reshape(z_star, __getproperty(deq.model, Val(:split_idxs)), - __getproperty(deq.model, Val(:scales))) + res = split_and_reshape(z_star, LuxOps.getproperty(deq.model, Val(:split_idxs)), + LuxOps.getproperty(deq.model, Val(:scales))) return res, (; st..., model=model.st, solution, rng) end -function (deq::DEQ{pType})(x, ps, st::NamedTuple, ::Val{false}) where {pType} - z, st = __get_initial_condition(deq, x, ps, st) +function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{false}) + z, st = get_initial_condition(deq, x, ps, st) model = StatefulLuxLayer{true}(deq.model, ps.model, st.model) @@ -106,21 +104,21 @@ function (deq::DEQ{pType})(x, ps, st::NamedTuple, ::Val{false}) where {pType} return y .- u end - prob = __construct_prob(pType, ODEFunction{false}(dudt), z, (; ps=ps.model, x)) - alg = __normalize_alg(deq) + prob = construct_prob(deq.kind, ODEFunction{false}(dudt), z, (; ps=ps.model, x)) + alg = normalize_alg(deq) termination_condition = AbsNormTerminationMode(Base.Fix1(maximum, abs)) - sol = solve(prob, alg; sensealg=__default_sensealg(prob), abstol=1e-3, + sol = solve(prob, alg; sensealg=default_sensealg(prob), 
abstol=1e-3, reltol=1e-3, termination_condition, maxiters=32, deq.kwargs...) - z_star = __get_steady_state(sol) + z_star = get_steady_state(sol) rng = LuxCore.replicate(st.rng) - jac_loss = __estimate_jacobian_trace( - __getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng) + jac_loss = estimate_jacobian_trace( + LuxOps.getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng) solution = DeepEquilibriumSolution( - z_star, z, __getproperty(sol, Val(:resid)), jac_loss, __get_nfe(sol), sol) - res = __split_and_reshape(z_star, __getproperty(deq.model, Val(:split_idxs)), - __getproperty(deq.model, Val(:scales))) + z_star, z, LuxOps.getproperty(sol, Val(:resid)), jac_loss, get_nfe(sol), sol) + res = split_and_reshape(z_star, LuxOps.getproperty(deq.model, Val(:split_idxs)), + LuxOps.getproperty(deq.model, Val(:scales))) return res, (; st..., model=model.st, solution, rng) end @@ -128,7 +126,7 @@ end ## Constructors """ DeepEquilibriumNetwork(model, solver; init = missing, jacobian_regularization=nothing, - problem_type::Type{pType}=SteadyStateProblem{false}, kwargs...) + problem_type::Type=SteadyStateProblem{false}, kwargs...) Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing](@cite). @@ -142,7 +140,7 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing] - `init`: Initial Condition for the rootfinding problem. If `nothing`, the initial condition is set to `zero(x)`. If `missing`, the initial condition is set to - `WrappedFunction{:direct_call}(zero)`. In other cases the initial condition is set to + `WrappedFunction(zero)`. In other cases the initial condition is set to `init(x, ps, st)`. - `jacobian_regularization`: Must be one of `nothing`, `AutoForwardDiff`, `AutoFiniteDiff` or `AutoZygote`. @@ -171,19 +169,17 @@ See also: [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwo """ function DeepEquilibriumNetwork( model, solver; init=missing, jacobian_regularization=nothing, - problem_type::Type{pType}=SteadyStateProblem{false}, kwargs...) where {pType} - model isa AbstractExplicitLayer || (model = Lux.transform(model)) - + problem_type::Type=SteadyStateProblem{false}, kwargs...) if init === missing # Regular DEQ - init = WrappedFunction{:direct_call}(Base.Fix1( - __zeros_init, __getproperty(model, Val(:scales)))) + init = WrappedFunction(Base.Fix1( + zeros_init, LuxOps.getproperty(model, Val(:scales)))) elseif init === nothing # SkipRegDEQ init = NoOpLayer() - elseif !(init isa AbstractExplicitLayer) - init = Lux.transform(init) + elseif !(init isa AbstractLuxLayer) + error("init::$(typeof(init)) is not a valid input for DeepEquilibriumNetwork.") end - return DeepEquilibriumNetwork{pType}( - init, model, solver, jacobian_regularization, kwargs) + return DeepEquilibriumNetwork(init, model, solver, jacobian_regularization, + kwargs, problem_type_to_symbol(problem_type)) end """ @@ -277,7 +273,7 @@ If `init` is not passed, it creates a MultiScale Regularized Deep Equilibrium Ne """ function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, post_fuse_layer::Union{Nothing, Tuple}, init::Tuple, solver, scales; kwargs...) - init = Chain(Parallel(nothing, init...), __flatten_vcat) + init = Chain(Parallel(nothing, init...), flatten_vcat) return MultiScaleDeepEquilibriumNetwork( main_layers, mapping_layers, post_fuse_layer, solver, scales; init, kwargs...) end @@ -300,43 +296,35 @@ function MultiScaleNeuralODE(args...; kwargs...) 
end ## Generate Initial Condition -@inline function __get_initial_condition( - deq::DEQ{pType, NoOpLayer}, x, ps, st) where {pType} - zₓ = __zeros_init(__getproperty(deq.model, Val(:scales)), x) +function get_initial_condition(deq::DEQ{NoOpLayer}, x, ps, st) + zₓ = zeros_init(LuxOps.getproperty(deq.model, Val(:scales)), x) z, st_ = deq.model((zₓ, x), ps.model, st.model) return z, (; st..., model=st_) end -@inline function __get_initial_condition(deq::DEQ, x, ps, st) +function get_initial_condition(deq::DEQ, x, ps, st) z, st_ = deq.init(x, ps.init, st.init) return z, (; st..., init=st_) end # Other Layers -@concrete struct MultiScaleInputLayer{N, M <: AbstractExplicitLayer} <: - AbstractExplicitContainerLayer{(:model,)} - model::M +@concrete struct MultiScaleInputLayer <: AbstractLuxWrapperLayer{:model} + n <: StaticInt + model <: AbstractLuxLayer split_idxs scales end -function ConstructionBase.constructorof(::Type{<:MultiScaleInputLayer{N}}) where {N} - return MultiScaleInputLayer{N} -end -function LuxCore.display_name(::MultiScaleInputLayer{N}) where {N} - return "MultiScaleInputLayer{scales = $N}" -end - function MultiScaleInputLayer(model, split_idxs, scales::Val{S}) where {S} - return MultiScaleInputLayer{length(S)}(model, split_idxs, scales) + return MultiScaleInputLayer(static(length(S)), model, split_idxs, scales) end @generated function (m::MultiScaleInputLayer{N})(z, ps, st) where {N} - inputs = (:((u_[1], x)), (:(u_[$i]) for i in 2:N)...) + inputs = (:((u_[1], x)), (:(u_[$i]) for i in 2:known(N))...) return quote u, x = z - u_ = __split_and_reshape(u, m.split_idxs, m.scales) + u_ = split_and_reshape(u, m.split_idxs, m.scales) u_res, st = LuxCore.apply(m.model, ($(inputs...),), ps, st) - return __flatten_vcat(u_res), st + return flatten_vcat(u_res), st end end diff --git a/src/utils.jl b/src/utils.jl index 5c7e23b3..e4277a3e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,4 +1,4 @@ -@generated function __split_and_reshape( +@generated function split_and_reshape( x::AbstractMatrix, ::Val{idxs}, ::Val{shapes}) where {idxs, shapes} dims = [reshape((idxs[i] + 1):idxs[i + 1], shapes[i]...) 
for i in 1:(length(idxs) - 1)] varnames = map(_ -> gensym("x_view"), dims) @@ -8,10 +8,10 @@ return tuple($(varnames...)) end end -__split_and_reshape(x::AbstractMatrix, ::Nothing, ::Nothing) = x -__split_and_reshape(x::AbstractArray, ::Nothing, ::Nothing) = x +split_and_reshape(x::AbstractMatrix, ::Nothing, ::Nothing) = x +split_and_reshape(x::AbstractArray, ::Nothing, ::Nothing) = x -function __split_and_reshape(y::AbstractMatrix, x) +function split_and_reshape(y::AbstractMatrix, x) szs = [prod(size(xᵢ)[1:(end - 1)]) for xᵢ in x] counters = vcat(0, cumsum(szs)[1:(end - 1)]) # Make the data contiguous @@ -19,92 +19,95 @@ function __split_and_reshape(y::AbstractMatrix, x) szs, counters, x) end -@inline __flatten(x::AbstractVector) = reshape(x, length(x), 1) -@inline __flatten(x::AbstractMatrix) = x -@inline __flatten(x::AbstractArray) = reshape(x, :, size(x, ndims(x))) +flatten(x::AbstractVector) = reshape(x, length(x), 1) +flatten(x::AbstractMatrix) = x +flatten(x::AbstractArray) = reshape(x, :, size(x, ndims(x))) -@inline __flatten_vcat(x) = mapreduce(__flatten, vcat, x) +flatten_vcat(x) = mapreduce(flatten, vcat, x) -function CRC.rrule(::typeof(__flatten_vcat), x) - y = __flatten_vcat(x) +function CRC.rrule(::typeof(flatten_vcat), x) + y = flatten_vcat(x) project_x = CRC.ProjectTo(x) - ∇__flatten_vcat = @closure ∂y -> begin + ∇flatten_vcat = @closure ∂y -> begin ∂y isa CRC.NoTangent && return (CRC.NoTangent(), CRC.NoTangent()) - return CRC.NoTangent(), project_x(__split_and_reshape(∂y, x)) + return CRC.NoTangent(), project_x(split_and_reshape(∂y, x)) end - return y, ∇__flatten_vcat + return y, ∇flatten_vcat end -@inline __check_unrolled_mode(::Val{d}) where {d} = Val(d ≥ 1) -@inline __check_unrolled_mode(st::NamedTuple) = __check_unrolled_mode(st.fixed_depth) +check_unrolled_mode(::Val{d}) where {d} = Val(d ≥ 1) +check_unrolled_mode(st::NamedTuple) = check_unrolled_mode(st.fixed_depth) -@inline __get_unrolled_depth(::Val{d}) where {d} = d -@inline __get_unrolled_depth(st::NamedTuple) = __get_unrolled_depth(st.fixed_depth) +get_unrolled_depth(::Val{d}) where {d} = d +get_unrolled_depth(st::NamedTuple) = get_unrolled_depth(st.fixed_depth) -CRC.@non_differentiable __check_unrolled_mode(::Any) -CRC.@non_differentiable __get_unrolled_depth(::Any) +CRC.@non_differentiable check_unrolled_mode(::Any) +CRC.@non_differentiable get_unrolled_depth(::Any) -@inline @generated function __getproperty(obj, ::Val{field}) where {field} - hasfield(obj, field) && return :(obj.$field) - return :(nothing) -end - -@inline __get_nfe(sol::ODESolution) = __get_nfe(sol.stats) -@inline function __get_nfe(sol::NonlinearSolution) +get_nfe(sol::ODESolution) = get_nfe(sol.stats) +function get_nfe(sol::NonlinearSolution) return ifelse(sol.stats === nothing, - ifelse(sol.original === nothing, -1, __get_nfe(sol.original)), - __get_nfe(sol.stats)) + ifelse(sol.original === nothing, -1, get_nfe(sol.original)), get_nfe(sol.stats)) end -@inline __get_nfe(stats) = -1 -@inline __get_nfe(stats::Union{SciMLBase.NLStats, SciMLBase.DEStats}) = stats.nf +get_nfe(stats) = -1 +get_nfe(stats::Union{SciMLBase.NLStats, SciMLBase.DEStats}) = stats.nf + +problem_type_to_symbol(::Type{<:SteadyStateProblem{false}}) = static(:SteadyState) +problem_type_to_symbol(::Type{<:ODEProblem{false}}) = static(:ODE) -@inline __normalize_alg(deq::DEQ{pType}) where {pType} = __normalize_alg(pType, deq.solver) -@inline __normalize_alg(::Type{<:SteadyStateProblem}, alg) = alg -@inline __normalize_alg(::Type{<:SteadyStateProblem}, alg::AbstractODEAlgorithm) 
= DynamicSS(alg) -@inline __normalize_alg(::Type{<:SteadyStateProblem}, alg::AbstractNonlinearAlgorithm) = SSRootfind(alg) -@inline __normalize_alg(::Type{<:ODEProblem}, alg::AbstractODEAlgorithm) = alg +normalize_alg(deq::DEQ) = normalize_alg(deq.kind, deq.solver) +normalize_alg(_, alg) = alg +normalize_alg(::StaticSymbol{:SteadyState}, alg::AbstractODEAlgorithm) = DynamicSS(alg) +function normalize_alg(::StaticSymbol{:SteadyState}, alg::AbstractNonlinearAlgorithm) + return SSRootfind(alg) +end +normalize_alg(::StaticSymbol{:ODE}, alg::AbstractODEAlgorithm) = alg -@inline __get_steady_state(sol::ODESolution) = last(sol.u) -@inline __get_steady_state(sol::NonlinearSolution) = sol.u -@inline __get_steady_state(sol::AbstractArray) = sol +get_steady_state(sol::ODESolution) = last(sol.u) +get_steady_state(sol::NonlinearSolution) = sol.u +get_steady_state(sol::AbstractArray) = sol -@inline function __construct_prob(::Type{<:SteadyStateProblem{false}}, f, u₀, p) +function construct_prob(::StaticSymbol{:SteadyState}, f, u₀, p) return SteadyStateProblem{false}(f, u₀, p) end -@inline function __construct_prob(::Type{<:ODEProblem{false}}, f, u₀, p) - return ODEProblem{false}(f, u₀, (0.0f0, 1.0f0), p) -end +construct_prob(::StaticSymbol{:ODE}, f, u₀, p) = ODEProblem{false}(f, u₀, (0.0f0, 1.0f0), p) -@inline function __zeros_init(::Val{scales}, x::AbstractArray) where {scales} +function zeros_init(::Val{scales}, x::AbstractArray) where {scales} u₀ = similar(x, sum(prod, scales), size(x, ndims(x))) fill!(u₀, false) return u₀ end -@inline __zeros_init(::Nothing, x::AbstractArray) = zero(x) +zeros_init(::Nothing, x::AbstractArray) = zero(x) -CRC.@non_differentiable __zeros_init(::Any, ::Any) +CRC.@non_differentiable zeros_init(::Any, ::Any) ## Don't rely on SciMLSensitivity's choice -@inline __default_sensealg(prob) = nothing +function default_sensealg(::SteadyStateProblem) + # Ideally we should use GMRES here, but it is not very robust + return SteadyStateAdjoint(; + linsolve=nothing, linsolve_kwargs=(; maxiters=10, abstol=1e-3, reltol=1e-3), + autojacvec=ZygoteVJP()) +end +default_sensealg(::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP()) -@inline function __gaussian_like(rng::AbstractRNG, x::AbstractArray) +function randn_like(rng::AbstractRNG, x::AbstractArray) y = similar(x)::typeof(x) randn!(rng, y) return y end -CRC.@non_differentiable __gaussian_like(::Any...) +CRC.@non_differentiable randn_like(::Any...) 
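The `estimate_jacobian_trace` methods renamed in the next hunk implement a Hutchinson-style estimator: E[vᵀJv] ≈ tr(J) for random probes v with identity covariance. A standalone sketch of that idea on a toy map (the function `f`, the probe count, and the averaging are illustrative; the package itself uses a single probe per call via `vector_jacobian_product`):

```julia
using Zygote, Random, LinearAlgebra, Statistics

f(z) = tanh.(0.5f0 .* z .+ 0.1f0)  # toy map whose Jacobian trace we estimate

function hutchinson_trace(f, z, rng; nprobes=64)
    estimates = map(1:nprobes) do _
        v = randn(rng, eltype(z), size(z))   # Gaussian probe vector
        _, back = Zygote.pullback(f, z)
        vjp = only(back(v))                  # vᵀ J
        dot(vjp, v)                          # vᵀ J v
    end
    return mean(estimates)
end

rng = Random.default_rng()
z = randn(rng, Float32, 8)
exact = tr(Zygote.jacobian(f, z)[1])
approx = hutchinson_trace(f, z, rng)
@show exact approx   # the estimate concentrates around the exact trace
```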
-@inline __tupleify(x) = @closure(u->(u, x)) +tupleify(x) = @closure(u->(u, x)) # Jacobian Stabilization -function __estimate_jacobian_trace(::AutoFiniteDiff, model::StatefulLuxLayer, z, x, rng) +function estimate_jacobian_trace(::AutoFiniteDiff, model::StatefulLuxLayer, z, x, rng) __f = @closure u -> model((u, x)) res = zero(eltype(x)) ϵ = cbrt(eps(typeof(res))) ϵ⁻¹ = inv(ϵ) f₀ = __f(z) - v = __gaussian_like(rng, x) + v = randn_like(rng, x) for idx in eachindex(z) _z = z[idx] @@ -120,20 +123,20 @@ function __estimate_jacobian_trace(::AutoFiniteDiff, model::StatefulLuxLayer, z, return res end -function __estimate_jacobian_trace(ad::AutoZygote, model::StatefulLuxLayer, z, x, rng) - v = __gaussian_like(rng, x) - smodel = model ∘ __tupleify(x) +function estimate_jacobian_trace(ad::AutoZygote, model::StatefulLuxLayer, z, x, rng) + v = randn_like(rng, x) + smodel = model ∘ tupleify(x) vjp = Lux.vector_jacobian_product(smodel, ad, z, v) return sum(reshape(vjp, 1, :, size(vjp, ndims(vjp))) ⊠ reshape(v, :, 1, size(v, ndims(v)))) end -function __estimate_jacobian_trace(ad::AutoForwardDiff, model::StatefulLuxLayer, z, x, rng) - v = __gaussian_like(rng, x) - smodel = model ∘ __tupleify(x) +function estimate_jacobian_trace(ad::AutoForwardDiff, model::StatefulLuxLayer, z, x, rng) + v = randn_like(rng, x) + smodel = model ∘ tupleify(x) jvp = Lux.jacobian_vector_product(smodel, ad, z, v) return sum(reshape(v, 1, :, size(v, ndims(v))) ⊠ reshape(jvp, :, 1, size(jvp, ndims(jvp)))) end -__estimate_jacobian_trace(::Nothing, model, z, x, rng) = zero(eltype(x)) +estimate_jacobian_trace(::Nothing, model, z, x, rng) = zero(eltype(x)) diff --git a/test/layers_tests.jl b/test/layers_tests.jl index b2d81b94..127a84f0 100644 --- a/test/layers_tests.jl +++ b/test/layers_tests.jl @@ -138,7 +138,7 @@ end x = randn(rng, Float32, x_size...) 
|> dev z, st = model(x, ps, st) - z_ = DEQs.__flatten_vcat(z) + z_ = DEQs.flatten_vcat(z) opt_broken = mtype !== :node @jet model(x, ps, st) opt_broken=opt_broken @@ -160,7 +160,7 @@ end @test st.solution == DeepEquilibriumSolution() z, st = model(x, ps, st) - z_ = DEQs.__flatten_vcat(z) + z_ = DEQs.flatten_vcat(z) opt_broken = jacobian_regularization isa AutoZygote @jet model(x, ps, st) opt_broken=opt_broken diff --git a/test/utils_tests.jl b/test/utils_tests.jl index 2d114a79..be3a6b6c 100644 --- a/test/utils_tests.jl +++ b/test/utils_tests.jl @@ -7,28 +7,28 @@ x = vcat(x1, x2, x3) split_idxs = Val(cumsum((0, size(x1, 1), size(x2, 1), size(x3, 1)))) shapes = Val((size(x1, 1), size(x2, 1), size(x3, 1))) - x_split = DEQs.__split_and_reshape(x, split_idxs, shapes) + x_split = DEQs.split_and_reshape(x, split_idxs, shapes) @test x1 == x_split[1] @test x2 == x_split[2] @test x3 == x_split[3] - @jet DEQs.__split_and_reshape(x, split_idxs, shapes) + @jet DEQs.split_and_reshape(x, split_idxs, shapes) end end @testitem "unrolled_mode check" setup=[SharedTestSetup] begin using SciMLBase - @test SciMLBase._unwrap_val(DEQs.__check_unrolled_mode(Val(10))) - @test !SciMLBase._unwrap_val(DEQs.__check_unrolled_mode(Val(0))) - @test SciMLBase._unwrap_val(DEQs.__check_unrolled_mode((; fixed_depth=Val(10)))) - @test !SciMLBase._unwrap_val(DEQs.__check_unrolled_mode((; fixed_depth=Val(0)))) + @test SciMLBase._unwrap_val(DEQs.check_unrolled_mode(Val(10))) + @test !SciMLBase._unwrap_val(DEQs.check_unrolled_mode(Val(0))) + @test SciMLBase._unwrap_val(DEQs.check_unrolled_mode((; fixed_depth=Val(10)))) + @test !SciMLBase._unwrap_val(DEQs.check_unrolled_mode((; fixed_depth=Val(0)))) end @testitem "get unrolled_mode" setup=[SharedTestSetup] begin - @test DEQs.__get_unrolled_depth(Val(10)) == 10 - @test DEQs.__get_unrolled_depth((; fixed_depth=Val(10))) == 10 + @test DEQs.get_unrolled_depth(Val(10)) == 10 + @test DEQs.get_unrolled_depth((; fixed_depth=Val(10))) == 10 end @testitem "deep equilibrium solution" setup=[SharedTestSetup] begin