Skip to content

Commit

Permalink
Add variate transport
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz authored Jun 19, 2022
1 parent bf7eae6 commit 648376d
Show file tree
Hide file tree
Showing 20 changed files with 742 additions and 40 deletions.
11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
name = "MeasureBase"
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
authors = ["Chad Scherrer <[email protected]> and contributors"]
version = "0.10.0"
version = "0.11.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Expand All @@ -24,11 +27,14 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"

[compat]
ChainRulesCore = "1"
ChangesOfVariables = "0.1.3"
Compat = "3.35, 4"
ConstructionBase = "1.3"
DensityInterface = "0.4"
FillArrays = "0.12, 0.13"
IfElse = "0.1"
InverseFunctions = "0.1.7"
IrrationalConstants = "0.1"
LogExpFunctions = "0.3"
LogarithmicNumbers = "1"
Expand All @@ -42,6 +48,7 @@ julia = "1.3"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"

[targets]
test = ["Aqua"]
test = ["Aqua", "ChainRulesTestUtils"]
29 changes: 14 additions & 15 deletions src/MeasureBase.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module MeasureBase

using Base: @propagate_inbounds

using Random
import Random: rand!
import Random: gentype
Expand All @@ -11,13 +13,17 @@ import DensityInterface: densityof
import DensityInterface: DensityKind
using DensityInterface

using InverseFunctions
using ChangesOfVariables

import Base.iterate
import ConstructionBase
using ConstructionBase: constructorof

using PrettyPrinting
const Pretty = PrettyPrinting

using ChainRulesCore
using FillArrays
using Static

Expand All @@ -32,20 +38,11 @@ export logdensity_def
export basemeasure
export basekernel
export productmeasure

"""
inssupport(m, x)
insupport(m)
`insupport(m,x)` computes whether `x` is in the support of `m`.
`insupport(m)` returns a function, and satisfies
insupport(m)(x) == insupport(m, x)
"""
function insupport end

export insupport
export getdof
export transport_to

include("insupport.jl")

abstract type AbstractMeasure end

Expand All @@ -63,7 +60,7 @@ gentype(μ::AbstractMeasure) = typeof(testvalue(μ))
# gentype(μ::AbstractMeasure) = gentype(basemeasure(μ))

using NaNMath
using LogExpFunctions: logsumexp
using LogExpFunctions: logsumexp, logistic, logit

@deprecate instance_type(x) Core.Typeof(x) false

Expand Down Expand Up @@ -94,6 +91,8 @@ using Compat

using IrrationalConstants

include("getdof.jl")
include("transport.jl")
include("schema.jl")
include("splat.jl")
include("proxies.jl")
Expand Down Expand Up @@ -125,9 +124,9 @@ include("combinators/powerweighted.jl")
include("combinators/conditional.jl")

include("standard/stdmeasure.jl")
include("standard/stdnormal.jl")
include("standard/stduniform.jl")
include("standard/stdexponential.jl")
include("standard/stdlogistic.jl")

include("rand.jl")

Expand Down
28 changes: 28 additions & 0 deletions src/combinators/power.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ end
end
end

@inline function logdensity_def(
d::PowerMeasure{M,NTuple{N, Base.OneTo{StaticInt{0}}}},
x,
) where {M,N}
static(0.0)
end

@inline function insupport::PowerMeasure, x)
p = μ.parent
all(x) do xj
Expand All @@ -100,3 +107,24 @@ end
dynamic(insupport(p, xj))
end
end


@inline getdof::PowerMeasure) = getdof.parent) * prod(map(length, μ.axes))

@inline getdof(::PowerMeasure{<:Any, NTuple{N,Base.OneTo{StaticInt{0}}}}) where N = static(0)


@propagate_inbounds function checked_var::PowerMeasure, x::AbstractArray{<:Any})
@boundscheck begin
sz_μ = map(length, μ.axes)
sz_x = size(x)
if sz_μ != sz_x
throw(ArgumentError("Size of variate doesn't match size of power measure"))
end
end
return x
end

