Update docs remove extra returns from loss and extra args from callback #1128

Merged on Nov 5, 2024 (26 commits)
Commits
227e87f  Update docs remove extra returns from loss and extra args from callback (Vaibhavdixit02, Sep 28, 2024)
daaeb1b  Update neural_ode_flux.md (Vaibhavdixit02, Sep 28, 2024)
296e125  Update docs/src/examples/neural_ode/simplechains.md (Vaibhavdixit02, Sep 28, 2024)
c9c8e67  Update second_order_neural.md (Vaibhavdixit02, Sep 28, 2024)
ed5cea2  Update docs/src/examples/pde/pde_constrained.md (Vaibhavdixit02, Sep 28, 2024)
5110c86  Update docs/src/examples/pde/pde_constrained.md (Vaibhavdixit02, Sep 28, 2024)
059f787  Update neural_ode_flux.md (Vaibhavdixit02, Sep 28, 2024)
03401e2  Update docs/src/examples/ode/second_order_neural.md (Vaibhavdixit02, Sep 28, 2024)
df6355a  Update pde_constrained.md (Vaibhavdixit02, Sep 28, 2024)
2900158  Update docs/src/examples/pde/pde_constrained.md (Vaibhavdixit02, Sep 28, 2024)
4fb1c43  Add FAQ entry for repeated solve and some more callback updates (Vaibhavdixit02, Oct 28, 2024)
8707a1b  use DataLoader from MLUtils (Vaibhavdixit02, Oct 28, 2024)
0ca102c  Update docs/Project.toml (Vaibhavdixit02, Oct 29, 2024)
fbfa20a  another Flux.Data (Vaibhavdixit02, Oct 29, 2024)
84232f0  Set SciMLBase bound (ChrisRackauckas, Oct 30, 2024)
dbbd40d  bump versions (ChrisRackauckas, Nov 3, 2024)
f91fcf3  try forward (ChrisRackauckas, Nov 3, 2024)
ff8d3b6  remove optimal control (ChrisRackauckas, Nov 4, 2024)
ac7b768  don't make so many plots (ChrisRackauckas, Nov 4, 2024)
e07eadd  fix link (ChrisRackauckas, Nov 4, 2024)
a263d14  try adding optimal control back? (ChrisRackauckas, Nov 5, 2024)
4952ac7  put a random seed on there (ChrisRackauckas, Nov 5, 2024)
aee2ca7  don't plot in the callback (ChrisRackauckas, Nov 5, 2024)
c49523b  only ADAM (ChrisRackauckas, Nov 5, 2024)
e0cd6d0  change optimal control to forward for now (ChrisRackauckas, Nov 5, 2024)
5d458ad  fix typo (ChrisRackauckas, Nov 5, 2024)
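
Across the diffs below, loss functions drop their extra `pred` return value and callbacks take the two-argument `(state, l)` form used by current Optimization.jl. As a minimal, self-contained sketch of that convention (the quadratic objective, optimizer, and iteration count here are illustrative assumptions, not taken from this PR):

```julia
using Optimization, OptimizationOptimisers, Zygote

# The loss returns only a scalar; no extra values are threaded through.
loss(u, p) = sum(abs2, u .- 1)

# The callback receives the optimization state and the current loss value.
# `state.u` holds the current parameters; returning `false` continues the run.
callback = function (state, l)
    println("loss: ", l)
    return false
end

optf = Optimization.OptimizationFunction(loss, Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, zeros(3))
sol = Optimization.solve(optprob, OptimizationOptimisers.Adam(0.05);
    callback = callback, maxiters = 100)
```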
2 changes: 1 addition & 1 deletion Project.toml
@@ -81,7 +81,7 @@ PreallocationTools = "0.4.4"
QuadGK = "2.9.1"
Random = "1.10"
RandomNumbers = "1.5.3"
RecursiveArrayTools = "3.18.1"
RecursiveArrayTools = "3.27.2"
Reexport = "1.0"
ReverseDiff = "1.15.1"
SafeTestsets = "0.1.0"
4 changes: 4 additions & 0 deletions docs/Project.toml
@@ -14,6 +14,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
@@ -23,6 +24,7 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -45,6 +47,7 @@ Enzyme = "0.12, 0.13"
Flux = "0.14"
ForwardDiff = "0.10"
IterTools = "1"
MLUtils = "0.4"
Lux = "1"
LuxCUDA = "0.3"
Optimization = "3.9, 4"
@@ -56,6 +59,7 @@ Plots = "1.36"
QuadGK = "2.6"
RecursiveArrayTools = "2.32, 3"
ReverseDiff = "1.14"
SciMLBase = "2.58"
SciMLSensitivity = "7.11"
SimpleChains = "0.4"
StaticArrays = "1"
4 changes: 2 additions & 2 deletions docs/src/examples/dde/delay_diffeq.md
@@ -38,7 +38,7 @@ end
loss_dde(p) = sum(abs2, x - 1 for x in predict_dde(p))

using Plots
callback = function (state, l...; doplot = false)
callback = function (state, l; doplot = false)
display(loss_dde(state.u))
doplot &&
display(plot(
@@ -60,7 +60,7 @@ We define a callback to display the solution at the current parameters for each iteration of the optimization

```@example dde
using Plots
callback = function (state, l...; doplot = false)
callback = function (state, l; doplot = false)
display(loss_dde(state.u))
doplot &&
display(plot(
8 changes: 4 additions & 4 deletions docs/src/examples/neural_ode/minibatch.md
@@ -2,7 +2,7 @@

```@example
using SciMLSensitivity
using DifferentialEquations, Flux, Random, Plots
using DifferentialEquations, Flux, Random, Plots, MLUtils
using IterTools: ncycle

rng = Random.default_rng()
@@ -46,7 +46,7 @@ ode_data = Array(solve(true_prob, Tsit5(), saveat = t))
prob = ODEProblem{false}(dudt_, u0, tspan, θ)

k = 10
train_loader = Flux.Data.DataLoader((ode_data, t), batchsize = k)
train_loader = DataLoader((ode_data, t), batchsize = k)

for (x, y) in train_loader
@show x
@@ -96,7 +96,7 @@ When training a neural network, we need to find the gradient with respect to our
For this example, we will use a very simple ordinary differential equation, Newton's law of cooling. We can represent this in Julia like so.
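
The cooling model itself lives in an unchanged part of this file; as a rough sketch of how Newton's law of cooling can be written as an in-place ODE function in Julia (the parameter names and sample values are illustrative assumptions):

```julia
# Newton's law of cooling: the temperature relaxes toward the ambient value
# at a rate proportional to the temperature difference.
function newtons_cooling(du, u, p, t)
    k, temp_ambient = p
    du[1] = -k * (u[1] - temp_ambient)
end

# e.g. u0 = [90.0]; p = (0.5, 18.0); tspan = (0.0, 10.0)
```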

```@example minibatch
using SciMLSensitivity
using SciMLSensitivity, MLUtils
using DifferentialEquations, Flux, Random, Plots
using IterTools: ncycle

@@ -152,7 +152,7 @@ ode_data = Array(solve(true_prob, Tsit5(), saveat = t))
prob = ODEProblem{false}(dudt_, u0, tspan, θ)

k = 10
train_loader = Flux.Data.DataLoader((ode_data, t), batchsize = k)
train_loader = DataLoader((ode_data, t), batchsize = k)

for (x, y) in train_loader
@show x
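For context on how those `(x, y)` batches are typically consumed (not part of this diff): a per-batch loss re-solves the ODE only at the sampled time points. A minimal sketch, assuming `prob` and `Tsit5` as in the full example:

```julia
# Solve at this batch's time points and compare against the batched data.
function predict_batch(θ, time_batch)
    Array(solve(prob, Tsit5(); p = θ, saveat = time_batch))
end

function loss_batch(θ, batch, time_batch)
    pred = predict_batch(θ, time_batch)
    return sum(abs2, batch .- pred)
end
```
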
9 changes: 5 additions & 4 deletions docs/src/examples/neural_ode/neural_ode_flux.md
@@ -114,22 +114,23 @@ end
function loss_n_ode(θ)
pred = predict_n_ode(θ)
loss = sum(abs2, ode_data .- pred)
loss, pred
return loss
end

loss_n_ode(θ)

callback = function (θ, l, pred; doplot = false) #callback function to observe training
callback = function (state, l; doplot = false) #callback function to observe training
display(l)
# plot current prediction against data
pred = predict_n_ode(state.u)
pl = scatter(t, ode_data[1, :], label = "data")
scatter!(pl, t, pred[1, :], label = "prediction")
display(plot(pl))
return false
end

# Display the ODE with the initial parameter values.
callback(θ, loss_n_ode(θ)...)
callback((; u = θ), loss_n_ode(θ)...)

# use Optimization.jl to solve the problem
adtype = Optimization.AutoZygote()
@@ -143,7 +144,7 @@ result_neuralode = Optimization.solve(optprob,
maxiters = 300)
```

Notice that the advantage of this format is that we can use Optim's optimizers, like
Notice that the advantage of this format is that we can use other optimizers, like
`LBFGS` with a full `Chain` object, for all of Flux's neural networks, like
convolutional neural networks.
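
As an illustration of that flexibility (not part of this diff), a second-stage refinement with a quasi-Newton optimizer can reuse the same `OptimizationFunction`. A sketch assuming `optf`, `result_neuralode`, and `callback` from the surrounding example:

```julia
using OptimizationOptimJL  # wraps the Optim.jl solvers, including LBFGS

# Warm-start LBFGS from the Adam result for a final polish.
optprob2 = Optimization.OptimizationProblem(optf, result_neuralode.u)
result_refined = Optimization.solve(optprob2, Optim.LBFGS();
    callback = callback, maxiters = 100)
```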

7 changes: 4 additions & 3 deletions docs/src/examples/neural_ode/simplechains.md
@@ -54,19 +54,20 @@ end
function loss_neuralode(p)
pred = predict_neuralode(p)
loss = sum(abs2, data .- pred)
return loss, pred
return loss
end
```

## Training

The next step is to minimize the loss so that the NeuralODE gets trained. But in order to do that, we have to be able to backpropagate through the NeuralODE model. Backpropagation through the neural network itself is the easy part; we get that out of the box with any deep learning package (although not as fast as SimpleChains for the small network used here). But we also have to propagate the sensitivities of the loss backwards, first through the ODE solver and then to the neural network.

The adjoint of a neural ODE can be calculated through the various AD algorithms available in SciMLSensitivity.jl. But working with [StaticArrays](https://docs.sciml.ai/StaticArrays/stable/) in SimpleChains.jl requires a special adjoint method as StaticArrays do not allow any mutation. All the adjoint methods make heavy use of in-place mutation to be performant with the heap allocated normal arrays. For our statically sized, stack allocated StaticArrays, in order to be able to compute the ODE adjoint we need to do everything out of place. Hence, we have specifically used `QuadratureAdjoint(autojacvec=ZygoteVJP())` adjoint algorithm in the solve call inside `predict_neuralode(p)` which computes everything out-of-place when u0 is a StaticArray. Hence, we can move forward with the training of the NeuralODE
The adjoint of a neural ODE can be calculated through the various AD algorithms available in SciMLSensitivity.jl. But working with [StaticArrays](https://juliaarrays.github.io/StaticArrays.jl/stable/) in SimpleChains.jl requires a special adjoint method as StaticArrays do not allow any mutation. All the adjoint methods make heavy use of in-place mutation to be performant with the heap allocated normal arrays. For our statically sized, stack allocated StaticArrays, in order to be able to compute the ODE adjoint we need to do everything out of place. Hence, we have specifically used `QuadratureAdjoint(autojacvec=ZygoteVJP())` adjoint algorithm in the solve call inside `predict_neuralode(p)` which computes everything out-of-place when u0 is a StaticArray. Hence, we can move forward with the training of the NeuralODE
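
A rough sketch of how that adjoint choice can be wired into the predict function (assuming a `prob` and `tsteps` as in the full example; this is not part of the diff):

```julia
using OrdinaryDiffEq, SciMLSensitivity

function predict_neuralode(p)
    # Out-of-place solve; QuadratureAdjoint with ZygoteVJP avoids mutation,
    # which is what statically sized arrays require.
    sol = solve(prob, Tsit5(); p = p, saveat = tsteps,
        sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))
    return Array(sol)
end
```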

```@example sc_neuralode
callback = function (state, l, pred; doplot = true)
callback = function (state, l; doplot = true)
display(l)
pred = predict_neuralode(state.u)
plt = scatter(tsteps, data[1, :], label = "data")
scatter!(plt, tsteps, pred[1, :], label = "prediction")
if doplot
5 changes: 3 additions & 2 deletions docs/src/examples/ode/second_order_adjoints.md
@@ -49,13 +49,13 @@ end
function loss_neuralode(p)
pred = predict_neuralode(p)
loss = sum(abs2, ode_data .- pred)
return loss, pred
return loss
end

# Callback function to observe training
list_plots = []
iter = 0
callback = function (state, l, pred; doplot = false)
callback = function (state, l; doplot = false)
global list_plots, iter

if iter == 0
@@ -66,6 +66,7 @@ callback = function (state, l, pred; doplot = false)
display(l)

# plot current prediction against data
pred = predict_neuralode(state.u)
plt = scatter(tsteps, ode_data[1, :], label = "data")
scatter!(plt, tsteps, pred[1, :], label = "prediction")
push!(list_plots, plt)
6 changes: 3 additions & 3 deletions docs/src/examples/ode/second_order_neural.md
@@ -33,7 +33,7 @@ t = range(tspan[1], tspan[2], length = 20)
model = Chain(Dense(2, 50, tanh), Dense(50, 2))
ps, st = Lux.setup(Random.default_rng(), model)
ps = ComponentArray(ps)
model = StatefulLuxLayer{true}(model, ps, st)
model = Lux.StatefulLuxLayer{true}(model, ps, st)

ff(du, u, p, t) = model(u, p)
prob = SecondOrderODEProblem{false}(ff, du0, u0, tspan, ps)
@@ -46,12 +46,12 @@ correct_pos = Float32.(transpose(hcat(collect(0:0.05:1)[2:end], collect(2:-0.05:

function loss_n_ode(p)
pred = predict(p)
sum(abs2, correct_pos .- pred[1:2, :]), pred
sum(abs2, correct_pos .- pred[1:2, :])
end

l1 = loss_n_ode(ps)

callback = function (state, l, pred)
callback = function (state, l)
println(l)
l < 0.01
end
14 changes: 10 additions & 4 deletions docs/src/examples/optimal_control/feedback_control.md
@@ -61,7 +61,7 @@ l = loss_univ(θ)
```@example udeneuralcontrol
list_plots = []
iter = 0
cb = function (state, l)
cb = function (state, l; makeplot = false)
global list_plots, iter

if iter == 0
@@ -71,9 +71,11 @@ cb = function (state, l; makeplot = false)

println(l)

plt = plot(predict_univ(state.u)', ylim = (0, 6))
push!(list_plots, plt)
display(plt)
if makeplot
plt = plot(predict_univ(state.u)', ylim = (0, 6))
push!(list_plots, plt)
display(plt)
end
return false
end
```
@@ -84,3 +86,7 @@ optf = Optimization.OptimizationFunction((x, p) -> loss_univ(x), adtype)
optprob = Optimization.OptimizationProblem(optf, θ)
result_univ = Optimization.solve(optprob, PolyOpt(), callback = cb)
```

```@example udeneuralcontrol
cb(result_univ, result_univ.minimum; makeplot=true)
```
5 changes: 3 additions & 2 deletions docs/src/examples/optimal_control/optimal_control.md
@@ -37,7 +37,8 @@ of a local minimum. This looks like:

```@example neuraloptimalcontrol
using Lux, ComponentArrays, OrdinaryDiffEq, Optimization, OptimizationOptimJL,
OptimizationOptimisers, SciMLSensitivity, Zygote, Plots, Statistics, Random
OptimizationOptimisers, SciMLSensitivity, Zygote, Plots, Statistics, Random,
ForwardDiff

rng = Random.default_rng()
tspan = (0.0f0, 8.0f0)
@@ -89,7 +90,7 @@ end
# Setup and run the optimization

loss1 = loss_adjoint(θ)
adtype = Optimization.AutoZygote()
adtype = Optimization.AutoForwardDiff()
optf = Optimization.OptimizationFunction((x, p) -> loss_adjoint(x), adtype)

optprob = Optimization.OptimizationProblem(optf, θ)
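The substantive change in this file swaps the AD backend handed to `OptimizationFunction` from `AutoZygote` to `AutoForwardDiff` (and imports ForwardDiff). Swapping backends is a one-line change; a sketch with an illustrative loss, not this example's:

```julia
using Optimization, ForwardDiff, Zygote

loss(x) = sum(abs2, x .- 1)

# Forward mode: cost grows with the number of parameters, but it is robust
# and has little overhead for modest parameter counts.
optf_forward = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoForwardDiff())

# Reverse mode: better scaling when there are many parameters.
optf_reverse = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote())
```
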
26 changes: 14 additions & 12 deletions docs/src/examples/pde/pde_constrained.md
@@ -67,25 +67,26 @@ end
## Defining Loss function
function loss(θ)
pred = predict(θ)
return sum(abs2.(predict(θ) .- arr_sol)), pred # Mean squared error
return sum(abs2.(predict(θ) .- arr_sol)) # Mean squared error
end

l, pred = loss(ps)
size(pred), size(sol), size(t) # Checking sizes
l = loss(ps)
size(sol), size(t) # Checking sizes

LOSS = [] # Loss accumulator
PRED = [] # prediction accumulator
PARS = [] # parameters accumulator

cb = function (θ, l, pred) #callback function to observe training
cb = function (st, l) #callback function to observe training
display(l)
pred = predict(st.u)
append!(PRED, [pred])
append!(LOSS, l)
append!(PARS, [θ])
append!(PARS, [st.u])
false
end

cb(ps, loss(ps)...) # Testing callback function
cb((; u = ps), loss(ps)) # Testing callback function

# Let see prediction vs. Truth
# Let's see prediction vs. Truth
scatter(sol[:, end], label = "Truth", size = (800, 500))
@@ -228,11 +229,11 @@ use the **mean squared error**.
## Defining Loss function
function loss(θ)
pred = predict(θ)
return sum(abs2.(predict(θ) .- arr_sol)), pred # Mean squared error
return sum(abs2.(predict(θ) .- arr_sol)) # Mean squared error
end

l, pred = loss(ps)
size(pred), size(sol), size(t) # Checking sizes
l = loss(ps)
size(sol), size(t) # Checking sizes
```

#### Optimizer
@@ -251,15 +252,16 @@ LOSS = [] # Loss accumulator
PRED = [] # prediction accumulator
PARS = [] # parameters accumulator

cb = function (θ, l, pred) #callback function to observe training
cb = function (st, l) #callback function to observe training
display(l)
pred = predict(st.u)
append!(PRED, [pred])
append!(LOSS, l)
append!(PARS, [θ])
append!(PARS, [st.u])
false
end

cb(ps, loss(ps)...) # Testing callback function
cb((; u = ps), loss(ps)) # Testing callback function
```

### Plotting Prediction vs Ground Truth
11 changes: 5 additions & 6 deletions docs/src/examples/sde/SDE_control.md
@@ -189,19 +189,18 @@ function loss(p_nn; alg = EM(), sensealg = BacksolveAdjoint(autojacvec = Reverse
W = sqrt(myparameters.dt) * randn(typeof(myparameters.dt), size(myparameters.ts)) #for 1 trajectory
W1 = cumsum([zero(myparameters.dt); W[1:(end - 1)]], dims = 1)
NG = CreateGrid(myparameters.ts, W1)

remake(prob,
p = pars,
u0 = u0tmp,
callback = callback,
noise = NG)
end
_prob = remake(prob, p = pars)

ensembleprob = EnsembleProblem(prob,
ensembleprob = EnsembleProblem(_prob,
prob_func = prob_func,
safetycopy = true)

_sol = solve(ensembleprob, alg, EnsembleThreads(),
_sol = solve(ensembleprob, alg, EnsembleSerial(),
sensealg = sensealg,
saveat = myparameters.tinterval,
dt = myparameters.dt,
@@ -293,7 +292,7 @@ visualization_callback((; u = p_nn), l; doplot = true)

# optimize the parameters for a few epochs with Adam on time span
# Setup and run the optimization
adtype = Optimization.AutoZygote()
adtype = Optimization.AutoForwardDiff()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)

optprob = Optimization.OptimizationProblem(optf, p_nn)
@@ -655,7 +654,7 @@ is computed under the hood in the SciMLSensitivity package.
```@example sdecontrol
# optimize the parameters for a few epochs with Adam on time span
# Setup and run the optimization
adtype = Optimization.AutoZygote()
adtype = Optimization.AutoForwardDiff()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)

optprob = Optimization.OptimizationProblem(optf, p_nn)