Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added belief space binning #18

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ LinearAlgebra = "1"
POMDPTools = "0.1, 1"
POMDPs = "0.9, 1"
Printf = "1"
SparseArrays = "1"
julia = "1.7"
9 changes: 7 additions & 2 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@ function sample_points(sol::SARSOPSolver, tree::SARSOPTree, b_idx::Int, L, U, t,
V̲, V̄ = tree.V_lower[b_idx], tree.V_upper[b_idx]
γ = discount(tree)

V̂ = V̄ #TODO: BAD, binning method
if V̂ ≤ V̲ + sol.kappa*ϵ*γ^(-t) || (V̂ ≤ L && V̄ ≤ max(U, V̲ + ϵ*γ^(-t)))
if sol.use_binning
V̂ = get_bin_value(tree, b_idx)
dylan-asmar marked this conversation as resolved.
Show resolved Hide resolved
else
V̂ = V̄
end

if V̄ ≤ V̲ + sol.kappa*ϵ*γ^(-t) || (V̂ ≤ L && V̄ ≤ max(U, V̲ + ϵ*γ^(-t)))
return
else
Q̲, Q̄, a′ = max_r_and_q(tree, b_idx)
Expand Down
3 changes: 2 additions & 1 deletion src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Base.@kwdef struct SARSOPSolver{LOW,UP} <: Solver
init_lower::LOW = BlindLowerBound(bel_res = 1e-2)
init_upper::UP = FastInformedBound(bel_res=1e-2)
prunethresh::Float64= 0.10
use_binning::Bool = true
end

function POMDPTools.solve_info(solver::SARSOPSolver, pomdp::POMDP)
Expand All @@ -20,7 +21,7 @@ function POMDPTools.solve_info(solver::SARSOPSolver, pomdp::POMDP)

t0 = time()
iter = 0
while time()-t0 < solver.max_time && root_diff(tree) > solver.precision
while iter <= solver.max_steps && time()-t0 < solver.max_time && root_diff(tree) > solver.precision
sample!(solver, tree)
backup!(tree)
prune!(solver, tree)
Expand Down
190 changes: 188 additions & 2 deletions src/tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,60 @@ mutable struct PruneData
prune_threshold::Float64
end

struct BinManager
lowest_ub::Float64
num_levels::Int
num_of_bins_per_level::Vector{Int}
bin_levels_intervals::Vector{Dict{Symbol, Float64}}
bin_levels_nodes::Vector{Dict{Int, Dict{Symbol, Union{Tuple{Int, Int}, Float64}}}}
bin_levels::Vector{Dict{Symbol, Dict{Tuple{Int, Int}, Union{Float64, Int}}}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These vectors of dictionaries (nested with further dictionaries) are really expensive in terms of performance. Also, the union typing results in some type instabilities.

Provided that a lot of the keys for these dictionaries are always the same (e.g. bin_value, bin_count, bin_error), can you find any way to flatten everything out into vectors, use fewer dictionaries, or just use structs/named tuples?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. I wasn't a big fan of the dictionary implementation when I was doing it. Reworking as vectors wouldn't be too bad since the number of bins stays the same. I can mark this as WIP until I make these changes if desired.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the implementation with 3762018. In the process, I also found a bug with the dictionaries. With the binning, we are back down to the normal level of allocations vs the original implementation.

I updated the performance comparisons and added results for both delta=0.1 and delta=0.0001.

previous_lowerbound::Dict{Int, Float64}
end

"""
    BinManager(Vs_upper::Vector{Float64}, num_bins_per_level=[5, 10])

Build an empty `BinManager` with one binning level per entry of `num_bins_per_level`.

For every level the interval width along the upper-bound axis is
`(maximum(Vs_upper) - minimum(Vs_upper)) / num_bins`, and along the entropy axis it is
`max_entropy(length(Vs_upper)) / num_bins`. All node and bin tables start out empty.
"""
function BinManager(Vs_upper::Vector{Float64}, num_bins_per_level=[5, 10])
    # Local shorthands for the nested container types stored in the struct.
    NodeTable = Dict{Int, Dict{Symbol, Union{Tuple{Int, Int}, Float64}}}
    BinTable = Dict{Tuple{Int, Int}, Union{Float64, Int}}
    LevelTable = Dict{Symbol, BinTable}

    num_levels = length(num_bins_per_level)
    lowest_ub = minimum(Vs_upper)
    ub_range = maximum(Vs_upper) - lowest_ub
    max_e = max_entropy(length(Vs_upper))

    # [level][:ub|:entropy] => width of one interval along that axis
    bin_levels_intervals = Dict{Symbol, Float64}[
        Dict{Symbol, Float64}(
            :ub => ub_range / num_bins_per_level[i],
            :entropy => max_e / num_bins_per_level[i],
        ) for i in 1:num_levels
    ]

    # [level][b_idx][:key|:prev_error] => (ub_interval_idx, entropy_interval_idx)|previous_error
    bin_levels_nodes = NodeTable[NodeTable() for _ in 1:num_levels]

    # [level][:bin_value|:bin_count|:bin_error][(ub_interval_idx, entropy_interval_idx)]
    bin_levels = LevelTable[
        LevelTable(
            :bin_value => BinTable(),
            :bin_count => BinTable(),
            :bin_error => BinTable(),
        ) for _ in 1:num_levels
    ]

    # b_idx => lower bound recorded at the last bin update
    previous_lowerbound = Dict{Int, Float64}()

    return BinManager(
        lowest_ub,
        num_levels,
        num_bins_per_level,
        bin_levels_intervals,
        bin_levels_nodes,
        bin_levels,
        previous_lowerbound,
    )
end

struct SARSOPTree
pomdp::ModifiedSparseTabular

Expand Down Expand Up @@ -32,16 +86,21 @@ struct SARSOPTree
prune_data::PruneData

Γ::Vector{AlphaVec{Int}}

use_binning::Bool
bm::BinManager
end


function SARSOPTree(solver, pomdp::POMDP)
function SARSOPTree(solver, pomdp::POMDP; num_bins_per_level=[5, 10])
sparse_pomdp = ModifiedSparseTabular(pomdp)
cache = TreeCache(sparse_pomdp)

upper_policy = solve(solver.init_upper, sparse_pomdp)
corner_values = map(maximum, zip(upper_policy.alphas...))

bin_manager = BinManager(corner_values, num_bins_per_level)

tree = SARSOPTree(
sparse_pomdp,

Expand All @@ -64,7 +123,9 @@ function SARSOPTree(solver, pomdp::POMDP)
BitVector(),
cache,
PruneData(0,0,solver.prunethresh),
AlphaVec{Int}[]
AlphaVec{Int}[],
solver.use_binning,
bin_manager
)
return insert_root!(solver, tree, _initialize_belief(pomdp, initialstate(pomdp)))
end
Expand Down Expand Up @@ -163,6 +224,9 @@ function fill_belief!(tree::SARSOPTree, b_idx::Int)
else
fill_populated!(tree, b_idx)
end
if tree.use_binning
update_bin_node!(tree, b_idx)
end
end

"""
Expand Down Expand Up @@ -249,3 +313,125 @@ function fill_unpopulated!(tree::SARSOPTree, b_idx::Int)
tree.V_lower[b_idx] = lower_value(tree, tree.b[b_idx])
tree.V_upper[b_idx] = maximum(tree.Qa_upper[b_idx])
end

"""
    initialize_bin_node!(tree::SARSOPTree, b_idx::Int)

Register belief node `b_idx` with the bin manager `tree.bm`.

On every binning level the node is assigned the bin keyed by
`(upper-bound interval index, entropy interval index)`. If that bin already holds
nodes, the node's squared deviation from the current bin value is added to the bin's
error and the bin value becomes the running average including this node's lower bound.
Otherwise the bin is created with value `V_lower`, count `1` and error
`(V_upper - V_lower)^2`. The node's current lower bound is recorded in
`tree.bm.previous_lowerbound`.
"""
function initialize_bin_node!(tree::SARSOPTree, b_idx::Int)
    bm = tree.bm
    lb_val = tree.V_lower[b_idx]
    ub_val = tree.V_upper[b_idx]
    node_entropy = entropy(tree.b[b_idx])

    for level_i in 1:bm.num_levels
        intervals = bm.bin_levels_intervals[level_i]
        num_bins = bm.num_of_bins_per_level[level_i]

        ub_interval_idx = get_interval_idx(ub_val, bm.lowest_ub, intervals[:ub], num_bins)
        entropy_interval_idx = get_interval_idx(
            node_entropy, 0.0, intervals[:entropy], num_bins
        )
        key = (ub_interval_idx, entropy_interval_idx)

        # Insert only this node's entry. Assigning a fresh dictionary to the whole
        # level would discard every node registered earlier.
        nodes = bm.bin_levels_nodes[level_i]
        if !haskey(nodes, b_idx)
            nodes[b_idx] = Dict{Symbol, Union{Tuple{Int, Int}, Float64}}(:key => key)
        end
        node = nodes[b_idx]

        level = bm.bin_levels[level_i]
        bin_values = level[:bin_value]
        bin_counts = level[:bin_count]
        bin_errors = level[:bin_error]

        bin_count = get(bin_counts, key, 0)
        if bin_count > 0
            err = bin_values[key] - lb_val
            node[:prev_error] = err * err
            # The squared error belongs in the bin's error accumulator, not its value:
            # update_bin_node! later subtracts :prev_error from :bin_error.
            bin_errors[key] += err * err

            bin_values[key] = (bin_values[key] * bin_count + lb_val) / (bin_count + 1)
            bin_counts[key] += 1
        else
            err = ub_val - lb_val
            node[:prev_error] = err * err
            bin_errors[key] = err * err

            bin_values[key] = lb_val
            # Set only this bin's count so the counts of other bins are preserved.
            bin_counts[key] = 1
        end
    end
    bm.previous_lowerbound[b_idx] = lb_val
end

"""
    update_bin_node!(tree::SARSOPTree, b_idx::Int)

Refresh the bin statistics in `tree.bm` for belief node `b_idx` from its current
`V_lower` / `V_upper`.

A node that is not yet known to the first binning level is handed to
`initialize_bin_node!`. Otherwise, on every level, the node's previous squared-error
contribution is replaced by a fresh one and the bin value is shifted by the change in
the node's lower bound divided by the bin count. Finally the current lower bound is
stored in `tree.bm.previous_lowerbound`.
"""
function update_bin_node!(tree::SARSOPTree, b_idx::Int)
    lower = tree.V_lower[b_idx]
    upper = tree.V_upper[b_idx]
    bm = tree.bm

    # First sighting of this node: it has to be placed into bins before updating.
    haskey(bm.bin_levels_nodes[1], b_idx) || return initialize_bin_node!(tree, b_idx)

    for lvl in 1:bm.num_levels
        node = bm.bin_levels_nodes[lvl][b_idx]
        bins = bm.bin_levels[lvl]
        key = node[:key]

        n = get(bins[:bin_count], key, 0)
        if n == 1
            # The node is alone in its bin: the error is just its own bound gap.
            gap = upper - lower
            node[:prev_error] = gap * gap
            bins[:bin_error][key] = gap * gap
        else
            # Swap the node's old squared-error contribution for the new one.
            dev = bins[:bin_value][key] - lower
            bins[:bin_error][key] -= node[:prev_error]
            node[:prev_error] = dev * dev
            bins[:bin_error][key] += dev * dev
        end

        # Running average adjusted by how much this node's lower bound moved.
        bins[:bin_value][key] =
            (bins[:bin_value][key] * n + lower - bm.previous_lowerbound[b_idx]) / n
    end
    bm.previous_lowerbound[b_idx] = lower
end

"""
    get_bin_value(tree::SARSOPTree, b_idx::Int)

Return the binned value estimate for belief node `b_idx`.

If the node's first-level bin has a count of exactly `1`, the node's own `V_upper` is
returned. Otherwise the level whose bin has the smallest recorded error (with a `1e-10`
tolerance) is selected, and its bin value is returned, limited to `V_upper` from above
and `V_lower` from below (both with a `1e-10` tolerance).
"""
function get_bin_value(tree::SARSOPTree, b_idx::Int)
    bm = tree.bm
    lower = tree.V_lower[b_idx]
    upper = tree.V_upper[b_idx]

    first_key = bm.bin_levels_nodes[1][b_idx][:key]
    # A bin holding a single node carries no information beyond the node itself.
    bm.bin_levels[1][:bin_count][first_key] == 1 && return upper

    best_level = 0
    best_key = first_key
    best_error = Inf
    for lvl in 1:bm.num_levels
        lvl_key = bm.bin_levels_nodes[lvl][b_idx][:key]
        lvl_error = bm.bin_levels[lvl][:bin_error][lvl_key]
        if lvl_error + 1e-10 < best_error
            best_level, best_key, best_error = lvl, lvl_key, lvl_error
        end
    end

    estimate = bm.bin_levels[best_level][:bin_value][best_key]
    estimate > upper + 1e-10 && return upper
    estimate + 1e-10 < lower && return lower
    return estimate
end

"""
    max_entropy(n::Int)

Return the entropy (natural log) of the uniform distribution over `n` outcomes,
computed as `-n * (p * log(p))` with `p = 1 / n`.
"""
function max_entropy(n::Int)
    p = 1.0 / n
    return -1 * (p * log(p)) * n
end

"""
    entropy(b::AbstractVector)

Return `-Σ bᵢ * log(bᵢ)` over the entries of `b` that are strictly positive.
Entries that are not greater than zero contribute nothing; an empty `b` yields `0.0`.
"""
function entropy(b::AbstractVector)
    return foldl(b; init=0.0) do acc, p
        p > 0 ? acc - p * log(p) : acc
    end
end
dylan-asmar marked this conversation as resolved.
Show resolved Hide resolved

"""
    get_interval_idx(value::Float64, lower::Float64, interval::Float64, num_intervals::Int)

Return the 1-based index of the interval of width `interval`, counted from `lower`,
that contains `value`. The result is clamped to `1:num_intervals`. When `interval` is
zero, `1` is returned.
"""
function get_interval_idx(value::Float64, lower::Float64, interval::Float64, num_intervals::Int)
    # Zero-width intervals: everything falls into the first one (also avoids dividing by 0).
    interval == 0.0 && return 1

    offset = (value - lower) / interval
    raw_idx = Int(floor(offset) + 1)
    return clamp(raw_idx, 1, num_intervals)
end
91 changes: 85 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ include("updater.jl")

include("tree.jl")

@testset "Tiger POMDP" begin
@testset "Tiger POMDP (no binning)" begin
pomdp = TigerPOMDP()
solver = SARSOPSolver(epsilon=0.5, precision=1e-3, verbose=false)
solver = SARSOPSolver(epsilon=0.5, precision=1e-3, verbose=false, use_binning=false)
tree = SARSOPTree(pomdp)
Γ = solve(solver, pomdp)
iterations = 0
Expand All @@ -50,9 +50,9 @@ include("tree.jl")
@test abs(value(policyCPP, initialstate(pomdp)) - value(Γ, initialstate(pomdp))) < 0.01
end

@testset "Baby POMDP" begin
@testset "Baby POMDP (no binning)" begin
pomdp = BabyPOMDP()
solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-3, verbose=false)
solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-3, verbose=false, use_binning=false)
tree = SARSOPTree(pomdp)
Γ = solve(solver, pomdp)
iterations = 0
Expand All @@ -71,9 +71,88 @@ end
@test abs(value(policyCPP, initialstate(pomdp)) - value(Γ, initialstate(pomdp))) < 0.01
end

@testset "RockSample POMDP" begin
@testset "RockSample POMDP (no binning)" begin
pomdp = RockSamplePOMDP()
solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-2, verbose=false)
solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-2, verbose=false, use_binning=false)
tree = SARSOPTree(pomdp)
Γ = solve(solver, pomdp)
iterations = 0
while JSOP.root_diff(tree) > solver.precision
iterations += 1
JSOP.sample!(solver, tree)
JSOP.backup!(tree)
JSOP.prune!(solver, tree)
end
# @test isapprox(tree.V_lower[1], -16.3; atol=1e-2)
@test JSOP.root_diff(tree) < solver.precision

solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor=0.5, precision=1e-2, verbose=false)
policyCPP = solve(solverCPP, pomdp)
@test abs(value(policyCPP, initialstate(pomdp)) - tree.V_lower[1]) < 0.1
@test abs(value(policyCPP, initialstate(pomdp)) - value(Γ, initialstate(pomdp))) < 0.1
end

# Checks that enabling belief-space binning does not change the solution value on
# BabyPOMDP, and that the bin manager attached to the returned tree is populated.
@testset "Binning" begin
pomdp = BabyPOMDP()

# Reference solve with binning disabled.
solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-8, max_time=3.0, verbose=false, use_binning=false)
Γ1 = solve(solver, pomdp)

# Same settings with binning enabled.
solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-8, max_time=3.0, verbose=false, use_binning=true)
Γ2 = solve(solver, pomdp)

# Both policies must agree on the value at the initial state.
@test abs(value(Γ1, initialstate(pomdp)) - value(Γ2, initialstate(pomdp))) < 1e-7

# Re-solve with the binning solver to get access to the tree through `info`.
Γ, info = solve_info(solver, pomdp)
# At least one bin on the first level has a recorded count.
@test !isempty(info.tree.bm.bin_levels[1][:bin_count])
# Two binning levels are expected — presumably from the default
# `num_bins_per_level=[5, 10]`; confirm against how solve_info builds the tree.
@test length(info.tree.bm.bin_levels) == 2
end

# Solves TigerPOMDP with binning enabled and compares the result against a fixed
# expected root value and against the `SARSOP.SARSOPSolver` reference policy.
@testset "Tiger POMDP (with binning)" begin
pomdp = TigerPOMDP()
solver = SARSOPSolver(epsilon=0.5, precision=1e-3, verbose=false, use_binning=true)
# NOTE(review): the constructor visible in src/tree.jl takes `(solver, pomdp)`;
# confirm that a single-argument `SARSOPTree(pomdp)` method exists.
tree = SARSOPTree(pomdp)
Γ = solve(solver, pomdp)
# Drive the sample/backup/prune loop by hand on `tree` until the root bound gap
# drops to the requested precision. `iterations` is counted but not asserted on.
iterations = 0
while JSOP.root_diff(tree) > solver.precision
iterations += 1
JSOP.sample!(solver, tree)
JSOP.backup!(tree)
JSOP.prune!(solver, tree)
end
@test isapprox(tree.V_lower[1], 19.37; atol=1e-1)
@test JSOP.root_diff(tree) < solver.precision

# Reference policy from the `SARSOP` package (presumably the C++ SARSOP wrapper,
# judging by the variable name — confirm).
solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor=0.5, precision=1e-3, verbose=false)
policyCPP = solve(solverCPP, pomdp)
# Both the hand-driven tree's root lower bound and the `solve` policy must match
# the reference value at the initial state.
@test abs(value(policyCPP, initialstate(pomdp)) - tree.V_lower[1]) < 0.01
@test abs(value(policyCPP, initialstate(pomdp)) - value(Γ, initialstate(pomdp))) < 0.01
end

# Solves BabyPOMDP with binning enabled and compares the result against a fixed
# expected root value and against the `SARSOP.SARSOPSolver` reference policy.
@testset "Baby POMDP (with binning)" begin
pomdp = BabyPOMDP()
solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-3, verbose=false, use_binning=true)
# NOTE(review): the constructor visible in src/tree.jl takes `(solver, pomdp)`;
# confirm that a single-argument `SARSOPTree(pomdp)` method exists.
tree = SARSOPTree(pomdp)
Γ = solve(solver, pomdp)
# Drive the sample/backup/prune loop by hand on `tree` until the root bound gap
# drops to the requested precision. `iterations` is counted but not asserted on.
iterations = 0
while JSOP.root_diff(tree) > solver.precision
iterations += 1
JSOP.sample!(solver, tree)
JSOP.backup!(tree)
JSOP.prune!(solver, tree)
end
@test isapprox(tree.V_lower[1], -16.3; atol=1e-2)
@test JSOP.root_diff(tree) < solver.precision

# Reference policy from the `SARSOP` package (presumably the C++ SARSOP wrapper,
# judging by the variable name — confirm).
solverCPP = SARSOP.SARSOPSolver(trial_improvement_factor=0.5, precision=1e-3, verbose=false)
policyCPP = solve(solverCPP, pomdp)
# Both the hand-driven tree's root lower bound and the `solve` policy must match
# the reference value at the initial state.
@test abs(value(policyCPP, initialstate(pomdp)) - tree.V_lower[1]) < 0.01
@test abs(value(policyCPP, initialstate(pomdp)) - value(Γ, initialstate(pomdp))) < 0.01
end

@testset "RockSample POMDP (with binning)" begin
pomdp = RockSamplePOMDP()
solver = SARSOPSolver(epsilon=0.1, delta=0.1, precision=1e-2, verbose=false, use_binning=true)
tree = SARSOPTree(pomdp)
Γ = solve(solver, pomdp)
iterations = 0
Expand Down
Loading