From 77c3abca05511b5a45fab16a53a80e214fa3dfb2 Mon Sep 17 00:00:00 2001 From: Gabriel Gerlero Date: Fri, 29 Dec 2023 20:06:59 -0300 Subject: [PATCH] Update ODE retcode handling --- src/Fronts.jl | 1 - src/ParamEstim.jl | 2 +- src/integration.jl | 22 ++++++++++++---------- src/odes.jl | 8 ++++---- src/shooting.jl | 38 ++++++++------------------------------ test/runtests.jl | 2 +- 6 files changed, 26 insertions(+), 47 deletions(-) diff --git a/src/Fronts.jl b/src/Fronts.jl index 9d5db553..cde98ce0 100644 --- a/src/Fronts.jl +++ b/src/Fronts.jl @@ -18,7 +18,6 @@ using PCHIPInterpolation: Interpolator, integrate import NumericalIntegration using RecipesBase -using OrdinaryDiffEq.SciMLBase: NullParameters using OrdinaryDiffEq: ODEFunction, ODEProblem, ODESolution using OrdinaryDiffEq: init, solve!, reinit! using OrdinaryDiffEq: DiscreteCallback, terminate! diff --git a/src/ParamEstim.jl b/src/ParamEstim.jl index ce2e22e8..73e90bd3 100644 --- a/src/ParamEstim.jl +++ b/src/ParamEstim.jl @@ -3,9 +3,9 @@ module ParamEstim import ..Fronts using ..Fronts: InverseProblem, AbstractSemiinfiniteProblem, Solution, ReturnCode, solve import ..Fronts: sorptivity +import ..Fronts.SciMLBase: successful_retcode, NullParameters using LsqFit: curve_fit -import OrdinaryDiffEq.SciMLBase: successful_retcode, NullParameters """ ScaledSolution diff --git a/src/integration.jl b/src/integration.jl index 4496454e..98af1e26 100644 --- a/src/integration.jl +++ b/src/integration.jl @@ -29,7 +29,7 @@ end Transform `prob` into an ODE problem in terms of the Boltzmann variable `o`. -The ODE problem is set up to terminate automatically (`ReturnCode.Terminated`) when the steady state is reached. +The ODE problem is set up to terminate automatically (with `.retcode == ReturnCode.Success`) when the steady state is reached. See also: [`DifferentialEquations`](https://diffeq.sciml.ai/stable/) """ @@ -45,7 +45,11 @@ function boltzmann(prob::Union{CauchyProblem, SorptivityCauchyProblem}) settled = DiscreteCallback(let direction = monotonicity(prob) (u, t, integrator) -> direction * u[2] ≤ zero(u[2]) end, - terminate!, + function succeed!(integrator) + terminate!(integrator) + integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, + ReturnCode.Success) + end, save_positions = (false, false)) ODEProblem(boltzmann(prob.eq), u0, (ob, typemax(ob)), callback = settled) @@ -112,16 +116,14 @@ function solve(prob::Union{CauchyProblem, SorptivityCauchyProblem}, verbose = true) odesol = solve!(_init(prob, alg, verbose = verbose)) - @assert odesol.retcode != ReturnCode.Success - - if odesol.retcode != ReturnCode.Terminated - return Solution(odesol, prob, alg, _retcode = odesol.retcode, _niter = 1) - end - - return Solution(odesol, prob, alg, _retcode = ReturnCode.Success, _niter = 1) + return Solution(odesol, prob, alg, _niter = 1) end -function Solution(_odesol::ODESolution, _prob, _alg::BoltzmannODE; _retcode, _niter) +function Solution(_odesol::ODESolution, + _prob, + _alg::BoltzmannODE; + _retcode = _odesol.retcode, + _niter) return Solution(o -> _odesol(o, idxs = 1), _prob, _alg, diff --git a/src/odes.jl b/src/odes.jl index ee56578e..0a7eac90 100644 --- a/src/odes.jl +++ b/src/odes.jl @@ -11,14 +11,14 @@ See also: [`DifferentialEquations`](https://diffeq.sciml.ai/stable/), [`StaticAr """ function boltzmann(eq::DiffusionEquation{1}) let K = u -> conductivity(eq, u), C = u -> capacity(eq, u) - function f((u, du_do), ::NullParameters, o) + function f((u, du_do), ::SciMLBase.NullParameters, o) K_, dK_du = value_and_derivative(K, u) d²u_do² = -((C(u) * o / 2 + dK_du * du_do) / K_) * du_do return @SVector [du_do, d²u_do²] end - function jac((u, du_do), ::NullParameters, o) + function jac((u, du_do), ::SciMLBase.NullParameters, o) K_, dK_du, d²K_du² = value_and_derivatives(K, u) C_, dC_du = value_and_derivative(C, u) @@ -36,14 +36,14 @@ end function boltzmann(eq::DiffusionEquation{m}) where {m} @assert m in 2:3 let K = u -> conductivity(eq, u), C = u -> capacity(eq, u), k = m - 1 - function f((u, du_do), ::NullParameters, o) + function f((u, du_do), ::SciMLBase.NullParameters, o) K_, dK_du = value_and_derivative(K, u) d²u_do² = -((C(u) * o / 2 + dK_du * du_do) / K_ + k / o) * du_do return @SVector [du_do, d²u_do²] end - function jac((u, du_do), ::NullParameters, o) + function jac((u, du_do), ::SciMLBase.NullParameters, o) K_, dK_du, d²K_du² = value_and_derivatives(K, u) C_, dC_du = value_and_derivative(C, u) diff --git a/src/shooting.jl b/src/shooting.jl index 70b3f21b..97f8e4ec 100644 --- a/src/shooting.jl +++ b/src/shooting.jl @@ -51,14 +51,11 @@ function solve(prob::DirichletProblem, alg::BoltzmannODE = BoltzmannODE(); CauchyProblem(prob.eq, b = prob.b, d_dob = zero(d_dob_hint), ob = prob.ob)) solve!(integrator) - @assert integrator.sol.retcode != ReturnCode.Success - retcode = integrator.sol.retcode == ReturnCode.Terminated ? ReturnCode.Success : - integrator.sol.retcode - if verbose && !SciMLBase.successful_retcode(integrator.sol) + if verbose && integrator.sol.retcode != ReturnCode.Success @warn "Problem has a trivial solution but failed to obtain it" end - return Solution(integrator.sol, prob, alg, _retcode = retcode, _niter = 0) + return Solution(integrator.sol, prob, alg, _niter = 0) end d_dob_trial = bracket_bisect(zero(d_dob_hint), d_dob_hint, resid) @@ -68,20 +65,14 @@ function solve(prob::DirichletProblem, alg::BoltzmannODE = BoltzmannODE(); CauchyProblem(prob.eq, b = prob.b, d_dob = d_dob_trial(resid), ob = prob.ob)) solve!(integrator) - @assert integrator.sol.retcode != ReturnCode.Success - if integrator.sol.retcode == ReturnCode.Terminated && - direction * integrator.sol.u[end][1] ≤ direction * limit + if integrator.sol.retcode == ReturnCode.Success resid = integrator.sol.u[end][1] - prob.i else resid = direction * typemax(prob.i) end if abs(resid) ≤ abstol - return Solution(integrator.sol, - prob, - alg, - _retcode = ReturnCode.Success, - _niter = niter) + return Solution(integrator.sol, prob, alg, _niter = niter) end end @@ -150,18 +141,11 @@ function solve(prob::Union{FlowrateProblem, SorptivityProblem}, SorptivityCauchyProblem(prob.eq, b = prob.i, S = zero(S), ob = ob)) solve!(integrator) - @assert integrator.sol.retcode != ReturnCode.Success - retcode = integrator.sol.retcode == ReturnCode.Terminated ? ReturnCode.Success : - integrator.sol.retcode - if verbose && !SciMLBase.successful_retcode(integrator.sol) + if verbose && integrator.sol.retcode != ReturnCode.Success @warn "Problem has a trivial solution but failed to obtain it" end - return Solution(integrator.sol, - prob, - alg, - _retcode = retcode, - _niter = 0) + return Solution(integrator.sol, prob, alg, _niter = 0) end b_trial = bracket_bisect(prob.i, b_hint) @@ -171,9 +155,7 @@ function solve(prob::Union{FlowrateProblem, SorptivityProblem}, SorptivityCauchyProblem(prob.eq, b = b_trial(resid), S = S, ob = ob)) solve!(integrator) - @assert integrator.sol.retcode != ReturnCode.Success - if integrator.sol.retcode == ReturnCode.Terminated && - direction * integrator.sol.u[end][1] ≤ direction * limit + if integrator.sol.retcode == ReturnCode.Success resid = integrator.sol.u[end][1] - prob.i elseif integrator.sol.retcode != ReturnCode.Terminated && integrator.t == ob resid = -direction * typemax(prob.i) @@ -182,11 +164,7 @@ function solve(prob::Union{FlowrateProblem, SorptivityProblem}, end if abs(resid) ≤ abstol - return Solution(integrator.sol, - prob, - alg, - _retcode = ReturnCode.Success, - _niter = niter) + return Solution(integrator.sol, prob, alg, _niter = niter) end end diff --git a/test/runtests.jl b/test/runtests.jl index 5a9343ae..4159c05c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,13 +2,13 @@ using Fronts using Fronts._Diff using Fronts.PorousModels using Fronts.ParamEstim +using Fronts.SciMLBase: NullParameters using Test import ForwardDiff import NaNMath using NumericalIntegration using OrdinaryDiffEq: ODEFunction, ODEProblem -using OrdinaryDiffEq.DiffEqBase: NullParameters using StaticArrays: @SVector, SVector using Plots: plot