Skip to content

Commit

Permalink
RFC: Compute working weights and residuals more carefully to reduce u…
Browse files Browse the repository at this point in the history
…nderflow (#312)

* Compute working weights and residuals more carefully to reduce underflow
and get the limits right when overflow is unavoidable. To do this, change
inverselink to return μ, 1-μ, dμdη, instead of μ, dμdη, μ*(1-μ) for Link01
in order to have access to accurate μ or 1-μ.

Introduce an absolute tolerance criterion to avoid convergence issues when
deviance is almost zero and rename tol to rtol.

* Adjust compat section of Project.toml

* Fix deprecation

* Tighten test
  • Loading branch information
andreasnoack authored Jun 20, 2019
1 parent 0926a95 commit ef246bb
Show file tree
Hide file tree
Showing 6 changed files with 268 additions and 165 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ test = ["CategoricalArrays", "CSV", "DataFrames", "RDatasets", "Test"]
[compat]
CategoricalArrays = "0.3, 0.4, 0.5"
CSV = "0.2, 0.3, 0.4"
DataFrames = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17"
DataFrames = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18"
Distributions = "0.16, 0.17, 0.18, 0.19"
RDatasets = "0.5, 0.6"
Reexport = "0.1, 0.2"
Expand Down
149 changes: 117 additions & 32 deletions src/glmfit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,79 @@ function updateμ!(r::GlmResp{V,D,L}) where {V<:FPVector,D,L}
end
end

function _weights_residuals(yᵢ, ηᵢ, μᵢ, omμᵢ, dμdηᵢ, l::LogitLink)
# LogitLink is the canonical link function for Binomial so only wrkresᵢ can
# possibly fail when dμdη==0 in which case it evaluates to ±1.
if iszero(dμdηᵢ)
wrkresᵢ = ifelse(yᵢ == 1, one(μᵢ), -one(μᵢ))
else
wrkresᵢ = ifelse(yᵢ == 1, omμᵢ, yᵢ - μᵢ) / dμdηᵢ
end
wrkwtᵢ = μᵢ*omμᵢ

return wrkresᵢ, wrkwtᵢ
end

function _weights_residuals(yᵢ, ηᵢ, μᵢ, omμᵢ, dμdηᵢ, l::ProbitLink)
# Since μomμ will underflow before dμdη for Probit, we can just check the
# former to decide when to evaluate with the tail approximation.
μomμᵢ = μᵢ*omμᵢ
if iszero(μomμᵢ)
wrkresᵢ = 1/abs(ηᵢ)
wrkwtᵢ = dμdηᵢ
else
wrkresᵢ = ifelse(yᵢ == 1, omμᵢ, yᵢ - μᵢ) / dμdηᵢ
wrkwtᵢ = abs2(dμdηᵢ)/μomμᵢ
end

return wrkresᵢ, wrkwtᵢ
end

function _weights_residuals(yᵢ, ηᵢ, μᵢ, omμᵢ, dμdηᵢ, l::CloglogLink)
if yᵢ == 1
wrkresᵢ = exp(-ηᵢ)
else
emη = exp(-ηᵢ)
if iszero(emη)
# Diverges to -∞
wrkresᵢ = -typeof(wrkresᵢ)(Inf)
elseif isinf(emη)
# converges to -1
wrkresᵢ = -one(emη)
else
wrkresᵢ = (yᵢ - μᵢ)/omμᵢ*emη
end
end

wrkwtᵢ = exp(2*ηᵢ)/expm1(exp(ηᵢ))
# We know that both limits are zero so we'll convert NaNs
wrkwtᵢ = ifelse(isnan(wrkwtᵢ), zero(wrkwtᵢ), wrkwtᵢ)

return wrkresᵢ, wrkwtᵢ
end

# Fallback for remaining link functions
function _weights_residuals(yᵢ, ηᵢ, μᵢ, omμᵢ, dμdηᵢ, l::Link01)
wrkresᵢ = ifelse(yᵢ == 1, omμᵢ, yᵢ - μᵢ)/dμdηᵢ
wrkwtᵢ = abs2(dμdηᵢ)/(μᵢ*omμᵢ)

return wrkresᵢ, wrkwtᵢ
end

function updateμ!(r::GlmResp{V,D,L}) where {V<:FPVector,D<:Union{Bernoulli,Binomial},L<:Link01}
y, η, μ, wrkres, wrkwt, dres = r.y, r.eta, r.mu, r.wrkresid, r.wrkwt, r.devresid

@inbounds for i in eachindex(y, η, μ, wrkres, wrkwt, dres)
μi, dμdη, μomμ = inverselink(L(), η[i])
μ[i] = μi
yi = y[i]
wrkres[i] = (yi - μi) / dμdη
wrkwt[i] = cancancel(r) ? dμdη : abs2(dμdη) / μomμ
dres[i] = devresid(r.d, yi, μi)
yᵢ, ηᵢ = y[i], η[i]
μᵢ, omμᵢ, dμdηᵢ = inverselink(L(), ηᵢ)
μ[i] = μᵢ
# For large values of ηᵢ the quantities dμdη and μomμ will underflow.
# The ratios defining (yᵢ - μᵢ)/dμdη and dμdη^2/μomμ have fairly stable
# tail behavior so we can switch algorithm to avoid 0/0. The behavior
# is specific to the link function so _weights_residuals dispatches to
# robust versions for LogitLink and ProbitLink
wrkres[i], wrkwt[i] = _weights_residuals(yᵢ, ηᵢ, μᵢ, omμᵢ, dμdηᵢ, L())
dres[i] = devresid(r.d, yᵢ, μᵢ)
end
end

Expand Down Expand Up @@ -200,7 +263,7 @@ end
dof(x::GeneralizedLinearModel) = dispersion_parameter(x.rr.d) ? length(coef(x)) + 1 : length(coef(x))

function _fit!(m::AbstractGLM, verbose::Bool, maxiter::Integer, minstepfac::Real,
tol::Real, start)
atol::Real, rtol::Real, start)

