Skip to content

Commit

Permalink
Merge pull request #157 from SciML/ChrisRackauckas-patch-2
Browse files Browse the repository at this point in the history
don't skip the first point
  • Loading branch information
ChrisRackauckas authored Feb 6, 2021
2 parents ea38ca6 + 0f19c11 commit 9c05926
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 6 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LsqFit = "2fda8390-95c7-5789-9bda-21331edee243"
PenaltyFunctions = "06bb1623-fdd5-5ca2-a01c-88eae3ea319e"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

[compat]
Calculus = "0.5"
Expand All @@ -25,6 +26,7 @@ ForwardDiff = "0.10"
LsqFit = "0.8, 0.9, 0.10, 0.11, 0.12"
PenaltyFunctions = "0.1, 0.2"
RecursiveArrayTools = "1.0, 2.0"
SciMLBase = "1"
julia = "1"

[extras]
Expand Down
4 changes: 3 additions & 1 deletion src/DiffEqParamEstim.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module DiffEqParamEstim
using DiffEqBase, LsqFit, PenaltyFunctions,
RecursiveArrayTools, ForwardDiff, Calculus, Distributions, LinearAlgebra, DiffEqSensitivity, Dierckx
RecursiveArrayTools, ForwardDiff, Calculus, Distributions,
LinearAlgebra, DiffEqSensitivity, Dierckx,
SciMLBase

STANDARD_PROB_GENERATOR(prob,p) = remake(prob;u0=eltype(p).(prob.u0),p=p)
STANDARD_PROB_GENERATOR(prob::EnsembleProblem,p) = EnsembleProblem(
Expand Down
2 changes: 1 addition & 1 deletion src/build_loss_objective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ end
(f::DiffEqObjective)(x) = f.cost_function(x)
(f::DiffEqObjective)(x,y) = f.cost_function2(x,y)

function build_loss_objective(prob::DiffEqBase.DEProblem,alg,loss,regularization=nothing,args...;
function build_loss_objective(prob::SciMLBase.SciMLProblem,alg,loss,regularization=nothing,args...;
priors=nothing,mpg_autodiff = false,
verbose_opt = false,verbose_steps = 100,
prob_generator = STANDARD_PROB_GENERATOR,
Expand Down
8 changes: 4 additions & 4 deletions src/cost_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ function (f::L2Loss)(sol::DiffEqBase.DESolution)
sumsq = 0.0

if weight == nothing
@inbounds for i in 2:length(sol)
@inbounds for i in 1:length(sol)
for j in 1:length(sol[i])
sumsq +=(data[j,i] - sol[j,i])^2
end
if diff_weight != nothing
if diff_weight != nothing && i != 1
for j in 1:length(sol[i])
if typeof(diff_weight) <: Real
sumsq += diff_weight*((data[j,i] - data[j,i-1] - sol[j,i] + sol[j,i-1])^2)
Expand All @@ -98,7 +98,7 @@ function (f::L2Loss)(sol::DiffEqBase.DESolution)
end
end
else
@inbounds for i in 2:length(sol)
@inbounds for i in 1:length(sol)
if typeof(weight) <: Real
for j in 1:length(sol[i])
sumsq = sumsq + ((data[j,i] - sol[j,i])^2)*weight
Expand All @@ -108,7 +108,7 @@ function (f::L2Loss)(sol::DiffEqBase.DESolution)
sumsq = sumsq + ((data[j,i] - sol[j,i])^2)*weight[j,i]
end
end
if diff_weight != nothing
if diff_weight != nothing && i != 1
for j in 1:length(sol[i])
if typeof(diff_weight) <: Real
sumsq += diff_weight*((data[j,i] - data[j,i-1] - sol[j,i] + sol[j,i-1])^2)
Expand Down

0 comments on commit 9c05926

Please sign in to comment.