Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Move from implicitly mapped measures and kernels to data tagged as mapped #155

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion ext/MeasureBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,16 @@

# = return type inference ====================================================

using MeasureBase: logdensityof_rt
using MeasureBase: logdensityof_rt, strict_logdensityof_rt

_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent())
function ChainRulesCore.rrule(::typeof(logdensityof_rt), target, v)
logdensityof_rt(target, v), _logdensityof_rt_pullback
end

_strict_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent())
function ChainRulesCore.rrule(::typeof(strict_logdensityof_rt), target, v)
strict_logdensityof_rt(target, v), _strict_logdensityof_rt_pullback

Check warning on line 56 in ext/MeasureBaseChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasureBaseChainRulesCoreExt.jl#L54-L56

Added lines #L54 - L56 were not covered by tests
end

end # module MeasureBaseChainRulesCoreExt
4 changes: 2 additions & 2 deletions src/combinators/half.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, μ::Half) where {T}
return abs(rand(rng, T, unhalf(μ)))
end

function logdensityof(μ::Half, x)
ld = logdensityof(unhalf(μ), x) - loghalf
function strict_logdensityof(μ::Half, x)
ld = strict_logdensityof(unhalf(μ), x) - loghalf
Comment on lines +22 to +23
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to have another look at this - should it just be logdensity_def?

If we need to merge before this is resolved, let's add a #TODO comment

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this is part of specializing logdensityof (e242314) to avoid the multi-step logdensity_def machinery if users don't need logdensity_rel (until we revamp that machinery to make it more type-stable and Zygote-friendly).

return x ≥ 0 ? ld : oftype(ld, -Inf)
end

Expand Down
11 changes: 7 additions & 4 deletions src/combinators/likelihood.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
export AbstractLikelihood, Likelihood

abstract type AbstractLikelihood end
(lik::AbstractLikelihood)(p) = exp(ULogarithmic, logdensityof(lik.k(p), lik.x))

Check warning on line 4 in src/combinators/likelihood.jl

View check run for this annotation

Codecov / codecov/patch

src/combinators/likelihood.jl#L4

Added line #L4 was not covered by tests

DensityInterface.DensityKind(::AbstractLikelihood) = IsDensity()

Check warning on line 6 in src/combinators/likelihood.jl

View check run for this annotation

Codecov / codecov/patch

src/combinators/likelihood.jl#L6

Added line #L6 was not covered by tests

Base.:∘(::typeof(log), lik::AbstractLikelihood) = logdensityof(lik)

Check warning on line 8 in src/combinators/likelihood.jl

View check run for this annotation

Codecov / codecov/patch

src/combinators/likelihood.jl#L8

Added line #L8 was not covered by tests

# @inline function logdensityof(ℓ::AbstractLikelihood, p)
# t() = dynamic(unsafe_logdensityof(ℓ, p))
Expand All @@ -11,6 +16,7 @@
# insupport(ℓ::AbstractLikelihood, p) = insupport(ℓ.k(p), ℓ.x)

