a2c.jl
function a2c_loss(π, 𝒫, 𝒟; info = Dict())
    # Log-probability of the sampled actions under the current policy
    new_probs = logpdf(π, 𝒟[:s], 𝒟[:a])

    # Advantage-weighted policy-gradient loss and an entropy regularizer
    p_loss = -mean(new_probs .* 𝒟[:advantage])
    e_loss = -mean(entropy(π, 𝒟[:s]))

    # Log useful information (excluded from the gradient computation)
    ignore_derivatives() do
        info[:entropy] = -e_loss
        info[:kl] = mean(𝒟[:logprob] .- new_probs)
    end

    # Weighted combination of the policy and entropy losses
    𝒫[:λp]*p_loss + 𝒫[:λe]*e_loss
end

function A2C(;π::ActorCritic,
             a_opt::NamedTuple=(;),
             c_opt::NamedTuple=(;),
             log::NamedTuple=(;),
             λp::Float32=1f0,
             λe::Float32=0.1f0,
             required_columns=[],
             kwargs...)
    OnPolicySolver(;agent=PolicyParams(π),
                    𝒫=(λp=λp, λe=λe),
                    log=LoggerParams(;dir="log/a2c", log...),
                    # Actor update: stop early if the KL divergence from the old policy exceeds 0.015
                    a_opt=TrainingParams(;loss=a2c_loss, early_stopping=(infos) -> (infos[end][:kl] > 0.015), name="actor_", a_opt...),
                    # Critic update: mean-squared error between the value estimate and the sampled return
                    c_opt=TrainingParams(;loss=(π, 𝒫, D; kwargs...) -> Flux.mse(value(π, D[:s]), D[:return]), name="critic_", c_opt...),
                    # Whiten (normalize) the advantages after each batch of collected experience
                    post_sample_callback=(𝒟; kwargs...) -> (𝒟[:advantage] .= whiten(𝒟[:advantage])),
                    required_columns=unique([required_columns..., :return, :logprob, :advantage]),
                    kwargs...)
end
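
# A minimal usage sketch (not part of this file), assuming the surrounding Crux.jl
# package: `DiscreteNetwork`, `ContinuousNetwork`, `ActorCritic`, `state_space`, and
# `solve` are taken from Crux/POMDPs, while the CartPole environment, network sizes,
# and sample count `N` are illustrative assumptions only.
using POMDPs, Crux, Flux, POMDPGym

mdp = GymPOMDP(:CartPole, version=:v1)
as = actions(mdp)
S = state_space(mdp)

# Actor: categorical policy over the discrete actions; critic: scalar state-value network
A() = DiscreteNetwork(Chain(Dense(Crux.dim(S)..., 64, relu), Dense(64, length(as))), as)
V() = ContinuousNetwork(Chain(Dense(Crux.dim(S)..., 64, relu), Dense(64, 1)))

𝒮 = A2C(π=ActorCritic(A(), V()), S=S, N=10_000)
π_trained = solve(𝒮, mdp)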