diff --git a/src/cost_functions.jl b/src/cost_functions.jl index 56b482bf..6f249b94 100644 --- a/src/cost_functions.jl +++ b/src/cost_functions.jl @@ -40,9 +40,9 @@ function (f::L2Loss)(sol::DiffEqBase.AbstractNoTimeSolution) dudt = f.dudt if sol isa DiffEqBase.AbstractEnsembleSolution - failure = any(s.retcode !== :Success && s.retcode !== :Terminated for s in sol) + failure = any(!SciMLBase.successful_retcode(s.retcode) for s in sol) else - failure = sol.retcode !== :Success && sol.retcode !== :Terminated + failure = !SciMLBase.successful_retcode(sol.retcode) end failure && return Inf @@ -72,9 +72,9 @@ function (f::L2Loss)(sol::SciMLBase.AbstractSciMLSolution) dudt = f.dudt if sol isa DiffEqBase.AbstractEnsembleSolution - failure = any(s.retcode !== :Success && s.retcode !== :Terminated for s in sol) + failure = any(!SciMLBase.successful_retcode(s.retcode) for s in sol) else - failure = sol.retcode !== :Success && sol.retcode !== :Terminated + failure = !SciMLBase.successful_retcode(sol.retcode) end failure && return Inf @@ -171,9 +171,9 @@ end function (f::LogLikeLoss)(sol::SciMLBase.AbstractSciMLSolution) distributions = f.data_distributions if sol isa DiffEqBase.AbstractEnsembleSolution - failure = any(s.retcode !== :Success && s.retcode !== :Terminated for s in sol) + failure = any(!SciMLBase.successful_retcode(s.retcode) for s in sol) else - failure = sol.retcode !== :Success && sol.retcode !== :Terminated + failure = !SciMLBase.successful_retcode(sol.retcode) end failure && return Inf ll = 0.0 @@ -220,9 +220,9 @@ end function (f::LogLikeLoss)(sol::DiffEqBase.AbstractEnsembleSolution) distributions = f.data_distributions if sol_tmp isa DiffEqBase.AbstractEnsembleSolution - failure = any(s.retcode !== :Success && s.retcode !== :Terminated for s in sol_tmp) + failure = any(!SciMLBase.successful_retcode(s.retcode) for s in sol_tmp) else - failure = sol_tmp.retcode !== :Success && sol_tmp.retcode !== :Terminated + failure = !SciMLBase.successful_retcode(sol_tmp.retcode) end failure && return Inf ll = 0.0 diff --git a/src/multiple_shooting_objective.jl b/src/multiple_shooting_objective.jl index bc9918a4..9b4d56ba 100644 --- a/src/multiple_shooting_objective.jl +++ b/src/multiple_shooting_objective.jl @@ -54,7 +54,7 @@ function multiple_shooting_objective(prob::DiffEqBase.DEProblem, alg, loss, push!(sol, solve(tmp_prob, alg; kwargs...)) end end - if any((s.retcode != :Success for s in sol)) + if any((!SciMLBase.successful_retcode(s.retcode) for s in sol)) return Inf end u = [uc for k in 1:K for uc in sol[k].u[1:(end - 1)]] @@ -63,7 +63,7 @@ function multiple_shooting_objective(prob::DiffEqBase.DEProblem, alg, loss, push!(t, sol[K].t[end]) sol_loss = Merged_Solution(u, t, sol) sol_new = DiffEqBase.build_solution(prob, alg, sol_loss.t, sol_loss.u, - retcode = :Success) + retcode = ReturnCode.Success) loss_val = loss(sol_new) if priors !== nothing loss_val += prior_loss(priors, p[(end - length(priors)):end])