Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add implicitly mapped measures and kernels #153

Merged
merged 18 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions .github/workflows/Breakage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ jobs:
pkgversion: [latest, stable]

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4

# Install Julia
- uses: julia-actions/setup-julia@v1
- uses: julia-actions/setup-julia@v2
with:
version: 1
version: '1'
arch: x64
- uses: actions/cache@v1
- uses: actions/cache@v3
env:
cache-name: cache-artifacts
with:
Expand Down Expand Up @@ -79,7 +79,7 @@ jobs:
end;
end'

- uses: actions/upload-artifact@v2
- uses: actions/upload-artifact@v3
with:
name: pr
path: pr/
Expand All @@ -88,9 +88,9 @@ jobs:
needs: break
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4

- uses: actions/download-artifact@v2
- uses: actions/download-artifact@v3
with:
name: pr
path: pr/
Expand Down Expand Up @@ -121,7 +121,7 @@ jobs:
fi
done >> MSG

- uses: actions/upload-artifact@v2
- uses: actions/upload-artifact@v3
with:
name: pr
path: pr/
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.10'
- '1'
- 'pre'
os:
Expand Down
13 changes: 11 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ authors = ["Chad Scherrer <[email protected]>", "Oliver Schulz <oschulz@mp
version = "0.14.12"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstantRNGs = "aa9b60e7-6b1c-4c29-a6e5-e43521412437"
Expand All @@ -22,14 +21,22 @@ LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337"
PropertyFunctions = "09e99361-2bb8-48a2-a80f-de58f0739eb4"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[extensions]
MeasureBaseChainRulesCoreExt = "ChainRulesCore"

[compat]
ChainRulesCore = "1"
ChangesOfVariables = "0.1.3"
Expand All @@ -49,11 +56,13 @@ LogarithmicNumbers = "1"
MappedArrays = "0.4"
NaNMath = "0.3, 1"
PrettyPrinting = "0.3, 0.4"
PropertyFunctions = "0.2.2"
Random = "1"
Reexport = "1"
SpecialFunctions = "2"
Static = "0.8, 1"
StaticArrays = "1.5"
Statistics = "1"
Test = "1"
Tricks = "0.1"
julia = "1.6"
julia = "1.10"
15 changes: 3 additions & 12 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,16 @@ using Documenter
using MeasureBase

# Doctest setup
DocMeta.setdocmeta!(
MeasureBase,
:DocTestSetup,
:(using MeasureBase);
recursive=true,
)
DocMeta.setdocmeta!(MeasureBase, :DocTestSetup, :(using MeasureBase); recursive = true)

makedocs(
sitename = "MeasureBase",
modules = [MeasureBase],
format = Documenter.HTML(
prettyurls = !("local" in ARGS),
canonical = "https://juliamath.github.io/MeasureBase.jl/stable/"
canonical = "https://juliamath.github.io/MeasureBase.jl/stable/",
),
pages = [
"Home" => "index.md",
"API" => "api.md",
"LICENSE" => "LICENSE.md",
],
pages = ["Home" => "index.md", "API" => "api.md", "LICENSE" => "LICENSE.md"],
doctest = ("fixdoctests" in ARGS) ? :fix : true,
linkcheck = !("nonstrict" in ARGS),
warnonly = ("nonstrict" in ARGS),
Expand Down
54 changes: 54 additions & 0 deletions ext/MeasureBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT).

module MeasureBaseChainRulesCoreExt

using MeasureBase
using ChainRulesCore: NoTangent, ZeroTangent
import ChainRulesCore

# = utils ====================================================================

using MeasureBase: isneginf, isposinf

_isneginf_pullback(::Any) = (NoTangent(), ZeroTangent())
ChainRulesCore.rrule(::typeof(isneginf), x) = isneginf(x), _logdensityof_rt_pullback

Check warning on line 14 in ext/MeasureBaseChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasureBaseChainRulesCoreExt.jl#L13-L14

Added lines #L13 - L14 were not covered by tests

_isposinf_pullback(::Any) = (NoTangent(), ZeroTangent())
ChainRulesCore.rrule(::typeof(isposinf), x) = isposinf(x), _isposinf_pullback

Check warning on line 17 in ext/MeasureBaseChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasureBaseChainRulesCoreExt.jl#L16-L17

Added lines #L16 - L17 were not covered by tests

# = insupport & friends ======================================================

using MeasureBase: check_dof, require_insupport, checked_arg, _checksupport, _origin_depth

@inline function ChainRulesCore.rrule(::typeof(_checksupport), cond, result)
y = _checksupport(cond, result)
function _checksupport_pullback(ȳ)
return NoTangent(), ZeroTangent(), one(ȳ)

Check warning on line 26 in ext/MeasureBaseChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasureBaseChainRulesCoreExt.jl#L23-L26

Added lines #L23 - L26 were not covered by tests
end
y, _checksupport_pullback

Check warning on line 28 in ext/MeasureBaseChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasureBaseChainRulesCoreExt.jl#L28

Added line #L28 was not covered by tests
end

_require_insupport_pullback(ΔΩ) = NoTangent(), ZeroTangent()
function ChainRulesCore.rrule(::typeof(require_insupport), μ, x)
return require_insupport(μ, x), _require_insupport_pullback

Check warning on line 33 in ext/MeasureBaseChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasureBaseChainRulesCoreExt.jl#L31-L33

Added lines #L31 - L33 were not covered by tests
end

_origin_depth_pullback(ΔΩ) = NoTangent(), NoTangent()
ChainRulesCore.rrule(::typeof(_origin_depth), ν) = _origin_depth(ν), _origin_depth_pullback

_check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent()
ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback

_checked_arg_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ
ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checked_arg_pullback

# = return type inference ====================================================

using MeasureBase: logdensityof_rt

_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent())
function ChainRulesCore.rrule(::typeof(logdensityof_rt), target, v)
logdensityof_rt(target, v), _logdensityof_rt_pullback

