diff --git a/lib/POMDPTools/src/BeliefUpdaters/discrete.jl b/lib/POMDPTools/src/BeliefUpdaters/discrete.jl index fade2143..8d7f3901 100644 --- a/lib/POMDPTools/src/BeliefUpdaters/discrete.jl +++ b/lib/POMDPTools/src/BeliefUpdaters/discrete.jl @@ -13,10 +13,10 @@ Normalization of `b` is assumed in some calculations (e.g. pdf), but it is only # Constructor DiscreteBelief(pomdp, b::Vector{Float64}; check::Bool=true) -# Fields -- `pomdp` : the POMDP problem +# Fields +- `pomdp` : the POMDP problem - `state_list` : a vector of ordered states -- `b` : the probability vector +- `b` : the probability vector """ struct DiscreteBelief{P<:POMDP, S} pomdp::P @@ -29,14 +29,14 @@ function DiscreteBelief(pomdp, b::Vector{Float64}; check::Bool=true) if !isapprox(sum(b), 1.0, atol=0.001) @warn(""" b in DiscreteBelief(pomdp, b) does not sum to 1. - + To suppress this warning use `DiscreteBelief(pomdp, b, check=false)` """, b) end if !all(0.0 <= p <= 1.0 for p in b) @warn(""" b in DiscreteBelief(pomdp, b) contains entries outside [0,1]. - + To suppress this warning use `DiscreteBelief(pomdp, b, check=false)` """, b) end @@ -44,6 +44,22 @@ function DiscreteBelief(pomdp, b::Vector{Float64}; check::Bool=true) return DiscreteBelief(pomdp, ordered_states(pomdp), b) end +import .POMDPDistributions: SparseCat +""" + SparseCat(d::DiscreteBelief; check_zeros=true) + +Create a sparse categorical distribution from a DiscreteBelief. +`check_zeros` is a flag to "sparsify" the distribution. +""" +function SparseCat(b::DiscreteBelief; check_zeros=true) + if !check_zeros + return SparseCat(b.state_list, b.b) + end + non_zero_indices = findall(x -> !isapprox(x, 0.0; atol=eps()), b.b) + return SparseCat(b.state_list[non_zero_indices], b.b[non_zero_indices]) +end + +Base.show(io::IO, m::MIME"text/plain", d::DiscreteBelief) = showdistribution(io, m, d, title="$(typeof(d.pomdp)) DiscreteBelief") """ uniform_belief(pomdp) diff --git a/lib/POMDPTools/test/belief_updaters/test_belief.jl b/lib/POMDPTools/test/belief_updaters/test_belief.jl index 2876c5dc..42d88dd8 100644 --- a/lib/POMDPTools/test/belief_updaters/test_belief.jl +++ b/lib/POMDPTools/test/belief_updaters/test_belief.jl @@ -76,3 +76,16 @@ bnew = update(up, bold, a, o) b5 = DiscreteBelief(pomdp, [0.4, 0.6]) @test @inferred(mean(b5)) == 0.6 @test @inferred(mode(b5)) == true + +# test display of DiscreteBelief +b = DiscreteBelief(MiniHallway(), [0.1, 0.1, 0.123, 0.4, 0, 0, 0, 0, 0.177, 0.05, 0, 0, 0.05]) +@test occursin("MiniHallway", sprint(showdistribution, b)) +@test occursin("0.123", sprint(showdistribution, b)) + +# test SparseCat of DiscreteBelief +b_sparse_cat = SparseCat(b) +@test length(b_sparse_cat.vals) == sum(b.b .!= 0.0) +@test isapprox(sum(b_sparse_cat.probs), 1.0; atol=eps()) +b_sparse_cat = SparseCat(b; check_zeros=false) +@test length(b_sparse_cat.vals) == length(b.state_list) +@test isapprox(sum(b_sparse_cat.probs), 1.0; atol=eps()) diff --git a/lib/POMDPTools/test/runtests.jl b/lib/POMDPTools/test/runtests.jl index 9cbcf2a7..321fb6a9 100644 --- a/lib/POMDPTools/test/runtests.jl +++ b/lib/POMDPTools/test/runtests.jl @@ -1,7 +1,7 @@ using POMDPs using DiscreteValueIteration: ValueIterationSolver using QuickPOMDPs -using POMDPModels: BabyPOMDP, TigerPOMDP, SimpleGridWorld, LegacyGridWorld, RandomMDP, TMaze, Starve, FeedWhenCrying, RandomPOMDP +using POMDPModels: BabyPOMDP, TigerPOMDP, SimpleGridWorld, LegacyGridWorld, RandomMDP, TMaze, Starve, FeedWhenCrying, RandomPOMDP, MiniHallway import POMDPLinter using POMDPTools