From c6cf4d42d08fb87fef758d8b265a89fd7b6c07b0 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 1 Jul 2024 18:01:58 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 mess.py           | 27 +++++++++++++++++----------
 tweetopic/_btm.py | 43 ++++++++++++++++++++++++++++++++-----------
 tweetopic/_dmm.py |  4 +++-
 tweetopic/_doc.py |  2 +-
 tweetopic/btm.py  | 18 +++++++++++-------
 tweetopic/func.py |  5 +++--
 6 files changed, 67 insertions(+), 32 deletions(-)

diff --git a/mess.py b/mess.py
index 3d544f4..490631e 100644
--- a/mess.py
+++ b/mess.py
@@ -17,10 +17,14 @@ from tqdm import trange
 
 from tweetopic._doc import init_doc_words
-from tweetopic.bayesian.dmm import (BayesianDMM, posterior_predictive,
-                                    predict_doc, sparse_multinomial_logpdf,
-                                    symmetric_dirichlet_logpdf,
-                                    symmetric_dirichlet_multinomial_logpdf)
+from tweetopic.bayesian.dmm import (
+    BayesianDMM,
+    posterior_predictive,
+    predict_doc,
+    sparse_multinomial_logpdf,
+    symmetric_dirichlet_logpdf,
+    symmetric_dirichlet_multinomial_logpdf,
+)
 from tweetopic.bayesian.sampling import batch_data, sample_nuts
 from tweetopic.func import spread
 
 
@@ -58,23 +62,26 @@ def logprior_fn(params):
 
 
 def loglikelihood_fn(params, data):
     doc_likelihood = jax.vmap(
-        partial(sparse_multinomial_logpdf, component=params["component"])
+        partial(sparse_multinomial_logpdf, component=params["component"]),
     )
     return jnp.sum(
         doc_likelihood(
             unique_words=data["doc_unique_words"],
             unique_word_counts=data["doc_unique_word_counts"],
-        )
+        ),
     )
 
 
 logdensity_fn(position)
 logdensity_fn = lambda params: logprior_fn(params) + loglikelihood_fn(
-    params, data
+    params,
+    data,
 )
 grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(
-    logprior_fn, loglikelihood_fn, data_size=n_documents
+    logprior_fn,
+    loglikelihood_fn,
+    data_size=n_documents,
 )
 rng_key = jax.random.PRNGKey(0)
 batch_key, warmup_key, sampling_key = jax.random.split(rng_key, 3)
@@ -88,8 +95,8 @@ def loglikelihood_fn(params, data):
 )
 position = dict(
     component=jnp.array(
-        transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha)))
-    )
+        transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha))),
+    ),
 )
 
 samples, states = sample_nuts(position, logdensity_fn)
diff --git a/tweetopic/_btm.py b/tweetopic/_btm.py
index b485fec..c336f38 100644
--- a/tweetopic/_btm.py
+++ b/tweetopic/_btm.py
@@ -1,4 +1,4 @@
-"""Module for utility functions for fitting BTMs"""
+"""Module for utility functions for fitting BTMs."""
 
 import random
 from typing import Dict, Tuple, TypeVar
@@ -12,7 +12,8 @@
 
 @njit
 def doc_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     (n_max_unique_words,) = doc_unique_words.shape
     biterm_counts = dict()
@@ -43,7 +44,7 @@
 
 @njit
 def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
-    """Adds one counter dict to another in place with Numba"""
+    """Adds one counter dict to another in place with Numba."""
     for key in source:
         if key in dest:
             dest[key] += source[key]
@@ -53,17 +54,20 @@ def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
 
 @njit
 def corpus_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     n_documents, _ = doc_unique_words.shape
     biterm_counts = doc_unique_biterms(
-        doc_unique_words[0], doc_unique_word_counts[0]
+        doc_unique_words[0],
+        doc_unique_word_counts[0],
     )
     for i_doc in range(1, n_documents):
         doc_unique_words_i = doc_unique_words[i_doc]
         doc_unique_word_counts_i = doc_unique_word_counts[i_doc]
         doc_biterms = doc_unique_biterms(
-            doc_unique_words_i, doc_unique_word_counts_i
+            doc_unique_words_i,
+            doc_unique_word_counts_i,
         )
         nb_add_counter(biterm_counts, doc_biterms)
     return biterm_counts
@@ -71,7 +75,7 @@
 
 @njit
 def compute_biterm_set(
-    biterm_counts: Dict[Tuple[int, int], int]
+    biterm_counts: Dict[Tuple[int, int], int],
 ) -> np.ndarray:
     return np.array(list(biterm_counts.keys()))
 
