diff --git a/ext/MeasureBaseChainRulesCoreExt.jl b/ext/MeasureBaseChainRulesCoreExt.jl index 57ed25fa..9fb3d347 100644 --- a/ext/MeasureBaseChainRulesCoreExt.jl +++ b/ext/MeasureBaseChainRulesCoreExt.jl @@ -44,11 +44,16 @@ ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checke # = return type inference ==================================================== -using MeasureBase: logdensityof_rt +using MeasureBase: logdensityof_rt, strict_logdensityof_rt _logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent()) function ChainRulesCore.rrule(::typeof(logdensityof_rt), target, v) logdensityof_rt(target, v), _logdensityof_rt_pullback end +_strict_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent()) +function ChainRulesCore.rrule(::typeof(strict_logdensityof_rt), target, v) + strict_logdensityof_rt(target, v), _strict_logdensityof_rt_pullback +end + end # module MeasureBaseChainRulesCoreExt diff --git a/src/combinators/half.jl b/src/combinators/half.jl index 24063b76..eede8c04 100644 --- a/src/combinators/half.jl +++ b/src/combinators/half.jl @@ -19,8 +19,8 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, μ::Half) where {T} return abs(rand(rng, T, unhalf(μ))) end -function logdensityof(μ::Half, x) - ld = logdensityof(unhalf(μ), x) - loghalf +function strict_logdensityof(μ::Half, x) + ld = strict_logdensityof(unhalf(μ), x) - loghalf return x ≥ 0 ? ld : oftype(ld, -Inf) end diff --git a/src/combinators/likelihood.jl b/src/combinators/likelihood.jl index 6dfd164f..f93b3829 100644 --- a/src/combinators/likelihood.jl +++ b/src/combinators/likelihood.jl @@ -1,6 +1,11 @@ export AbstractLikelihood, Likelihood abstract type AbstractLikelihood end +(lik::AbstractLikelihood)(p) = exp(ULogarithmic, logdensityof(lik.k(p), lik.x)) + +DensityInterface.DensityKind(::AbstractLikelihood) = IsDensity() + +Base.:∘(::typeof(log), lik::AbstractLikelihood) = logdensityof(lik) # @inline function logdensityof(ℓ::AbstractLikelihood, p) # t() = dynamic(unsafe_logdensityof(ℓ, p)) @@ -11,6 +16,7 @@ abstract type AbstractLikelihood end # insupport(ℓ::AbstractLikelihood, p) = insupport(ℓ.k(p), ℓ.x) @doc raw""" + Likelihood(k::Base.Callable, x) Likelihood(k::AbstractTransitionKernel, x) "Observe" a value `x`, yielding a function from the parameters to ℝ. @@ -117,14 +123,11 @@ struct Likelihood{K,X} <: AbstractLikelihood x::X Likelihood(k::K, x::X) where {K<:AbstractTransitionKernel,X} = new{K,X}(k, x) + Likelihood(::Type{K}, x::X) where {K,X} = new{Type{K},X}(K, x) Likelihood(k::K, x::X) where {K<:Function,X} = new{K,X}(k, x) Likelihood(μ, x) = Likelihood(kernel(μ), x) end -(lik::AbstractLikelihood)(p) = exp(ULogarithmic, logdensityof(lik.k(p), lik.x)) - -DensityInterface.DensityKind(::AbstractLikelihood) = IsDensity() - function Pretty.quoteof(ℓ::Likelihood) k = Pretty.quoteof(ℓ.k) x = Pretty.quoteof(ℓ.x) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index e6397c3f..d29fe4f5 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -78,7 +78,7 @@ params(d::PowerMeasure) = params(first(marginals(d))) basemeasure(d.parent)^d.axes end -for func in [:logdensityof, :logdensity_def] +for func in [:strict_logdensityof, :logdensity_def] @eval @inline function $func(d::PowerMeasure{M}, x) where {M} parent = d.parent sum(x) do xj diff --git a/src/combinators/product.jl b/src/combinators/product.jl index 0290419d..bc2213b4 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -72,7 +72,7 @@ function _rand_product( end |> collect end -for func in [:logdensityof, :logdensity_def] +for func in [:strict_logdensityof, :logdensity_def] @eval @inline function $func(d::AbstractProductMeasure, x) mapreduce($func, +, marginals(d), x) end @@ -82,7 +82,7 @@ struct ProductMeasure{M} <: AbstractProductMeasure marginals::M end -@inline function logdensity_rel(μ::ProductMeasure, ν::ProductMeasure, x) +@inline function strict_logdensity_rel(μ::ProductMeasure, ν::ProductMeasure, x) mapreduce(logdensity_rel, +, marginals(μ), marginals(ν), x) end @@ -109,7 +109,7 @@ end return q end -for func in [:logdensityof, :logdensity_def] +for func in [:strict_logdensityof, :logdensity_def] # For tuples, `mapreduce` has trouble with type inference @eval @inline function $func(d::ProductMeasure{T}, x) where {T<:Tuple} ℓs = map($func, marginals(d), x) diff --git a/src/combinators/spikemixture.jl b/src/combinators/spikemixture.jl index e39d4230..39751b03 100644 --- a/src/combinators/spikemixture.jl +++ b/src/combinators/spikemixture.jl @@ -21,7 +21,7 @@ end SpikeMixture(basemeasure(μ.m), static(1.0), static(1.0)) end -for func in [:logdensityof, :logdensity_def] +for func in [:strict_logdensityof, :logdensity_def] @eval @inline function $func(μ::SpikeMixture, x) # NOTE: We could instead write this as # R1 = typeof(log(one(μ.s))) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index c9db7a6b..081d1816 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -103,7 +103,7 @@ function Pretty.tile(ν::PushforwardMeasure) end # TODO: THIS IS ALMOST CERTAINLY WRONG -# @inline function logdensity_rel( +# @inline function strict_logdensity_rel( # ν::PushforwardMeasure{FF1,IF1,M1,<:AdaptRootMeasure}, # β::PushforwardMeasure{FF2,IF2,M2,<:AdaptRootMeasure}, # y, @@ -111,7 +111,7 @@ end # x = β.inv_f(y) # f = ν.inv_f ∘ β.f # inv_f = β.inv_f ∘ ν.f -# logdensity_rel(pushfwd(f, inv_f, ν.origin, AdaptRootMeasure()), β.origin, x) +# strict_logdensity_rel(pushfwd(f, inv_f, ν.origin, AdaptRootMeasure()), β.origin, x) # end # TODO: Would profit from custom pullback: @@ -132,7 +132,7 @@ function _combine_logd_with_ladj(logd_orig::Real, ladj::Real) end end -function logdensityof( +function strict_logdensityof( @nospecialize(μ::_NonBijectivePusfwdMeasure{M,<:PushfwdRootMeasure}), @nospecialize(v::Any) ) where {M} @@ -143,7 +143,7 @@ function logdensityof( ) end -function logdensityof( +function strict_logdensityof( @nospecialize(μ::_NonBijectivePusfwdMeasure{M,<:AdaptRootMeasure}), @nospecialize(v::Any) ) where {M} @@ -154,7 +154,7 @@ function logdensityof( ) end -for func in [:logdensityof, :logdensity_def] +for func in [:strict_logdensityof, :logdensity_def] @eval function $func(ν::PushforwardMeasure{F,I,M,<:AdaptRootMeasure}, y) where {F,I,M} f_inv = unwrap(ν.finv) x, inv_ladj = with_logabsdet_jacobian(f_inv, y) @@ -222,25 +222,52 @@ To manually specify an inverse, call function pushfwd end export pushfwd -@inline pushfwd(f, μ) = _pushfwd_impl(f, μ, AdaptRootMeasure()) -@inline pushfwd(f, μ, style::AdaptRootMeasure) = _pushfwd_impl(f, μ, style) -@inline pushfwd(f, μ, style::PushfwdRootMeasure) = _pushfwd_impl(f, μ, style) +@inline pushfwd(f, μ) = _pushfwd_impl1(f, μ, AdaptRootMeasure()) +@inline pushfwd(f, μ, style::AdaptRootMeasure) = _pushfwd_impl1(f, μ, style) +@inline pushfwd(f, μ, style::PushfwdRootMeasure) = _pushfwd_impl1(f, μ, style) -_pushfwd_impl(f, μ, style) = PushforwardMeasure(f, inverse(f), μ, style) +_pushfwd_impl1(f, μ, style::PushFwdStyle) = _pushfwd_impl2(f, inverse(f), μ, style) +_pushfwd_impl1(::typeof(identity), μ, ::AdaptRootMeasure) = μ +_pushfwd_impl1(::typeof(identity), μ, ::PushfwdRootMeasure) = μ -function _pushfwd_impl( +_pushfwd_impl2(f, finv, μ, style::PushFwdStyle) = PushforwardMeasure(f, finv, μ, style) + +function _pushfwd_impl2( f, + finv, μ::PushforwardMeasure{F,I,M,S}, style::S, ) where {F,I,M,S<:PushFwdStyle} orig_μ = μ.origin new_f = fcomp(f, μ.f) - new_f_inv = fcomp(μ.finv, inverse(f)) + new_f_inv = fcomp(μ.finv, finv) PushforwardMeasure(new_f, new_f_inv, orig_μ, style) end -_pushfwd_impl(::typeof(identity), μ, ::AdaptRootMeasure) = μ -_pushfwd_impl(::typeof(identity), μ, ::PushfwdRootMeasure) = μ +struct _CurriedPushfwd{F,I,S<:PushFwdStyle} <: Function + f::F + finv::I + style::S + + function _CurriedPushfwd{F,I,S}(f::F, finv::I, style::S) where {F,I,S<:PushFwdStyle} + new{F,I,S}(f, finv, style) + end + + function _CurriedPushfwd(f, finv, style::S) where {S<:PushFwdStyle} + new{Core.Typeof(f),Core.Typeof(finv),S}(f, finv, style) + end +end + +@inline (cf::_CurriedPushfwd{F,FI})(μ) where {F,FI} = + _pushfwd_impl2(cf.f, cf.finv, μ, cf.style) + +@inline pushfwd(f) = _curried_pushfwd_impl(f, AdaptRootMeasure()) +@inline pushfwd(f, style::AdaptRootMeasure) = _curried_pushfwd_impl(f, style) +@inline pushfwd(f, style::PushfwdRootMeasure) = _curried_pushfwd_impl(f, style) + +_curried_pushfwd_impl(f, style::PushFwdStyle) = _CurriedPushfwd(f, inverse(f), style) +@inline _curried_pushfwd_impl(::typeof(identity), ::AdaptRootMeasure) = identity +@inline _curried_pushfwd_impl(::typeof(identity), ::PushfwdRootMeasure) = identity ############################################################################### # pullback @@ -267,8 +294,16 @@ export pullbck @inline pullbck(f, μ, style::AdaptRootMeasure) = _pullback_impl(f, μ, style) @inline pullbck(f, μ, style::PushfwdRootMeasure) = _pullback_impl(f, μ, style) -function _pullback_impl(f, μ, style = AdaptRootMeasure()) - pushfwd(inverse(f), μ, style) -end +_pullback_impl(f, μ, style::PushFwdStyle) = _pushfwd_impl2(inverse(f), f, μ, style) +_pullback_impl(::typeof(identity), μ, ::AdaptRootMeasure) = μ +_pullback_impl(::typeof(identity), μ, ::PushfwdRootMeasure) = μ + +@inline pullbck(f) = _curried_pullbck_impl(f, AdaptRootMeasure()) +@inline pullbck(f, style::AdaptRootMeasure) = _curried_pullbck_impl(f, style) +@inline pullbck(f, style::PushfwdRootMeasure) = _curried_pullbck_impl(f, style) + +_curried_pullbck_impl(f, style::PushFwdStyle) = _CurriedPushfwd(inverse(f), f, style) +@inline _curried_pullbck_impl(::typeof(identity), ::AdaptRootMeasure) = identity +@inline _curried_pullbck_impl(::typeof(identity), ::PushfwdRootMeasure) = identity @deprecate pullback(f, μ, style::PushFwdStyle = AdaptRootMeasure()) pullbck(f, μ, style) diff --git a/src/density-core.jl b/src/density-core.jl index 6ac3d01e..121188f3 100644 --- a/src/density-core.jl +++ b/src/density-core.jl @@ -26,11 +26,15 @@ To compute log-density relative to `basemeasure(m)` or *define* a log-density `logdensity_def`. To compute a log-density relative to a specific base-measure, see -`logdensity_rel`. +`logdensity_rel`. + +# Implementation + +Do not specialize `logdensityof` directly for subtypes of `AbstractMeasure`, +specialize `MeasureBase.logdensity_def` and `MeasureBase.strict_logdensityof` instead. """ -@inline function logdensityof(μ::AbstractMeasure, x) - result = dynamic(unsafe_logdensityof(μ, x)) - _checksupport(insupport(μ, x), result) +@inline function logdensityof(μ::AbstractMeasure, x) #!!!!!!!!!!!!!!!!! + strict_logdensityof(μ, x) end @inline function logdensityof_rt(::T, ::U) where {T,U} @@ -41,6 +45,24 @@ _checksupport(cond, result) = ifelse(cond == true, result, oftype(result, -Inf)) export unsafe_logdensityof +""" + MeasureBase.strict_logdensityof(μ, x) + +Compute the log-density of the measure `μ` at `x` relative to `rootmeasure(m)`. +In contrast to [`logdensityof(μ, x)`](@ref), this will not take implicit pushforwards +of `μ` (depending on the type of `x`) into account. +""" +function strict_logdensityof end + +@inline function strict_logdensityof(μ, x) + result = dynamic(unsafe_logdensityof(μ, x)) + _checksupport(insupport(μ, x), result) +end + +@inline function strict_logdensityof_rt(::T, ::U) where {T,U} + Core.Compiler.return_type(strict_logdensityof, Tuple{T,U}) +end + # https://discourse.julialang.org/t/counting-iterations-to-a-type-fixpoint/75876/10?u=cscherrer """ unsafe_logdensityof(m, x) @@ -68,14 +90,27 @@ See also `logdensityof`. end """ - logdensity_rel(m1, m2, x) + logdensity_rel(μ, ν, x) -Compute the log-density of `m1` relative to `m2` at `x`. This function checks -whether `x` is in the support of `m1` or `m2` (or both, or neither). If `x` is +Compute the log-density of `μ` relative to `ν` at `x`. This function checks +whether `x` is in the support of `μ` or `ν` (or both, or neither). If `x` is known to be in the support of both, it can be more efficient to call -`unsafe_logdensity_rel`. +`unsafe_logdensity_rel`. +""" +function logdensity_rel(μ, ν, x) + strict_logdensity_rel(μ, ν, x) +end + """ -@inline function logdensity_rel(μ::M, ν::N, x::X) where {M,N,X} + MeasureBase.strict_logdensity_rel(μ, ν, x) + +Compute the log-density of `μ` relative to `ν` at `x`. In contrast to +[`logdensity_rel(μ, ν, x)`](@ref), this will not take implicit pushforwards +of `μ` and `ν` (depending on the type of `x`) into account. +""" +function strict_logdensity_rel end + +@inline function strict_logdensity_rel(μ::M, ν::N, x::X) where {M,N,X} T = unstatic( promote_type( return_type(logdensity_def, (μ, x)), diff --git a/src/density.jl b/src/density.jl index a79021de..1a4e797a 100644 --- a/src/density.jl +++ b/src/density.jl @@ -163,10 +163,10 @@ logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x) density_def(μ::DensityMeasure, x) = densityof(μ.f, x) -function logdensityof(μ::DensityMeasure, x::Any) +function strict_logdensityof(μ::DensityMeasure, x::Any) integrand, μ_base = μ.f, μ.base - base_logval = logdensityof(μ_base, x) + base_logval = strict_logdensityof(μ_base, x) T = typeof(base_logval) U = logdensityof_rt(integrand, x) diff --git a/src/kernel.jl b/src/kernel.jl index d6667c7b..b90d0063 100644 --- a/src/kernel.jl +++ b/src/kernel.jl @@ -1,6 +1,7 @@ export AbstractTransitionKernel, GenericTransitionKernel, TypedTransitionKernel, ParameterizedTransitionKernel +# ToDo (breaking): A transition kernel should be a Function, not an AbstractMeasure. abstract type AbstractTransitionKernel <: AbstractMeasure end struct GenericTransitionKernel{F} <: AbstractTransitionKernel diff --git a/src/primitive.jl b/src/primitive.jl index 85cf2beb..a28b8e44 100644 --- a/src/primitive.jl +++ b/src/primitive.jl @@ -19,8 +19,8 @@ basemeasure(μ::PrimitiveMeasure) = μ @inline basemeasure_depth(::PrimitiveMeasure) = static(0) -@inline logdensityof(::PrimitiveMeasure, x::Real) = zero(float(typeof(x))) -@inline logdensityof(::PrimitiveMeasure, x) = false +@inline strict_logdensityof(::PrimitiveMeasure, x::Real) = zero(float(typeof(x))) +@inline strict_logdensityof(::PrimitiveMeasure, x) = false logdensity_def(::PrimitiveMeasure, x) = static(0.0) diff --git a/src/primitives/counting.jl b/src/primitives/counting.jl index c61d0624..dd6be543 100644 --- a/src/primitives/counting.jl +++ b/src/primitives/counting.jl @@ -12,14 +12,14 @@ struct Counting{T} <: AbstractMeasure Counting(supp) = new{Core.Typeof(supp)}(supp) end -@inline function logdensityof(μ::Counting, x::Real) +@inline function strict_logdensityof(μ::Counting, x::Real) R = float(typeof(x)) insupport(μ, x) ? zero(R) : R(-Inf) end -@inline logdensityof(μ::Counting, x) = insupport(μ, x) ? 0.0 : -Inf +@inline strict_logdensityof(μ::Counting, x) = insupport(μ, x) ? 0.0 : -Inf -@inline logdensity_def(μ::Counting, x) = logdensityof(μ, x) +@inline logdensity_def(μ::Counting, x) = strict_logdensityof(μ, x) basemeasure(::Counting) = CountingBase() diff --git a/src/primitives/dirac.jl b/src/primitives/dirac.jl index 01297486..0bfd1430 100644 --- a/src/primitives/dirac.jl +++ b/src/primitives/dirac.jl @@ -20,12 +20,12 @@ basemeasure(d::Dirac) = CountingBase() massof(::Dirac) = static(1.0) -function logdensityof(μ::Dirac, x::Real) +function strict_logdensityof(μ::Dirac, x::Real) R = float(typeof(x)) insupport(μ, x) ? zero(R) : R(-Inf) end -logdensityof(μ::Dirac, x) = insupport(μ, x) ? 0.0 : -Inf +strict_logdensityof(μ::Dirac, x) = insupport(μ, x) ? 0.0 : -Inf logdensity_def(::Dirac, x::Real) = zero(float(typeof(x))) logdensity_def(::Dirac, x) = 0.0 diff --git a/src/primitives/lebesgue.jl b/src/primitives/lebesgue.jl index 3846eaf5..223a65fe 100644 --- a/src/primitives/lebesgue.jl +++ b/src/primitives/lebesgue.jl @@ -63,12 +63,12 @@ insupport(μ::Lebesgue, x) = x ∈ μ.support insupport(::Lebesgue{RealNumbers}, ::Real) = true -@inline function logdensityof(μ::Lebesgue, x::Real) +@inline function strict_logdensityof(μ::Lebesgue, x::Real) R = float(typeof(x)) insupport(μ, x) ? zero(R) : R(-Inf) end -@inline logdensityof(μ::Lebesgue, x) = insupport(μ, x) ? 0.0 : -Inf +@inline strict_logdensityof(μ::Lebesgue, x) = insupport(μ, x) ? 0.0 : -Inf massof(::Lebesgue{RealNumbers}, s::Interval) = width(s) diff --git a/src/standard/stdexponential.jl b/src/standard/stdexponential.jl index c985c224..ae267e5b 100644 --- a/src/standard/stdexponential.jl +++ b/src/standard/stdexponential.jl @@ -4,7 +4,7 @@ export StdExponential insupport(::StdExponential, x) = x ≥ zero(x) -@inline function logdensityof(::StdExponential, x) +@inline function strict_logdensityof(::StdExponential, x) R = float(typeof(x)) x ≥ zero(R) ? convert(R, -x) : R(-Inf) end diff --git a/src/standard/stdlogistic.jl b/src/standard/stdlogistic.jl index 58a1ba67..f7881e5d 100644 --- a/src/standard/stdlogistic.jl +++ b/src/standard/stdlogistic.jl @@ -4,9 +4,9 @@ export StdLogistic @inline insupport(d::StdLogistic, x) = true -@inline logdensityof(::StdLogistic, x) = (u = -abs(x); u - 2 * log1pexp(u)) +@inline strict_logdensityof(::StdLogistic, x) = (u = -abs(x); u - 2 * log1pexp(u)) -@inline logdensity_def(::StdLogistic, x) = logdensityof(StdLogistic(), x) +@inline logdensity_def(::StdLogistic, x) = strict_logdensityof(StdLogistic(), x) @inline basemeasure(::StdLogistic) = LebesgueBase() @inline transport_def(::StdUniform, μ::StdLogistic, x) = logistic(x) diff --git a/src/standard/stdnormal.jl b/src/standard/stdnormal.jl index 057a8629..077b9bb8 100644 --- a/src/standard/stdnormal.jl +++ b/src/standard/stdnormal.jl @@ -7,7 +7,7 @@ export StdNormal @inline insupport(::StdNormal, x) = true -@inline logdensityof(::StdNormal, x) = (-x^2 - log2π) / 2 +@inline strict_logdensityof(::StdNormal, x) = (-x^2 - log2π) / 2 @inline logdensity_def(::StdNormal, x) = -x^2 / 2 @inline basemeasure(::StdNormal) = WeightedMeasure(static(-0.5 * log2π), LebesgueBase()) diff --git a/src/standard/stduniform.jl b/src/standard/stduniform.jl index 7bbe15ed..4a3a3268 100644 --- a/src/standard/stduniform.jl +++ b/src/standard/stduniform.jl @@ -4,7 +4,7 @@ export StdUniform insupport(::StdUniform, x) = zero(x) ≤ x ≤ one(x) -@inline function logdensityof(::StdUniform, x) +@inline function strict_logdensityof(::StdUniform, x) R = float(typeof(x)) zero(x) ≤ x ≤ one(x) ? zero(R) : R(-Inf) end diff --git a/src/transport.jl b/src/transport.jl index b0c8ed41..8d59ef5d 100644 --- a/src/transport.jl +++ b/src/transport.jl @@ -243,8 +243,8 @@ end function ChangesOfVariables.with_logabsdet_jacobian(f::TransportFunction, x) y = f(x) - logpdf_src = logdensityof(f.μ, x) - logpdf_trg = logdensityof(f.ν, y) + logpdf_src = strict_logdensityof(f.μ, x) + logpdf_trg = strict_logdensityof(f.ν, y) ladj = logpdf_src - logpdf_trg # If logpdf_src and logpdf_trg are -Inf setting lafj to zero is safe: fixed_ladj = logpdf_src == logpdf_trg == -Inf ? zero(ladj) : ladj diff --git a/test/Project.toml b/test/Project.toml index 376c1b05..05fbef8f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,6 +7,7 @@ DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/test/combinators/transformedmeasure.jl b/test/combinators/transformedmeasure.jl index 28ddbb50..de2a76de 100644 --- a/test/combinators/transformedmeasure.jl +++ b/test/combinators/transformedmeasure.jl @@ -12,6 +12,7 @@ import InverseFunctions: inverse, FunctionWithInverse, setinverse using IrrationalConstants: invsqrt2, sqrt2 import ChangesOfVariables: with_logabsdet_jacobian using MeasureBase.Interface: transport_to, test_transport +using FunctionChains: fchain Φ(z) = erfc(-z * invsqrt2) / 2 Φinv(p) = -erfcinv(2 * p) * sqrt2 @@ -106,7 +107,7 @@ using ChangesOfVariables # Test basic pushforward construction μ = StdNormal() f = exp - ν = pushfwd(f, μ) + ν = @inferred pushfwd(f, μ) @test ν isa PushforwardMeasure @@ -167,10 +168,29 @@ using ChangesOfVariables # Test pullback pb = pullbck(f, ν) @test pb isa PushforwardMeasure + @test pb.origin === μ + @test pb.f === fchain(exp, log) @test logdensityof(pb, y) ≈ logdensityof(μ, y) # Test deprecated pullback @test_deprecated pullback(f, μ) + + # Test identity specializations + for stylearg in [(), (AdaptRootMeasure(),), (PushfwdRootMeasure(),)] + @test @inferred(pushfwd(identity, μ, stylearg...)) === μ + @test @inferred(pullbck(identity, ν, stylearg...)) === ν + end + + # Test curried pushfwd and pullback + for stylearg in [(), (AdaptRootMeasure(),), (PushfwdRootMeasure(),)] + @test @inferred(pushfwd(f, stylearg...)(μ)) === pushfwd(f, μ, stylearg...) + @test @inferred(pullbck(f, stylearg...)(ν)) === pullbck(f, ν, stylearg...) + + @test @inferred(pushfwd(identity, stylearg...)) === identity + @test @inferred(pullbck(identity, stylearg...)) === identity + end + + @test pushfwd(identity) === identity end @testset "PushFwdStyle types" begin