@doc raw"""
Likelihood(k::Base.Callable, x)
Likelihood(k::AbstractTransitionKernel, x)

"Observe" a value `x`, yielding a function from the parameters to ℝ.
Expand Down Expand Up @@ -117,14 +123,11 @@
x::X

Likelihood(k::K, x::X) where {K<:AbstractTransitionKernel,X} = new{K,X}(k, x)
Likelihood(::Type{K}, x::X) where {K,X} = new{Type{K},X}(K, x)

Check warning on line 126 in src/combinators/likelihood.jl

View check run for this annotation

Codecov / codecov/patch

src/combinators/likelihood.jl#L126

Added line #L126 was not covered by tests
Likelihood(k::K, x::X) where {K<:Function,X} = new{K,X}(k, x)
Likelihood(μ, x) = Likelihood(kernel(μ), x)
end

(lik::AbstractLikelihood)(p) = exp(ULogarithmic, logdensityof(lik.k(p), lik.x))

DensityInterface.DensityKind(::AbstractLikelihood) = IsDensity()

function Pretty.quoteof(ℓ::Likelihood)
k = Pretty.quoteof(ℓ.k)
x = Pretty.quoteof(ℓ.x)
Expand Down
2 changes: 1 addition & 1 deletion src/combinators/power.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ params(d::PowerMeasure) = params(first(marginals(d)))
basemeasure(d.parent)^d.axes
end

for func in [:logdensityof, :logdensity_def]
for func in [:strict_logdensityof, :logdensity_def]
@eval @inline function $func(d::PowerMeasure{M}, x) where {M}
parent = d.parent
sum(x) do xj
Expand Down
6 changes: 3 additions & 3 deletions src/combinators/product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
end |> collect
end

for func in [:logdensityof, :logdensity_def]
for func in [:strict_logdensityof, :logdensity_def]
@eval @inline function $func(d::AbstractProductMeasure, x)
mapreduce($func, +, marginals(d), x)
end
Expand All @@ -82,7 +82,7 @@
marginals::M
end

@inline function logdensity_rel(μ::ProductMeasure, ν::ProductMeasure, x)
@inline function strict_logdensity_rel(μ::ProductMeasure, ν::ProductMeasure, x)

Check warning on line 85 in src/combinators/product.jl

View check run for this annotation

Codecov / codecov/patch

src/combinators/product.jl#L85

Added line #L85 was not covered by tests
mapreduce(logdensity_rel, +, marginals(μ), marginals(ν), x)
end

Expand All @@ -109,7 +109,7 @@
return q
end

for func in [:logdensityof, :logdensity_def]
for func in [:strict_logdensityof, :logdensity_def]
# For tuples, `mapreduce` has trouble with type inference
@eval @inline function $func(d::ProductMeasure{T}, x) where {T<:Tuple}
ℓs = map($func, marginals(d), x)
Expand Down
2 changes: 1 addition & 1 deletion src/combinators/spikemixture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ end
SpikeMixture(basemeasure(μ.m), static(1.0), static(1.0))
end

for func in [:logdensityof, :logdensity_def]
for func in [:strict_logdensityof, :logdensity_def]
@eval @inline function $func(μ::SpikeMixture, x)
# NOTE: We could instead write this as
# R1 = typeof(log(one(μ.s)))
Expand Down
67 changes: 51 additions & 16 deletions src/combinators/transformedmeasure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@
end

# TODO: THIS IS ALMOST CERTAINLY WRONG
# @inline function logdensity_rel(
# @inline function strict_logdensity_rel(
# ν::PushforwardMeasure{FF1,IF1,M1,<:AdaptRootMeasure},
# β::PushforwardMeasure{FF2,IF2,M2,<:AdaptRootMeasure},
# y,
# ) where {FF1,IF1,M1,FF2,IF2,M2}
# x = β.inv_f(y)
# f = ν.inv_f ∘ β.f
# inv_f = β.inv_f ∘ ν.f
# logdensity_rel(pushfwd(f, inv_f, ν.origin, AdaptRootMeasure()), β.origin, x)
# strict_logdensity_rel(pushfwd(f, inv_f, ν.origin, AdaptRootMeasure()), β.origin, x)
# end

# TODO: Would profit from custom pullback:
Expand All @@ -132,7 +132,7 @@
end
end

function logdensityof(
function strict_logdensityof(
@nospecialize(μ::_NonBijectivePusfwdMeasure{M,<:PushfwdRootMeasure}),
@nospecialize(v::Any)
) where {M}
Expand All @@ -143,7 +143,7 @@
)
end

function logdensityof(
function strict_logdensityof(
@nospecialize(μ::_NonBijectivePusfwdMeasure{M,<:AdaptRootMeasure}),
@nospecialize(v::Any)
) where {M}
Expand All @@ -154,7 +154,7 @@
)
end

for func in [:logdensityof, :logdensity_def]
for func in [:strict_logdensityof, :logdensity_def]
@eval function $func(ν::PushforwardMeasure{F,I,M,<:AdaptRootMeasure}, y) where {F,I,M}
f_inv = unwrap(ν.finv)
x, inv_ladj = with_logabsdet_jacobian(f_inv, y)
Expand Down Expand Up @@ -222,25 +222,52 @@
function pushfwd end
export pushfwd

@inline pushfwd(f, μ) = _pushfwd_impl(f, μ, AdaptRootMeasure())
@inline pushfwd(f, μ, style::AdaptRootMeasure) = _pushfwd_impl(f, μ, style)
@inline pushfwd(f, μ, style::PushfwdRootMeasure) = _pushfwd_impl(f, μ, style)
@inline pushfwd(f, μ) = _pushfwd_impl1(f, μ, AdaptRootMeasure())
@inline pushfwd(f, μ, style::AdaptRootMeasure) = _pushfwd_impl1(f, μ, style)
@inline pushfwd(f, μ, style::PushfwdRootMeasure) = _pushfwd_impl1(f, μ, style)

_pushfwd_impl(f, μ, style) = PushforwardMeasure(f, inverse(f), μ, style)
_pushfwd_impl1(f, μ, style::PushFwdStyle) = _pushfwd_impl2(f, inverse(f), μ, style)
_pushfwd_impl1(::typeof(identity), μ, ::AdaptRootMeasure) = μ
_pushfwd_impl1(::typeof(identity), μ, ::PushfwdRootMeasure) = μ

function _pushfwd_impl(
_pushfwd_impl2(f, finv, μ, style::PushFwdStyle) = PushforwardMeasure(f, finv, μ, style)

function _pushfwd_impl2(
f,
finv,
μ::PushforwardMeasure{F,I,M,S},
style::S,
) where {F,I,M,S<:PushFwdStyle}
orig_μ = μ.origin
new_f = fcomp(f, μ.f)
new_f_inv = fcomp(μ.finv, inverse(f))
new_f_inv = fcomp(μ.finv, finv)
PushforwardMeasure(new_f, new_f_inv, orig_μ, style)
end

_pushfwd_impl(::typeof(identity), μ, ::AdaptRootMeasure) = μ
_pushfwd_impl(::typeof(identity), μ, ::PushfwdRootMeasure) = μ
struct _CurriedPushfwd{F,I,S<:PushFwdStyle} <: Function
f::F
finv::I
style::S

function _CurriedPushfwd{F,I,S}(f::F, finv::I, style::S) where {F,I,S<:PushFwdStyle}
new{F,I,S}(f, finv, style)

Check warning on line 253 in src/combinators/transformedmeasure.jl

View check run for this annotation

Codecov / codecov/patch

src/combinators/transformedmeasure.jl#L252-L253

Added lines #L252 - L253 were not covered by tests
end

function _CurriedPushfwd(f, finv, style::S) where {S<:PushFwdStyle}
new{Core.Typeof(f),Core.Typeof(finv),S}(f, finv, style)
end
end

@inline (cf::_CurriedPushfwd{F,FI})(μ) where {F,FI} =
_pushfwd_impl2(cf.f, cf.finv, μ, cf.style)

@inline pushfwd(f) = _curried_pushfwd_impl(f, AdaptRootMeasure())
@inline pushfwd(f, style::AdaptRootMeasure) = _curried_pushfwd_impl(f, style)
@inline pushfwd(f, style::PushfwdRootMeasure) = _curried_pushfwd_impl(f, style)

_curried_pushfwd_impl(f, style::PushFwdStyle) = _CurriedPushfwd(f, inverse(f), style)
@inline _curried_pushfwd_impl(::typeof(identity), ::AdaptRootMeasure) = identity
@inline _curried_pushfwd_impl(::typeof(identity), ::PushfwdRootMeasure) = identity

Check warning on line 270 in src/combinators/transformedmeasure.jl

View check run for this annotation

Codecov / codecov/patch

src/combinators/transformedmeasure.jl#L269-L270

Added lines #L269 - L270 were not covered by tests

###############################################################################
# pullback
Expand All @@ -267,8 +294,16 @@
@inline pullbck(f, μ, style::AdaptRootMeasure) = _pullback_impl(f, μ, style)
@inline pullbck(f, μ, style::PushfwdRootMeasure) = _pullback_impl(f, μ, style)

function _pullback_impl(f, μ, style = AdaptRootMeasure())
pushfwd(inverse(f), μ, style)
end
_pullback_impl(f, μ, style::PushFwdStyle) = _pushfwd_impl2(inverse(f), f, μ, style)
_pullback_impl(::typeof(identity), μ, ::AdaptRootMeasure) = μ
_pullback_impl(::typeof(identity), μ, ::PushfwdRootMeasure) = μ

@inline pullbck(f) = _curried_pullbck_impl(f, AdaptRootMeasure())
@inline pullbck(f, style::AdaptRootMeasure) = _curried_pullbck_impl(f, style)
@inline pullbck(f, style::PushfwdRootMeasure) = _curried_pullbck_impl(f, style)

_curried_pullbck_impl(f, style::PushFwdStyle) = _CurriedPushfwd(inverse(f), f, style)
@inline _curried_pullbck_impl(::typeof(identity), ::AdaptRootMeasure) = identity
@inline _curried_pullbck_impl(::typeof(identity), ::PushfwdRootMeasure) = identity

Check warning on line 307 in src/combinators/transformedmeasure.jl

View check run for this annotation

Codecov / codecov/patch

src/combinators/transformedmeasure.jl#L306-L307

Added lines #L306 - L307 were not covered by tests

@deprecate pullback(f, μ, style::PushFwdStyle = AdaptRootMeasure()) pullbck(f, μ, style)
53 changes: 44 additions & 9 deletions src/density-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@
`logdensity_def`.

To compute a log-density relative to a specific base-measure, see
`logdensity_rel`.
`logdensity_rel`.

# Implementation

Do not specialize `logdensityof` directly for subtypes of `AbstractMeasure`,
specialize `MeasureBase.logdensity_def` and `MeasureBase.strict_logdensityof` instead.
"""
@inline function logdensityof(μ::AbstractMeasure, x)
result = dynamic(unsafe_logdensityof(μ, x))
_checksupport(insupport(μ, x), result)
@inline function logdensityof(μ::AbstractMeasure, x) #!!!!!!!!!!!!!!!!!
strict_logdensityof(μ, x)
end

