Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Random updates #152

Draft
wants to merge 32 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
bc9c80a
small updates
cscherrer Sep 27, 2024
9e1d332
bump version
cscherrer Sep 27, 2024
875cb79
improving tests
cscherrer Sep 27, 2024
fc8737c
use ConstantRNGs
cscherrer Sep 27, 2024
e88e952
proxyfuns
cscherrer Sep 27, 2024
a653520
Merge remote-tracking branch 'origin/cs-constantrngs' into cs-dev
cscherrer Sep 27, 2024
20ea0ce
Merge branch 'cs-proxyfuns' into cs-dev
cscherrer Sep 27, 2024
b85d408
update
cscherrer Oct 1, 2024
80f715e
Merge remote-tracking branch 'origin/main' into cs-dev
cscherrer Oct 2, 2024
a715ae7
format
cscherrer Oct 16, 2024
c868118
drop old comment
cscherrer Oct 16, 2024
d908436
drop splat
cscherrer Oct 25, 2024
645c899
more
cscherrer Nov 1, 2024
d1e12f3
Merge branch 'implicit-maps' into cs-dev
cscherrer Nov 4, 2024
86a3f1c
update test_smf
cscherrer Nov 7, 2024
efea9d3
update test_interface
cscherrer Nov 7, 2024
5d02508
bugfix
cscherrer Nov 7, 2024
d182868
update interface
cscherrer Nov 7, 2024
8cd0f9f
debug
cscherrer Nov 7, 2024
e0f5b06
make rand default to Random.default_rng()
cscherrer Nov 15, 2024
87c9e8f
update test_smf
cscherrer Nov 15, 2024
d1e829b
get tests passing
cscherrer Nov 15, 2024
da6cbe6
structarrays
cscherrer Jan 9, 2025
c2dd263
simplify StructArrays methods
cscherrer Jan 10, 2025
6febce0
mass interface
cscherrer Jan 14, 2025
3221d7c
use root measure for base of superposition over common type
cscherrer Jan 14, 2025
07b911b
update
cscherrer Jan 14, 2025
585c0d1
drop JET tests
cscherrer Jan 15, 2025
015094b
simplify version constraints
cscherrer Jan 15, 2025
4e6f8bf
fix unbound type parameters
cscherrer Jan 15, 2025
0f94d22
fix unbound type parameter
cscherrer Jan 16, 2025
331f776
give up on unbound type params
cscherrer Jan 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"

Expand All @@ -39,7 +41,7 @@ MeasureBaseChainRulesCoreExt = "ChainRulesCore"

