From 28b55391fe8334476d981e97fe163ae1a5b83a22 Mon Sep 17 00:00:00 2001
From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
Date: Mon, 23 Aug 2021 13:38:37 +0200
Subject: [PATCH 1/2] Categorical distribution

---
Review notes (not part of the commit message):
- The second `≪` method now names its Categorical argument `d`; the original
  left it anonymous while still calling `ncategories(d)`.
- Type parameters are written `(:p,)` (a one-element tuple) instead of `(:p)`,
  which Julia parses as the bare symbol `:p`, so the methods can dispatch.
- Blob hashes on the `index` lines are unchanged from the original submission.

 src/parameterized/categorical.jl | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)
 create mode 100644 src/parameterized/categorical.jl

diff --git a/src/parameterized/categorical.jl b/src/parameterized/categorical.jl
new file mode 100644
index 00000000..4e2abc4d
--- /dev/null
+++ b/src/parameterized/categorical.jl
@@ -0,0 +1,27 @@
+# Categorical distribution
+
+# REFERENCES
+# https://juliastats.org/Distributions.jl/stable/univariate/#Distributions.Categorical
+# https://juliastats.org/Distributions.jl/stable/univariate/#Distributions.DiscreteNonParametric
+
+export Categorical
+
+@parameterized Categorical(p) ≪ CountingMeasure(ℤ[0:∞])
+
+ncategories(d::Categorical) = length(d.p)
+
+(d::Categorical ≪ ::CountingMeasure{IntegerRange{a,b}}) where {a,b} = a ≤ 1 && b ≥ ncategories(d)
+
+(::CountingMeasure{IntegerRange{a,b}} ≪ d::Categorical) where {a,b} = a ≥ 1 && b ≤ ncategories(d)
+
+###############################################################################
+@kwstruct Categorical(p)
+
+logdensity(d::Categorical{(:p,)}, y) = log(d.p[y])
+
+# Very inefficient because of the heavy implementation of Dists.DiscreteNonParametric
+distproxy(d::Categorical{(:p,)}) = Dists.Categorical(d.p)
+
+Base.rand(rng::AbstractRNG, T::Type, d::Categorical{(:p,)}) = rand(rng, distproxy(d))
+
+asparams(::Type{<:Categorical}, ::Val{:p}) = as𝕀

From 2f7e03c39e85414cff13e20bb0bb9b8e45fc92a7 Mon Sep 17 00:00:00 2001
From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
Date: Thu, 26 Aug 2021 09:43:50 +0200
Subject: [PATCH 2/2] Add logp parametrization for Categorical

---
Review notes (not part of the commit message):
- Type parameters are written `(:logp,)` / `(:p,)` (one-element tuples) so the
  methods can dispatch; context lines match the corrected PATCH 1/2.
- Added an `ncategories` method for the `logp` parametrization, because the
  generic method reads `d.p`; hunk header and diffstat updated accordingly.
- Blob hashes on the `index` line are unchanged from the original submission.

 src/parameterized/categorical.jl | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/src/parameterized/categorical.jl b/src/parameterized/categorical.jl
index 4e2abc4d..f4ccde37 100644
--- a/src/parameterized/categorical.jl
+++ b/src/parameterized/categorical.jl
@@ -19,9 +19,25 @@ ncategories(d::Categorical) = length(d.p)
 
 logdensity(d::Categorical{(:p,)}, y) = log(d.p[y])
 
-# Very inefficient because of the heavy implementation of Dists.DiscreteNonParametric
+# The implementation of Dists.DiscreteNonParametric has heavy argument checks
+# But I think since the values of Categorical are 1:n the sortperm has no effect
+# So it might be OK
 distproxy(d::Categorical{(:p,)}) = Dists.Categorical(d.p)
 
 Base.rand(rng::AbstractRNG, T::Type, d::Categorical{(:p,)}) = rand(rng, distproxy(d))
 
 asparams(::Type{<:Categorical}, ::Val{:p}) = as𝕀
+
+###############################################################################
+@kwstruct Categorical(logp)
+
+# The generic `ncategories` reads `d.p`; count categories from `logp` here instead
+ncategories(d::Categorical{(:logp,)}) = length(d.logp)
+
+logdensity(d::Categorical{(:logp,)}, y) = d.logp[y]
+
+distproxy(d::Categorical{(:logp,)}) = Dists.Categorical(exp.(d.logp)) # inefficient
+
+Base.rand(rng::AbstractRNG, T::Type, d::Categorical{(:logp,)}) = rand(rng, distproxy(d))
+
+asparams(::Type{<:Categorical}, ::Val{:logp}) = asℝ