From e85503548a8fda03a69d75591f518efd6bba02da Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 19 Sep 2024 15:57:26 +0530 Subject: [PATCH 1/3] fix: handle `nothing` passed as `u0` to `ODEProblem` fix: handle non-vector non-symbolic array u0 for ODEProblem --- src/systems/diffeqs/abstractodesystem.jl | 5 +++-- test/odesystem.jl | 7 +++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index d8085d7e33..f7f94504e4 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -817,11 +817,12 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; ModelingToolkit.get_tearing_state(sys) !== nothing) || !isempty(initialization_equations(sys))) && t !== nothing if eltype(u0map) <: Number - u0map = unknowns(sys) .=> u0map + u0map = unknowns(sys) .=> vec(u0map) end - if isempty(u0map) + if u0map === nothing || isempty(u0map) u0map = Dict() end + initializeprob = ModelingToolkit.InitializationProblem( sys, t, u0map, parammap; guesses, warn_initialize_determined, initialization_eqs, eval_expression, eval_module, fully_determined, check_units) diff --git a/test/odesystem.jl b/test/odesystem.jl index e39649c1c2..c229eebe28 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1387,3 +1387,10 @@ end @test obsfn(ones(2), 2ones(2), 3ones(4), 4.0) == 6ones(2) end + +@testset "Passing `nothing` to `u0`" begin + @variables x(t) = 1 + @mtkbuild sys = ODEProblem(D(x) ~ t, t) + prob = @test_nowarn ODEProblem(sys, nothing, (0.0, 1.0)) + @test_nowarn solve(prob) +end From bbc8bf788d2d069faadb752a0e3e1ebda680fff0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 19 Sep 2024 17:41:59 +0530 Subject: [PATCH 2/3] fix: handle non-symbolic and `nothing` u0 for `DiscreteProblem` --- src/systems/discrete_system/discrete_system.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 4a2e9bca97..7103cfca80 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -247,6 +247,13 @@ function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, paramm dvs = unknowns(sys) ps = parameters(sys) + if eltype(u0map) <: Number + u0map = unknowns(sys) .=> vec(u0map) + end + if u0map === nothing || isempty(u0map) + u0map = Dict() + end + trueu0map = Dict() for (k, v) in u0map k = unwrap(k) From f84e5715a366f05c7a096c99da3eb9009292ad17 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 19 Sep 2024 17:42:12 +0530 Subject: [PATCH 3/3] test: test passing `nothing` to u0 for all system types --- test/discrete_system.jl | 8 ++++++++ test/nonlinearsystem.jl | 7 +++++++ test/odesystem.jl | 2 +- test/optimizationsystem.jl | 7 +++++++ test/sdesystem.jl | 8 ++++++++ 5 files changed, 31 insertions(+), 1 deletion(-) diff --git a/test/discrete_system.jl b/test/discrete_system.jl index f8ed0a911f..7d8c0d9678 100644 --- a/test/discrete_system.jl +++ b/test/discrete_system.jl @@ -271,3 +271,11 @@ end k = ShiftIndex(t) @named sys = DiscreteSystem([x ~ x^2 + y^2, y ~ x(k - 1) + y(k - 1)], t) @test_throws ["algebraic equations", "not yet supported"] structural_simplify(sys) + +@testset "Passing `nothing` to `u0`" begin + @variables x(t) = 1 + k = ShiftIndex() + @mtkbuild sys = DiscreteSystem([x(k) ~ x(k - 1) + 1], t) + prob = @test_nowarn DiscreteProblem(sys, nothing, (0.0, 1.0)) + @test_nowarn solve(prob, FunctionMap()) +end diff --git a/test/nonlinearsystem.jl b/test/nonlinearsystem.jl index 841a026d0f..a71d34a880 100644 --- a/test/nonlinearsystem.jl +++ b/test/nonlinearsystem.jl @@ -318,3 +318,10 @@ sys = structural_simplify(ns; conservative = true) sol = solve(prob, NewtonRaphson()) @test sol[x] ≈ sol[y] ≈ sol[z] ≈ -3 end + +@testset "Passing `nothing` to `u0`" begin + @variables x = 1 + @mtkbuild sys = NonlinearSystem([0 ~ x^2 - x^3 + 3]) + prob = @test_nowarn NonlinearProblem(sys, nothing) + @test_nowarn solve(prob) +end diff --git a/test/odesystem.jl b/test/odesystem.jl index c229eebe28..0837573baa 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1390,7 +1390,7 @@ end @testset "Passing `nothing` to `u0`" begin @variables x(t) = 1 - @mtkbuild sys = ODEProblem(D(x) ~ t, t) + @mtkbuild sys = ODESystem(D(x) ~ t, t) prob = @test_nowarn ODEProblem(sys, nothing, (0.0, 1.0)) @test_nowarn solve(prob) end diff --git a/test/optimizationsystem.jl b/test/optimizationsystem.jl index f182e3ef87..426d6d5de0 100644 --- a/test/optimizationsystem.jl +++ b/test/optimizationsystem.jl @@ -340,3 +340,10 @@ end prob.f.cons_h(H3, [1.0, 1.0], [1.0, 100.0]) @test prob.f.cons_h([1.0, 1.0], [1.0, 100.0]) == H3 end + +@testset "Passing `nothing` to `u0`" begin + @variables x = 1.0 + @mtkbuild sys = OptimizationSystem((x - 3)^2, [x], []) + prob = @test_nowarn OptimizationProblem(sys, nothing) + @test_nowarn solve(prob, NelderMead()) +end diff --git a/test/sdesystem.jl b/test/sdesystem.jl index 5628772041..cae9ec9376 100644 --- a/test/sdesystem.jl +++ b/test/sdesystem.jl @@ -776,3 +776,11 @@ end prob = SDEProblem(de, u0map, (0.0, 100.0), parammap) @test solve(prob, SOSRI()).retcode == ReturnCode.Success end + +@testset "Passing `nothing` to `u0`" begin + @variables x(t) = 1 + @brownian b + @mtkbuild sys = System([D(x) ~ x + b], t) + prob = @test_nowarn SDEProblem(sys, nothing, (0.0, 1.0)) + @test_nowarn solve(prob, ImplicitEM()) +end