forked from JuliaPOMDP/POMDPModels.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMiniHallway.jl
76 lines (62 loc) · 2.41 KB
/
MiniHallway.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Mini Hallway problem defined in http://cs.brown.edu/research/ai/pomdp/examples/mini-hall2.POMDP.
# Original idea published in Littman, Cassandra and Kaelbling's ML-95 paper.
# Basic parameters are:
# discount: 0.950000
# values: reward
# states: 13
# actions: 3
# observations: 9
# The rest of the specification is available at the link above.
"""
    MiniHallway

Mini Hallway POMDP whose states, actions and observations are all `Int`s.

# Fields
- `T`: one deterministic successor distribution per state, indexed by state; `transition`
  reads it for action 1 and for every action taken in state 13.
"""
struct MiniHallway <: POMDP{Int, Int, Int}
    T::Vector{Deterministic{Int}}
end
"""
    MiniHallway()

Build the Mini Hallway model by filling `T` with the deterministic successor of each of
the 13 states under action 1 (state 13 maps to itself).
"""
function MiniHallway()
    # successors[s] is the state reached from state s under action 1. The rule is given
    # as a table because it has no simple closed form.
    #
    # State 7 maps to itself: every other entry keeps `s % 4` unchanged, and state 7
    # emits the same observation as state 9 (see `observation`), which also maps to
    # itself. NOTE(review): confirm against the mini-hall2.POMDP transition table.
    successors = [1, 2, 7, 4, 1, 10, 7, 8, 9, 10, 13, 8, 13]
    T = Deterministic{Int}[Deterministic(sp) for sp in successors]
    return MiniHallway(T)
end
##################
# mdps interface #
##################
# The state space is the integer range 1 through 13.
function POMDPs.states(m::MiniHallway)
    return 1:13
end
# A state is its own index into `states(m)`.
function POMDPs.stateindex(m::MiniHallway, ss::Int)::Int
    return ss
end
# State 13 is the only terminal state.
function POMDPs.isterminal(m::MiniHallway, ss::Int)::Bool
    return ss == 13
end
# Deterministic transition from state `ss` under action `a`.
#
# - State 13 (any action) and action 1 (any state) use the successor table `m.T`.
# - Action 2 steps to the next state within each block of four (1-4, 5-8, 9-12),
#   wrapping from the last state of the block back to its first.
# - Any other action (3) steps the opposite way, wrapping from the first state of the
#   block to its last.
function POMDPs.transition(m::MiniHallway, ss::Int, a::Int)
    if ss == 13 || a == 1
        return m.T[ss]
    end
    if a == 2
        sp = ss % 4 == 0 ? ss - 3 : ss + 1
        return Deterministic(sp)
    end
    # a == 3
    sp = (ss - 1) % 4 == 0 ? ss + 3 : ss - 1
    return Deterministic(sp)
end
# The action space is the integer range 1 through 3.
function POMDPs.actions(m::MiniHallway)
    return 1:3
end
# An action is its own index into `actions(m)`.
function POMDPs.actionindex(m::MiniHallway, a::Int)::Int
    return a
end
# Reward is 1.0 when the step ends in state 13 from a different state, and 0.0 otherwise
# (including staying in state 13).
function POMDPs.reward(m::MiniHallway, ss::Int, a::Int, sp::Int)
    entered_terminal = ss != sp && sp == 13
    return float(entered_terminal)
end
# Reward for a state-action pair, delegated to `mean_reward`.
function POMDPs.reward(m::MiniHallway, ss::Int, a::Int)
    return mean_reward(m, ss, a)
end
# Discount factor of the problem.
function POMDPs.discount(m::MiniHallway)::Float64
    return 0.95
end
####################
# pomdps interface #
####################
# Initial belief: probability 1/12 on each of states 1-12 and zero on state 13.
function POMDPs.initialstate(m::MiniHallway)
    probs = vcat(fill(1/12, 12), 0.)
    return SparseCat(1:13, probs)
end
# The observation space is the integer range 1 through 9.
function POMDPs.observations(m::MiniHallway)
    return 1:9
end
# Deterministic observation emitted on arriving in state `sp`; the action `a` is unused.
#
# Mapping from state to observation:
#   1-8   -> 1-8 (the state itself)
#   9-10  -> 7-8 (same observations as states 7 and 8)
#   11-12 -> 5-6 (same observations as states 5 and 6)
#   13    -> 9
function POMDPs.observation(m::MiniHallway, a::Int, sp::Int)::Deterministic
    o = if sp <= 8
        sp
    elseif sp <= 10
        sp - 2
    elseif sp <= 12
        sp - 6
    else
        9
    end
    return Deterministic(o)
end
# Four-argument form: the previous state `s` is ignored and the call is forwarded to the
# `(m, a, sp)` method.
function POMDPs.observation(m::MiniHallway, s::Int, a::Int, sp::Int)::Deterministic
    return observation(m, a, sp)
end
# An observation is its own index into `observations(m)`.
function POMDPs.obsindex(m::MiniHallway, o::Int)::Int
    return o
end
# Wrap the model in a `Ref` so broadcasting treats it as a scalar.
function Base.broadcastable(m::MiniHallway)
    return Ref(m)
end