Skip to content

Commit

Permalink
support for ArrayOfUnivariateDistribution
Browse files Browse the repository at this point in the history
  • Loading branch information
dmetivie committed Nov 5, 2024
1 parent 173c6ea commit ca79999
Showing 1 changed file with 26 additions and 27 deletions.
53 changes: 26 additions & 27 deletions src/that_should_be_in_Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,52 +17,51 @@ function fit_mle(g::D, args...) where {D<:Distribution}
fit_mle(typeof(g), args...)
end

fit_mle(d::T, x::AbstractArray{<:Integer}) where {T<:Binomial} =
fit_mle(T, suffstats(T, ntrials(d), x))
fit_mle(d::T, x::AbstractArray{<:Integer}) where {T<:Categorical} =
fit_mle(T, ncategories(d), x)
fit_mle(d::T, x::AbstractArray{<:Integer}) where {T<:Binomial} = fit_mle(T, suffstats(T, ntrials(d), x))
fit_mle(d::T, x::AbstractArray{<:Integer}) where {T<:Categorical} = fit_mle(T, ncategories(d), x)

## * `fit_mle` for `product_distribution`

#TODO: add deprecation notice!
#TODO! but currently still have `product_distribution([d1, d2]) ≠ product_distribution(d1, d2)` (first is still `Product` while second is `Distributions.ProductDistribution`)
#TODO! open issue in `Distributions.jl`

"""
fit_mle(g::Product, x::AbstractMatrix)
fit_mle(g::Product, x::AbstractMatrix, γ::AbstractVector)
The `fit_mle` for multivariate Product distributions `g` is the `product_distribution` of `fit_mle` of each components of `g`.
Product is meant to be depreacated in next version of `Distribution.jl`. Use the analog `VectorOfUnivariateDistribution` type instead.
Product is meant to be depreacated in next versions of `Distribution.jl`. Use the analog `VectorOfUnivariateDistribution` type instead.
"""
function fit_mle(g::Product, x::AbstractMatrix, args...)
d = size(x, 1)
length(g) == d ||
throw(DimensionMismatch("The dimensions of g and x are inconsistent."))
return product_distribution([
fit_mle(g.v[s], y, args...) for (s, y) in enumerate(eachrow(x))
])
length(g) == d || throw(DimensionMismatch("The dimensions of g and x are inconsistent."))
return product_distribution([fit_mle(g.v[s], y, args...) for (s, y) in enumerate(eachrow(x))])
end

params(g::Product) = params.(g.v)

#! `ArrayOfUnivariateDistribution` is not released yet
# params(d::ArrayOfUnivariateDistribution) = params.(d.dists)
params(d::ArrayOfUnivariateDistribution) = params.(d.dists) #

# #### Fitting
# promote_sample(::Type{dT}, x::AbstractArray{T}) where {T<:Real, dT<:Real} = T <: dT ? x : convert.(dT, x)
#### Fitting
promote_sample(::Type{dT}, x::AbstractArray{T}) where {T<:Real,dT<:Real} = T <: dT ? x : convert.(dT, x)

# """
# fit_mle(dists::ArrayOfUnivariateDistribution, x::AbstractArray)
# fit_mle(dists::ArrayOfUnivariateDistribution, x::AbstractArray, γ::AbstractVector)
"""
fit_mle(dists::ArrayOfUnivariateDistribution, x::AbstractArray)
fit_mle(dists::ArrayOfUnivariateDistribution, x::AbstractArray, γ::AbstractVector)
# The `fit_mle` for a `ArrayOfUnivariateDistribution` distributions `dists` is the `product_distribution` of `fit_mle` of each components of `dists`.
# """
# function fit_mle(dists::VectorOfUnivariateDistribution, x::AbstractMatrix{<:Real}, args...)
# length(dists) == size(x, 1) || throw(DimensionMismatch("The dimensions of dists and x are inconsistent."))
# return product_distribution([fit_mle(d, promote_sample(eltype(d), x[s, :]), args...) for (s, d) in enumerate(dists.dists)])
# end
The `fit_mle` for a `ArrayOfUnivariateDistribution` distributions `dists` is the `product_distribution` of `fit_mle` of each components of `dists`.
`VectorOfUnivariateDistribution` should act like old `Product` while `ArrayOfUnivariateDistribution` are not really tested yet.
"""
function fit_mle(dists::VectorOfUnivariateDistribution, x::AbstractMatrix{<:Real}, args...)
length(dists) == size(x, 1) || throw(DimensionMismatch("The dimensions of dists and x are inconsistent."))
return product_distribution([fit_mle(d, promote_sample(eltype(d), x[s, :]), args...) for (s, d) in enumerate(dists.dists)]...)
end

# function fit_mle(dists::ArrayOfUnivariateDistribution, x::AbstractArray, args...)
# size(dists) == size(first(x)) || throw(DimensionMismatch("The dimensions of dists and x are inconsistent."))
# return product_distribution([fit_mle(d, promote_sample(eltype(d), [x[i][s] for i in eachindex(x)]), args...) for (s, d) in enumerate(dists.dists)])
# end
function fit_mle(dists::ArrayOfUnivariateDistribution, x::AbstractArray, args...)
size(dists) == size(first(x)) || throw(DimensionMismatch("The dimensions of dists and x are inconsistent."))
return product_distribution([fit_mle(d, promote_sample(eltype(d), [x[i][s] for i in eachindex(x)]), args...) for (s, d) in enumerate(dists.dists)]...)
end


## * New `fit_mle` * ##
Expand Down

0 comments on commit ca79999

Please sign in to comment.