Skip to content

Commit

Permalink
worked on docs for GBMDP
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Jul 12, 2024
1 parent ac2ee88 commit 2d7f0cf
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
27 changes: 21 additions & 6 deletions lib/POMDPTools/src/ModelTools/generative_belief_mdp.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
"""
GenerativeBeliefMDP(pomdp, updater)
GenerativeBeliefMDP(pomdp, updater, terminal_behavior)
GenerativeBeliefMDP(pomdp, updater; terminal_behavior=TerminalStateTerminalBehavior())
Create a generative model of the belief MDP corresponding to POMDP `pomdp` with belief updates performed by `updater`.
Create a generative model of the belief MDP corresponding to POMDP `pomdp` with belief updates performed by `updater`. Each step is performed by sampling a state from the current belief, generating an observation from that state and action, and then using `updater` to update the belief.
A belief is considered terminal when _all_ POMDP states in the support with nonzero probability are terminal.
The default behavior when a terminal POMDP state is sampled from the belief is to transition to [`terminalstate`](@ref). This can be controlled by the `terminal_behavior` keyword argument. Using `terminal_behavior=ContinueTerminalBehavior(pomdp, updater)` will cause the MDP to keep attempting a belief update even when the sampled state is terminal. This can be further customized by providing `terminal_behavior` with a `Function` or callable object that takes arguments b, s, a, rng and returns a new belief (see the implementation of `ContinueTerminalBehavior` for an example).
"""
struct GenerativeBeliefMDP{P<:POMDP, U<:Updater, T, B, A} <: MDP{B, A}
pomdp::P
updater::U
terminal_behavior::T
end

function GenerativeBeliefMDP(pomdp, updater; terminal_behavior=DefaultGBMDPTerminalBehavior(pomdp, updater))
function GenerativeBeliefMDP(pomdp, updater; terminal_behavior=BackwardCompatibleTerminalBehavior(pomdp, updater))
B = determine_gbmdp_state_type(pomdp, updater, terminal_behavior)
GenerativeBeliefMDP{typeof(pomdp),
typeof(updater),
Expand Down Expand Up @@ -43,19 +47,30 @@ isterminal(bmdp::GenerativeBeliefMDP, ts::TerminalState) = true

discount(bmdp::GenerativeBeliefMDP) = discount(bmdp.pomdp)

"""
determine_gbmdp_state_type(pomdp, updater, [terminal_behavior])
This function is called to determine the state type for a GenerativeBeliefMDP. By default, it will return typeof(initialize_belief(updater, initialstate(pomdp))).
If a belief updater may use a belief type different from the output of initialize_belief, for example if the belief type can change after an update, override `determine_gbmdp_state_type(pomdp, updater)`.
If the terminal behavior adds a new possible state type, override `determine_gbmdp_state_type(pomdp, updater, terminal_behavior)` to return the `Union` of the new state type and the output of `determine_gbmdp_state_type(pomdp, updater)`
"""
function determine_gbmdp_state_type end # for documentation

function determine_gbmdp_state_type(pomdp, updater)
b0 = initialize_belief(updater, initialstate(pomdp))
return typeof(b0)
end

determine_gbmdp_state_type(pomdp, updater, terminal_behavior) = determine_gbmdp_state_type(pomdp, updater)

struct DefaultGBMDPTerminalBehavior{M, U}
struct BackwardCompatibleTerminalBehavior{M, U}
pomdp::M
updater::U
end

function (tb::DefaultGBMDPTerminalBehavior)(b, s, a, rng)
function (tb::BackwardCompatibleTerminalBehavior)(b, s, a, rng)

# This code block is only to handle backwards compatibility for the deprecated gbmdp_handle_terminal function
bp = gbmdp_handle_terminal(tb.pomdp, tb.updater, b, s, a, rng)
Expand All @@ -67,7 +82,7 @@ function (tb::DefaultGBMDPTerminalBehavior)(b, s, a, rng)
return TerminalStateTerminalBehavior()(b, s, a, rng)
end

determine_gbmdp_state_type(pomdp, updater, tb::DefaultGBMDPTerminalBehavior) = determine_gbmdp_state_type(pomdp, updater, TerminalStateTerminalBehavior())
determine_gbmdp_state_type(pomdp, updater, tb::BackwardCompatibleTerminalBehavior) = determine_gbmdp_state_type(pomdp, updater, TerminalStateTerminalBehavior())

struct ContinueTerminalBehavior{M, U}
pomdp::M
Expand Down
4 changes: 2 additions & 2 deletions lib/POMDPTools/src/ModelTools/terminal_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ struct TerminalState end
"""
terminalstate
The singleton instance of type `TerminalState` representing a terminal state.
The singleton instance of type [`TerminalState`](@ref) representing a terminal state.
"""
const terminalstate = TerminalState()

isterminal(m::Union{MDP,POMDP}, ts::TerminalState) = true
isterminal(m::Union{MDP,POMDP}, ts::TerminalState) = true
Base.promote_rule(::Type{TerminalState}, T::Type) = Union{TerminalState, T}

0 comments on commit 2d7f0cf

Please sign in to comment.