Skip to content

Commit

Permalink
Merge pull request #78 from iliailmer/verbose-local
Browse files Browse the repository at this point in the history
Create more verbose output for local identifiability
  • Loading branch information
pogudingleb authored Feb 2, 2022
2 parents 925a085 + e9b9d42 commit 5a35bf8
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 48 deletions.
10 changes: 3 additions & 7 deletions src/ODE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,11 @@ end

#------------------------------------------------------------------------------
"""
function PreprocessODE(de::ModelingToolkit.ODESystem, inputs)
function PreprocessODE(de::ModelingToolkit.ODESystem, measured_quantities::Array{ModelingToolkit.Equation})
Input:
- `diff_eqs` - array of ModelingToolkit differential equations
- `out_eqs` - array of output equations
- `states` - array of state variables
- `outputs` - array of output function names
- `inputs` - array of input function names
- `parameters` - array of parameter names
- `de` - ModelingToolkit.ODESystem, a system for identifiability query
- `measured_quantities` - array of output functions
Output:
- `ODE` object containing required data for identifiability assessment
Expand Down
21 changes: 10 additions & 11 deletions src/StructuralIdentifiability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,9 @@ function assess_identifiability(ode::ODE{P}, funcs_to_check::Array{<:RingElem,1}
@debug "Bound: $bound"
end

loc_id = [local_result[each] for each in funcs_to_check]
locally_identifiable = Array{Any,1}()
for (loc, f) in zip(local_result, funcs_to_check)
for (loc, f) in zip(loc_id, funcs_to_check)
if loc
push!(locally_identifiable, f)
end
Expand All @@ -118,22 +119,22 @@ function assess_identifiability(ode::ODE{P}, funcs_to_check::Array{<:RingElem,1}
@info "Global identifiability assessed in $runtime seconds"
_runtime_logger[:glob_time] = runtime

result = Array{Symbol,1}()
result = Dict{Any, Symbol}()
glob_ind = 1
for i in 1:length(funcs_to_check)
if !local_result[i]
push!(result, :nonidentifiable)
if !local_result[funcs_to_check[i]]
result[funcs_to_check[i]] = :nonidentifiable
else
if global_result[glob_ind]
push!(result, :globally)
result[funcs_to_check[i]] = :globally
else
push!(result, :locally)
result[funcs_to_check[i]] = :locally
end
glob_ind += 1
end
end

return result
return Dict(result)
end

"""
Expand Down Expand Up @@ -163,11 +164,9 @@ function assess_identifiability(ode::ModelingToolkit.ODESystem; measured_quantit
ode, syms, gens_ = PreprocessODE(ode, measured_quantities)
out_dict = Dict{Num,Symbol}()
funcs_to_check_ = [eval_at_nemo(each, Dict(syms .=> gens_)) for each in funcs_to_check]
tmp = Dict(param => res for (param, res) in zip(funcs_to_check_, assess_identifiability(ode, funcs_to_check_, p)))
result = assess_identifiability(ode, funcs_to_check_, p)
nemo2mtk = Dict(funcs_to_check_ .=> funcs_to_check)
for (func, res) in pairs(tmp)
out_dict[nemo2mtk[func]] = res
end
out_dict = Dict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
return out_dict
end

Expand Down
29 changes: 20 additions & 9 deletions src/local_identifiability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ end

# ------------------------------------------------------------------------------
"""
assess_local_identifiability(function assess_local_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=Array{ModelingToolkit.Equation}[], funcs_to_check=Array{}[], p::Float64=0.99, type=:SE)
function assess_local_identifiability(ode::ModelingToolkit.ODESystem; measured_quantities=Array{ModelingToolkit.Equation}[], funcs_to_check=Array{}[], p::Float64=0.99, type=:SE)
Input:
- `ode` - the ODESystem object from ModelingToolkit
Expand Down Expand Up @@ -170,8 +170,19 @@ function assess_local_identifiability(ode::ModelingToolkit.ODESystem; measured_q
funcs_to_check = ModelingToolkit.parameters(ode)
end
ode, syms, gens_ = PreprocessODE(ode, measured_quantities)
funcs_to_check = [eval_at_nemo(x, Dict(syms .=> gens_)) for x in funcs_to_check]
return assess_local_identifiability(ode, funcs_to_check, p, type)
funcs_to_check_ = [eval_at_nemo(x, Dict(syms .=> gens_)) for x in funcs_to_check]

if isequal(type, :SE)
result = assess_local_identifiability(ode, funcs_to_check_, p, type)
nemo2mtk = Dict(funcs_to_check_ .=> funcs_to_check)
out_dict = Dict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
return out_dict
elseif isequal(type, :ME)
result, bd = assess_local_identifiability(ode, funcs_to_check_, p, type)
nemo2mtk = Dict(funcs_to_check_ .=> funcs_to_check)
out_dict = Dict(nemo2mtk[param] => result[param] for param in funcs_to_check_)
return (out_dict, bd)
end
end
# ------------------------------------------------------------------------------
"""
Expand Down Expand Up @@ -200,10 +211,10 @@ function assess_local_identifiability(ode::ODE{P}, p::Float64=0.99, type=:SE) wh
end
result = assess_local_identifiability(ode, funcs_to_check, p, type)
if type == :SE
return Dict(a => b for (a, b) in zip(funcs_to_check, result))
return Dict(a => result[a] for a in funcs_to_check)
end
return (
Dict(a => b for (a, b) in zip(funcs_to_check, result[1])),
Dict(a => result[1][a] for a in funcs_to_check),
result[2]
)
end
Expand Down Expand Up @@ -314,21 +325,21 @@ function assess_local_identifiability(ode::ODE{P}, funcs_to_check::Array{<: Any,

@debug "Computing the result"
base_rank = LinearAlgebra.rank(Jac)
result = Array{Bool,1}()
result = Dict{Any, Bool}()
for i in 1:length(funcs_to_check)
for (k, p) in enumerate(ode_red.parameters)
Jac[k, 1] = coeff(output_derivatives[str_to_var("loc_aux_$i", ode_red.poly_ring)][p], 0)
end
for (k, x) in enumerate(ode_red.x_vars)
Jac[end - k + 1, 1] = coeff(output_derivatives[str_to_var("loc_aux_$i", ode_red.poly_ring)][x], 0)
end
push!(result, LinearAlgebra.rank(Jac) == base_rank)
result[funcs_to_check[i]] = LinearAlgebra.rank(Jac) == base_rank
end

if type == :SE
return result
return Dict(result)
end
return (result, num_exp)
return (Dict(result), num_exp)
end

# ------------------------------------------------------------------------------
10 changes: 5 additions & 5 deletions test/identifiability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
push!(test_cases, Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => correct
:correct => Dict(funcs_to_test .=> correct)
))

#--------------------------------------------------------------------------
Expand All @@ -28,7 +28,7 @@
push!(test_cases, Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => correct
:correct => Dict(funcs_to_test .=>correct)
))

#--------------------------------------------------------------------------
Expand All @@ -45,7 +45,7 @@
push!(test_cases, Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => correct
:correct => Dict(funcs_to_test .=>correct)
))

