Skip to content

Commit

Permalink
added signature for sampleK
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinavnatarajan committed Jan 14, 2023
1 parent efda04e commit 79fb1f8
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RedClust"
uuid = "bf1adee6-87fe-4679-8d23-51fe99940a25"
authors = ["Abhinav Natarajan <[email protected]>"]
version = "1.0.1"
version = "1.1.0"

[deps]
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
Expand Down
7 changes: 7 additions & 0 deletions docs/src/changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## [1.1.0]
### Added
- added function signature for [`sampleK`](@ref) to accept individual parameters.

## [1.0.1]
Documentation now updated to use Literate.jl to generate the example.

## [1.0.0]

### Added
Expand Down
13 changes: 8 additions & 5 deletions src/prior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,12 @@ function fitprior(
input = dissM
clustfn = kmedoids
end

@inbounds for k in ProgressBar(1:(Kmax-Kmin+1), output_stream = ostream)
temp = clustfn(input, k; maxiter=1000)
objective[k] = temp.totalcost
!temp.converged && @warn "Clustering did not converge at K = $k"
if !temp.converged
@warn "Clustering did not converge at K = $k"
end
end
elbow = detectknee(Kmin:Kmax, objective)[1]
K = elbow
Expand Down Expand Up @@ -167,10 +168,11 @@ end

"""
sampleK(params::PriorHyperparamsList, numsamples::Int, n::Int)::Vector{Int}
sampleK(η::Real, σ::Real, u::Real, v::Real, numsamples::Int, n::Int)
Returns a vector of length `numsamples` containing samples of ``K`` (number of clusters) generated from its marginal prior predictive distribution inferred from `params`. The parameter `n` is the number of observations in the model.
"""
function sampleK(params::PriorHyperparamsList, numsamples::Int, n::Int)::Vector{Int}
function sampleK(η::Real, σ::Real, u::Real, v::Real, numsamples::Int, n::Int)
# Input validation
if n < 1
throw(ArgumentError("n must be a positive integer."))
Expand All @@ -179,8 +181,8 @@ function sampleK(params::PriorHyperparamsList, numsamples::Int, n::Int)::Vector{
throw(ArgumentError("numsamples must be a positive integer."))
end
samples = Vector{Int}(undef, numsamples)
rdist = Gamma(params.η, 1/params.σ)
pdist = Beta(params.u, params.v)
rdist = Gamma(η, 1/σ)
pdist = Beta(u, v)
K = 1:(n-1)
logprobs = zeros(n)
@inbounds for i = 1:numsamples
Expand All @@ -192,6 +194,7 @@ function sampleK(params::PriorHyperparamsList, numsamples::Int, n::Int)::Vector{
end
return samples
end
sampleK(params::PriorHyperparamsList, numsamples::Int, n::Int)::Vector{Int} = sampleK(params.η, params.σ, params.u, params.v, numsamples, n)

function detectknee(
xvalues::AbstractVector{<:Real},
Expand Down

2 comments on commit 79fb1f8

@abhinavnatarajan
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/77131

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.1.0 -m "<description of version>" 79fb1f8fa3f09a78f92ed62efb55aa6fe00a74ef
git push origin v1.1.0

Please sign in to comment.