
Figure out sciml dispatching for _concrete_solve_adjoint #56

Open
jlperla opened this issue Mar 13, 2022 · 2 comments


@jlperla
Collaborator

jlperla commented Mar 13, 2022

Couldn't figure out how to do the correct dispatching with the SciML sensitivity machinery. We don't need to enable different sensitivity algorithms yet; we just need to get the dispatching working correctly.

To get things working, the package currently adds an rrule directly on DiffEqBase.solve for the problem type, i.e. the signature is

function ChainRulesCore.rrule(::typeof(DiffEqBase.solve), prob::LinearStateSpaceProblem,
                              alg::DirectIteration, args...; kwargs...)

But ideally I believe this should be

function DiffEqBase._concrete_solve_adjoint(prob::LinearStateSpaceProblem, alg::DirectIteration,
                                           sensealg, u0, p, args...; kwargs...)

Then the dispatches for DiffEqBase.solve should route things through the default ChainRulesCore.rrule(::typeof(DiffEqBase.solve), prob::DiffEqBase.DEProblem, ...) implementation and on to that specialization.
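For readers less familiar with the pattern being described: the idea is that a generic entry point never names LinearStateSpaceProblem or DirectIteration and only forwards to an overloadable function, while the package adds a method of that function specialized on its own problem and algorithm types. Below is a minimal plain-Julia sketch of that multiple-dispatch pattern. Every name in it (ToyLinearProblem, ToyDirectIteration, toy_concrete_solve_adjoint, toy_solve_with_pullback) is hypothetical, and it does not reproduce the actual DiffEqBase call chain through solve_up or its argument and return conventions.

```julia
# Minimal sketch of the dispatch pattern; no SciML packages are used.
# All names are HYPOTHETICAL stand-ins, not the real DiffEqBase/SciMLBase API.

abstract type AbstractProblem end
abstract type AbstractAlgorithm end

# Hypothetical stand-ins for LinearStateSpaceProblem and DirectIteration
struct ToyLinearProblem <: AbstractProblem
    A::Matrix{Float64}   # transition matrix
    u0::Vector{Float64}  # initial state
    T::Int               # number of steps
end
struct ToyDirectIteration <: AbstractAlgorithm end

# Generic fallback, playing the role of a library default
toy_concrete_solve_adjoint(prob::AbstractProblem, alg::AbstractAlgorithm, sensealg, u0, p) =
    error("no adjoint defined for $(typeof(prob)) with $(typeof(alg))")

# Package-specific method, selected by multiple dispatch on (problem, algorithm).
# Returns (solution, pullback), loosely mirroring the ChainRulesCore rrule convention.
function toy_concrete_solve_adjoint(prob::ToyLinearProblem, ::ToyDirectIteration,
                                    sensealg, u0, p)
    u = copy(u0)
    for _ in 1:prob.T
        u = prob.A * u             # forward pass: u_{t+1} = A u_t
    end
    function pullback(Δu)
        du0 = Δu
        for _ in 1:prob.T
            du0 = prob.A' * du0    # reverse pass: cotangent of u_t = A' * cotangent of u_{t+1}
        end
        return du0
    end
    return u, pullback
end

# Generic entry point, playing the role of the default rrule on solve:
# it only forwards to the overloadable function and never names concrete types.
toy_solve_with_pullback(prob::AbstractProblem, alg::AbstractAlgorithm; sensealg = nothing) =
    toy_concrete_solve_adjoint(prob, alg, sensealg, prob.u0, nothing)

# Usage
prob = ToyLinearProblem([0.5 0.0; 0.0 2.0], [1.0, 1.0], 2)
u, back = toy_solve_with_pullback(prob, ToyDirectIteration())
# u == [0.25, 4.0]; back([1.0, 1.0]) == [0.25, 4.0]
```

The point of the design is that the entry point stays generic; adding a new problem/algorithm pair only requires a new method of the overloadable function, rather than a new rrule on the top-level solve.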

The problem isn't _concrete_solve_adjoint itself but rather Zygote problems with solve, solve_up, or promotion of the problem. I'm not certain, but my gut says that if those are fixed, it will dispatch to the _concrete_solve_adjoint specialization correctly.
See, for example:

# Ideally hook into existing sensitivity dispatching
# Trouble with Zygote. The problem isn't the _concrete_solve_adjoint but rather something in the
# adjoint of the basic solve and `solve_up`. Probably promotion on the prob
# function DiffEqBase._concrete_solve_adjoint(prob::LinearStateSpaceProblem, alg::DirectIteration,
#                                            sensealg, u0, p, args...; kwargs...)
function ChainRulesCore.rrule(::typeof(DiffEqBase.solve), prob::LinearStateSpaceProblem,
                              alg::DirectIteration, args...; kwargs...)

The error when the built-in rrule is replaced with the _concrete_solve_adjoint overload is:

ERROR: LoadError: ArgumentError: tuple must be non-empty
Stacktrace:
  [1] first(#unused#::Tuple{})
    @ Base .\tuple.jl:140
  [2] _unapply(t::Nothing, xs::Tuple{})
    @ Zygote C:\Users\jesse\.julia\packages\Zygote\3I4nT\src\lib\lib.jl:172
  [3] _unapply(t::Tuple{Nothing, Nothing}, xs::Tuple{})
    @ Zygote C:\Users\jesse\.julia\packages\Zygote\3I4nT\src\lib\lib.jl:176
  [4] _unapply(t::Tuple{Nothing, Nothing, Nothing}, xs::Tuple{Nothing}) (repeats 2 times)
    @ Zygote C:\Users\jesse\.julia\packages\Zygote\3I4nT\src\lib\lib.jl:177
  [5] _unapply(t::Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, xs::Tuple{NamedTuple{(:f, :A, :B, :C, :observables_noise, :observables, :u0, :u0_prior, :tspan, :p, :noise, :kwargs, :seed, :syms), Tuple{Nothing, Matrix{Float64}, Matrix{Float64}, Matrix{Float64}, Nothing, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Matrix{Float64}, Nothing, Nothing, Nothing}}, Nothing})
    @ Zygote C:\Users\jesse\.julia\packages\Zygote\3I4nT\src\lib\lib.jl:176
  [6] unapply(t::Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, xs::Tuple{NamedTuple{(:f, :A, :B, :C, :observables_noise, :observables, :u0, :u0_prior, :tspan, :p, :noise, :kwargs, :seed, :syms), Tuple{Nothing, Matrix{Float64}, Matrix{Float64}, Matrix{Float64}, Nothing, Nothing, Vector{Float64}, Nothing, Nothing, Nothing, Matrix{Float64}, Nothing, Nothing, Nothing}}, Nothing})
    @ Zygote C:\Users\jesse\.julia\packages\Zygote\3I4nT\src\lib\lib.jl:186
  [7] (::Zygote.var"#213#214"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, Zygote.ZBack{DifferenceEquations.var"#solve_pb#31"{LinearStateSpaceProblem{Vector{Float64}, Nothing, Tuple{Int64, Int64}, SciMLBase.NullParameters, Matrix{Float64}, Matrix{Float64}, Matrix{Float64}, Matrix{Float64}, ZeroMeanDiagNormal{Tuple{Base.OneTo{Int64}}}, Matrix{Float64}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, Tuple{}, Vector{Vector{Float64}}, Vector{Vector{Float64}}, Matrix{Float64}, Matrix{Float64}, Matrix{Float64}, Int64}}})(Δ::NamedTuple{(:u, :u_analytic, :errors, :t, :W, :prob, :alg, :interp, :dense, :tslocation, :destats, :retcode, :seed, :P, :logpdf), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Float64}})
    @ Zygote C:\Users\jesse\.julia\packages\Zygote\3I4nT\src\lib\lib.jl:204
  [8] (::Zygote.var"#1754#back#215"{Zygote.var"#213#214"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, Zygote.ZBack{DifferenceEquations.var"#solve_pb#31"{LinearStateSpaceProblem{Vector{Float64}, Nothing, Tuple{Int64, Int64}, SciMLBase.NullParameters, Matrix{Float64}, Matrix{Float64}, Matrix{Float64}, Matrix{Float64}, ZeroMeanDiagNormal{Tuple{Base.OneTo{Int64}}}, Matrix{Float64}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, Tuple{}, Vector{Vector{Float64}}, Vector{Vector{Float64}}, Matrix{Float64}, Matrix{Float64}, Matrix{Float64}, Int64}}}})(Δ::NamedTuple{(:u, :u_analytic, :errors, :t, :W, :prob, :alg, :interp, :dense, :tslocation, :destats, :retcode, :seed, :P, :logpdf), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Float64}})
    @ Zygote C:\Users\jesse\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
  [9] Pullback
    @ C:\Users\jesse\.julia\packages\DiffEqBase\bMXa3\src\solve.jl:73 [inlined]
 [10] (::typeof((#solve#38)))(Δ::NamedTuple{(:u, :u_analytic, :errors, :t, :W, :prob, :alg, :interp, :dense, :tslocation, :destats, :retcode, :seed, :P, :logpdf), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Float64}})
    @ Zygote C:\Users\jesse\.julia\packages\Zygote\3I4nT\src\compiler\interface2.jl:0
 [11] (::Zygote.var"#213#214"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, typeof((#solve#38))})(Δ::NamedTuple{(:u, :u_analytic, :errors, :t, :W, :prob, :alg, :interp, :dense, :tslocation, :destats, :retcode, :seed, :P, :logpdf), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Float64}})
    @ Zygote C:\Users\jesse\.julia\packages\Zygote\3I4nT\src\lib\lib.jl:203
 [12] #1754#back
    @ C:\Users\jesse\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [13] Pullback
    @ C:\Users\jesse\.julia\packages\DiffEqBase\bMXa3\src\solve.jl:68 [inlined]
 [14] (::typeof((solve)))(Δ::NamedTuple{(:u, :u_analytic, :errors, :t, :W, :prob, :alg, :interp, :dense, :tslocation, :destats, :retcode, :seed, :P, :logpdf), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Float64}})

But the trouble seems to be at the solve_up call in https://github.com/SciML/DiffEqBase.jl/blob/master/src/solve.jl#L73, which has a comment about a different Zygote bug. I could be wrong, though.

@ChrisRackauckas
Member

What's the MWE?

@jlperla
Collaborator Author

jlperla commented Mar 14, 2022

Sorry, this was written more as a note for me to discuss with people later, but basically: toggle off the custom rule in

# Ideally hook into existing sensitivity dispatching
# Trouble with Zygote. The problem isn't the _concrete_solve_adjoint but rather something in the
# adjoint of the basic solve and `solve_up`. Probably promotion on the prob
# function DiffEqBase._concrete_solve_adjoint(prob::LinearStateSpaceProblem, alg::DirectIteration,
#                                            sensealg, u0, p, args...; kwargs...)
function ChainRulesCore.rrule(::typeof(DiffEqBase.solve), prob::LinearStateSpaceProblem,
                              alg::DirectIteration, args...; kwargs...)
and run the line
gradient((args...) -> joint_likelihood_1(args..., observables_rbc, D_rbc), A_rbc, B_rbc, C_rbc,

The error then triggers.

But hold off for a day or two: I'm adding a few final SciML features first, which might interact with the dispatching and make_concrete.
