diff --git a/src/pomdps.jl b/src/pomdps.jl
index d6712ed..ac0197d 100644
--- a/src/pomdps.jl
+++ b/src/pomdps.jl
@@ -16,9 +16,7 @@ function predict!(pm, m::POMDP, b, a, rng)
     all_terminal = true
     for i in 1:n_particles(b)
         s = particle(b, i)
-        if isterminal(m, s)
-            pm[i] = undef
-        else
+        if !isterminal(m, s)
             all_terminal = false
             sp = generate_s(m, s, a, rng)
             pm[i] = sp
diff --git a/test/runtests.jl b/test/runtests.jl
index d4c0a4c..4f8cb6e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -4,6 +4,7 @@ using POMDPModels
 using Test
 using POMDPPolicies
 using POMDPSimulators
+using POMDPModelTools
 using Random
 using Distributions
 using NBInclude
@@ -77,6 +78,17 @@ struct ContinuousPOMDP <: POMDP{Float64, Float64, Float64} end
         end
     end
 end
+struct TerminalPOMDP <: POMDP{Int, Int, Float64} end
+POMDPs.isterminal(::TerminalPOMDP, s) = s == 1
+POMDPs.observation(::TerminalPOMDP, a, sp) = Normal(sp)
+POMDPs.transition(::TerminalPOMDP, s, a) = Deterministic(s+a)
+@testset "pomdp terminal" begin
+    pomdp = TerminalPOMDP()
+    pf = SIRParticleFilter(pomdp, 100)
+    bp = update(pf, initialize_belief(pf, Categorical([0.5, 0.5])), -1, 1.0)
+    @test all(particles(bp) .== 1)
+end
+
 @testset "alpha" begin # test specific method for alpha vector policies and particle beliefs
     pomdp = BabyPOMDP()