function checked_var::PowerMeasure, x::Any)
throw(ArgumentError("Size of variate doesn't match size of power measure"))
end
91 changes: 91 additions & 0 deletions src/combinators/transformedmeasure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,94 @@ function params(::AbstractTransformedMeasure) end
function paramnames(::AbstractTransformedMeasure) end

function parent(::AbstractTransformedMeasure) end


export PushforwardMeasure

"""
struct PushforwardMeasure{FF,IF,MU,VC<:TransformVolCorr} <: AbstractPushforward
f :: FF
inv_f :: IF
origin :: MU
volcorr :: VC
end
"""
struct PushforwardMeasure{FF,IF,M,VC<:TransformVolCorr} <: AbstractPushforward
f::FF
inv_f::IF
origin::M
volcorr::VC
end

gettransform::PushforwardMeasure) = ν.f
parent::PushforwardMeasure) = ν.origin


function Pretty.tile::PushforwardMeasure)
Pretty.list_layout(Pretty.tile.([ν.f, ν.inv_f, ν.origin]); prefix = :PushforwardMeasure)
end


@inline function logdensity_def::PushforwardMeasure{FF,IF,M,<:WithVolCorr}, y) where {FF,IF,M}
x_orig, inv_ladj = with_logabsdet_jacobian.inv_f, y)
logd_orig = logdensity_def.origin, x_orig)
logd = float(logd_orig + inv_ladj)
neginf = oftype(logd, -Inf)
return ifelse(
# Zero density wins against infinite volume:
(isnan(logd) && logd_orig == -Inf && inv_ladj == +Inf) ||
# Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ?
# Return constant -Inf to prevent problems with ForwardDiff:
(isfinite(logd_orig) && (inv_ladj == -Inf)),
neginf,
logd
)
end

@inline function logdensity_def::PushforwardMeasure{FF,IF,M,<:NoVolCorr}, y) where {FF,IF,M}
x_orig = to_origin(ν, y)
return logdensity_def.origin, x_orig)
end


insupport::PushforwardMeasure, y) = insupport(transport_origin(ν), to_origin(ν, y))

testvalue::PushforwardMeasure) = from_origin(ν, testvalue(transport_origin(ν)))

@inline function basemeasure::PushforwardMeasure)
PushforwardMeasure.f, ν.inv_f, basemeasure(transport_origin(ν)), NoVolCorr())
end


_pushfwd_dof(::Type{MU}, ::Type, dof) where MU = NoDOF{MU}()
_pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where MU = dof

# Assume that DOF are preserved if with_logabsdet_jacobian is functional:
@inline function getdof::MU) where {MU<:PushforwardMeasure}
T = Core.Compiler.return_type(testvalue, Tuple{typeof.origin)})
R = Core.Compiler.return_type(with_logabsdet_jacobian, Tuple{typeof.f), T})
_pushfwd_dof(MU, R, getdof.origin))
end

# Bypass `checked_var`, would require potentially costly transformation:
@inline checked_var(::PushforwardMeasure, x) = x


@inline transport_origin::PushforwardMeasure) = ν.origin
@inline from_origin::PushforwardMeasure, x) = ν.f(x)
@inline to_origin::PushforwardMeasure, y) = ν.inv_f(y)

function Base.rand(rng::AbstractRNG, ::Type{T}, ν::PushforwardMeasure) where T
return from_origin(ν, rand(rng, T, transport_origin(ν)))
end


export pushfwd

"""
pushfwd(f, μ, volcorr = WithVolCorr())
Return the [pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure)
from `μ` the [measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`.
"""
pushfwd(f, μ, volcorr = WithVolCorr()) = PushforwardMeasure(f, inverse(f), μ, volcorr)
4 changes: 4 additions & 0 deletions src/combinators/weighted.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,7 @@ Base.:*(m::AbstractMeasure, k::Real) = k * m
gentype::WeightedMeasure) = gentype.base)

