From 79fb1f8fa3f09a78f92ed62efb55aa6fe00a74ef Mon Sep 17 00:00:00 2001 From: Abhinav Natarajan Date: Sat, 14 Jan 2023 21:43:23 +0000 Subject: [PATCH] added signature for sampleK --- Project.toml | 2 +- docs/src/changelog.md | 7 +++++++ src/prior.jl | 13 ++++++++----- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index c0f0c88..565bb51 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RedClust" uuid = "bf1adee6-87fe-4679-8d23-51fe99940a25" authors = ["Abhinav Natarajan "] -version = "1.0.1" +version = "1.1.0" [deps] Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" diff --git a/docs/src/changelog.md b/docs/src/changelog.md index aae62a7..e52aad0 100644 --- a/docs/src/changelog.md +++ b/docs/src/changelog.md @@ -1,5 +1,12 @@ # Changelog +## [1.1.0] +### Added +- added function signature for [`sampleK`](@ref) to accept individual parameters. + +## [1.0.1] +Documentation now updated to use Literate.jl to generate the example. + ## [1.0.0] ### Added diff --git a/src/prior.jl b/src/prior.jl index 5335c75..0c66105 100644 --- a/src/prior.jl +++ b/src/prior.jl @@ -68,11 +68,12 @@ function fitprior( input = dissM clustfn = kmedoids end - @inbounds for k in ProgressBar(1:(Kmax-Kmin+1), output_stream = ostream) temp = clustfn(input, k; maxiter=1000) objective[k] = temp.totalcost - !temp.converged && @warn "Clustering did not converge at K = $k" + if !temp.converged + @warn "Clustering did not converge at K = $k" + end end elbow = detectknee(Kmin:Kmax, objective)[1] K = elbow @@ -167,10 +168,11 @@ end """ sampleK(params::PriorHyperparamsList, numsamples::Int, n::Int)::Vector{Int} + sampleK(η::Real, σ::Real, u::Real, v::Real, numsamples::Int, n::Int) Returns a vector of length `numsamples` containing samples of ``K`` (number of clusters) generated from its marginal prior predictive distribution inferred from `params`. The parameter `n` is the number of observations in the model. """ -function sampleK(params::PriorHyperparamsList, numsamples::Int, n::Int)::Vector{Int} +function sampleK(η::Real, σ::Real, u::Real, v::Real, numsamples::Int, n::Int) # Input validation if n < 1 throw(ArgumentError("n must be a positive integer.")) @@ -179,8 +181,8 @@ function sampleK(params::PriorHyperparamsList, numsamples::Int, n::Int)::Vector{ throw(ArgumentError("numsamples must be a positive integer.")) end samples = Vector{Int}(undef, numsamples) - rdist = Gamma(params.η, 1/params.σ) - pdist = Beta(params.u, params.v) + rdist = Gamma(η, 1/σ) + pdist = Beta(u, v) K = 1:(n-1) logprobs = zeros(n) @inbounds for i = 1:numsamples @@ -192,6 +194,7 @@ function sampleK(params::PriorHyperparamsList, numsamples::Int, n::Int)::Vector{ end return samples end +sampleK(params::PriorHyperparamsList, numsamples::Int, n::Int)::Vector{Int} = sampleK(params.η, params.σ, params.u, params.v, numsamples, n) function detectknee( xvalues::AbstractVector{<:Real},