Check warning on line 51 in ext/MeasureBaseChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasureBaseChainRulesCoreExt.jl#L49-L51

Added lines #L49 - L51 were not covered by tests
end

end # module MeasureBaseChainRulesCoreExt
11 changes: 10 additions & 1 deletion src/MeasureBase.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module MeasureBase

using Base: @propagate_inbounds
using Base: OneTo

using Random
import Random: rand!
Expand All @@ -21,6 +22,7 @@ using DensityInterface: FuncDensity, LogFuncDensity
using DensityInterface

using InverseFunctions
using InverseFunctions: FunctionWithInverse
using ChangesOfVariables
using ConstantRNGs

Expand All @@ -29,14 +31,17 @@ import ConstructionBase
using ConstructionBase: constructorof
using IntervalSets

using StaticArrays:
StaticArray, StaticVector, StaticMatrix, SArray, SVector, SMatrix, SOneTo

using PrettyPrinting
const Pretty = PrettyPrinting

using ChainRulesCore
import FillArrays
using Static
using Static: StaticInteger
using FunctionChains
using PropertyFunctions: PropSelFunction

export gentype
export rebase
Expand Down Expand Up @@ -106,6 +111,7 @@ function logdensity_def end
using Compat

using IrrationalConstants
using IrrationalConstants: loghalf

include("static.jl")
include("smf.jl")
Expand Down Expand Up @@ -139,6 +145,7 @@ include("combinators/restricted.jl")
include("combinators/smart-constructors.jl")
include("combinators/powerweighted.jl")
include("combinators/conditional.jl")
include("combinators/implicitlymapped.jl")

include("standard/stdmeasure.jl")
include("standard/stduniform.jl")
Expand All @@ -147,6 +154,8 @@ include("standard/stdlogistic.jl")
include("standard/stdnormal.jl")
include("combinators/half.jl")

#include("implicitmaps.jl")

include("rand.jl")

include("density.jl")
Expand Down
5 changes: 5 additions & 0 deletions src/combinators/half.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, μ::Half) where {T}
return abs(rand(rng, T, unhalf(μ)))
end

function logdensityof(μ::Half, x)
ld = logdensityof(unhalf(μ), x) - loghalf
oschulz marked this conversation as resolved.
Show resolved Hide resolved
return x ≥ 0 ? ld : oftype(ld, -Inf)
end

logdensity_def(μ::Half, x) = logdensity_def(unhalf(μ), x)

@inline function insupport(d::Half, x)
Expand Down
Loading
Loading