insupport::WeightedMeasure, x) = insupport.base, x)

transport_origin::WeightedMeasure) = ν.base
to_origin(::WeightedMeasure, y) = y
from_origin(::WeightedMeasure, x) = x
77 changes: 77 additions & 0 deletions src/getdof.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
MeasureBase.NoDOF{MU}
Indicates that there is no way to compute degrees of freedom of a measure
of type `MU` with the given information, e.g. because the DOF are not
a global property of the measure.
"""
struct NoDOF{MU} end


"""
getdof(μ)
Returns the effective number of degrees of freedom of variates of
measure `μ`.
The effective NDOF my differ from the length of the variates. For example,
the effective NDOF for a Dirichlet distribution with variates of length `n`
is `n - 1`.
Also see [`check_dof`](@ref).
"""
function getdof end

# Prevent infinite recursion:
@inline _default_getdof(::Type{MU}, ::MU) where MU = NoDOF{MU}
@inline _default_getdof(::Type{MU}, mu_base) where MU = getdof(mu_base)

@inline getdof::MU) where MU = _default_getdof(MU, basemeasure(μ))


"""
MeasureBase.check_dof(ν, μ)::Nothing
Check if `ν` and `μ` have the same effective number of degrees of freedom
according to [`MeasureBase.getdof`](@ref).
"""
function check_dof end

function check_dof(ν, μ)
n_ν = getdof(ν)
n_μ = getdof(μ)
if n_ν != n_μ
throw(ArgumentError("Measure ν of type $(nameof(typeof(ν))) has $(n_ν) DOF but μ of type $(nameof(typeof(μ))) has $(n_μ) DOF"))
end
return nothing
end

_check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent()
ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback


"""
MeasureBase.NoVarCheck{MU,T}
Indicates that there is no way to check of a values of type `T` are
variate of measures of type `MU`.
"""
struct NoVarCheck{MU,T} end


"""
MeasureBase.checked_var(μ::MU, x::T)::T
Return `x` if `x` is a valid variate of `μ`, throw an `ArgumentError` if not,
return `NoVarCheck{MU,T}()` if not check can be performed.
"""
function checked_var end

# Prevent infinite recursion:
@propagate_inbounds _default_checked_var(::Type{MU}, ::MU, ::T) where {MU,T} = NoVarCheck{MU,T}
@propagate_inbounds _default_checked_var(::Type{MU}, mu_base, x) where MU = checked_var(mu_base, x)

@propagate_inbounds checked_var(mu::MU, x) where MU = _default_checked_var(MU, basemeasure(mu), x)

_checked_var_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ
ChainRulesCore.rrule(::typeof(checked_var), ν, x) = checked_var(ν, x), _checked_var_pullback
32 changes: 32 additions & 0 deletions src/insupport.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
inssupport(m, x)
insupport(m)
`insupport(m,x)` computes whether `x` is in the support of `m`.
`insupport(m)` returns a function, and satisfies
insupport(m)(x) == insupport(m, x)
"""
function insupport end


"""
MeasureBase.require_insupport(μ, x)::Nothing
Checks if `x` is in the support of distribution/measure `μ`, throws an
`ArgumentError` if not.
"""
function require_insupport end

_require_insupport_pullback(ΔΩ) = NoTangent(), ZeroTangent()
function ChainRulesCore.rrule(::typeof(require_insupport), μ, x)
return require_insupport(μ, x), _require_insupport_pullback
end

function require_insupport(μ, x)
if !insupport(μ, x)
throw(ArgumentError("x is not within the support of μ"))
end
return nothing
end
Loading

2 comments on commit 648376d

@oschulz
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/62671

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.11.0 -m "<description of version>" 648376d4549bc76540daf206c737713d6d7b4cb8
git push origin v0.11.0

Please sign in to comment.