@@ -116,7 +120,12 @@ def add_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        True, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        True,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )
 
 
@@ -129,7 +138,12 @@ def remove_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        False, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        False,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )
 
 
@@ -147,7 +161,11 @@ def init_components(
         i_topic = random.randint(0, n_components - 1)
         biterm_topic_assignments[i_biterm] = i_topic
         add_biterm(
-            i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+            i_biterm,
+            i_topic,
+            biterms,
+            topic_word_count,
+            topic_biterm_count,
         )
     return biterm_topic_assignments, topic_word_count, topic_biterm_count
 
@@ -448,7 +466,10 @@ def predict_docs(
         )
         biterms = doc_unique_biterms(words, word_counts)
         prob_topic_given_document(
-            pred, biterms, topic_distribution, topic_word_distribution
+            pred,
+            biterms,
+            topic_distribution,
+            topic_word_distribution,
         )
         predictions[i_doc, :] = pred
     return predictions
diff --git a/tweetopic/_dmm.py b/tweetopic/_dmm.py
index 21f59f6..728565c 100644
--- a/tweetopic/_dmm.py
+++ b/tweetopic/_dmm.py
@@ -1,4 +1,6 @@
-"""Module containing tools for fitting a Dirichlet Multinomial Mixture Model."""
+"""Module containing tools for fitting a Dirichlet Multinomial Mixture
+Model."""
+
 from __future__ import annotations
 
 from math import exp, log
diff --git a/tweetopic/_doc.py b/tweetopic/_doc.py
index 657c6dc..e66e65a 100644
--- a/tweetopic/_doc.py
+++ b/tweetopic/_doc.py
@@ -11,7 +11,7 @@ def init_doc_words(
     n_docs, _ = doc_term_matrix.shape
     doc_unique_words = np.zeros((n_docs, max_unique_words)).astype(np.uint32)
     doc_unique_word_counts = np.zeros((n_docs, max_unique_words)).astype(
-        np.uint32
+        np.uint32,
     )
     for i_doc in range(n_docs):
         unique_words = doc_term_matrix[i_doc].rows[0]  # type: ignore
diff --git a/tweetopic/btm.py b/tweetopic/btm.py
index c8ee3ab..bf8a988 100644
--- a/tweetopic/btm.py
+++ b/tweetopic/btm.py
@@ -9,16 +9,19 @@
 import sklearn
 from numpy.typing import ArrayLike
 
-from tweetopic._btm import (compute_biterm_set, corpus_unique_biterms,
-                            fit_model, predict_docs)
+from tweetopic._btm import (
+    compute_biterm_set,
+    corpus_unique_biterms,
+    fit_model,
+    predict_docs,
+)
 from tweetopic._doc import init_doc_words
 from tweetopic.exceptions import NotFittedException
 from tweetopic.utils import set_numba_seed
 
 
 class BTM(sklearn.base.TransformerMixin, sklearn.base.BaseEstimator):
-    """Implementation of the Biterm Topic Model with Gibbs Sampling
-    solver.
+    """Implementation of the Biterm Topic Model with Gibbs Sampling solver.
 
     Parameters
     ----------
@@ -144,7 +147,9 @@ def fit(self, X: Union[spr.spmatrix, ArrayLike], y: None = None):
             X.tolil(),
             max_unique_words=max_unique_words,
         )
-        biterms = corpus_unique_biterms(doc_unique_words, doc_unique_word_counts)
+        biterms = corpus_unique_biterms(
+            doc_unique_words, doc_unique_word_counts
+        )
         biterm_set = compute_biterm_set(biterms)
         self.topic_distribution, self.components_ = fit_model(
             n_iter=self.n_iterations,
@@ -159,8 +164,7 @@ def fit(self, X: Union[spr.spmatrix, ArrayLike], y: None = None):
 
     # TODO: Something goes terribly wrong here, fix this
     def transform(self, X: Union[spr.spmatrix, ArrayLike]) -> np.ndarray:
-        """Predicts probabilities for each document belonging to each
-        topic.
+        """Predicts probabilities for each document belonging to each topic.
 
         Parameters
         ----------
diff --git a/tweetopic/func.py b/tweetopic/func.py
index cfd2029..934eb5f 100644
--- a/tweetopic/func.py
+++ b/tweetopic/func.py
@@ -1,11 +1,12 @@
 """Utility functions for use in the library."""
+
 from functools import wraps
 from typing import Callable
 
 
 def spread(fn: Callable):
-    """Creates a new function from the given function so that it takes one
-    dict (PyTree) and spreads the arguments."""
+    """Creates a new function from the given function so that it takes one dict
+    (PyTree) and spreads the arguments."""
 
     @wraps(fn)
     def inner(kwargs):