Skip to content

Commit

Permalink
Added show option to DiscreteBelief and extended SparseCat to for Dis…
Browse files Browse the repository at this point in the history
…creteBelief inputs (#529)
  • Loading branch information
dylan-asmar authored Dec 4, 2023
1 parent be404e0 commit 4aca91a
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 6 deletions.
26 changes: 21 additions & 5 deletions lib/POMDPTools/src/BeliefUpdaters/discrete.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ Normalization of `b` is assumed in some calculations (e.g. pdf), but it is only
# Constructor
DiscreteBelief(pomdp, b::Vector{Float64}; check::Bool=true)
# Fields
- `pomdp` : the POMDP problem
# Fields
- `pomdp` : the POMDP problem
- `state_list` : a vector of ordered states
- `b` : the probability vector
- `b` : the probability vector
"""
struct DiscreteBelief{P<:POMDP, S}
pomdp::P
Expand All @@ -29,21 +29,37 @@ function DiscreteBelief(pomdp, b::Vector{Float64}; check::Bool=true)
if !isapprox(sum(b), 1.0, atol=0.001)
@warn("""
b in DiscreteBelief(pomdp, b) does not sum to 1.
To suppress this warning use `DiscreteBelief(pomdp, b, check=false)`
""", b)
end
if !all(0.0 <= p <= 1.0 for p in b)
@warn("""
b in DiscreteBelief(pomdp, b) contains entries outside [0,1].
To suppress this warning use `DiscreteBelief(pomdp, b, check=false)`
""", b)
end
end
return DiscreteBelief(pomdp, ordered_states(pomdp), b)
end

import .POMDPDistributions: SparseCat
"""
SparseCat(d::DiscreteBelief; check_zeros=true)
Create a sparse categorical distribution from a DiscreteBelief.
`check_zeros` is a flag to "sparsify" the distribution.
"""
function SparseCat(b::DiscreteBelief; check_zeros=true)
if !check_zeros
return SparseCat(b.state_list, b.b)
end
non_zero_indices = findall(x -> !isapprox(x, 0.0; atol=eps()), b.b)
return SparseCat(b.state_list[non_zero_indices], b.b[non_zero_indices])
end

Base.show(io::IO, m::MIME"text/plain", d::DiscreteBelief) = showdistribution(io, m, d, title="$(typeof(d.pomdp)) DiscreteBelief")

"""
uniform_belief(pomdp)
Expand Down
13 changes: 13 additions & 0 deletions lib/POMDPTools/test/belief_updaters/test_belief.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,16 @@ bnew = update(up, bold, a, o)
b5 = DiscreteBelief(pomdp, [0.4, 0.6])
@test @inferred(mean(b5)) == 0.6
@test @inferred(mode(b5)) == true

# test display of DiscreteBelief
b = DiscreteBelief(MiniHallway(), [0.1, 0.1, 0.123, 0.4, 0, 0, 0, 0, 0.177, 0.05, 0, 0, 0.05])
@test occursin("MiniHallway", sprint(showdistribution, b))
@test occursin("0.123", sprint(showdistribution, b))

# test SparseCat of DiscreteBelief
b_sparse_cat = SparseCat(b)
@test length(b_sparse_cat.vals) == sum(b.b .!= 0.0)
@test isapprox(sum(b_sparse_cat.probs), 1.0; atol=eps())
b_sparse_cat = SparseCat(b; check_zeros=false)
@test length(b_sparse_cat.vals) == length(b.state_list)
@test isapprox(sum(b_sparse_cat.probs), 1.0; atol=eps())
2 changes: 1 addition & 1 deletion lib/POMDPTools/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using POMDPs
using DiscreteValueIteration: ValueIterationSolver
using QuickPOMDPs
using POMDPModels: BabyPOMDP, TigerPOMDP, SimpleGridWorld, LegacyGridWorld, RandomMDP, TMaze, Starve, FeedWhenCrying, RandomPOMDP
using POMDPModels: BabyPOMDP, TigerPOMDP, SimpleGridWorld, LegacyGridWorld, RandomMDP, TMaze, Starve, FeedWhenCrying, RandomPOMDP, MiniHallway
import POMDPLinter

using POMDPTools
Expand Down

0 comments on commit 4aca91a

Please sign in to comment.