diff --git a/lib/POMDPTools/src/Policies/Policies.jl b/lib/POMDPTools/src/Policies/Policies.jl
index 409637cc..33d02ccb 100644
--- a/lib/POMDPTools/src/Policies/Policies.jl
+++ b/lib/POMDPTools/src/Policies/Policies.jl
@@ -68,7 +68,8 @@ export LinearDecaySchedule,
        EpsGreedyPolicy,
        SoftmaxPolicy,
        ExplorationPolicy,
-       loginfo
+       loginfo,
+       update!
 
 include("exploration_policies.jl")
 
diff --git a/lib/POMDPTools/src/Policies/exploration_policies.jl b/lib/POMDPTools/src/Policies/exploration_policies.jl
index 438d22a3..f64bb6f9 100644
--- a/lib/POMDPTools/src/Policies/exploration_policies.jl
+++ b/lib/POMDPTools/src/Policies/exploration_policies.jl
@@ -47,38 +47,60 @@ The evolution of epsilon can be controlled using a schedule. This feature is use
 
 If a function is passed for `eps`, `eps(k)` is called to compute the value of epsilon when calling `action(exploration_policy, on_policy, k, s)`.
- 
-# Fields 
+
+# Fields
 - `eps::Function`
 - `rng::AbstractRNG`
 - `m::M` POMDPs or MDPs problem
+- `on_policy::P` the policy to use for the greedy action
+- `k::Int` the current training step, used to compute `eps(k)`
 """
-struct EpsGreedyPolicy{T<:Function, R<:AbstractRNG, M<:Union{MDP,POMDP}} <: ExplorationPolicy
+mutable struct EpsGreedyPolicy{P<:Union{Nothing,Policy},T<:Function,R<:AbstractRNG,M<:Union{MDP,POMDP}} <: ExplorationPolicy
+    on_policy::P
+    k::Int
     eps::T
     rng::R
     m::M
 end
 
-function EpsGreedyPolicy(problem::Union{MDP,POMDP}, eps::Function; 
+function EpsGreedyPolicy(problem::Union{MDP,POMDP}, eps::Function;
                          rng::AbstractRNG=Random.default_rng())
-    return EpsGreedyPolicy(eps, rng, problem)
+    return EpsGreedyPolicy(nothing, 1, eps, rng, problem)
 end
-function EpsGreedyPolicy(problem::Union{MDP,POMDP}, eps::Real; 
+function EpsGreedyPolicy(problem::Union{MDP,POMDP}, eps::Real;
                          rng::AbstractRNG=Random.default_rng())
-    return EpsGreedyPolicy(x->eps, rng, problem)
+    return EpsGreedyPolicy(problem, x -> eps, rng=rng)
+end
+function EpsGreedyPolicy(problem::Union{MDP,POMDP}, on_policy::Policy, eps::Function;
+                         k::Int=1, rng::AbstractRNG=Random.default_rng())
+    return EpsGreedyPolicy(on_policy, k, eps, rng, problem)
+end
+function EpsGreedyPolicy(problem::Union{MDP,POMDP}, on_policy::Policy, eps::Real;
+                         k::Int=1, rng::AbstractRNG=Random.default_rng())
+    return EpsGreedyPolicy(problem, on_policy, x -> eps, k=k, rng=rng)
 end
 
-
 function POMDPs.action(p::EpsGreedyPolicy, on_policy::Policy, k, s)
     if rand(p.rng) < p.eps(k)
         return rand(p.rng, actions(p.m,s))
-    else 
+    else
         return action(on_policy, s)
     end
 end
+POMDPs.action(p::EpsGreedyPolicy{<:Policy}, s) = action(p, p.on_policy, p.k, s)
 
 loginfo(p::EpsGreedyPolicy, k) = (eps=p.eps(k),)
+loginfo(p::EpsGreedyPolicy) = loginfo(p, p.k)
+
+function update!(p::EpsGreedyPolicy, k::Int)
+    p.k = k
+    return p
+end
+function update!(p::EpsGreedyPolicy{P}, on_policy::P) where {P<:Policy}
+    p.on_policy = on_policy
+    return p
+end
 
 # softmax
 """
diff --git a/lib/POMDPTools/test/policies/test_exploration_policies.jl b/lib/POMDPTools/test/policies/test_exploration_policies.jl
index dae36a22..120b51f0 100644
--- a/lib/POMDPTools/test/policies/test_exploration_policies.jl
+++ b/lib/POMDPTools/test/policies/test_exploration_policies.jl
@@ -7,18 +7,24 @@ a = first(actions(problem))
 @inferred action(policy, FunctionPolicy(s->a::Symbol), 1, GWPos(1,1))
 policy = EpsGreedyPolicy(problem, 0.0)
 @test action(policy, FunctionPolicy(s->a), 1, GWPos(1,1)) == a
+policy = EpsGreedyPolicy(problem, FunctionPolicy(s->a), 0.0)
+@test action(policy, GWPos(1,1)) == a
 
-# softmax 
+# softmax
 policy = SoftmaxPolicy(problem, 0.5)
 @test loginfo(policy, 1).temperature == 0.5
 on_policy = ValuePolicy(problem)
 @inferred action(policy, on_policy, 1, GWPos(1,1))
 
-# test linear schedule 
-policy = EpsGreedyPolicy(problem, LinearDecaySchedule(start=1.0, stop=0.0, steps=10))
-for i=1:11 
+# test linear schedule
+schedule = LinearDecaySchedule(start=1.0, stop=0.0, steps=10)
+policy = EpsGreedyPolicy(problem, FunctionPolicy(s->a), schedule)
+for i=1:11
     action(policy, FunctionPolicy(s->a), i, GWPos(1,1))
-    @test policy.eps(i) < 1.0 
+    @test policy.eps(i) < 1.0
     @test loginfo(policy, i).eps == policy.eps(i)
 end
 @test policy.eps(11) ≈ 0.0
+update!(policy, 11)
+@test policy.eps(policy.k) ≈ 0.0
+@test action(policy, FunctionPolicy(s->a), 11, GWPos(1,1)) == action(policy, GWPos(1,1))
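
For reference, a minimal usage sketch of the new stateful interface this diff introduces (not part of the diff itself). It assumes the `SimpleGridWorld`/`GWPos` setup from POMDPModels that the tests appear to use; `greedy` stands in for whatever learned policy a solver would supply.

```julia
# Sketch only: exercises the stateful EpsGreedyPolicy API added in this diff.
# SimpleGridWorld and GWPos come from POMDPModels, mirroring the test setup.
using POMDPs, POMDPTools, POMDPModels

mdp = SimpleGridWorld()
greedy = FunctionPolicy(s -> first(actions(mdp)))  # stand-in for a learned policy

# The on-policy and the epsilon schedule now live inside the exploration
# policy, so `on_policy` and `k` need not be threaded through every call.
policy = EpsGreedyPolicy(mdp, greedy, LinearDecaySchedule(start=1.0, stop=0.01, steps=1000))

for k in 1:1000
    update!(policy, k)               # advance the internal training step
    a = action(policy, GWPos(1, 1))  # new 2-arg form: uses policy.on_policy and policy.k
    # ... run a learning update here; update!(policy, new_greedy) would swap in
    # an improved greedy policy (the `where {P<:Policy}` signature requires it
    # to have the same concrete type as the current one)
end

loginfo(policy)  # == (eps = policy.eps(policy.k),)
```

Note that the old three-argument `action(p, on_policy, k, s)` form still works unchanged; the two-argument method only dispatches when the policy was constructed with an `on_policy`, since `EpsGreedyPolicy{<:Policy}` excludes the `Nothing` case.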