@inline function logdensityof_rt(::T, ::U) where {T,U}
Expand All @@ -41,6 +45,24 @@

export unsafe_logdensityof

"""
MeasureBase.strict_logdensityof(μ, x)

Compute the log-density of the measure `μ` at `x` relative to `rootmeasure(m)`.
In contrast to [`logdensityof(μ, x)`](@ref), this will not take implicit pushforwards
of `μ` (depending on the type of `x`) into account.
"""
function strict_logdensityof end

@inline function strict_logdensityof(μ, x)
result = dynamic(unsafe_logdensityof(μ, x))
_checksupport(insupport(μ, x), result)
end

@inline function strict_logdensityof_rt(::T, ::U) where {T,U}
Core.Compiler.return_type(strict_logdensityof, Tuple{T,U})

Check warning on line 63 in src/density-core.jl

View check run for this annotation

Codecov / codecov/patch

src/density-core.jl#L62-L63

Added lines #L62 - L63 were not covered by tests
end

# https://discourse.julialang.org/t/counting-iterations-to-a-type-fixpoint/75876/10?u=cscherrer
"""
unsafe_logdensityof(m, x)
Expand Down Expand Up @@ -68,14 +90,27 @@
end

"""
logdensity_rel(m1, m2, x)
logdensity_rel(μ, ν, x)