#--------------------------------------------------------------------------
Expand All @@ -62,7 +62,7 @@
push!(test_cases, Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => correct
:correct => Dict(funcs_to_test .=>correct)
))

#--------------------------------------------------------------------------
Expand All @@ -78,7 +78,7 @@
push!(test_cases, Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => correct
:correct => Dict(funcs_to_test .=>correct)
))

#--------------------------------------------------------------------------
Expand Down
8 changes: 4 additions & 4 deletions test/local_identifiability.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
push!(test_cases, Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => correct
:correct => Dict(funcs_to_test .=> correct)
))

#--------------------------------------------------------------------------
Expand All @@ -28,7 +28,7 @@
push!(test_cases, Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => correct
:correct => Dict(funcs_to_test .=> correct)
))

#--------------------------------------------------------------------------
Expand All @@ -45,7 +45,7 @@
push!(test_cases, Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => correct
:correct => Dict(funcs_to_test .=> correct)
))

#--------------------------------------------------------------------------
Expand All @@ -58,7 +58,7 @@
y(t) = x1(t)
)
funcs_to_test = [b, c, alpha, beta, delta, gama, beta + delta, beta * delta]
correct = [true, true, false, true, true, false, true, true]
correct = Dict([b=>true, c=>true, alpha=>false, beta=>true, delta=>true, gama=>false, beta+delta=>true, beta*delta=>true])
push!(test_cases, Dict(
:ode => ode,
:funcs => funcs_to_test,
Expand Down
6 changes: 3 additions & 3 deletions test/local_identifiability_me.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ end
push!(test_cases, Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => (correct, 1)
:correct => (Dict(funcs_to_test .=> correct), 1)
))

#--------------------------------------------------------------------------
Expand All @@ -183,7 +183,7 @@ end
push!(test_cases, Dict(
:ode => ode,
:funcs => funcs_to_test,
:correct => (correct, 2)
:correct => (Dict(funcs_to_test .=> correct), 2)
))

#--------------------------------------------------------------------------
Expand All @@ -199,7 +199,7 @@ end
y3(t) = N
)
funcs_to_test = [b, nu, d, a]
correct = [true, true, true, true]
correct = Dict([b=>true, nu=>true, d=>true, a=>true])
push!(test_cases, Dict(
:ode => ode,
:funcs => funcs_to_test,
Expand Down
18 changes: 9 additions & 9 deletions test/mtk_compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,16 @@
de = ODESystem(eqs, t, name=:TestSIWR)
measured_quantities = [y ~ k * I]
# check all parameters (default)
@test isequal(true, all(assess_local_identifiability(de; measured_quantities=measured_quantities)))
@test isequal(true, all(values(assess_local_identifiability(de; measured_quantities=measured_quantities))))

# check specific parameters
funcs_to_check = [mu, bi, bw, a, xi, gm, gm + mu, k, S, I, W, R]
correct = [true for _ in funcs_to_check]
correct = Dict(f=>true for f in funcs_to_check)
@test isequal(correct, assess_local_identifiability(de; measured_quantities=measured_quantities, funcs_to_check=funcs_to_check))

# checking ME identifiability
funcs_to_check = [mu, bi, bw, a, xi, gm, gm + mu, k]
correct = [true for _ in funcs_to_check]
correct = Dict(f=>true for f in funcs_to_check)
@test isequal((correct, 1), assess_local_identifiability(de; measured_quantities=measured_quantities, funcs_to_check=funcs_to_check, p=0.99, type=:ME))

# --------------------------------------------------------------------------
Expand All @@ -99,18 +99,18 @@
]
de = ODESystem(eqs, t, name=:TestSIWR)
# check all parameters (default)
@test isequal(true, all(assess_local_identifiability(de)))
@test isequal(true, all(values(assess_local_identifiability(de))))

@test isequal(true, all(assess_local_identifiability(de; measured_quantities=[y~k*I])))
@test isequal(true, all(values(assess_local_identifiability(de; measured_quantities=[y~k*I]))))

# check specific parameters
funcs_to_check = [mu, bi, bw, a, xi, gm, gm + mu, k, S, I, W, R]
correct = [true for _ in funcs_to_check]
correct = Dict(f=>true for f in funcs_to_check)
@test isequal(correct, assess_local_identifiability(de; funcs_to_check=funcs_to_check))

# checking ME identifiability
funcs_to_check = [mu, bi, bw, a, xi, gm, gm + mu, k]
correct = [true for _ in funcs_to_check]
correct = Dict(f=>true for f in funcs_to_check)
@test isequal((correct, 1), assess_local_identifiability(de; funcs_to_check=funcs_to_check, p=0.99, type=:ME))

# --------------------------------------------------------------------------
Expand All @@ -126,12 +126,12 @@
de = ODESystem(eqs, t, name=:TestSIWR)
measured_quantities = [y ~ 1.57 * I * k]
funcs_to_check = [mu, bi, bw, a, xi, gm, mu, gm + mu, k, S, I, W, R]
correct = [true for _ in funcs_to_check]
correct = Dict(f=>true for f in funcs_to_check)
@test isequal(correct, assess_local_identifiability(de; measured_quantities=measured_quantities, funcs_to_check=funcs_to_check))

# checking ME identifiability
funcs_to_check = [bi, bw, a, xi, gm, mu, gm + mu, k]
correct = [true for _ in funcs_to_check]
correct = Dict(f=>true for f in funcs_to_check)
@test isequal((correct, 1), assess_local_identifiability(de; measured_quantities=measured_quantities, funcs_to_check=funcs_to_check, p=0.99, type=:ME))

# ----------
Expand Down

0 comments on commit 5a35bf8

Please sign in to comment.