# Return early if model has the fit flag set
m.fit && return m
Expand Down Expand Up @@ -246,9 +309,9 @@ function _fit!(m::AbstractGLM, verbose::Bool, maxiter::Integer, minstepfac::Real

# Line search
## If the deviance isn't declining then half the step size
## The tol*dev term is to avoid failure when deviance
## The rtol*dev term is to avoid failure when deviance
## is unchanged except for rouding errors.
while dev > devold + tol*dev
while dev > devold + rtol*dev
f /= 2
f > minstepfac || error("step-halving failed at beta0 = $(p.beta0)")
try
Expand All @@ -261,22 +324,26 @@ function _fit!(m::AbstractGLM, verbose::Bool, maxiter::Integer, minstepfac::Real
installbeta!(p, f)

# Test for convergence
crit = (devold - dev)/dev
verbose && println("$i: $dev, $crit")
if crit < tol || dev == 0
verbose && println("Iteration: $i, deviance: $dev, diff.dev.:$(devold - dev)")
if devold - dev < max(rtol*devold, atol)
cvg = true
break
end
@assert isfinite(crit)
@assert isfinite(dev)
devold = dev
end
cvg || throw(ConvergenceException(maxiter))
m.fit = true
m
end

function StatsBase.fit!(m::AbstractGLM; verbose::Bool=false, maxiter::Integer=30,
minstepfac::Real=0.001, tol::Real=1e-6, start=nothing,
function StatsBase.fit!(m::AbstractGLM;
verbose::Bool=false,
maxiter::Integer=30,
minstepfac::Real=0.001,
atol::Real=1e-6,
rtol::Real=1e-6,
start=nothing,
kwargs...)
if haskey(kwargs, :maxIter)
Base.depwarn("'maxIter' argument is deprecated, use 'maxiter' instead", :fit!)
Expand All @@ -287,19 +354,32 @@ function StatsBase.fit!(m::AbstractGLM; verbose::Bool=false, maxiter::Integer=30
minstepfac = kwargs[:minStepFac]
end
if haskey(kwargs, :convTol)
Base.depwarn("'convTol' argument is deprecated, use 'tol' instead", :fit!)
tol = kwargs[:convTol]
Base.depwarn("'convTol' argument is deprecated, use `atol` and `rtol` instead", :fit!)
rtol = kwargs[:convTol]
end
if !issubset(keys(kwargs), (:maxIter, :minStepFac, :convTol))
throw(ArgumentError("unsupported keyword argument"))
end
if haskey(kwargs, :tol)
Base.depwarn("`tol` argument is deprecated, use `atol` and `rtol` instead", :fit!)
rtol = kwargs[:tol]
end

_fit!(m, verbose, maxiter, minstepfac, tol, start)
_fit!(m, verbose, maxiter, minstepfac, atol, rtol, start)
end

function StatsBase.fit!(m::AbstractGLM, y; wts=nothing, offset=nothing, dofit::Bool=true,
verbose::Bool=false, maxiter::Integer=30, minstepfac::Real=0.001,
tol::Real=1e-6, start=nothing, kwargs...)
function StatsBase.fit!(m::AbstractGLM,
y;
wts=nothing,
offset=nothing,
dofit::Bool=true,
verbose::Bool=false,
maxiter::Integer=30,
minstepfac::Real=0.001,
atol::Real=1e-6,
rtol::Real=1e-6,
start=nothing,
kwargs...)
if haskey(kwargs, :maxIter)
Base.depwarn("'maxIter' argument is deprecated, use 'maxiter' instead", :fit!)
maxiter = kwargs[:maxIter]
Expand All @@ -309,12 +389,16 @@ function StatsBase.fit!(m::AbstractGLM, y; wts=nothing, offset=nothing, dofit::B
minstepfac = kwargs[:minStepFac]
end
if haskey(kwargs, :convTol)
Base.depwarn("'convTol' argument is deprecated, use 'tol' instead", :fit!)
tol = kwargs[:convTol]
Base.depwarn("'convTol' argument is deprecated, use `atol` and `rtol` instead", :fit!)
rtol = kwargs[:convTol]
end
if !issubset(keys(kwargs), (:maxIter, :minStepFac, :convTol))
throw(ArgumentError("unsupported keyword argument"))
end
if haskey(kwargs, :tol)
Base.depwarn("`tol` argument is deprecated, use `atol` and `rtol` instead", :fit!)
rtol = kwargs[:tol]
end

r = m.rr
V = typeof(r.y)
Expand All @@ -326,7 +410,7 @@ function StatsBase.fit!(m::AbstractGLM, y; wts=nothing, offset=nothing, dofit::B
fill!(m.pp.beta0, 0)
m.fit = false
if dofit
_fit!(m, verbose, maxiter, minstepfac, tol, start)
_fit!(m, verbose, maxiter, minstepfac, atol, rtol, start)
else
m
end
Expand All @@ -346,8 +430,10 @@ vector, respectively, or a formula and a data frame. `d` must be a
length 0
- `verbose::Bool=false`: Display convergence information for each iteration
- `maxiter::Integer=30`: Maximum number of iterations allowed to achieve convergence
- `tol::Real=1e-6`: Convergence is achieved when the relative change in
deviance is less than this
- `atol::Real=1e-6`: Convergence is achieved when the relative change in
deviance is less than `max(rtol*dev, atol)`.
- `rtol::Real=1e-6`: Convergence is achieved when the relative change in
deviance is less than `max(rtol*dev, atol)`.
- `minstepfac::Real=0.001`: Minimum line step fraction. Must be between 0 and 1.
- `start::AbstractVector=nothing`: Starting values for beta. Should have the
same length as the number of columns in the model matrix.
Expand All @@ -373,11 +459,11 @@ function fit(::Type{M},
end

fit(::Type{M},
X::Union{Matrix,SparseMatrixCSC},
y::AbstractVector,
d::UnivariateDistribution,
l::Link=canonicallink(d); kwargs...) where {M<:AbstractGLM} =
fit(M, float(X), float(y), d, l; kwargs...)
X::Union{Matrix,SparseMatrixCSC},
y::AbstractVector,
d::UnivariateDistribution,
l::Link=canonicallink(d); kwargs...) where {M<:AbstractGLM} =
fit(M, float(X), float(y), d, l; kwargs...)

"""
glm(F, D, args...; kwargs...)
Expand Down Expand Up @@ -485,4 +571,3 @@ function checky(y, d::Binomial)
end
return nothing
end

Loading

0 comments on commit ef246bb

Please sign in to comment.