Skip to content

Commit

Permalink
Fixed embedding parameter not working (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaartenGr authored Jan 10, 2021
1 parent c271ec6 commit 8813b4d
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 65 deletions.
7 changes: 3 additions & 4 deletions bertopic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from bertopic._bertopic import BERTopic
from bertopic._ctfidf import ClassTFIDF
from bertopic._embeddings import languages, embedding_models
from bertopic._embeddings import languages

__version__ = "0.4.1"
__version__ = "0.4.2"

__all__ = [
"BERTopic",
"ClassTFIDF",
"languages",
"embedding_models",
"languages"
]
16 changes: 5 additions & 11 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# BERTopic
from ._ctfidf import ClassTFIDF
from ._utils import MyLogger, check_documents_type, check_embeddings_shape, check_is_fitted
from ._embeddings import languages, embedding_models
from ._embeddings import languages
from ._mmr import mmr

# Additional dependencies
Expand Down Expand Up @@ -884,6 +884,10 @@ def _select_embedding_model(self) -> SentenceTransformer:
if self.custom_embeddings and self.allow_st_model:
return SentenceTransformer("xlm-r-bert-base-nli-stsb-mean-tokens")

# Select embedding model based on specific sentence transformer model
elif self.embedding_model:
return SentenceTransformer(self.embedding_model)

# Select embedding model based on language
elif self.language:
if self.language.lower() in ["English", "english", "en"]:
Expand All @@ -901,16 +905,6 @@ def _select_embedding_model(self) -> SentenceTransformer:
"Else, please select a language from the following list:\n"
f"{languages}")

# Select embedding model based on specific sentence transformer model
elif self.embedding_model:
if self.embedding_model in embedding_models:
return SentenceTransformer(self.embedding_model)
else:
raise ValueError("Please select an embedding model from the following list:\n"
f"{embedding_models}\n\n"
f"For more information about the models, see:\n"
f"https://www.sbert.net/docs/pretrained_models.html")

return SentenceTransformer("xlm-r-bert-base-nli-stsb-mean-tokens")

def _reduce_topics(self, documents: pd.DataFrame) -> pd.DataFrame:
Expand Down
44 changes: 0 additions & 44 deletions bertopic/_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,3 @@
# All models, as of 13/12/2020, pretrained for sentence transformers
embedding_models = ['LaBSE',
'average_word_embeddings_glove.6B.300d',
'average_word_embeddings_glove.840B.300d',
'average_word_embeddings_komninos',
'average_word_embeddings_levy_dependency',
'bert-base-nli-cls-token',
'bert-base-nli-max-tokens',
'bert-base-nli-mean-tokens',
'bert-base-nli-stsb-mean-tokens',
'bert-base-nli-stsb-wkpooling',
'bert-base-nli-wkpooling',
'bert-base-wikipedia-sections-mean-tokens',
'bert-large-nli-cls-token',
'bert-large-nli-max-tokens',
'bert-large-nli-mean-tokens',
'bert-large-nli-stsb-mean-tokens',
'distilbert-base-nli-max-tokens',
'distilbert-base-nli-mean-tokens',
'distilbert-base-nli-stsb-mean-tokens',
'distilbert-base-nli-stsb-quora-ranking',
'distilbert-base-nli-stsb-wkpooling',
'distilbert-base-nli-wkpooling',
'distilbert-multilingual-nli-stsb-quora-ranking',
'distilroberta-base-msmarco-v1',
'distilroberta-base-msmarco-v2',
'distilroberta-base-paraphrase-v1',
'distiluse-base-multilingual-cased-v1',
'distiluse-base-multilingual-cased-v2',
'distiluse-base-multilingual-cased',
'roberta-base-nli-mean-tokens',
'roberta-base-nli-stsb-mean-tokens',
'roberta-base-nli-stsb-wkpooling',
'roberta-base-nli-wkpooling',
'roberta-large-nli-mean-tokens',
'roberta-large-nli-stsb-mean-tokens',
'xlm-r-100langs-bert-base-nli-mean-tokens',
'xlm-r-100langs-bert-base-nli-stsb-mean-tokens',
'xlm-r-base-en-ko-nli-ststb',
'xlm-r-bert-base-nli-mean-tokens',
'xlm-r-bert-base-nli-stsb-mean-tokens',
'xlm-r-distilroberta-base-paraphrase-v1',
'xlm-r-large-en-ko-nli-ststb']

languages = ['afrikaans', 'albanian', 'amharic', 'arabic', 'armenian', 'assamese',
'azerbaijani', 'basque', 'belarusian', 'bengali', 'bengali romanize',
'bosnian', 'breton', 'bulgarian', 'burmese', 'burmese zawgyi font', 'catalan',
Expand Down
9 changes: 9 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
## **Version 0.4.2**
*Release date: 10 January, 2021*

**Fixes**:

* Selecting `embedding_model` did not work when `language` was also set. Because `language` has a default
value, users had to explicitly set `language` to None before `embedding_model` took effect. Fixed by giving
`embedding_model` precedence over `language` when both are set.

## **Version 0.4.1**
*Release date: 07 January, 2021*

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
setuptools.setup(
name="bertopic",
packages=["bertopic"],
version="0.4.1",
version="0.4.2",
author="Maarten Grootendorst",
author_email="[email protected]",
description="BERTopic performs topic Modeling with state-of-the-art transformer models.",
Expand Down
6 changes: 1 addition & 5 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,6 @@ def test_load_model():

def test_extract_incorrect_embeddings():
""" Test if errors are raised when loading incorrect model """
with pytest.raises(ValueError):
model = BERTopic(language=None, embedding_model='not_a_model')
model._extract_embeddings(["Some document"])

with pytest.raises(ValueError):
model = BERTopic(language="Unknown language")
model._extract_embeddings(["Some document"])
Expand All @@ -98,7 +94,7 @@ def test_extract_incorrect_embeddings():
def test_extract_embeddings():
""" Test if correct model is loaded and embeddings match the sentence-transformers version """
docs = ["some document"]
model = BERTopic(language=None, embedding_model="distilbert-base-nli-stsb-mean-tokens")
model = BERTopic(embedding_model="distilbert-base-nli-stsb-mean-tokens")
bertopic_embeddings = model._extract_embeddings(docs)

assert isinstance(bertopic_embeddings, np.ndarray)
Expand Down

0 comments on commit 8813b4d

Please sign in to comment.