From c55138fdceea3a2ef3284acd7ac2658dfe161a03 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Sun, 22 Oct 2023 11:18:06 -0700 Subject: [PATCH] Fix pomdpmodels 97 (#522) * verified problems started adding tests * fixed --- .../src/CommonRLIntegration/to_env.jl | 6 ---- src/pomdp.jl | 23 +++++++++---- test/runtests.jl | 34 +++++++++++++++---- 3 files changed, 44 insertions(+), 19 deletions(-) diff --git a/lib/POMDPTools/src/CommonRLIntegration/to_env.jl b/lib/POMDPTools/src/CommonRLIntegration/to_env.jl index efa60c2e..74d32fb2 100644 --- a/lib/POMDPTools/src/CommonRLIntegration/to_env.jl +++ b/lib/POMDPTools/src/CommonRLIntegration/to_env.jl @@ -102,9 +102,3 @@ POMDPsCommonRLEnv(m::MDP, s) = MDPCommonRLEnv(m, s) Base.convert(::Type{MDP}, env::MDPCommonRLEnv) = env.m Base.convert(::Type{POMDP}, env::POMDPCommonRLEnv) = env.m - -POMDPs.convert_s(::Type{Any}, s::S, problem::Union{MDP{S},POMDP{S}}) where {S} = s -POMDPs.convert_s(::Type{S}, s, problem::Union{MDP{S},POMDP{S}}) where {S} = convert(S, s) - -POMDPs.convert_o(::Type{Any}, o::O, problem::POMDP{<:Any,<:Any,O}) where {O} = o -POMDPs.convert_o(::Type{O}, o, problem::POMDP{<:Any,<:Any,O}) where {O} = convert(O, o) diff --git a/src/pomdp.jl b/src/pomdp.jl index 4d60b8bb..fd9589ae 100644 --- a/src/pomdp.jl +++ b/src/pomdp.jl @@ -150,10 +150,13 @@ Convert a state to vectorized form or vice versa. """ function convert_s end +convert_s(::Type{Any}, s::S, problem::Union{MDP{S,<:Any},POMDP{S,<:Any,<:Any}}) where {S} = s +convert_s(::Type{S}, s, problem::Union{MDP{S,<:Any},POMDP{S,<:Any,<:Any}}) where {S} = convert(S, s) + convert_s(T::Type{A1}, s::A2, problem::Union{MDP, POMDP}) where {A1<:AbstractArray, A2<:AbstractArray} = convert(T, s) convert_s(::Type{A}, s::Number, problem::Union{MDP,POMDP}) where A<:AbstractArray = convert(A, [s]) -convert_s(::Type{N}, v::AbstractArray{F}, problem::Union{MDP, POMDP}) where {N<:Number, F<:Number} = convert(N, first(v)) +convert_s(::Type{N}, v::AbstractArray{F}, problem::Union{MDP, POMDP}) where {N<:Number, F<:Number} = convert(N, only(v)) """ @@ -164,10 +167,13 @@ Convert an action to vectorized form or vice versa. """ function convert_a end -convert_a(T::Type{A1}, s::A2, problem::Union{MDP, POMDP}) where {A1<:AbstractArray, A2<:AbstractArray} = convert(T, s) +convert_a(::Type{Any}, a::A, problem::Union{MDP{<:Any,A},POMDP{<:Any,A,<:Any}}) where {A} = a +convert_a(::Type{A}, a, problem::Union{MDP{<:Any,A},POMDP{<:Any,A,<:Any}}) where {A} = convert(A, a) + +convert_a(T::Type{A1}, a::A2, problem::Union{MDP, POMDP}) where {A1<:AbstractArray, A2<:AbstractArray} = convert(T, a) -convert_a(::Type{A}, s::Number, problem::Union{MDP,POMDP}) where A<:AbstractArray = convert(A,[s]) -convert_a(::Type{N}, v::AbstractArray{F}, problem::Union{MDP, POMDP}) where {N<:Number, F<:Number} = convert(N, first(v)) +convert_a(::Type{A}, a::Number, problem::Union{MDP,POMDP}) where A<:AbstractArray = convert(A,[a]) +convert_a(::Type{N}, v::AbstractArray{F}, problem::Union{MDP, POMDP}) where {N<:Number, F<:Number} = convert(N, only(v)) """ @@ -178,7 +184,10 @@ Convert an observation to vectorized form or vice versa. """ function convert_o end -convert_o(T::Type{A1}, s::A2, problem::Union{MDP, POMDP}) where {A1<:AbstractArray, A2<:AbstractArray} = convert(T, s) +convert_o(::Type{Any}, o::O, problem::POMDP{<:Any,<:Any,O}) where {O} = o +convert_o(::Type{O}, o, problem::POMDP{<:Any,<:Any,O}) where {O} = convert(O, o) + +convert_o(T::Type{A1}, s::A2, problem::POMDP) where {A1<:AbstractArray, A2<:AbstractArray} = convert(T, s) -convert_o(::Type{A}, s::Number, problem::Union{MDP,POMDP}) where A<:AbstractArray = convert(A, [s]) -convert_o(::Type{N}, v::AbstractArray{F}, problem::Union{MDP, POMDP}) where {N<:Number, F<:Number} = convert(N, first(v)) +convert_o(::Type{A}, s::Number, problem::POMDP) where A<:AbstractArray = convert(A, [s]) +convert_o(::Type{N}, v::AbstractArray{F}, problem::POMDP) where {N<:Number, F<:Number} = convert(N, only(v)) diff --git a/test/runtests.jl b/test/runtests.jl index 0cc3cc80..cd2aa1f0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,19 +31,32 @@ struct CV <: POMDP{Vector{Float64},Vector{Float64},Vector{Float64}} end @testset "convert" begin @test convert_s(Vector{Float32}, 1, CI()) == Float32[1.0] - @test convert_s(statetype(CI), Float32[1.0], CI()) == 1 - @test convert_s(statetype(CV), Float32[2.0,3.0], CV()) == [2.0, 3.0] + @test convert_s(statetype(CI()), Float32[1.0], CI()) == 1 + @test convert_s(statetype(CV()), Float32[2.0,3.0], CV()) == [2.0, 3.0] @test convert_s(Vector{Float32}, [2.0, 3.0], CV()) == Float32[2.0, 3.0] + @test convert_s(Any, 1, CI()) == 1 + @test convert_s(Any, [1.], CV()) == [1.] + @test convert_s(statetype(CI()), 1.0, CI()) == 1 + @test convert_s(statetype(CV()), Float32[1.0], CV()) == [1.] @test convert_a(Vector{Float32}, 1, CI()) == Float32[1.0] - @test convert_a(statetype(CI), Float32[1.0], CI()) == 1 - @test convert_a(statetype(CV), Float32[2.0,3.0], CV()) == [2.0, 3.0] + @test convert_a(actiontype(CI()), Float32[1.0], CI()) == 1 + @test convert_a(actiontype(CV()), Float32[2.0,3.0], CV()) == [2.0, 3.0] @test convert_a(Vector{Float32}, [2.0, 3.0], CV()) == Float32[2.0, 3.0] + @test convert_a(Any, 1, CI()) == 1 + @test convert_a(Any, [1.], CV()) == [1.] + @test convert_a(actiontype(CI()), 1.0, CI()) == 1 + @test convert_a(actiontype(CV()), Float32[1.0], CV()) == [1.] + @test convert_o(Vector{Float32}, 1, CI()) == Float32[1.0] - @test convert_o(statetype(CI), Float32[1.0], CI()) == 1 - @test convert_o(statetype(CV), Float32[2.0,3.0], CV()) == [2.0, 3.0] + @test convert_o(obstype(CI()), Float32[1.0], CI()) == 1 + @test convert_o(obstype(CV()), Float32[2.0,3.0], CV()) == [2.0, 3.0] @test convert_o(Vector{Float32}, [2.0, 3.0], CV()) == Float32[2.0, 3.0] + @test convert_o(Any, 1, CI()) == 1 + @test convert_o(Any, [1.], CV()) == [1.] + @test convert_o(obstype(CI()), 1.0, CI()) == 1 + @test convert_o(obstype(CV()), Float32[1.0], CV()) == [1.] end struct EA <: POMDP{Int, Int, Int} end @@ -54,3 +67,12 @@ struct EB <: POMDP{Int, Int, Int} end @test history(4)[end][:o] == 4 @test currentobs(4) == 4 end + +@testset "Issues" begin + @testset "POMDPModels Issue #97" begin + struct ModelsIssue97POMDP <: POMDP{Bool, Bool, Bool} end + m = ModelsIssue97POMDP() + @test convert_o(Bool, [1.], m) == true + @test convert_s(Bool, [1.], m) == true + end +end