diff --git a/README.md b/README.md index 70155f8..5924a43 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,15 @@ policy = RandomPolicy(pomdp) h = simulate(HistoryRecorder(), pomdp, policy, up) ``` +

+ +

+ +| | | | +| :---: | :---: | :---: | +| | | | + + ## Unscented Kalman filter 🧼 _(Derivative free! How clean!)_ diff --git a/img/cas.svg b/img/cas.svg index d5b8a9d..5288314 100644 --- a/img/cas.svg +++ b/img/cas.svg @@ -1,98 +1,103 @@ - - - + + - + - - + + - - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +" transform="translate(1640, 485)"> - + \ No newline at end of file diff --git a/img/pfail.svg b/img/pfail.svg new file mode 100644 index 0000000..5138fae --- /dev/null +++ b/img/pfail.svg @@ -0,0 +1,1072 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/img/policy.svg b/img/policy.svg new file mode 100644 index 0000000..0abe138 --- /dev/null +++ b/img/policy.svg @@ -0,0 +1,217 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/img/trajectories.svg b/img/trajectories.svg new file mode 100644 index 0000000..6a51917 --- /dev/null +++ b/img/trajectories.svg @@ -0,0 +1,182 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/img/value.svg b/img/value.svg new file mode 100644 index 0000000..5ceee8c --- /dev/null +++ b/img/value.svg @@ -0,0 +1,1037 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/scripts/cas_plots.jl b/scripts/cas_plots.jl new file mode 100644 index 0000000..91f679e --- /dev/null +++ b/scripts/cas_plots.jl @@ -0,0 +1,97 @@ +using ProgressMeter +using ImageFiltering +default(fontfamily="Computer Modern", framestyle=:box) + +blur(img, σ) = imfilter(Float64.(img), ImageFiltering.Kernel.gaussian(σ)) + +coarse = true +h_rel_range = coarse ? (-100:10:100) : (-100:0.5:100) +dh_rel = 0 +a_prev = 0 +τ_range = coarse ? (0:40) : (0:0.25:40) + +ds0 = initialstate(pomdp) +b0 = initialize_belief(up, ds0) +A = actions(pomdp) + +@enum PlotTyle ValuePlot PolicyPlot PfailPlot + +# plot_type = ValuePlot +plot_type = PolicyPlot +# plot_type = PfailPlot + +replan = true +verbose = false +n_runs = 3 +use_mean = true +discrete_action_colors = false + +policy = online_mode!(solver, policy) + +s = rand(ds0) +policy_map = Matrix(undef, length(h_rel_range), length(τ_range)) + +@showprogress for (i,x) in enumerate(τ_range) + for (j,y) in enumerate(h_rel_range) + # s[1] = y + # s[2] = dh_rel + # s[3] = a_prev + # s[4] = x + # o = rand(observation(pomdp, s)) + b = deepcopy(b0) + b.ukf.μ[1] = y + b.ukf.μ[2] = dh_rel # randn() + b.ukf.μ[3] = a_prev # rand(actions(pomdp)) + b.ukf.μ[4] = x + # b = update(up, b, a_prev, o) + if plot_type == ValuePlot + z = value_lookup(policy.surrogate, b) + elseif plot_type == PolicyPlot + if replan + as = [] + for n in 1:n_runs + a = action(policy, b) + push!(as, a) + end + if use_mean + z = mean(as) + else + a = as[argmax(map(a->sum(a .== as), as))] + z = A[actionindex(pomdp, a)] + end + else + z = A[argmax(policy_lookup(policy.surrogate, b))] + end + elseif plot_type == PfailPlot + z = pfail_lookup(policy.surrogate, b) + end + verbose && @info b.ukf.μ z + policy_map[j,i] = z + end +end + +if plot_type == ValuePlot + c = cgrad(:viridis, rev=true) + kwargs = (c=c,) + title_str = "value estimate" +elseif plot_type == PolicyPlot + # c = palette(:pigeon, 3) + # action_colors = ["#8c1515", :white, "#007662"] + action_colors = ["#008000", :white, "#0000FF"] + c = cgrad(action_colors) + kwargs = (c=c,) + if discrete_action_colors + c = palette(c, length(A)) + kwargs = (c=c, level=length(A)) + end + title_str = "online policy" +elseif plot_type == PfailPlot + c = cgrad(["#007662", :white, "#8c1515"]) + kwargs = (c=c, ) # clims=(0,1)) + title_str = "failure probability estimate" +end + +# Plots.heatmap(τ_range, h_rel_range, blur(policy_map, 1.5); xflip=true, label=false, size=(400,250), title=title_str, clims=(minimum(A), maximum(5)), kwargs...) +Plots.heatmap(τ_range, h_rel_range, policy_map; xflip=true, label=false, size=(400,250), title=title_str, kwargs...) +Plots.xlabel!(raw"time to closest approach ($\tau$)") +Plots.ylabel!(raw"relative altitude ($h_\mathrm{rel}$)") diff --git a/src/CollisionAvoidancePOMDPs.jl b/src/CollisionAvoidancePOMDPs.jl index 2d5096e..5dfa0e2 100644 --- a/src/CollisionAvoidancePOMDPs.jl +++ b/src/CollisionAvoidancePOMDPs.jl @@ -29,7 +29,9 @@ export get_obs_h_rel, get_belief_mean_h_rel, get_belief_std_h_rel, - get_rewards + get_rewards, + generate_histories, + plot_histories include("pomdp.jl") include("ukf.jl") diff --git a/src/plotting.jl b/src/plotting.jl index 6eb9f9a..6b7c74c 100644 --- a/src/plotting.jl +++ b/src/plotting.jl @@ -11,39 +11,59 @@ get_rewards(h) = [step.r for step in h] rectangle(w, h, x, y) = Shape(x .+ [0,w,w,0], y .+ [0,0,h,h]) function plot_history(pomdp::CollisionAvoidancePOMDP, h::SimHistory, t=length(h); - show_actions=true, action_colors=["#8c1515", :white, "#007662"], - show_collision_area=true, show_aircraft=true, - ymin=missing, ymax=missing) + show_actions=true, show_belief=true, show_obs=false, + show_zero_actions=true, show_collision_area=true, show_aircraft=true, + action_colors=["#8c1515", :white, "#007662"], + hold=false, ymin=missing, ymax=missing, + alpha=1, fillalpha=0.5, belief_lw=2, action_ms=4) X = get_taus(h)[1:t] - plot(size=(450, 300), xlims=(0, pomdp.τ_max), - xlabel=raw"time to closest approach ($\tau$)", ylabel=raw"relative altitude ($h_\mathrm{rel}$)", - fontfamily="Computer Modern", framestyle=:box) + plotf = hold ? plot! : plot - hline!([0], label=false, c=:black, lw=0.5) + plotf(size=(450, 300), xlims=(0, pomdp.τ_max), + xlabel=raw"time to closest approach ($\tau$)", ylabel=raw"relative altitude ($h_\mathrm{rel}$)", + fontfamily="Computer Modern", framestyle=:box) - # belief - plot!(X, get_belief_mean_h_rel(h)[1:t], c=:gray, lw=2, ls=:dash, - ribbon=get_belief_std_h_rel(h)[1:t], label=false) + !hold && hline!([0], label=false, c=:black, lw=0.5, ls=:dot) + + if show_belief + plot!(X, get_belief_mean_h_rel(h)[1:t], c=:gray, lw=belief_lw, ls=:dash, + ribbon=get_belief_std_h_rel(h)[1:t], label=false, fillalpha=fillalpha) + end # true state - plot!(X, get_h_rel(h)[1:t], label=false, xflip=true, c=:black, lw=1) + plot!(X, get_h_rel(h)[1:t], label=false, xflip=true, c=:black, lw=1, alpha=alpha) + + if show_obs + mark = :circle + color = :white + msc = :black + ms = 2 + obs_x = X + obs_y = get_obs_h_rel(h)[1:t] + scatter!(obs_x, obs_y; ms, label=false, mark, color, msc) + end if show_actions + act_x = copy(X) + act_y = get_h_rel(h)[1:t] AI = map(a->actionindex(pomdp, a), get_actions(h)[1:t]) markers = [:dtriangle, :square, :utriangle] - stroke_colors = [action_colors[1], :gray, action_colors[3]] + stroke_colors = [:white, :gray, :white] mark = [markers[ai] for ai in AI] color = [action_colors[ai] for ai in AI] msc = [stroke_colors[ai] for ai in AI] - ms = 3 - else - mark = :circle - color = :white - msc = :black - ms = 2 + ms = action_ms + if !show_zero_actions + idx = AI .== actionindex(pomdp, 0) + deleteat!(mark, idx) + deleteat!(color, idx) + deleteat!(msc, idx) + deleteat!(act_x, idx) + deleteat!(act_y, idx) + end + scatter!(act_x, act_y; ms, label=false, mark, color, msc) end - scatter!(X, get_obs_h_rel(h)[1:t]; ms, label=false, mark, color, msc) if show_collision_area plot!(rectangle(1, 2pomdp.collision_threshold, 0, -pomdp.collision_threshold), opacity=0.25, color=:crimson, label=false) @@ -78,3 +98,27 @@ function overlay_aircraft!() Y = [-height/2, height/2] return plot!(X, Y, reverse(img, dims=1), yflip=false, ratio=:none) end + +function generate_histories(pomdp::CollisionAvoidancePOMDP, policy::Policy, up::Updater, n::Int) + return [simulate(HistoryRecorder(), pomdp, policy, up) for _ in 1:n] +end + +function plot_histories(pomdp, H::Vector{<:SimHistory}) + for (i,h) in enumerate(H) + isfirst = i==1 + plot_history(pomdp, h; + ymin=-350, + ymax=350, + show_actions=true, + show_obs=false, + show_collision_area=isfirst, + show_aircraft=isfirst, + show_belief=true, + fillalpha=0.1, + belief_lw=0, + show_zero_actions=false, + alpha=0.25, + hold=!isfirst) + end + return plot!() +end diff --git a/test/runtests.jl b/test/runtests.jl index 47cea3a..d37d60f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -103,8 +103,12 @@ end policy = RandomPolicy(pomdp) h = simulate(HistoryRecorder(), pomdp, policy, up) plot_history(pomdp, h) + plot_history(pomdp, h; show_obs=true) plot_history(pomdp, h; ymin=-350, ymax=350) plot_history(pomdp, h; show_actions=false) + plot_history(pomdp, h; show_actions=true, show_zero_actions=false) + H = generate_histories(pomdp, policy, up, 2) + plot_histories(pomdp, H) get_actions(h) get_h_rel(h) get_dh_rel(h)