diff --git a/src/optimise/optimisers.jl b/src/optimise/optimisers.jl
index 0149490059..53cef6be9e 100644
--- a/src/optimise/optimisers.jl
+++ b/src/optimise/optimisers.jl
@@ -555,6 +555,47 @@ function apply!(o::AdaBelief, x, Δ)
   return Δ
 end
 
+"""
+    PAdam(η = 0.01, β::Tuple = (0.9, 0.999), ρ = 0.25, ϵ = $EPS)
+
+The partially adaptive momentum estimation method (PADAM) [https://arxiv.org/pdf/1806.06763v1.pdf].
+
+# Parameters
+- Learning rate (`η`): Amount by which gradients are discounted before updating
+  the weights.
+- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+  second (β2) momentum estimate.
+- Partially adaptive parameter (`ρ`): Varies between 0 and 0.5.
+- Machine epsilon (`ϵ`): Constant to prevent division by zero
+  (no need to change default)
+# Examples
+```julia
+opt = PAdam()
+```
+"""
+mutable struct PAdam <: AbstractOptimiser
+  eta::Float64
+  beta::Tuple{Float64, Float64}
+  rho::Float64
+  epsilon::Float64
+  state::IdDict{Any, Any}
+end
+
+PAdam(η::Real = 0.01, β = (0.9, 0.999), ρ::Real = 0.25, ϵ::Real = EPS) = PAdam(η, β, ρ, ϵ, IdDict())
+PAdam(η::Real, β::Tuple, ρ::Real, state::IdDict) = PAdam(η, β, ρ, EPS, state)
+
+function apply!(o::PAdam, x, Δ)
+  η, β, ρ = o.eta, o.beta, o.rho
+
+  mt, vt, v̂t = get!(o.state, x) do
+      (fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon))
+  end :: NTuple{3,typeof(x)}
+
+  @. mt = β[1] * mt + (1 - β[1]) * Δ
+  @. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
+  @. v̂t = max(v̂t, vt)
+  @. Δ = η * mt / (v̂t ^ ρ + o.epsilon)
+end
 
 # Compose optimisers
 
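For context, the PAdam update interpolates between SGD with momentum and an AMSGrad-style step: with `ρ = 0.5` the denominator `v̂t^ρ` is the usual square root, while smaller `ρ` makes the step less adaptive. Below is a minimal usage sketch against the implicit-parameters `Flux.train!` API; the model, data, and loss are made up for illustration, and it assumes `PAdam` is exported alongside the other optimisers (otherwise qualify it as `Flux.Optimise.PAdam`).

```julia
using Flux

# Toy model and data, purely illustrative.
model = Dense(10, 1)
data = [(rand(Float32, 10, 16), rand(Float32, 1, 16))]
loss(x, y) = Flux.Losses.mse(model(x), y)

# Defaults: η = 0.01, (β1, β2) = (0.9, 0.999), partial-adaptivity ρ = 0.25.
opt = PAdam()
Flux.train!(loss, Flux.params(model), data, opt)
```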