From 7b7bbe5f4ffde9791a044939af2fc3e66c4b5386 Mon Sep 17 00:00:00 2001 From: Chris Lloyd Date: Sat, 4 Nov 2023 15:23:26 +0000 Subject: [PATCH] Support saving/loading models that use OnlineCountVectorizer --- bertopic/vectorizers/_online_cv.py | 39 ++++++++++++++++++++++-- tests/test_vectorizers/test_online_cv.py | 13 ++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/bertopic/vectorizers/_online_cv.py b/bertopic/vectorizers/_online_cv.py index f7244bd1..370af337 100644 --- a/bertopic/vectorizers/_online_cv.py +++ b/bertopic/vectorizers/_online_cv.py @@ -71,10 +71,45 @@ class OnlineCountVectorizer(CountVectorizer): def __init__(self, decay: float = None, delete_min_df: float = None, - **kwargs): + input="content", + encoding="utf-8", + decode_error="strict", + strip_accents=None, + lowercase=True, + preprocessor=None, + tokenizer=None, + stop_words=None, + token_pattern=r"(?u)\b\w\w+\b", + ngram_range=(1, 1), + analyzer="word", + max_df=1.0, + min_df=1, + max_features=None, + vocabulary=None, + binary=False, + dtype=np.int64, + ): self.decay = decay self.delete_min_df = delete_min_df - super(OnlineCountVectorizer, self).__init__(**kwargs) + super(OnlineCountVectorizer, self).__init__( + input=input, + encoding=encoding, + decode_error=decode_error, + strip_accents=strip_accents, + lowercase=lowercase, + preprocessor=preprocessor, + tokenizer=tokenizer, + stop_words=stop_words, + token_pattern=token_pattern, + ngram_range=ngram_range, + analyzer=analyzer, + max_df=max_df, + min_df=min_df, + max_features=max_features, + vocabulary=vocabulary, + binary=binary, + dtype=dtype + ) def partial_fit(self, raw_documents: List[str]) -> None: """ Perform a partial fit and update vocabulary with OOV tokens diff --git a/tests/test_vectorizers/test_online_cv.py b/tests/test_vectorizers/test_online_cv.py index d7ab677e..1d3c63ea 100644 --- a/tests/test_vectorizers/test_online_cv.py +++ b/tests/test_vectorizers/test_online_cv.py @@ -1,6 +1,8 @@ import copy import pytest + +from bertopic import BERTopic from bertopic.vectorizers import OnlineCountVectorizer @@ -28,3 +30,14 @@ def test_clean_bow(model, request): assert original_shape[0] == topic_model.vectorizer_model.X_.shape[0] assert original_shape[1] > topic_model.vectorizer_model.X_.shape[1] + +@pytest.mark.parametrize('model', [('online_topic_model')]) +def test_load_save(model, request): + topic_model = copy.deepcopy(request.getfixturevalue(model)) + original_vectorizer = topic_model.vectorizer_model + topic_model.save("test") + loaded_model = BERTopic.load("test") + loaded_vectorizer = loaded_model.vectorizer_model + + assert isinstance(loaded_vectorizer, OnlineCountVectorizer) + assert loaded_vectorizer.decay == original_vectorizer.decay