Skip to content

Commit

Permalink
Added BetaZero policy images
Browse files Browse the repository at this point in the history
  • Loading branch information
mossr committed Nov 18, 2023
1 parent e44428b commit 400d07b
Show file tree
Hide file tree
Showing 10 changed files with 2,776 additions and 107 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ policy = RandomPolicy(pomdp)
h = simulate(HistoryRecorder(), pomdp, policy, up)
```

<p align="center">
<img src="./img/trajectories.svg">
</p>

| | | |
| :---: | :---: | :---: |
| <kbd> <img src="./img/policy.svg"> </kbd> | <kbd> <img src="./img/value.svg"> </kbd> | <kbd> <img src="./img/pfail.svg"> </kbd> |


## Unscented Kalman filter 🧼
_(Derivative-free! How clean!)_

Expand Down
179 changes: 92 additions & 87 deletions img/cas.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,072 changes: 1,072 additions & 0 deletions img/pfail.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
217 changes: 217 additions & 0 deletions img/policy.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
182 changes: 182 additions & 0 deletions img/trajectories.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,037 changes: 1,037 additions & 0 deletions img/value.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
97 changes: 97 additions & 0 deletions scripts/cas_plots.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
using ProgressMeter
using ImageFiltering
default(fontfamily="Computer Modern", framestyle=:box)

"""
    blur(img, σ)

Convert `img` element-wise to `Float64` and filter the result with the kernel
returned by `ImageFiltering.Kernel.gaussian(σ)`.
"""
function blur(img, σ)
    pixels = Float64.(img)
    kernel = ImageFiltering.Kernel.gaussian(σ)
    return imfilter(pixels, kernel)
end

# Sweep a grid of belief means and render a heatmap of the value estimate, the
# online policy, or the failure-probability estimate.
#
# This script does not define `pomdp`, `up`, `solver`, or `policy`; they must
# already exist in the enclosing scope. It likewise expects `mean`, `cgrad`,
# `palette`, and the `Plots` module to be available.

# --- Grid configuration -------------------------------------------------------
coarse = true  # `true` selects the low-resolution ranges below
h_rel_range = coarse ? (-100:10:100) : (-100:0.5:100)  # heatmap y-axis values
dh_rel = 0     # written to component 2 of the belief mean at every grid point
a_prev = 0     # written to component 3 of the belief mean at every grid point
τ_range = coarse ? (0:40) : (0:0.25:40)                # heatmap x-axis values

ds0 = initialstate(pomdp)
b0 = initialize_belief(up, ds0)
A = actions(pomdp)

# Which quantity to plot. The original declaration misspelled the type name as
# `PlotTyle`; the alias keeps that spelling usable.
@enum PlotType ValuePlot PolicyPlot PfailPlot
const PlotTyle = PlotType

# plot_type = ValuePlot
plot_type = PolicyPlot
# plot_type = PfailPlot

# --- Policy-plot options ------------------------------------------------------
replan = true                   # query `action(policy, b)` instead of the surrogate lookup
verbose = false                 # log the belief mean and plotted value per cell
n_runs = 3                      # number of `action` queries per cell when `replan`
use_mean = true                 # plot the mean of the queried actions, else the most frequent
discrete_action_colors = false  # use a palette with `length(A)` colors

policy = online_mode!(solver, policy)

s = rand(ds0)  # only referenced by the commented-out state-sampling path below
policy_map = Matrix(undef, length(h_rel_range), length(τ_range))

# --- Grid sweep ---------------------------------------------------------------
# Rows of `policy_map` follow `h_rel_range` (index j) and columns follow
# `τ_range` (index i), matching the `heatmap(x, y, z)` call at the bottom.
@showprogress for (i, x) in enumerate(τ_range)
    for (j, y) in enumerate(h_rel_range)
        # s[1] = y
        # s[2] = dh_rel
        # s[3] = a_prev
        # s[4] = x
        # o = rand(observation(pomdp, s))

        # Work on a copy so the shared initial belief `b0` is never mutated.
        b = deepcopy(b0)
        b.ukf.μ[1] = y
        b.ukf.μ[2] = dh_rel # randn()
        b.ukf.μ[3] = a_prev # rand(actions(pomdp))
        b.ukf.μ[4] = x
        # b = update(up, b, a_prev, o)

        if plot_type == ValuePlot
            z = value_lookup(policy.surrogate, b)
        elseif plot_type == PolicyPlot
            if replan
                # A comprehension lets Julia infer the element type, unlike
                # pushing into an untyped `[]`.
                as = [action(policy, b) for _ in 1:n_runs]
                if use_mean
                    z = mean(as)
                else
                    # Pick the sample that occurs most often among the queries.
                    a = as[argmax(map(a->sum(a .== as), as))]
                    z = A[actionindex(pomdp, a)]
                end
            else
                z = A[argmax(policy_lookup(policy.surrogate, b))]
            end
        elseif plot_type == PfailPlot
            z = pfail_lookup(policy.surrogate, b)
        end

        verbose && @info b.ukf.μ z
        policy_map[j,i] = z
    end
end

# --- Colors and title per plot type -------------------------------------------
if plot_type == ValuePlot
    c = cgrad(:viridis, rev=true)
    kwargs = (c=c,)
    title_str = "value estimate"
elseif plot_type == PolicyPlot
    # c = palette(:pigeon, 3)
    # action_colors = ["#8c1515", :white, "#007662"]
    action_colors = ["#008000", :white, "#0000FF"]
    c = cgrad(action_colors)
    kwargs = (c=c,)
    if discrete_action_colors
        c = palette(c, length(A))
        # NOTE(review): Plots documents this attribute as `levels`; confirm that
        # the singular `level` is accepted before relying on this branch.
        kwargs = (c=c, level=length(A))
    end
    title_str = "online policy"
elseif plot_type == PfailPlot
    c = cgrad(["#007662", :white, "#8c1515"])
    kwargs = (c=c, ) # clims=(0,1))
    title_str = "failure probability estimate"
end

# --- Render -------------------------------------------------------------------
# Blurred alternative (the upper color limit previously read `maximum(5)`):
# Plots.heatmap(τ_range, h_rel_range, blur(policy_map, 1.5); xflip=true, label=false, size=(400,250), title=title_str, clims=(minimum(A), maximum(A)), kwargs...)
Plots.heatmap(τ_range, h_rel_range, policy_map; xflip=true, label=false, size=(400,250), title=title_str, kwargs...)
Plots.xlabel!(raw"time to closest approach ($\tau$)")
Plots.ylabel!(raw"relative altitude ($h_\mathrm{rel}$)")
4 changes: 3 additions & 1 deletion src/CollisionAvoidancePOMDPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ export
get_obs_h_rel,
get_belief_mean_h_rel,
get_belief_std_h_rel,
get_rewards
get_rewards,
generate_histories,
plot_histories

include("pomdp.jl")
include("ukf.jl")
Expand Down
82 changes: 63 additions & 19 deletions src/plotting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,59 @@ get_rewards(h) = [step.r for step in h]
"""
    rectangle(w, h, x, y)

Build a `Shape` whose vertices are `(x, y)`, `(x + w, y)`, `(x + w, y + h)`,
and `(x, y + h)`, in that order.
"""
function rectangle(w, h, x, y)
    xs = x .+ [0, w, w, 0]
    ys = y .+ [0, 0, h, h]
    return Shape(xs, ys)
end

function plot_history(pomdp::CollisionAvoidancePOMDP, h::SimHistory, t=length(h);
show_actions=true, action_colors=["#8c1515", :white, "#007662"],
show_collision_area=true, show_aircraft=true,
ymin=missing, ymax=missing)
show_actions=true, show_belief=true, show_obs=false,
show_zero_actions=true, show_collision_area=true, show_aircraft=true,
action_colors=["#8c1515", :white, "#007662"],
hold=false, ymin=missing, ymax=missing,
alpha=1, fillalpha=0.5, belief_lw=2, action_ms=4)
X = get_taus(h)[1:t]

plot(size=(450, 300), xlims=(0, pomdp.τ_max),
xlabel=raw"time to closest approach ($\tau$)", ylabel=raw"relative altitude ($h_\mathrm{rel}$)",
fontfamily="Computer Modern", framestyle=:box)
plotf = hold ? plot! : plot

hline!([0], label=false, c=:black, lw=0.5)
plotf(size=(450, 300), xlims=(0, pomdp.τ_max),
xlabel=raw"time to closest approach ($\tau$)", ylabel=raw"relative altitude ($h_\mathrm{rel}$)",
fontfamily="Computer Modern", framestyle=:box)

# belief
plot!(X, get_belief_mean_h_rel(h)[1:t], c=:gray, lw=2, ls=:dash,
ribbon=get_belief_std_h_rel(h)[1:t], label=false)
!hold && hline!([0], label=false, c=:black, lw=0.5, ls=:dot)

if show_belief
plot!(X, get_belief_mean_h_rel(h)[1:t], c=:gray, lw=belief_lw, ls=:dash,
ribbon=get_belief_std_h_rel(h)[1:t], label=false, fillalpha=fillalpha)
end

# true state
plot!(X, get_h_rel(h)[1:t], label=false, xflip=true, c=:black, lw=1)
plot!(X, get_h_rel(h)[1:t], label=false, xflip=true, c=:black, lw=1, alpha=alpha)

if show_obs
mark = :circle
color = :white
msc = :black
ms = 2
obs_x = X
obs_y = get_obs_h_rel(h)[1:t]
scatter!(obs_x, obs_y; ms, label=false, mark, color, msc)
end

if show_actions
act_x = copy(X)
act_y = get_h_rel(h)[1:t]
AI = map(a->actionindex(pomdp, a), get_actions(h)[1:t])
markers = [:dtriangle, :square, :utriangle]
stroke_colors = [action_colors[1], :gray, action_colors[3]]
stroke_colors = [:white, :gray, :white]
mark = [markers[ai] for ai in AI]
color = [action_colors[ai] for ai in AI]
msc = [stroke_colors[ai] for ai in AI]
ms = 3
else
mark = :circle
color = :white
msc = :black
ms = 2
ms = action_ms
if !show_zero_actions
idx = AI .== actionindex(pomdp, 0)
deleteat!(mark, idx)
deleteat!(color, idx)
deleteat!(msc, idx)
deleteat!(act_x, idx)
deleteat!(act_y, idx)
end
scatter!(act_x, act_y; ms, label=false, mark, color, msc)
end
scatter!(X, get_obs_h_rel(h)[1:t]; ms, label=false, mark, color, msc)

if show_collision_area
plot!(rectangle(1, 2pomdp.collision_threshold, 0, -pomdp.collision_threshold), opacity=0.25, color=:crimson, label=false)
Expand Down Expand Up @@ -78,3 +98,27 @@ function overlay_aircraft!()
Y = [-height/2, height/2]
return plot!(X, Y, reverse(img, dims=1), yflip=false, ratio=:none)
end

"""
    generate_histories(pomdp, policy, up, n)

Call `simulate` `n` times on `pomdp` with `policy` and updater `up`, using a
new `HistoryRecorder()` for each call, and return the results as a vector.
"""
function generate_histories(pomdp::CollisionAvoidancePOMDP, policy::Policy, up::Updater, n::Int)
    run_once(_) = simulate(HistoryRecorder(), pomdp, policy, up)
    return map(run_once, 1:n)
end

"""
    plot_histories(pomdp, H; ymin=-350, ymax=350)

Overlay every history in `H` on one figure by calling `plot_history` once per
history, and return the resulting plot.

The first history is drawn with `hold=false` and with the collision area and
aircraft overlay enabled. Every later history is drawn with `hold=true` and
with both overlays disabled, so they appear only once. All histories use the
same faint styling (`alpha=0.25`, `fillalpha=0.1`, `belief_lw=0`) and omit
markers for the zero action (`show_zero_actions=false`).

# Keywords
- `ymin`, `ymax`: forwarded unchanged to every `plot_history` call.

# Returns
The current plot after all histories are drawn, or a new empty plot when `H`
is empty.
"""
function plot_histories(pomdp, H::AbstractVector{<:SimHistory}; ymin=-350, ymax=350)
    # With no histories the loop never creates a figure, and `plot!()` would
    # act on whatever plot happens to be current. Return a fresh one instead.
    isempty(H) && return plot()

    for (i, h) in enumerate(H)
        isfirst = i == 1
        plot_history(pomdp, h;
            ymin=ymin,
            ymax=ymax,
            show_actions=true,
            show_obs=false,
            show_collision_area=isfirst,
            show_aircraft=isfirst,
            show_belief=true,
            fillalpha=0.1,
            belief_lw=0,
            show_zero_actions=false,
            alpha=0.25,
            hold=!isfirst)
    end
    return plot!()
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,12 @@ end
policy = RandomPolicy(pomdp)
h = simulate(HistoryRecorder(), pomdp, policy, up)
plot_history(pomdp, h)
plot_history(pomdp, h; show_obs=true)
plot_history(pomdp, h; ymin=-350, ymax=350)
plot_history(pomdp, h; show_actions=false)
plot_history(pomdp, h; show_actions=true, show_zero_actions=false)
H = generate_histories(pomdp, policy, up, 2)
plot_histories(pomdp, H)
get_actions(h)
get_h_rel(h)
get_dh_rel(h)
Expand Down

0 comments on commit 400d07b

Please sign in to comment.