Skip to content

Commit

Permalink
Extend EpsGreedyPolicy by internal policy field
Browse files Browse the repository at this point in the history
  • Loading branch information
johannes-fischer committed Jul 12, 2023
1 parent 7dcbd58 commit dcd8e1d
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 15 deletions.
3 changes: 2 additions & 1 deletion lib/POMDPTools/src/Policies/Policies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ export LinearDecaySchedule,
EpsGreedyPolicy,
SoftmaxPolicy,
ExplorationPolicy,
loginfo
loginfo,
update!

include("exploration_policies.jl")

Expand Down
40 changes: 31 additions & 9 deletions lib/POMDPTools/src/Policies/exploration_policies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,38 +47,60 @@ The evolution of epsilon can be controlled using a schedule. This feature is use
If a function is passed for `eps`, `eps(k)` is called to compute the value of epsilon when calling `action(exploration_policy, on_policy, k, s)`.
# Fields
# Fields
- `eps::Function`
- `rng::AbstractRNG`
- `m::M` POMDPs or MDPs problem
- `on_policy::P` a policy to use for the greedy part
- `k::Int` the current training step to use for computing eps(k)
"""
mutable struct EpsGreedyPolicy{P<:Union{Nothing,Policy},T<:Function,R<:AbstractRNG,M<:Union{MDP,POMDP}} <: ExplorationPolicy
    on_policy::P # internal policy used for the greedy part, or `nothing` if supplied at call time
    k::Int       # current training step, used to evaluate eps(k)
    eps::T       # epsilon schedule: eps(k) gives the probability of a random action
    rng::R       # RNG for both the exploration coin flip and the random action draw
    m::M         # the MDP or POMDP problem (provides the action space)
end

# Outer constructors. `eps` may be a constant `Real` or a function of the
# training step `k`. Without an explicit `on_policy`, the internal policy
# slot is `nothing` and the greedy policy must be passed to `action` directly.
function EpsGreedyPolicy(problem::Union{MDP,POMDP}, eps::Function;
                         rng::AbstractRNG=Random.default_rng())
    return EpsGreedyPolicy(nothing, 1, eps, rng, problem)
end
function EpsGreedyPolicy(problem::Union{MDP,POMDP}, eps::Real;
                         rng::AbstractRNG=Random.default_rng())
    # wrap the constant in a schedule function so `eps(k)` works uniformly
    return EpsGreedyPolicy(problem, x -> eps, rng=rng)
end
function EpsGreedyPolicy(problem::Union{MDP,POMDP}, on_policy::Policy, eps::Function;
                         k::Int=1, rng::AbstractRNG=Random.default_rng())
    return EpsGreedyPolicy(on_policy, k, eps, rng, problem)
end
function EpsGreedyPolicy(problem::Union{MDP,POMDP}, on_policy::Policy, eps::Real;
                         k::Int=1, rng::AbstractRNG=Random.default_rng())
    return EpsGreedyPolicy(problem, on_policy, x -> eps, k=k, rng=rng)
end


"""
    POMDPs.action(p::EpsGreedyPolicy, on_policy::Policy, k, s)

With probability `p.eps(k)`, return a uniformly random action from
`actions(p.m, s)`; otherwise return the greedy action `action(on_policy, s)`.
"""
function POMDPs.action(p::EpsGreedyPolicy, on_policy::Policy, k, s)
    if rand(p.rng) < p.eps(k)
        return rand(p.rng, actions(p.m, s))
    else
        return action(on_policy, s)
    end
end

# Convenience method: use the stored internal policy and training step.
# Only defined when the policy actually carries an internal `on_policy`.
POMDPs.action(p::EpsGreedyPolicy{<:Policy}, s) = action(p, p.on_policy, p.k, s)

# Report the exploration parameter for logging: epsilon evaluated at step `k`,
# defaulting to the policy's stored training step `p.k`.
function loginfo(p::EpsGreedyPolicy, k)
    return (eps = p.eps(k),)
end
loginfo(p::EpsGreedyPolicy) = loginfo(p, p.k)

"""
    update!(p::EpsGreedyPolicy, k::Int)
    update!(p::EpsGreedyPolicy, on_policy::Policy)

Update the stored training step `k` or the internal `on_policy` of `p`
in place and return `p`.
"""
update!(p::EpsGreedyPolicy, k::Int) = (p.k = k; p)
update!(p::EpsGreedyPolicy{P}, on_policy::P) where {P<:Policy} = (p.on_policy = on_policy; p)

# softmax
"""
Expand Down
16 changes: 11 additions & 5 deletions lib/POMDPTools/test/policies/test_exploration_policies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,24 @@ a = first(actions(problem))
# eps-greedy: both the explicit on-policy call and the internal-policy call
@inferred action(policy, FunctionPolicy(s->a::Symbol), 1, GWPos(1,1))
policy = EpsGreedyPolicy(problem, 0.0)
@test action(policy, FunctionPolicy(s->a), 1, GWPos(1,1)) == a
policy = EpsGreedyPolicy(problem, FunctionPolicy(s->a), 0.0)
@test action(policy, GWPos(1,1)) == a

# softmax
policy = SoftmaxPolicy(problem, 0.5)
@test loginfo(policy, 1).temperature == 0.5
on_policy = ValuePolicy(problem)
@inferred action(policy, on_policy, 1, GWPos(1,1))

# test linear schedule
schedule = LinearDecaySchedule(start=1.0, stop=0.0, steps=10)
policy = EpsGreedyPolicy(problem, FunctionPolicy(s->a), schedule)
for i=1:11
    action(policy, FunctionPolicy(s->a), i, GWPos(1,1))
    @test policy.eps(i) < 1.0
    @test loginfo(policy, i).eps == policy.eps(i)
end
update!(policy, 11)
@test policy.eps(policy.k) ≈ 0.0
@test action(policy, FunctionPolicy(s->a), 11, GWPos(1,1)) == action(policy, GWPos(1,1))

0 comments on commit dcd8e1d

Please sign in to comment.