Compute the log-density of `m1` relative to `m2` at `x`. This function checks
whether `x` is in the support of `m1` or `m2` (or both, or neither). If `x` is
Compute the log-density of `μ` relative to `ν` at `x`. This function checks
whether `x` is in the support of `μ` or `ν` (or both, or neither). If `x` is
known to be in the support of both, it can be more efficient to call
`unsafe_logdensity_rel`.
`unsafe_logdensity_rel`.
"""
function logdensity_rel(μ, ν, x)
strict_logdensity_rel(μ, ν, x)
end

"""
@inline function logdensity_rel(μ::M, ν::N, x::X) where {M,N,X}
MeasureBase.strict_logdensity_rel(μ, ν, x)

Compute the log-density of `μ` relative to `ν` at `x`. In contrast to
[`logdensity_rel(μ, ν, x)`](@ref), this will not take implicit pushforwards
of `μ` and `ν` (depending on the type of `x`) into account.
"""
function strict_logdensity_rel end

@inline function strict_logdensity_rel(μ::M, ν::N, x::X) where {M,N,X}
T = unstatic(
promote_type(
return_type(logdensity_def, (μ, x)),
Expand Down
4 changes: 2 additions & 2 deletions src/density.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x)

density_def(μ::DensityMeasure, x) = densityof(μ.f, x)

function logdensityof(μ::DensityMeasure, x::Any)
function strict_logdensityof(μ::DensityMeasure, x::Any)
integrand, μ_base = μ.f, μ.base

base_logval = logdensityof(μ_base, x)
base_logval = strict_logdensityof(μ_base, x)

T = typeof(base_logval)
U = logdensityof_rt(integrand, x)
Expand Down
1 change: 1 addition & 0 deletions src/kernel.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
export AbstractTransitionKernel,
GenericTransitionKernel, TypedTransitionKernel, ParameterizedTransitionKernel

# ToDo (breaking): A transition kernel should be a Function, not an AbstractMeasure.
abstract type AbstractTransitionKernel <: AbstractMeasure end

struct GenericTransitionKernel{F} <: AbstractTransitionKernel
Expand Down
4 changes: 2 additions & 2 deletions src/primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

@inline basemeasure_depth(::PrimitiveMeasure) = static(0)

@inline logdensityof(::PrimitiveMeasure, x::Real) = zero(float(typeof(x)))
@inline logdensityof(::PrimitiveMeasure, x) = false
@inline strict_logdensityof(::PrimitiveMeasure, x::Real) = zero(float(typeof(x)))
@inline strict_logdensityof(::PrimitiveMeasure, x) = false

Check warning on line 23 in src/primitive.jl

View check run for this annotation

Codecov / codecov/patch

src/primitive.jl#L23

Added line #L23 was not covered by tests

logdensity_def(::PrimitiveMeasure, x) = static(0.0)

Expand Down
6 changes: 3 additions & 3 deletions src/primitives/counting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
Counting(supp) = new{Core.Typeof(supp)}(supp)
end

@inline function logdensityof(μ::Counting, x::Real)
@inline function strict_logdensityof(μ::Counting, x::Real)
R = float(typeof(x))
insupport(μ, x) ? zero(R) : R(-Inf)
end

@inline logdensityof(μ::Counting, x) = insupport(μ, x) ? 0.0 : -Inf
@inline strict_logdensityof(μ::Counting, x) = insupport(μ, x) ? 0.0 : -Inf

Check warning on line 20 in src/primitives/counting.jl

View check run for this annotation

Codecov / codecov/patch

src/primitives/counting.jl#L20

Added line #L20 was not covered by tests

@inline logdensity_def(μ::Counting, x) = logdensityof(μ, x)
@inline logdensity_def(μ::Counting, x) = strict_logdensityof(μ, x)

basemeasure(::Counting) = CountingBase()

Expand Down
4 changes: 2 additions & 2 deletions src/primitives/dirac.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@

massof(::Dirac) = static(1.0)

function logdensityof(μ::Dirac, x::Real)
function strict_logdensityof(μ::Dirac, x::Real)
R = float(typeof(x))
insupport(μ, x) ? zero(R) : R(-Inf)
end

logdensityof(μ::Dirac, x) = insupport(μ, x) ? 0.0 : -Inf
strict_logdensityof(μ::Dirac, x) = insupport(μ, x) ? 0.0 : -Inf

Check warning on line 28 in src/primitives/dirac.jl

View check run for this annotation

Codecov / codecov/patch

src/primitives/dirac.jl#L28

Added line #L28 was not covered by tests

logdensity_def(::Dirac, x::Real) = zero(float(typeof(x)))
logdensity_def(::Dirac, x) = 0.0
Expand Down
4 changes: 2 additions & 2 deletions src/primitives/lebesgue.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@

insupport(::Lebesgue{RealNumbers}, ::Real) = true

@inline function logdensityof(μ::Lebesgue, x::Real)
@inline function strict_logdensityof(μ::Lebesgue, x::Real)
R = float(typeof(x))
insupport(μ, x) ? zero(R) : R(-Inf)
end

@inline logdensityof(μ::Lebesgue, x) = insupport(μ, x) ? 0.0 : -Inf
@inline strict_logdensityof(μ::Lebesgue, x) = insupport(μ, x) ? 0.0 : -Inf

Check warning on line 71 in src/primitives/lebesgue.jl

View check run for this annotation

Codecov / codecov/patch

src/primitives/lebesgue.jl#L71

Added line #L71 was not covered by tests

massof(::Lebesgue{RealNumbers}, s::Interval) = width(s)

Expand Down
Loading
Loading