Skip to content

Commit

Permalink
Fix pomdpmodels 97 (#522)
Browse files Browse the repository at this point in the history
* verified problems started adding tests

* fixed
  • Loading branch information
zsunberg authored Oct 22, 2023
1 parent aa84846 commit c55138f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 19 deletions.
6 changes: 0 additions & 6 deletions lib/POMDPTools/src/CommonRLIntegration/to_env.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,3 @@ POMDPsCommonRLEnv(m::MDP, s) = MDPCommonRLEnv(m, s)

Base.convert(::Type{MDP}, env::MDPCommonRLEnv) = env.m
Base.convert(::Type{POMDP}, env::POMDPCommonRLEnv) = env.m

POMDPs.convert_s(::Type{Any}, s::S, problem::Union{MDP{S},POMDP{S}}) where {S} = s
POMDPs.convert_s(::Type{S}, s, problem::Union{MDP{S},POMDP{S}}) where {S} = convert(S, s)

POMDPs.convert_o(::Type{Any}, o::O, problem::POMDP{<:Any,<:Any,O}) where {O} = o
POMDPs.convert_o(::Type{O}, o, problem::POMDP{<:Any,<:Any,O}) where {O} = convert(O, o)
23 changes: 16 additions & 7 deletions src/pomdp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,13 @@ Convert a state to vectorized form or vice versa.
"""
function convert_s end

convert_s(::Type{Any}, s::S, problem::Union{MDP{S,<:Any},POMDP{S,<:Any,<:Any}}) where {S} = s
convert_s(::Type{S}, s, problem::Union{MDP{S,<:Any},POMDP{S,<:Any,<:Any}}) where {S} = convert(S, s)

convert_s(T::Type{A1}, s::A2, problem::Union{MDP, POMDP}) where {A1<:AbstractArray, A2<:AbstractArray} = convert(T, s)

convert_s(::Type{A}, s::Number, problem::Union{MDP,POMDP}) where A<:AbstractArray = convert(A, [s])
convert_s(::Type{N}, v::AbstractArray{F}, problem::Union{MDP, POMDP}) where {N<:Number, F<:Number} = convert(N, first(v))
convert_s(::Type{N}, v::AbstractArray{F}, problem::Union{MDP, POMDP}) where {N<:Number, F<:Number} = convert(N, only(v))


"""
Expand All @@ -164,10 +167,13 @@ Convert an action to vectorized form or vice versa.
"""
function convert_a end

convert_a(T::Type{A1}, s::A2, problem::Union{MDP, POMDP}) where {A1<:AbstractArray, A2<:AbstractArray} = convert(T, s)
convert_a(::Type{Any}, a::A, problem::Union{MDP{<:Any,A},POMDP{<:Any,A,<:Any}}) where {A} = a
convert_a(::Type{A}, a, problem::Union{MDP{<:Any,A},POMDP{<:Any,A,<:Any}}) where {A} = convert(A, a)

convert_a(T::Type{A1}, a::A2, problem::Union{MDP, POMDP}) where {A1<:AbstractArray, A2<:AbstractArray} = convert(T, a)

convert_a(::Type{A}, s::Number, problem::Union{MDP,POMDP}) where A<:AbstractArray = convert(A,[s])
convert_a(::Type{N}, v::AbstractArray{F}, problem::Union{MDP, POMDP}) where {N<:Number, F<:Number} = convert(N, first(v))
convert_a(::Type{A}, a::Number, problem::Union{MDP,POMDP}) where A<:AbstractArray = convert(A,[a])
convert_a(::Type{N}, v::AbstractArray{F}, problem::Union{MDP, POMDP}) where {N<:Number, F<:Number} = convert(N, only(v))


"""
Expand All @@ -178,7 +184,10 @@ Convert an observation to vectorized form or vice versa.
"""
function convert_o end

convert_o(T::Type{A1}, s::A2, problem::Union{MDP, POMDP}) where {A1<:AbstractArray, A2<:AbstractArray} = convert(T, s)
convert_o(::Type{Any}, o::O, problem::POMDP{<:Any,<:Any,O}) where {O} = o
convert_o(::Type{O}, o, problem::POMDP{<:Any,<:Any,O}) where {O} = convert(O, o)

convert_o(T::Type{A1}, s::A2, problem::POMDP) where {A1<:AbstractArray, A2<:AbstractArray} = convert(T, s)

convert_o(::Type{A}, s::Number, problem::Union{MDP,POMDP}) where A<:AbstractArray = convert(A, [s])
convert_o(::Type{N}, v::AbstractArray{F}, problem::Union{MDP, POMDP}) where {N<:Number, F<:Number} = convert(N, first(v))
convert_o(::Type{A}, s::Number, problem::POMDP) where A<:AbstractArray = convert(A, [s])
convert_o(::Type{N}, v::AbstractArray{F}, problem::POMDP) where {N<:Number, F<:Number} = convert(N, only(v))
34 changes: 28 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,32 @@ struct CV <: POMDP{Vector{Float64},Vector{Float64},Vector{Float64}} end

@testset "convert" begin
@test convert_s(Vector{Float32}, 1, CI()) == Float32[1.0]
@test convert_s(statetype(CI), Float32[1.0], CI()) == 1
@test convert_s(statetype(CV), Float32[2.0,3.0], CV()) == [2.0, 3.0]
@test convert_s(statetype(CI()), Float32[1.0], CI()) == 1
@test convert_s(statetype(CV()), Float32[2.0,3.0], CV()) == [2.0, 3.0]
@test convert_s(Vector{Float32}, [2.0, 3.0], CV()) == Float32[2.0, 3.0]
@test convert_s(Any, 1, CI()) == 1
@test convert_s(Any, [1.], CV()) == [1.]
@test convert_s(statetype(CI()), 1.0, CI()) == 1
@test convert_s(statetype(CV()), Float32[1.0], CV()) == [1.]

@test convert_a(Vector{Float32}, 1, CI()) == Float32[1.0]
@test convert_a(statetype(CI), Float32[1.0], CI()) == 1
@test convert_a(statetype(CV), Float32[2.0,3.0], CV()) == [2.0, 3.0]
@test convert_a(actiontype(CI()), Float32[1.0], CI()) == 1
@test convert_a(actiontype(CV()), Float32[2.0,3.0], CV()) == [2.0, 3.0]
@test convert_a(Vector{Float32}, [2.0, 3.0], CV()) == Float32[2.0, 3.0]
@test convert_a(Any, 1, CI()) == 1
@test convert_a(Any, [1.], CV()) == [1.]
@test convert_a(actiontype(CI()), 1.0, CI()) == 1
@test convert_a(actiontype(CV()), Float32[1.0], CV()) == [1.]


@test convert_o(Vector{Float32}, 1, CI()) == Float32[1.0]
@test convert_o(statetype(CI), Float32[1.0], CI()) == 1
@test convert_o(statetype(CV), Float32[2.0,3.0], CV()) == [2.0, 3.0]
@test convert_o(obstype(CI()), Float32[1.0], CI()) == 1
@test convert_o(obstype(CV()), Float32[2.0,3.0], CV()) == [2.0, 3.0]
@test convert_o(Vector{Float32}, [2.0, 3.0], CV()) == Float32[2.0, 3.0]
@test convert_o(Any, 1, CI()) == 1
@test convert_o(Any, [1.], CV()) == [1.]
@test convert_o(obstype(CI()), 1.0, CI()) == 1
@test convert_o(obstype(CV()), Float32[1.0], CV()) == [1.]
end

struct EA <: POMDP{Int, Int, Int} end
Expand All @@ -54,3 +67,12 @@ struct EB <: POMDP{Int, Int, Int} end
@test history(4)[end][:o] == 4
@test currentobs(4) == 4
end

@testset "Issues" begin
@testset "POMDPModels Issue #97" begin
struct ModelsIssue97POMDP <: POMDP{Bool, Bool, Bool} end
m = ModelsIssue97POMDP()
@test convert_o(Bool, [1.], m) == true
@test convert_s(Bool, [1.], m) == true
end
end

0 comments on commit c55138f

Please sign in to comment.