Skip to content

Commit

Permalink
Merge pull request #39 from TuringLang/minor-improvements
Browse files Browse the repository at this point in the history
Minor improvements
  • Loading branch information
ParadaCarleton authored Sep 17, 2021
2 parents cce392b + 75d5e9c commit 425404f
Show file tree
Hide file tree
Showing 11 changed files with 53 additions and 56 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6' # Replace this with the minimum Julia version that your package supports. E.g. if your package requires Julia 1.5 or higher, change this to '1.5'.
- '1' # Leave this line unchanged. '1' will automatically expand to the latest stable 1.x release of Julia.
#- 'nightly'
- '1' # latest stable 1.x release of Julia
- '1.6' # oldest supported version
- 'nightly'
os:
- ubuntu-latest
arch:
Expand Down
2 changes: 1 addition & 1 deletion CITATION.bib
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ @misc{ParetoSmooth.jl
author = {Carlos Parada <paradac@carleton.edu>},
title = {ParetoSmooth.jl},
url = {https://github.com/TuringLang/ParetoSmooth.jl},
version = {v0.3.0},
version = {v0.6.0},
year = {2021},
month = {6}
}
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"

Expand All @@ -33,6 +34,7 @@ NamedDims = "0.2.35"
Polyester = "0.3.4, 0.4, 0.5"
PrettyTables = "1.1.0"
Requires = "1.1.3"
StatsBase = "0.33.10"
StatsFuns = "0.9.9"
Tullio = "0.3.0"
julia = "1.6"
Expand Down
6 changes: 3 additions & 3 deletions src/ESS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ function psis_ess(
weights::AbstractVector{T}, r_eff::AbstractVector{T}
) where {T <: Union{Real, Missing}}
@tullio sum_of_squares := weights[x]^2
return r_eff ./ sum_of_squares
return @turbo r_eff ./ sum_of_squares
end


function psis_ess(
weights::AbstractMatrix{T}, r_eff::AbstractVector{T}
) where {T <: Union{Real, Missing}}
@tullio sum_of_squares[x] := weights[x, y]^2
return @tturbo r_eff ./ sum_of_squares
return @turbo r_eff ./ sum_of_squares
end


Expand Down Expand Up @@ -84,5 +84,5 @@ L-∞ norm.
function sup_ess(
weights::AbstractMatrix{T}, r_eff::V
) where {T<:Union{Real, Missing}, V<:AbstractVector{T}}
return @tturbo inv.(dropdims(maximum(weights; dims=2); dims=2)) .* r_eff
return @turbo inv.(dropdims(maximum(weights; dims=2); dims=2)) .* r_eff
end
19 changes: 6 additions & 13 deletions src/ImportanceSampling.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using LoopVectorization

using StatsBase
using Tullio

const LIKELY_ERROR_CAUSES = """
Expand Down Expand Up @@ -80,8 +80,7 @@ end
psis(
log_ratios::AbstractArray{T<:Real},
r_eff::AbstractVector;
source::String="mcmc",
log_weights::Bool=false
source::String="mcmc"
) -> Psis
Implements Pareto-smoothed importance sampling (PSIS).
Expand All @@ -99,15 +98,13 @@ Implements Pareto-smoothed importance sampling (PSIS).
- `source::String="mcmc"`: A string or symbol describing the source of the sample being
used. If `"mcmc"`, adjusts ESS for autocorrelation. Otherwise, samples are assumed to be
independent. Currently permitted values are $SAMPLE_SOURCES.
- `log_weights::Bool=false`: Return the log weights, rather than the PSIS weights.
See also: [`relative_eff`]@ref, [`psis_loo`]@ref, [`psis_ess`]@ref.
"""
function psis(
log_ratios::AbstractArray{<:Real, 3};
r_eff::AbstractVector{<:Real}=similar(log_ratios, 0),
source::Union{AbstractString, Symbol}="mcmc",
log_weights::Bool=false,
source::Union{AbstractString, Symbol}="mcmc"
)

source = lowercase(String(source))
Expand All @@ -118,17 +115,17 @@ function psis(

# Reshape to matrix (easier to deal with)
log_ratios = reshape(log_ratios, data_size, post_sample_size)
r_eff = _generate_r_eff(log_ratios, dims, r_eff, source)
weights = similar(log_ratios)
# Shift ratios by maximum to prevent overflow
@tturbo @. weights = exp(log_ratios - $maximum(log_ratios; dims=2))

r_eff = _generate_r_eff(weights, dims, r_eff, source)

_check_input_validity_psis(reshape(log_ratios, dims), r_eff)

tail_length = Vector{Int}(undef, data_size)
ξ = similar(r_eff)
@inbounds Threads.@threads for i in eachindex(tail_length)
tail_length[i] = @views _def_tail_length(post_sample_size, r_eff[i])
tail_length[i] = _def_tail_length(post_sample_size, r_eff[i])
ξ[i] = @views ParetoSmooth._do_psis_i!(weights[i, :], tail_length[i])
end

Expand All @@ -139,10 +136,6 @@ function psis(

weights = reshape(weights, dims)

if log_weights
@tturbo @. weights = log(weights)
end

return Psis(
weights,
ξ,
Expand Down
5 changes: 5 additions & 0 deletions src/InternalHelpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ of that point. This function must take the form `f(θ[1], ..., θ[n], data)`, wh
parameter vector. See also the `splat` keyword argument.
"""

const LIKELIHOOD_ARRAY_ARG = """
`log_likelihood::Array`: A matrix or 3d array of log-likelihood values indexed as
`[data, step, chain]`. See the `chain_index` argument if leaving the `chain` index off.
"""

const R_EFF_DOC = """
`r_eff::AbstractVector`: An (optional) vector of relative effective sample sizes used in ESS
calculations. If left empty, calculated automatically using the FFTESS method from
Expand Down
28 changes: 12 additions & 16 deletions src/LeaveOneOut.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,7 @@ score.
# Arguments
- `log_likelihood::Array`: A matrix or 3d array of log-likelihood values indexed as
`[data, step, chain]`. The chain argument can be left off if `chain_index` is provided
or if all posterior samples were drawn from a single chain.
- $LIKELIHOOD_ARRAY_ARG
- $ARGS [`psis`](@ref).
- $CHAIN_INDEX_DOC
- $KWARGS [`psis`](@ref).
Expand Down Expand Up @@ -135,11 +133,9 @@ Use a precalculated `Psis` object to estimate the leave-one-out cross validation
# Arguments
- `log_likelihood::Array`: A matrix or 3d array of log-likelihood values indexed as
`[data, step, chain]`. The chain argument can be left off if `chain_index` is provided
or if all posterior samples were drawn from a single chain.
- `psis_object`: A precomputed `Psis` object used to estimate the LOO-CV score.
- $CHAIN_INDEX_DOC
- $LIKELIHOOD_ARRAY_ARG
- `psis_object`: A precomputed `Psis` object used to estimate the LOO-CV score.
- $CHAIN_INDEX_DOC
See also: [`psis`](@ref), [`loo`](@ref), [`PsisLoo`](@ref).
Expand All @@ -160,6 +156,7 @@ function loo_from_psis(log_likelihood::AbstractArray{<:Real, 3}, psis_object::Ps
ξ = psis_object.pareto_k
r_eff = psis_object.r_eff


@tullio pointwise_loo[i] := weights[i, j, k] * exp(log_likelihood[i, j, k]) |> log
@tullio pointwise_naive[i] := exp(log_likelihood[i, j, k] - log_count) |> log
pointwise_p_eff = pointwise_naive - pointwise_loo
Expand All @@ -174,8 +171,7 @@ function loo_from_psis(log_likelihood::AbstractArray{<:Real, 3}, psis_object::Ps
table = _generate_loo_table(pointwise)

gmpd = exp.(table(column=:mean, statistic=:cv_elpd))

@tullio mcse := pointwise_mcse[i]^2
@tullio mcse := pointwise_mcse[i]^2
mcse = sqrt(mcse)

return PsisLoo(table, pointwise, psis_object, gmpd, mcse)
Expand Down Expand Up @@ -203,7 +199,7 @@ function _generate_loo_table(pointwise::AbstractMatrix{<:Real})

# calculate the sample expectation for the total score
to_sum = pointwise([:cv_elpd, :naive_lpd, :p_eff])
@tullio avgs[statistic] := to_sum[data, statistic] / data_size
@tullio avgs[statistic] := to_sum[data, statistic] |> _ / data_size
avgs = reshape(avgs, 3)
table(:, :mean) .= avgs

Expand All @@ -228,12 +224,12 @@ end


function _calc_mcse(weights, log_likelihood, pointwise_loo, r_eff)
@turbo E_epd = exp.(pointwise_loo)
pointwise_gmpd = exp.(pointwise_loo)
@tullio pointwise_var[i] :=
(weights[i, j, k] * (exp(log_likelihood[i, j, k]) - E_epd[i]))^2
# If MCMC draws follow a log-normal distribution, then their log has this std. error:
@turbo @. pointwise_var = log1p(pointwise_var / E_epd^2)
# (google "log-normal method of moments" for a proof)
(weights[i, j, k] * (exp(log_likelihood[i, j, k]) - pointwise_gmpd[i]))^2
# If MCMC draws follow a log-normal distribution, we can use method of moments to est
# the standard deviation of their log:
@turbo @. pointwise_var = log1p(pointwise_var / pointwise_gmpd^2)
# apply MCMC correlation correction:
return @turbo @. sqrt(pointwise_var / r_eff)
end
11 changes: 4 additions & 7 deletions src/ModelComparison.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,6 @@ import Base.show

export loo_compare, ModelComparison

const LOO_COMPARE_KWARGS = """
- `model_names`: A vector or tuple of strings or symbols used to identify models. If
none, models are numbered using the order of the arguments.
- `sort_models`: Sort models by total score.
- `high_to_low`: Sort models from best to worst score. If `false`, reverse the order.
"""

"""
ModelComparison
Expand Down Expand Up @@ -65,7 +59,10 @@ Construct a model comparison table from several [`PsisLoo`](@ref) objects.
- `cv_results`: One or more [`PsisLoo`](@ref) objects to be compared. Alternatively,
a tuple or named tuple of `PsisLoo` objects can be passed. If a named tuple is passed,
these names will be used to label each model.
- $LOO_COMPARE_KWARGS
- `model_names`: A vector or tuple of strings or symbols used to identify models. If
none, models are numbered using the order of the arguments.
- `sort_models`: Sort models by total score.
- `high_to_low`: Sort models from best to worst score. If `false`, reverse the order.
See also: [`ModelComparison`](@ref), [`PsisLoo`](@ref), [`psis_loo`](@ref)
"""
Expand Down
10 changes: 8 additions & 2 deletions src/NaiveLPD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@ using Tullio


"""
$(TYPEDSIGNATURES)
naive_lpd(log_likelihood::AbstractArray{<:Real}, chain_index::Vector}<:Int)
Calculate the naive (in-sample) estimate of the expected log probability density, otherwise
known as the in-sample Bayes score. Not recommended for most uses.
# Arguments
- $LIKELIHOOD_ARRAY_ARG
- $CHAIN_INDEX_DOC
"""
function naive_lpd(log_likelihood::AbstractArray{<:Real, 3})
@info "We advise against using `naive_lpd`, as it gives inconsistent and strongly " *
Expand All @@ -18,6 +22,7 @@ function naive_lpd(
log_likelihood::AbstractMatrix{<:Real},
chain_index::AbstractVector{<:Integer} = _assume_one_chain(log_likelihood)
)
@nospecialize(chain_index)
log_likelihood = _convert_to_array(log_likelihood, chain_index)
return _naive_lpd(log_likelihood)
end
Expand All @@ -29,6 +34,7 @@ function _naive_lpd(log_likelihood::AbstractArray{<:Real, 3})
mcmc_count = dims[2] * dims[3] # total number of samples from posterior
log_count = log(mcmc_count)

@tullio pointwise_naive[i] := exp(log_likelihood[i, j, k] - log_count) |> log
pointwise_naive = similar(log_likelihood, data_size)
@tullio pointwise_naive[i] = exp(log_likelihood[i, j, k] - log_count) |> log
return @tullio naive := pointwise_naive[i]
end
17 changes: 8 additions & 9 deletions src/TuringHelpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const TURING_MODEL_ARG = """


"""
$(TYPEDSIGNATURES) -> Array
-> Array
Compute pointwise log-likelihoods from a Turing model.
Expand Down Expand Up @@ -38,7 +38,7 @@ end


"""
$(TYPEDSIGNATURES) -> PsisLoo
psis_loo(model::DynamicPPL.Model, chains::Chains, args...; kwargs...) -> PsisLoo
Use Pareto-Smoothed Importance Sampling to calculate the leave-one-out cross validation
score from a `chains` object and a Turing model.
Expand All @@ -59,7 +59,7 @@ end


"""
$(TYPEDSIGNATURES) -> PsisLoo
psis_loo(model::DynamicPPL.Model, chains::Chains, psis::Psis) -> PsisLoo
Use Pareto-Smoothed Importance Sampling to calculate the leave-one-out cross validation
score from a `Chains` object, a Turing model, and a precalculated `Psis` object.
Expand All @@ -68,19 +68,18 @@ score from a `Chains` object, a Turing model, and a precalculated `Psis` object.
- $CHAINS_ARG
- $TURING_MODEL_ARG
- $ARGS [`psis`](@ref).
- $KWARGS [`psis`](@ref).
- `psis`: A `Psis` object containing the results of Pareto smoothed importance sampling.
See also: [`psis`](@ref), [`psis_loo`](@ref), [`PsisLoo`](@ref).
"""
function loo_from_psis(model::DynamicPPL.Model, chains::Chains, args...; kwargs...)
function loo_from_psis(model::DynamicPPL.Model, chains::Chains, psis::Psis)
pointwise_log_likes = pointwise_log_likelihoods(model, chains)
return loo_from_psis(pointwise_log_likes, args...; kwargs...)
return loo_from_psis(pointwise_log_likes, psis)
end


"""
$(TYPEDSIGNATURES) -> Psis
psis(model::DynamicPPL.Model, chains::Chains, args...; kwargs...) -> Psis
Generate samples using Pareto smoothed importance sampling (PSIS).
Expand Down
3 changes: 1 addition & 2 deletions test/tests/BasicTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import RData
log_lik_mat = reshape(log_lik_arr, 32, 1000)
chain_index = vcat(fill(1, 500), fill(2, 500))
matrix_psis = psis(log_lik_mat; chain_index=chain_index)
log_psis = psis(log_lik_arr; log_weights=true)

jul_loo = psis_loo(log_lik_arr)
r_eff_loo = psis_loo(log_lik_arr; r_eff=r_eff)
Expand All @@ -56,7 +55,7 @@ import RData
# RMSE less than .2% when using InferenceDiagnostics' ESS
@test sqrt(mean((jul_psis.weights ./ r_weights .- 1) .^ 2)) 0.002
# Max difference is 1%
@test maximum(log_psis.weights .- log.(r_weights)) 0.01
@test maximum(log.(jul_psis.weights) .- log.(r_weights)) 0.01


## Test difference in loo pointwise results
Expand Down

0 comments on commit 425404f

Please sign in to comment.