Skip to content

Commit

Permalink
Merge pull request #184 from SciML/retcode
Browse files Browse the repository at this point in the history
update retcode handling
  • Loading branch information
ChrisRackauckas authored Nov 6, 2022
2 parents a4e0056 + 8d8fe79 commit bc0b417
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
16 changes: 8 additions & 8 deletions src/cost_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ function (f::L2Loss)(sol::DiffEqBase.AbstractNoTimeSolution)
dudt = f.dudt

if sol isa DiffEqBase.AbstractEnsembleSolution
failure = any(s.retcode !== :Success && s.retcode !== :Terminated for s in sol)
failure = any(!SciMLBase.successful_retcode(s.retcode) for s in sol)
else
failure = sol.retcode !== :Success && sol.retcode !== :Terminated
failure = !SciMLBase.successful_retcode(sol.retcode)
end
failure && return Inf

Expand Down Expand Up @@ -72,9 +72,9 @@ function (f::L2Loss)(sol::SciMLBase.AbstractSciMLSolution)
dudt = f.dudt

if sol isa DiffEqBase.AbstractEnsembleSolution
failure = any(s.retcode !== :Success && s.retcode !== :Terminated for s in sol)
failure = any(!SciMLBase.successful_retcode(s.retcode) for s in sol)
else
failure = sol.retcode !== :Success && sol.retcode !== :Terminated
failure = !SciMLBase.successful_retcode(sol.retcode)
end
failure && return Inf

Expand Down Expand Up @@ -171,9 +171,9 @@ end
function (f::LogLikeLoss)(sol::SciMLBase.AbstractSciMLSolution)
distributions = f.data_distributions
if sol isa DiffEqBase.AbstractEnsembleSolution
failure = any(s.retcode !== :Success && s.retcode !== :Terminated for s in sol)
failure = any(!SciMLBase.successful_retcode(s.retcode) for s in sol)
else
failure = sol.retcode !== :Success && sol.retcode !== :Terminated
failure = !SciMLBase.successful_retcode(sol.retcode)
end
failure && return Inf
ll = 0.0
Expand Down Expand Up @@ -220,9 +220,9 @@ end
function (f::LogLikeLoss)(sol::DiffEqBase.AbstractEnsembleSolution)
distributions = f.data_distributions
if sol_tmp isa DiffEqBase.AbstractEnsembleSolution
failure = any(s.retcode !== :Success && s.retcode !== :Terminated for s in sol_tmp)
failure = any(!SciMLBase.successful_retcode(s.retcode) for s in sol_tmp)
else
failure = sol_tmp.retcode !== :Success && sol_tmp.retcode !== :Terminated
failure = !SciMLBase.successful_retcode(sol_tmp.retcode)
end
failure && return Inf
ll = 0.0
Expand Down
4 changes: 2 additions & 2 deletions src/multiple_shooting_objective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function multiple_shooting_objective(prob::DiffEqBase.DEProblem, alg, loss,
push!(sol, solve(tmp_prob, alg; kwargs...))
end
end
if any((s.retcode != :Success for s in sol))
if any((!SciMLBase.successful_retcode(s.retcode) for s in sol))
return Inf
end
u = [uc for k in 1:K for uc in sol[k].u[1:(end - 1)]]
Expand All @@ -63,7 +63,7 @@ function multiple_shooting_objective(prob::DiffEqBase.DEProblem, alg, loss,
push!(t, sol[K].t[end])
sol_loss = Merged_Solution(u, t, sol)
sol_new = DiffEqBase.build_solution(prob, alg, sol_loss.t, sol_loss.u,
retcode = :Success)
retcode = ReturnCode.Success)
loss_val = loss(sol_new)
if priors !== nothing
loss_val += prior_loss(priors, p[(end - length(priors)):end])
Expand Down

0 comments on commit bc0b417

Please sign in to comment.