Skip to content

Commit

Permalink
refactor imports
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinavnatarajan committed Oct 31, 2023
1 parent fcf72ac commit 7f42e07
Show file tree
Hide file tree
Showing 8 changed files with 495 additions and 511 deletions.
94 changes: 56 additions & 38 deletions src/RedClust.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,67 @@
## Introduction
RedClust is the main module of `RedClust.jl`, a Julia package for Bayesian clustering of high-dimensional Euclidean data using pairwise dissimilarities instead of the raw observations.
RedClust is the main module of `RedClust.jl`, a Julia package for Bayesian clustering of high-dimensional Euclidean data using pairwise dissimilarities instead of the raw observations.
Use `names(RedClust)` to get the export list of this module, and type `?name` to get help on a specific `name`.
Use `names(RedClust)` to get the export list of this module, and type `?name` to get help on a specific `name`.
See https://abhinavnatarajan.github.io/RedClust.jl/ for detailed documentation.
"""
module RedClust

export
# fit prior
fitprior,
fitprior2,
sampledist,
sampleK,

# MCMC sampler
runsampler,

# utility functions
adjacencymatrix,
sortlabels,
uppertriangle,
generatemixture,
makematrix,

# summary functions
evaluateclustering,
summarise,

# point estimate
getpointestimate,
binderloss,
infodist,

# types
MCMCOptionsList,
PriorHyperparamsList,
MCMCData,
MCMCResult,

# data
example_datasets,
example_dataset
# mcmc.jl
using Clustering: kmeans, kmedoids, mutualinfo, randindex, varinfo, vmeasure
using Dates
using Distances: Euclidean, pairwise
using Distributions: Beta, Dirichlet, Distributions, Gamma, MvNormal, Normal, Uniform,
fit_mle, logpdf, pdf, rate, shape, truncated
using HDF5: close, create_group, h5open
using LinearAlgebra: Diagonal, I
using LoopVectorization: @turbo
using Printf: @sprintf
using ProgressBars: ProgressBar
using Random: AbstractRNG, TaskLocalRNG, rand, seed!
using SpecialFunctions: logbeta, loggamma
using StaticArrays
using StatsBase:
autocor, counts, entropy, levelsmap, mean, mean_and_var, sample, std, var, wsample


export
# fit prior
fitprior,
fitprior2,
sampledist,
sampleK,

# MCMC sampler
runsampler,

# utility functions
adjacencymatrix,
sortlabels,
uppertriangle,
generatemixture,
makematrix,

# summary functions
evaluateclustering,
summarise,

# point estimate
getpointestimate,
binderloss,
infodist,

# types
MCMCOptionsList,
PriorHyperparamsList,
MCMCData,
MCMCResult,

# data
example_datasets,
example_dataset

include("./types.jl")
include("./utils.jl")
Expand All @@ -55,4 +73,4 @@ include("./pointestimate.jl")
include("./summaries.jl")
include("./example_data.jl")

end
end
37 changes: 17 additions & 20 deletions src/example_data.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
using Random: seed!
using HDF5: h5open, close, create_group

const datafilename = joinpath(@__DIR__, "..", "data", "example_datasets.h5")

# function generate_example_data(n::Int)
# seed!(44)
# K = 10 # Number of clusters
# seed!(44)
# K = 10 # Number of clusters
# N = 100 # Number of points
# data_σ = [0.25, 0.2, 0.18][n] # Variance of the normal kernel
# data_dim = [10, 50, 10][n] # Data dimension
Expand All @@ -24,51 +21,51 @@ const datafilename = joinpath(@__DIR__, "..", "data", "example_datasets.h5")
# example["cluster_labels"] = clusts
# example["cluster_weights"] = probs
# example["oracle_coclustering_probabilities"] = oracle_coclustering
# end
# end
# close(datafile)
# end

# save_example_data()

"""
Return a read-only handle to a HDF5 file that contains example datasets from the main paper. You must remember to close the file once you are done reading from it. This function is provided for reproducibility purposes only; it is recommended to read the datasets via the convenience function [`example_dataset`](@ref).
Return a read-only handle to a HDF5 file that contains example datasets from the main paper. You must remember to close the file once you are done reading from it. This function is provided for reproducibility purposes only; it is recommended to read the datasets via the convenience function [`example_dataset`](@ref).
"""
function example_datasets()
h5open(datafilename, "r")
end

@doc raw"""
example_dataset(n::Int)
example_dataset(n::Int)
Returns a named tuple containing the dataset from the n``^{\mathrm{th}}`` simulated example in the main paper. This dataset was generated using the following code in Julia v1.8.1:
```julia
using RedClust
using Random: seed!
seed!(44)
K = 10 # Number of clusters
N = 100 # Number of points
data_σ = [0.25, 0.2, 0.18][n] # Variance of the normal kernel
data_dim = [10, 50, 10][n] # Data dimension
data = generatemixture(N, K; α = 10, σ = data_σ, dim = data_dim);
using RedClust
using Random: seed!
seed!(44)
K = 10 # Number of clusters
N = 100 # Number of points
data_σ = [0.25, 0.2, 0.18][n] # Variance of the normal kernel
data_dim = [10, 50, 10][n] # Data dimension
data = generatemixture(N, K; α = 10, σ = data_σ, dim = data_dim);
```
Note however that the above code may produce different results on your computer because the random number generator in Julia is not meant for reproducible results across different computers, different versions of Julia, or different versions of the Random.jl package, even with appropriate seeding. Therefore the datasets have been included with this package, and it is recommended to access them via this function.
See also [`generatemixture`](@ref).
See also [`generatemixture`](@ref).
"""
function example_dataset(n::Int)
if n [1, 2, 3]
throw(ArgumentError("n must be 1, 2, or 3."))
end
datafile = example_datasets()
egdata = datafile["example" * string(n)]
egdata = datafile["example"*string(n)]
x = read(egdata["points"])
points = [x[:, i] for i in 1:last(size(x))]
clusts = read(egdata["cluster_labels"])
distmatrix = read(egdata["distance_matrix"])
probs = read(egdata["cluster_weights"])
oracle_coclustering = read(egdata["oracle_coclustering_probabilities"])
close(datafile)
data = (points = points, distmatrix = distmatrix, clusts = clusts, probs = probs, oracle_coclustering = oracle_coclustering)
data = (points=points, distmatrix=distmatrix, clusts=clusts, probs=probs, oracle_coclustering=oracle_coclustering)
return data
end
end
Loading

0 comments on commit 7f42e07

Please sign in to comment.