[compat]
ChainRulesCore = "1"
ChangesOfVariables = "0.1.3"
ChangesOfVariables = "0.1"
Compat = "3.35, 4"
ConstantRNGs = "0.1.1"
ConstructionBase = "1.3"
Expand All @@ -56,13 +58,15 @@ LogarithmicNumbers = "1"
MappedArrays = "0.4"
NaNMath = "0.3, 1"
PrettyPrinting = "0.3, 0.4"
PropertyFunctions = "0.2.2"
PropertyFunctions = "0.2"
Random = "1"
Reexport = "1"
SpecialFunctions = "2"
Static = "0.8, 1"
StaticArrays = "1.5"
Statistics = "1"
StatsBase = "0.34"
StructArrays = "0.7"
Test = "1"
Tricks = "0.1"
julia = "1.10"
1 change: 0 additions & 1 deletion src/MeasureBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ include("smf.jl")
include("getdof.jl")
include("transport.jl")
include("schema.jl")
include("splat.jl")
include("proxies.jl")
include("kernel.jl")
include("parameterized.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/combinators/conditional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ condition(μ, constraint) = ConditionalMeasure(μ, constraint)
# end
# end

function Base.:|(
function condition(
μ::ProductMeasure{NamedTuple{M,T}},
constraint::NamedTuple{N},
) where {M,T,N}
Expand Down
2 changes: 1 addition & 1 deletion src/combinators/smart-constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ end
productmeasure(nt::NamedTuple) = ProductMeasure(nt)
productmeasure(tup::Tuple) = ProductMeasure(tup)

productmeasure(f, param_maps, pars) = ProductMeasure(kernel(f, param_maps), pars)
productmeasure(f, param_maps, pars) = productmeasure(kernel(f, param_maps), pars)

function productmeasure(k::ParameterizedTransitionKernel, pars)
productmeasure(k.suff, k.param_maps, pars)
Expand Down
13 changes: 12 additions & 1 deletion src/combinators/superpose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,18 @@ end
function basemeasure(μ::SuperpositionMeasure{Tuple{A,B}}) where {A,B}
superpose(map(basemeasure, μ.components)...)
end
basemeasure(μ::SuperpositionMeasure) = superpose(map(basemeasure, μ.components))

# basemeasure(μ::SuperpositionMeasure) = superpose(map(basemeasure, μ.components))

function basemeasure(μ::SuperpositionMeasure{NTuple{N, M}}) where {N, M<:AbstractMeasure}
rootmeasure(first(μ.components))
end

function logdensity_def(μ::SuperpositionMeasure{NTuple{N, M}}, x) where {N, M<:AbstractMeasure}
ℓs = (logdensityof(c, x) for c in μ.components)
logsumexp(ℓs)
end


# TODO: Fix `rand` method (this one is wrong)
# function Base.rand(μ::SuperpositionMeasure{X,N}) where {X,N}
Expand Down
10 changes: 5 additions & 5 deletions src/density-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ To compute a log-density relative to a specific base-measure, see
"""
@inline function logdensityof(μ::AbstractMeasure, x)
result = dynamic(unsafe_logdensityof(μ, x))
_checksupport(insupport(μ, x), result)
# _checksupport(insupport(μ, x), result)
result
end

@inline function logdensityof_rt(::T, ::U) where {T,U}
Expand Down Expand Up @@ -125,7 +126,7 @@ See also `logdensity_rel`.
end

# Note that this method assumes `μ` and `ν` to have the same type
function logdensity_def(μ::T, ν::T, x) where {T}
function logdensity_def(μ::T, ν::T, x) where {T<:AbstractMeasure}
if μ === ν
return zero(logdensity_def(μ, x))
else
Expand Down Expand Up @@ -167,6 +168,5 @@ end

@inline density_rel(μ, ν, x) = exp(logdensity_rel(μ, ν, x))

# TODO: Do we need this method?
density_def(μ, ν::AbstractMeasure, x) = exp(logdensity_def(μ, ν, x))
density_def(μ, x) = exp(logdensity_def(μ, x))
density_def(μ::AbstractMeasure, ν::AbstractMeasure, x) = exp(logdensity_def(μ, ν, x))
density_def(μ::AbstractMeasure, x) = exp(logdensity_def(μ, x))
4 changes: 3 additions & 1 deletion src/domains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ struct ZeroSet{F,G} <: AbstractDomain
end

# Based on some quick tests, but may need some adjustment
Base.in(x::AbstractArray{T}, z::ZeroSet) where {T} = abs(z.f(x)) < ldexp(eps(float(T)), 6)
function Base.in(x::AbstractArray{T}, z::ZeroSet) where {T<:Real}
abs(z.f(x)) < ldexp(eps(float(T)), 6)
end

###########################################################
# CodimOne
Expand Down
29 changes: 18 additions & 11 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ function dynamic_basemeasure_depth(μ::M) where {M}
return depth
end

function test_interface(μ::M) where {M}
function test_interface(μ::M, f = identity) where {M}
@eval begin
μ = $μ
f = $f
@testset "$μ" begin
μ = $μ

Expand All @@ -64,14 +65,15 @@ function test_interface(μ::M) where {M}
# testvalue, logdensityof

x = @inferred testvalue(Float64, μ)
x = f(x)
β = @inferred basemeasure(μ, x)

ℓμ = @inferred logdensityof(μ, x)
ℓβ = @inferred logdensityof(β, x)

@test ℓμ ≈ logdensity_def(μ, x) + ℓβ

@test logdensity_def(μ, testvalue(Float64, μ)) isa Real
@test logdensity_def(μ, x) isa Real
end
end
end
Expand Down Expand Up @@ -106,7 +108,7 @@ function test_transport(ν, μ)
end
end

function test_smf(μ, n = 100)
function test_smf(μ, n = 100, k=10)
@testset "smf($μ)" begin
# Get `n` sorted uniforms in O(n) time
p = rand(n)
Expand All @@ -121,16 +123,21 @@ function test_smf(μ, n = 100)
@test issorted(x)
@test all(istrue ∘ insupport(μ), x)

@test all((Finv ∘ F).(x) .≈ x)

for j in 1:n
a = rand()
b = rand()
a, b = minmax(a, b)
x = Finv(a)
y = Finv(b)
@test μ(Interval{:open,:closed}(x, y)) ≈ (F(y) - F(x))
for (xj, pj) in zip(x, p)
# Ideally this would be exactly zero, but in practice we need to allow for
# numerical errors in the implementation of `smf` and `invsmf`.

# Numerical errors are surprisingly large:
# smf(Beta(α = 0.18454471614214718, β = 0.0648526227363212)): Test Failed at /home/chad/git/MeasureBase.jl/src/interface.jl:130
# Expression: F(xj) - pj ≥ -1.0e-10
# Evaluated: -0.00010369484104344462 ≥ -1.0e-10
@test F(xj) - pj ≥ -1e-3
end

p .= F.(x)

@test all(Finv.(p) .≈ x)
end
end

Expand Down
1 change: 1 addition & 0 deletions src/mass-interface.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

import LinearAlgebra: normalize

import Base
Expand Down
24 changes: 23 additions & 1 deletion src/parameterized.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ end
function Pretty.tile(d::ParameterizedMeasure)
result = Pretty.literal(nameof(typeof(d)))
par = getfield(d, :par)
result *= Pretty.literal(sprint(show, par; context = :compact => true))
result *= Pretty.tile(par)
result
end

Expand Down Expand Up @@ -67,6 +67,28 @@ function ConstructionBase.setproperties(
return constructorof(P)(merge(params(d), nt))
end

###############################################################################
# StructArrays

# TODO: Can this be in an extension?


import StructArrays

function StructArrays.staticschema(::Type{P}) where {N,P<:ParameterizedMeasure{N}}
par_type = fieldtype(P, :par)
types = Tuple{(fieldtype(par_type, n) for n in N)...}
NamedTuple{N,types}
end

function StructArrays.component(m::ParameterizedMeasure, key::Symbol)
getfield(getfield(m, :par), key)
end

function StructArrays.createinstance(::Type{P}, args...) where {N,P<:ParameterizedMeasure{N}}
constructorof(P)(NamedTuple{N}(args))
end

###############################################################################
# params

Expand Down
9 changes: 1 addition & 8 deletions src/primitives/dirac.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,8 @@ basemeasure(d::Dirac) = CountingBase()

massof(::Dirac) = static(1.0)

function logdensityof(μ::Dirac, x::Real)
R = float(typeof(x))
insupport(μ, x) ? zero(R) : R(-Inf)
end

logdensityof(μ::Dirac, x) = insupport(μ, x) ? 0.0 : -Inf

logdensity_def(::Dirac, x::Real) = zero(float(typeof(x)))
logdensity_def(::Dirac, x) = 0.0
logdensity_def(μ::Dirac, x) = insupport(μ, x) ? 0.0 : -Inf

Base.rand(::Random.AbstractRNG, T::Type, μ::Dirac) = μ.x

Expand Down
4 changes: 3 additions & 1 deletion src/primitives/lebesgue.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ gentype(::Lebesgue) = Float64

Lebesgue() = Lebesgue(ℝ)

testvalue(::Type{T}, d::Lebesgue) where {T} = testvalue(T, d.support)::T
function Base.rand(rng::ConstantRNG, ::Type{T}, d::Lebesgue) where {T}
zero(T)
end

proxy(d::Lebesgue) = restrict(in(d.support), LebesgueBase())
proxy(::Lebesgue{MeasureBase.RealNumbers}) = LebesgueBase()
Expand Down
26 changes: 24 additions & 2 deletions src/rand.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,35 @@
import Base

Base.rand(d::AbstractMeasure) = rand(Random.GLOBAL_RNG, Float64, d)
Base.rand(d::AbstractMeasure) = rand(Random.default_rng(), Float64, d)

Base.rand(T::Type, μ::AbstractMeasure) = rand(Random.GLOBAL_RNG, T, μ)
Base.rand(T::Type, μ::AbstractMeasure) = rand(Random.default_rng(), T, μ)

@nospecialize
function Base.rand(rng::AbstractRNG, ::Type{T}, d::M) where {T,M<:AbstractMeasure}
@error "No method defined for rand(::AbstractRNG, ::Type{T}, $M)"
end
@specialize

Base.rand(rng::AbstractRNG, d::AbstractMeasure) = rand(rng, Float64, d)

@inline Random.rand!(d::AbstractMeasure, args...) = rand!(GLOBAL_RNG, d, args...)

@inline function Base.rand(
rng::AbstractRNG,
::Type{T},
d::ProductMeasure{A},
) where {T,A<:AbstractArray}
mar = marginals(d)
elT = typeof(rand(rng, T, first(mar)))

sz = size(mar)
x = Array{elT,length(sz)}(undef, sz)
@inbounds @simd for j in eachindex(mar)
x[j] = rand(rng, T, mar[j])
end
x
end

# TODO: Make this work
# function Base.rand(rng::AbstractRNG, ::Type{T}, d::AbstractMeasure) where {T}
# x = testvalue(d)
Expand Down
11 changes: 0 additions & 11 deletions src/splat.jl

This file was deleted.

14 changes: 13 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ testvalue(::Type{T}) where {T} = zero(T)

export rootmeasure

basemeasure(μ, x) = basemeasure(μ)
@inline basemeasure(μ, x) = basemeasure(μ)

"""
rootmeasure(μ::AbstractMeasure)
Expand Down Expand Up @@ -180,3 +180,15 @@ isapproxzero(A::AbstractArray) = all(isapproxzero, A)

isapproxone(x::T) where {T<:Real} = x ≈ one(T)
isapproxone(A::AbstractArray) = all(isapproxone, A)

import Statistics
import StatsBase

using Statistics
using StatsBase: entropy

StatsBase.entropy(m::AbstractMeasure, b::Real) = entropy(proxy(m), b)
Statistics.mean(m::AbstractMeasure) = mean(proxy(m))
Statistics.std(m::AbstractMeasure) = std(proxy(m))
Statistics.var(m::AbstractMeasure) = var(proxy(m))
Statistics.quantile(m::AbstractMeasure, q) = quantile(proxy(m), q)
6 changes: 5 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ import LogarithmicNumbers
using MeasureBase
using MeasureBase: test_interface, test_smf

include("test_aqua.jl")
using Aqua
@testset "Code quality (Aqua.jl)" begin
Aqua.test_all(MeasureBase, ambiguities = false)
# Aqua.test_ambiguities(MeasureBase)
end

include("static.jl")

Expand Down
Loading