diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md
new file mode 100644
index 00000000..f4224779
--- /dev/null
+++ b/.github/CONTRIBUTING.md
@@ -0,0 +1,53 @@
+# Contributing to BERTopic
+
+Hi! Thank you for considering contributing to BERTopic. With the modular nature of BERTopic, many new add-ons, backends, representation models, sub-models, and LLMs can quickly be added to keep up with the incredibly fast-paced field.
+
+Whether contributions are new features, better documentation, bug fixes, or improvements to the repository itself, anything is appreciated!
+
+## πŸ“š Guidelines
+
+### πŸ€– Contributing Code
+
+To contribute to this project, we follow an `issue -> pull request` approach for main features and bug fixes. This means that any new feature, bug fix, or anything else that touches code directly needs to start from an issue first. That way, the main discussion about what needs to be added/fixed can be had in the issue before creating a pull request. This makes sure that we are on the same page before you start coding your pull request. If you start working on an issue, please assign it to yourself, but only do so after there is an agreement with the maintainer, [@MaartenGr](https://github.com/MaartenGr).
+
+When there is agreement on the approach, a pull request can be created in which the fix/feature is added. This follows a ["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow.
+Please do not try to push directly to this repo unless you are a maintainer.
+
+There are exceptions to the `issue -> pull request` approach, typically small changes that do not need agreement, such as:
+* Documentation
+* Spelling/grammar issues
+* Docstrings
+* etc.
+
+There is a large focus on documentation in this repository, so please make sure to add extensive descriptions of features when creating the pull request.
+
+Note that the main focus of pull requests and code should be:
+* Easy readability
+* Clear communication
+* Sufficient documentation
+
+## πŸš€ Quick Start
+
+To start contributing, make sure to first start from a fresh environment. Using an environment manager, such as `conda` or `pyenv`, helps in making sure that your code is reproducible and tracks the versions you have in your environment.
+
+If you are using conda, you can approach it as follows:
+
+1. Create and activate a new conda environment (e.g., `conda create -n bertopic python=3.9`)
+2. Install the requirements (e.g., `pip install .[dev]`)
+    * This makes sure to also install documentation and testing packages
+3. (Optional) Run `make docs` to build the documentation
+4. (Optional) Run `make test` to run the unit tests and `make coverage` to check the coverage of the unit tests
+
+❗Note: Unit testing the package can take quite some time since it needs to run several variants of the BERTopic pipeline.
+
+## πŸ€“ Collaborative Efforts
+
+When you run into any issue with the above or need help getting started with a pull request, feel free to reach out in the issues! As with all repositories, this one has its particularities as a result of the maintainer's view. Each repository is quite different, and so are its processes.
+
+## πŸ† Recognition
+
+If your contribution has made its way into a new release of BERTopic, you will be given credit in the changelog of the new release! Regardless of the size of the contribution, any help is greatly appreciated.
+
+## 🎈 Release
+
+BERTopic tries to mostly follow [semantic versioning](https://semver.org/) for its new releases. Even though BERTopic has been around for a few years now, it is still pre-1.0 software. With the rapid changes in the field and as a way to keep up, this versioning is intentional. Backwards compatibility is taken into account, but integrating new features and thereby keeping up with the field takes priority. Especially since BERTopic focuses on modularity, flexibility is necessary.
diff --git a/.gitignore b/.gitignore
index c1f2cc10..e526ce54 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,6 +26,9 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
+model_dir
+model_dir/
+test

 # PyInstaller
 # Usually these files are written by a python script from a template
diff --git a/Makefile b/Makefile
index d0bdae9c..114b01a9 100644
--- a/Makefile
+++ b/Makefile
@@ -1,12 +1,17 @@
 test:
	pytest

+coverage:
+	pytest --cov
+
 install:
	python -m pip install -e .

 install-test:
-	python -m pip install -e ".[test]"
-	python -m pip install -e "."
+	python -m pip install -e ".[dev]"
+
+docs:
+	mkdocs serve

 pypi:
	python setup.py sdist
diff --git a/README.md b/README.md
index 9acacb05..bf6fdc1b 100644
--- a/README.md
+++ b/README.md
@@ -13,18 +13,29 @@ BERTopic is a topic modeling technique that leverages πŸ€— transformers and c-TF-IDF
to create dense clusters allowing for easily interpretable topics whilst keeping important words in the topic descriptions.

-BERTopic supports
-[**guided**](https://maartengr.github.io/BERTopic/getting_started/guided/guided.html),
-[**supervised**](https://maartengr.github.io/BERTopic/getting_started/supervised/supervised.html),
-[**semi-supervised**](https://maartengr.github.io/BERTopic/getting_started/semisupervised/semisupervised.html),
-[**manual**](https://maartengr.github.io/BERTopic/getting_started/manual/manual.html),
-[**long-document**](https://maartengr.github.io/BERTopic/getting_started/distribution/distribution.html),
-[**hierarchical**](https://maartengr.github.io/BERTopic/getting_started/hierarchicaltopics/hierarchicaltopics.html),
-[**class-based**](https://maartengr.github.io/BERTopic/getting_started/topicsperclass/topicsperclass.html),
-[**dynamic**](https://maartengr.github.io/BERTopic/getting_started/topicsovertime/topicsovertime.html),
-[**online**](https://maartengr.github.io/BERTopic/getting_started/online/online.html),
-[**multimodal**](https://maartengr.github.io/BERTopic/getting_started/multimodal/multimodal.html), and
-[**multi-aspect**](https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html) topic modeling. It even supports visualizations similar to LDAvis!
+BERTopic supports all kinds of topic modeling techniques:
+
+<table>
GuidedSupervisedSemi-supervised
ManualMulti-topic distributionsHierarchical
Class-basedDynamicOnline/Incremental
MultimodalMulti-aspectText Generation/LLM
Corresponding medium posts can be found [here](https://towardsdatascience.com/topic-modeling-with-bert-779f7db187e6?source=friends_link&sk=0b5a470c006d1842ad4c8a3057063a99), [here](https://towardsdatascience.com/interactive-topic-modeling-with-bertopic-1ea55e7d73d8?sk=03c2168e9e74b6bda2a1f3ed953427e4) and [here](https://towardsdatascience.com/using-whisper-and-bertopic-to-model-kurzgesagts-videos-7d8a63139bdf?sk=b1e0fd46f70cb15e8422b4794a81161d). For a more detailed overview, you can read the [paper](https://arxiv.org/abs/2203.05794) or see a [brief overview](https://maartengr.github.io/BERTopic/algorithm/algorithm.html). @@ -39,13 +50,10 @@ pip install bertopic If you want to install BERTopic with other embedding models, you can choose one of the following: ```bash -# Embedding models -pip install bertopic[flair] -pip install bertopic[gensim] -pip install bertopic[spacy] -pip install bertopic[use] +# Choose an embedding backend +pip install bertopic[flair, gensim, spacy, use] -# Vision topic modeling +# Topic modeling with images pip install bertopic[vision] ``` @@ -61,6 +69,7 @@ with one of the examples below: | Advanced Customization in BERTopic | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1ClTYut039t-LDtlcd-oQAdXWgcsSGTw9?usp=sharing) | | (semi-)Supervised Topic Modeling with BERTopic | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1bxizKzv5vfxJEB29sntU__ZC7PBSIPaQ?usp=sharing) | | Dynamic Topic Modeling with Trump's Tweets | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1un8ooI-7ZNlRoK0maVkYhmNRl0XGK88f?usp=sharing) | +| Topic Modeling on Large Data | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1W7aEdDPxC29jP99GGZphUlqjMFFVKtBC?usp=sharing) | | Topic Modeling arXiv Abstracts | [![Kaggle](https://img.shields.io/static/v1?style=for-the-badge&message=Kaggle&color=222222&logo=Kaggle&logoColor=20BEFF&label=)](https://www.kaggle.com/maartengr/topic-modeling-arxiv-abstract-with-bertopic) | @@ -122,8 +131,7 @@ Think! It's the SCSI card doing... 49 49_windows_drive_dos_file windows - dr 1) I have an old Jasmine drive... 49 49_windows_drive_dos_file windows - drive - docs... 0.038983 ... ``` -> πŸ”₯ **Tip** -> Use `BERTopic(language="multilingual")` to select a model that supports 50+ languages. +**`πŸ”₯ Tip`**: Use `BERTopic(language="multilingual")` to select a model that supports 50+ languages. ## Fine-tune Topic Representations @@ -137,8 +145,20 @@ representation_model = KeyBERTInspired() topic_model = BERTopic(representation_model=representation_model) ``` -> πŸ”₯ **Tip** -> Instead of iterating over all of these different topic representations, you can model them simultaneously with [multi-aspect topic representations](https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html) in BERTopic. +However, you might want to use something more powerful to describe your clusters. You can even use ChatGPT or other models from OpenAI to generate labels, summaries, phrases, keywords, and more: + +```python +import openai +from bertopic.representation import OpenAI + +# Fine-tune topic representations with GPT +openai.api_key = "sk-..." 
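+# NOTE: a placeholder key; in practice, read it from an environment variable rather than hard-coding it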
+representation_model = OpenAI(model="gpt-3.5-turbo", chat=True) +topic_model = BERTopic(representation_model=representation_model) +``` + +**`πŸ”₯ Tip`**: Instead of iterating over all of these different topic representations, you can model them simultaneously with [multi-aspect topic representations](https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html) in BERTopic. + ## Visualizations After having trained our BERTopic model, we can iteratively go through hundreds of topics to get a good @@ -153,7 +173,7 @@ topic_model.visualize_topics() ## Modularity -By default, the main steps for topic modeling with BERTopic are sentence-transformers, UMAP, HDBSCAN, and c-TF-IDF run in sequence. However, it assumes some independence between these steps which makes BERTopic quite modular. In other words, BERTopic not only allows you to build your own topic model but to explore several topic modeling techniques on top of your customized topic model: +By default, the [main steps](https://maartengr.github.io/BERTopic/algorithm/algorithm.html) for topic modeling with BERTopic are sentence-transformers, UMAP, HDBSCAN, and c-TF-IDF run in sequence. However, it assumes some independence between these steps which makes BERTopic quite modular. In other words, BERTopic not only allows you to build your own topic model but to explore several topic modeling techniques on top of your customized topic model: https://user-images.githubusercontent.com/25746895/218420473-4b2bb539-9dbe-407a-9674-a8317c7fb3bf.mp4 @@ -166,7 +186,6 @@ You can swap out any of these models or even remove them entirely. The following 5. [Weight](https://maartengr.github.io/BERTopic/getting_started/ctfidf/ctfidf.html) tokens 6. [Represent topics](https://maartengr.github.io/BERTopic/getting_started/representation/representation.html) with one or [multiple](https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html) representations -To find more about the underlying algorithm and assumptions [here](https://maartengr.github.io/BERTopic/algorithm/algorithm.html). ## Functionality BERTopic has many functions that quickly can become overwhelming. To alleviate this issue, you will find an overview @@ -228,12 +247,14 @@ There are many different use cases in which topic modeling can be used. 
As such, | [Semi-supervised Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/semisupervised/semisupervised.html) | `.fit(docs, y=y)` | | [Supervised Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/supervised/supervised.html) | `.fit(docs, y=y)` | | [Manual Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/manual/manual.html) | `.fit(docs, y=y)` | +| [Multimodal Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/multimodal/multimodal.html) | ``.fit(docs, images=images)`` | | [Topic Modeling per Class](https://maartengr.github.io/BERTopic/getting_started/topicsperclass/topicsperclass.html) | `.topics_per_class(docs, classes)` | | [Dynamic Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/topicsovertime/topicsovertime.html) | `.topics_over_time(docs, timestamps)` | | [Hierarchical Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/hierarchicaltopics/hierarchicaltopics.html) | `.hierarchical_topics(docs)` | | [Guided Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/guided/guided.html) | `BERTopic(seed_topic_list=seed_topic_list)` | + ### Visualizations Evaluating topic models can be rather difficult due to the somewhat subjective nature of evaluation. Visualizing different aspects of the topic model helps in understanding the model and makes it easier diff --git a/bertopic/__init__.py b/bertopic/__init__.py index 5f1610ed..533092fa 100644 --- a/bertopic/__init__.py +++ b/bertopic/__init__.py @@ -1,6 +1,6 @@ from bertopic._bertopic import BERTopic -__version__ = "0.14.1" +__version__ = "0.15.0" __all__ = [ "BERTopic", diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index 2a6c56fe..0135364f 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -12,6 +12,7 @@ import math import joblib import inspect +import collections import numpy as np import pandas as pd import scipy.sparse as sp @@ -3004,8 +3005,13 @@ def load(cls, topics, params, tensors, ctfidf_tensors, ctfidf_config, images = save_utils.load_files_from_hf(path) else: raise ValueError("Make sure to either pass a valid directory or HF model.") + topic_model = _create_model_from_files(topics, params, tensors, ctfidf_tensors, ctfidf_config, images) + + # Replace embedding model if one is specifically chosen + if embedding_model is not None and type(topic_model.embedding_model) == BaseEmbedder: + topic_model.embedding_model = select_backend(embedding_model) - return _create_model_from_files(topics, params, tensors, ctfidf_tensors, ctfidf_config, images) + return topic_model def push_to_hf_hub( self, @@ -3510,8 +3516,7 @@ def _update_topic_size(self, documents: pd.DataFrame): Arguments: documents: Updated dataframe with documents and their corresponding IDs and newly added Topics """ - sizes = documents.groupby(['Topic']).count().sort_values("ID", ascending=False).reset_index() - self.topic_sizes_ = dict(zip(sizes.Topic, sizes.Document)) + self.topic_sizes_ = collections.Counter(documents.Topic.values.tolist()) self.topics_ = documents.Topic.astype(int).tolist() def _extract_words_per_topic(self, diff --git a/bertopic/_save_utils.py b/bertopic/_save_utils.py index e1e50aef..39e20d41 100644 --- a/bertopic/_save_utils.py +++ b/bertopic/_save_utils.py @@ -266,7 +266,11 @@ def generate_readme(model, repo_id: str): params = "\n".join([f"* {param}: {value}" for param, value in params.items()]) topics = sorted(list(set(model.topics_))) nr_topics = str(len(set(model.topics_))) - nr_documents = 
str(model.c_tf_idf_.shape[1])
+
+    if model.topic_sizes_ is not None:
+        nr_documents = str(sum(model.topic_sizes_.values()))
+    else:
+        nr_documents = ""

     # Topic information
     topic_keywords = [" - ".join(list(zip(*model.get_topic(topic)))[0][:5]) for topic in topics]
@@ -290,7 +294,7 @@ def generate_readme(model, repo_id: str):
     if not has_visual_aspect:
         model_card = model_card.replace("{PIPELINE_TAG}", "text-classification")
     else:
-        model_card = model_card.replace("pipeline_tag: {PIPELINE_TAG} /n","") # TODO add proper tag for this instance
+        model_card = model_card.replace("pipeline_tag: {PIPELINE_TAG}\n","") # TODO add proper tag for this instance

     return model_card
diff --git a/bertopic/representation/_openai.py b/bertopic/representation/_openai.py
index f8cc7eb9..d2e61c46 100644
--- a/bertopic/representation/_openai.py
+++ b/bertopic/representation/_openai.py
@@ -4,6 +4,7 @@ from scipy.sparse import csr_matrix
 from typing import Mapping, List, Tuple, Any
 from bertopic.representation._base import BaseRepresentation
+from bertopic.representation._utils import retry_with_exponential_backoff

 DEFAULT_PROMPT = """
@@ -69,6 +70,11 @@ class OpenAI(BaseRepresentation):
                 inserted.
         delay_in_seconds: The delay in seconds between consecutive prompts
                           in order to prevent RateLimitErrors.
+        exponential_backoff: Retry requests with a random exponential backoff.
+                             A short sleep is used when a rate limit error is hit,
+                             then the request is retried. The sleep length increases
+                             each time an error is hit, up to 10 unsuccessful requests.
+                             If True, overrides `delay_in_seconds`.
         chat: Set this to True if a GPT-3.5 model is used.
               See: https://platform.openai.com/docs/models/gpt-3-5
         nr_docs: The number of documents to pass to OpenAI if a prompt
@@ -116,6 +122,7 @@ def __init__(self,
                  prompt: str = None,
                  generator_kwargs: Mapping[str, Any] = {},
                  delay_in_seconds: float = None,
+                 exponential_backoff: bool = False,
                  chat: bool = False,
                  nr_docs: int = 4,
                  diversity: float = None
@@ -129,6 +136,7 @@ def __init__(self,
         self.default_prompt_ = DEFAULT_CHAT_PROMPT if chat else DEFAULT_PROMPT

         self.delay_in_seconds = delay_in_seconds
+        self.exponential_backoff = exponential_backoff
         self.chat = chat
         self.nr_docs = nr_docs
         self.diversity = diversity
@@ -176,10 +184,16 @@ def extract_topics(self,
                     {"role": "user", "content": prompt}
                 ]
                 kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs}
-                response = openai.ChatCompletion.create(**kwargs)
+                if self.exponential_backoff:
+                    response = chat_completions_with_backoff(**kwargs)
+                else:
+                    response = openai.ChatCompletion.create(**kwargs)
                 label = response["choices"][0]["message"]["content"].strip().replace("topic: ", "")
             else:
-                response = openai.Completion.create(model=self.model, prompt=prompt, **self.generator_kwargs)
+                if self.exponential_backoff:
+                    response = completions_with_backoff(model=self.model, prompt=prompt, **self.generator_kwargs)
+                else:
+                    response = openai.Completion.create(model=self.model, prompt=prompt, **self.generator_kwargs)
                 label = response["choices"][0]["text"].strip()
             updated_topics[topic] = [(label, 1)]
@@ -212,3 +226,11 @@ def _replace_documents(prompt, docs):
         to_replace += f"- {doc[:255]}\n"
     prompt = prompt.replace("[DOCUMENTS]", to_replace)
     return prompt
+
+
+def completions_with_backoff(**kwargs):
+    return retry_with_exponential_backoff(openai.Completion.create, errors=(openai.error.RateLimitError,))(**kwargs)
+
+
+def chat_completions_with_backoff(**kwargs):
+    return retry_with_exponential_backoff(openai.ChatCompletion.create,
errors=(openai.error.RateLimitError,))(**kwargs) diff --git a/bertopic/representation/_utils.py b/bertopic/representation/_utils.py index e69de29b..accd9a1a 100644 --- a/bertopic/representation/_utils.py +++ b/bertopic/representation/_utils.py @@ -0,0 +1,45 @@ +import random +import time + +def retry_with_exponential_backoff( + func, + initial_delay: float = 1, + exponential_base: float = 2, + jitter: bool = True, + max_retries: int = 10, + errors: tuple = None, +): + """Retry a function with exponential backoff.""" + + def wrapper(*args, **kwargs): + # Initialize variables + num_retries = 0 + delay = initial_delay + + # Loop until a successful response or max_retries is hit or an exception is raised + while True: + try: + return func(*args, **kwargs) + + # Retry on specific errors + except errors as e: + # Increment retries + num_retries += 1 + + # Check if max retries has been reached + if num_retries > max_retries: + raise Exception( + f"Maximum number of retries ({max_retries}) exceeded." + ) + + # Increment the delay + delay *= exponential_base * (1 + jitter * random.random()) + + # Sleep for the delay + time.sleep(delay) + + # Raise exceptions for any errors not specified + except Exception as e: + raise e + + return wrapper \ No newline at end of file diff --git a/docs/changelog.md b/docs/changelog.md index fbfbd655..9e39e8aa 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -5,6 +5,220 @@ hide: # Changelog + +## **Version 0.15.0** +*Release date: 29 May, 2023* + +

Highlights:

+ +* [**Multimodal**](https://maartengr.github.io/BERTopic/getting_started/multimodal/multimodal.html) Topic Modeling + * Train your topic modeling on text, images, or images and text! + * Use the `bertopic.backend.MultiModalBackend` to embed images, text, both or even caption images! +* [**Multi-Aspect**](https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html) Topic Modeling + * Create multiple topic representations simultaneously +* Improved [**Serialization**](https://maartengr.github.io/BERTopic/getting_started/serialization/serialization.html) options + * Push your model to the HuggingFace Hub with `.push_to_hf_hub` + * Safer, smaller and more flexible serialization options with `safetensors` + * Thanks to a great collaboration with HuggingFace and the authors of [BERTransfer](https://github.com/opinionscience/BERTransfer)! +* Added new embedding models + * OpenAI: `bertopic.backend.OpenAIBackend` + * Cohere: `bertopic.backend.CohereBackend` +* Added example of [summarizing topics](https://maartengr.github.io/BERTopic/getting_started/representation/representation.html#summarization) with OpenAI's GPT-models +* Added `nr_docs` and `diversity` parameters to OpenAI and Cohere representation models +* Use `custom_labels="Aspect1"` to use the aspect labels for visualizations instead +* Added cuML support for probability calculation in `.transform` +* Updated **topic embeddings** + * Centroids by default and c-TF-IDF weighted embeddings for `partial_fit` and `.update_topics` +* Added `exponential_backoff` parameter to `OpenAI` model + +

Fixes:

+ +* Fixed custom prompt not working in `TextGeneration` +* Fixed [#1142](https://github.com/MaartenGr/BERTopic/pull/1142) +* Add additional logic to handle cupy arrays by [@metasyn](https://github.com/metasyn) in [#1179](https://github.com/MaartenGr/BERTopic/pull/1179) +* Fix hierarchy viz and handle any form of distance matrix by [@elashrry](https://github.com/elashrry) in [#1173](https://github.com/MaartenGr/BERTopic/pull/1173) +* Updated languages list by [@sam9111](https://github.com/sam9111) in [#1099](https://github.com/MaartenGr/BERTopic/pull/1099) +* Added level_scale argument to visualize_hierarchical_documents by [@zilch42](https://github.com/zilch42) in [#1106](https://github.com/MaartenGr/BERTopic/pull/1106) +* Fix inconsistent naming by [@rolanderdei](https://github.com/rolanderdei) in [#1073](https://github.com/MaartenGr/BERTopic/pull/1073) + +

Multimodal Topic Modeling

+ +With v0.15, we can now perform multimodal topic modeling in BERTopic! The most basic example of multimodal topic modeling in BERTopic is when you have images that accompany your documents. This means that it is expected that each document has an image and vice versa. Instagram pictures, for example, almost always have some descriptions to them. + +
+ ![Image title](getting_started/multimodal/images_and_text.svg) +
+
+
+
+In this example, we are going to use images from `flickr` that each have a caption associated with them:
+
+```python
+# NOTE: This requires the `datasets` package which you can
+# install with `pip install datasets`
+from datasets import load_dataset

+ds = load_dataset("maderix/flickr_bw_rgb")
+images = ds["train"]["image"]
+docs = ds["train"]["caption"]
+```
+
+The `docs` variable contains the captions for each image in `images`. We can now use these variables to run our multimodal example:
+
+```python
+from bertopic import BERTopic
+from bertopic.representation import VisualRepresentation
+
+# Additional ways of representing a topic
+visual_model = VisualRepresentation()
+
+# Make sure to add the `visual_model` to a dictionary
+representation_model = {
+   "Visual_Aspect":  visual_model,
+}
+topic_model = BERTopic(representation_model=representation_model, verbose=True)
+
+# Train the model on both the documents and their accompanying images
+topics, probs = topic_model.fit_transform(docs, images=images)
+```
+
+We can now access our image representations for each topic with `topic_model.topic_aspects_["Visual_Aspect"]`.
+If you want an overview of the topic images together with their textual representations in Jupyter, you can run the following:
+
+```python
+import base64
+from io import BytesIO
+from IPython.display import HTML
+from PIL import Image
+
+def image_base64(im):
+    if isinstance(im, str):
+        # Load the image from disk when a file path is given
+        im = Image.open(im)
+    with BytesIO() as buffer:
+        im.save(buffer, 'jpeg')
+        return base64.b64encode(buffer.getvalue()).decode()
+
+
+def image_formatter(im):
+    return f'<img src="data:image/jpeg;base64,{image_base64(im)}">'
+
+# Extract dataframe
+df = topic_model.get_topic_info().drop(columns=["Representative_Docs", "Name"])
+
+# Visualize the images
+HTML(df.to_html(formatters={'Visual_Aspect': image_formatter}, escape=False))
+```
+
+![images_and_text](https://github.com/MaartenGr/BERTopic/assets/25746895/3a741e2b-5810-4865-9664-0c6bb24ca3f9)
+
+
+<h3><b>

Multi-aspect Topic Modeling

+ +In this new release, we introduce `multi-aspect topic modeling`! During the `.fit` or `.fit_transform` stages, you can now get multiple representations of a single topic. In practice, it works by generating and storing all kinds of different topic representations (see image below). + +
+ ![Image title](getting_started/multiaspect/multiaspect.svg) +
+
+
+The approach is rather straightforward. We might want to represent our topics using a `PartOfSpeech` representation model but we might also want to try out `KeyBERTInspired` and compare those representation models. We can do this as follows:
+
+```python
+from bertopic import BERTopic
+from bertopic.representation import KeyBERTInspired
+from bertopic.representation import PartOfSpeech
+from bertopic.representation import MaximalMarginalRelevance
+from sklearn.datasets import fetch_20newsgroups
+
+# Documents to train on
+docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
+
+# The main representation of a topic
+main_representation = KeyBERTInspired()
+
+# Additional ways of representing a topic
+aspect_model1 = PartOfSpeech("en_core_web_sm")
+aspect_model2 = [KeyBERTInspired(top_n_words=30), MaximalMarginalRelevance(diversity=.5)]
+
+# Add all models together to be run in a single `fit`
+representation_model = {
+   "Main": main_representation,
+   "Aspect1": aspect_model1,
+   "Aspect2": aspect_model2
+}
+topic_model = BERTopic(representation_model=representation_model).fit(docs)
+```
+
+As shown above, to perform multi-aspect topic modeling, we make sure that `representation_model` is a dictionary where each representation model pipeline is defined.
+The main pipeline, which is used in most visualization options, is defined with the `"Main"` key. All other aspects can be defined however you want. In the example above, the two additional aspects that we are interested in are defined as `"Aspect1"` and `"Aspect2"`.
+
+After we have fitted our model, we can access all representations with `topic_model.get_topic_info()`:
+
+
+ +As you can see, there are a number of different representations for our topics that we can inspect. All aspects are found in `topic_model.topic_aspects_`. + + +

Serialization

+ +Saving, loading, and sharing a BERTopic model can be done in several ways. With this new release, it is now advised to go with `.safetensors` as that allows for a small, safe, and fast method for saving your BERTopic model. However, other formats, such as `.pickle` and pytorch `.bin` are also possible. + +The methods are used as follows: + +```python +topic_model = BERTopic().fit(my_docs) + +# Method 1 - safetensors +embedding_model = "sentence-transformers/all-MiniLM-L6-v2" +topic_model.save("path/to/my/model_dir", serialization="safetensors", save_ctfidf=True, save_embedding_model=embedding_model) + +# Method 2 - pytorch +embedding_model = "sentence-transformers/all-MiniLM-L6-v2" +topic_model.save("path/to/my/model_dir", serialization="pytorch", save_ctfidf=True, save_embedding_model=embedding_model) + +# Method 3 - pickle +topic_model.save("my_model", serialization="pickle") +``` + +Saving the topic modeling with `.safetensors` or `pytorch` has a number of advantages: + +* `.safetensors` is a relatively **safe format** +* The resulting model can be **very small** (often < 20MB>) since no sub-models need to be saved +* Although version control is important, there is a bit more **flexibility** with respect to specific versions of packages +* More easily used in **production** +* **Share** models with the HuggingFace Hub + +

+ +

+ +The above image, a model trained on 100,000 documents, demonstrates the differences in sizes comparing `safetensors`, `pytorch`, and `pickle`. The difference in sizes can mostly be explained due to the efficient saving procedure and that the clustering and dimensionality reductions are not saved in safetensors/pytorch since inference can be done based on the topic embeddings. + + + + +

HuggingFace Hub

+ +When you have created a BERTopic model, you can easily share it with other through the HuggingFace Hub. First, you need to log in to your HuggingFace account: + +```python +from huggingface_hub import login +login() +``` + +When you have logged in to your HuggingFace account, you can save and upload the model as follows: + +```python +from bertopic import BERTopic + +# Train model +topic_model = BERTopic().fit(my_docs) + +# Push to HuggingFace Hub +topic_model.push_to_hf_hub( + repo_id="MaartenGr/BERTopic_ArXiv", + save_ctfidf=True +) + +# Load from HuggingFace +loaded_model = BERTopic.load("MaartenGr/BERTopic_ArXiv") +``` + ## **Version 0.14.1** *Release date: 2 March, 2023* diff --git a/docs/getting_started/quickstart/quickstart.md b/docs/getting_started/quickstart/quickstart.md index 78c4c903..67d12217 100644 --- a/docs/getting_started/quickstart/quickstart.md +++ b/docs/getting_started/quickstart/quickstart.md @@ -10,13 +10,10 @@ You may want to install more depending on the transformers and language backends The possible installations are: ```bash -# Embedding models -pip install bertopic[flair] -pip install bertopic[gensim] -pip install bertopic[spacy] -pip install bertopic[use] +# Choose an embedding backend +pip install bertopic[flair, gensim, spacy, use] -# Vision topic modeling +# Topic modeling with images pip install bertopic[vision] ``` @@ -93,6 +90,18 @@ representation_model = KeyBERTInspired() topic_model = BERTopic(representation_model=representation_model) ``` +However, you might want to use something more powerful to describe your clusters. You can even use ChatGPT or other models from OpenAI to generate labels, summaries, phrases, keywords, and more: + +```python +import openai +from bertopic.representation import OpenAI + +# Fine-tune topic representations with GPT +openai.api_key = "sk-..." +representation_model = OpenAI(model="gpt-3.5-turbo", chat=True) +topic_model = BERTopic(representation_model=representation_model) +``` + !!! tip "Multi-aspect Topic Modeling" Instead of iterating over all of these different topic representations, you can model them simultaneously with [multi-aspect topic representations](https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html) in BERTopic. @@ -127,12 +136,16 @@ Method 3 allows for saving the entire topic model but has several drawbacks: These methods have a number of advantages: * `.safetensors` is a relatively **safe format** -* The resulting model can be **very small** (often < 20MB>) since no sub-models need to be saved +* The resulting model can be **very small** (often < 20MB) since no sub-models need to be saved * Although version control is important, there is a bit more **flexibility** with respect to specific versions of packages * More easily used in **production** * **Share** models with the HuggingFace Hub +!!! Tip "Tip" + For more detail about how to load in a custom vectorizer, representation model, and more, it is highly advised to checkout the [serialization](https://maartengr.github.io/BERTopic/getting_started/serialization/serialization.html) page. It contains more examples, details, and some tips and tricks for loading and saving your environment. 
+
+
 The methods are used as follows:

```python
diff --git a/docs/getting_started/representation/llm.md b/docs/getting_started/representation/llm.md
new file mode 100644
index 00000000..6fa26da5
--- /dev/null
+++ b/docs/getting_started/representation/llm.md
@@ -0,0 +1,262 @@
+As we have seen in the [previous section](https://maartengr.github.io/BERTopic/getting_started/representation/representation.html), the topics that you get from BERTopic can be fine-tuned using a number of approaches. Here, we are going to focus on text generation Large Language Models, such as ChatGPT, GPT-4, and open-source solutions.
+
+Using these techniques, we can further fine-tune topics to generate labels, summaries, poems of topics, and more. To do so, we first generate a set of keywords and documents that describe a topic best using BERTopic's c-TF-IDF calculation. Then, these candidate keywords and documents are passed to the text generation model, which is asked to generate output that fits the topic best.
+
+A huge benefit of this is that we can describe a topic with only a few documents, so we do not need to pass all documents to the text generation model. Not only does this speed up the generation of topic labels significantly, it also means you do not need a massive amount of credits when using an external API, such as Cohere or OpenAI.
+
+
+## **Prompts**
+
+In most of the examples below, we use certain tags to customize our prompts. There are currently two tags, namely `"[KEYWORDS]"` and `"[DOCUMENTS]"`.
+These tags indicate where in the prompt they are to be replaced with a topic's keywords and top 4 most representative documents respectively.
+For example, if we have the following prompt:
+
+```python
+prompt = """
+I have a topic that contains the following documents: \n[DOCUMENTS]
+The topic is described by the following keywords: [KEYWORDS]
+
+Based on the above information, can you give a short label of the topic?
+"""
+```
+
+then that will be rendered as follows:
+
+```python
+"""
+I have a topic that contains the following documents:
+- Our videos are also made possible by your support on patreon.co.
+- If you want to help us make more videos, you can do so on patreon.com or get one of our posters from our shop.
+- If you want to help us make more videos, you can do so there.
+- And if you want to support us in our endeavor to survive in the world of online video, and make more videos, you can do so on patreon.com.
+
+The topic is described by the following keywords: videos video you our support want this us channel patreon make on we if facebook to patreoncom can for and more watch
+
+Based on the above information, can you give a short label of the topic?
+"""
+```
+
+!!! tip Tip
+    You can access the default prompts of these models with `representation_model.default_prompt_`
+
+## **πŸ€— Transformers**
+
+Nearly every week, there are new and improved models released on the πŸ€— [Model Hub](https://huggingface.co/models) that, with some creativity, allow for
+further fine-tuning of our c-TF-IDF based topics. These models range from text generation to zero-shot classification. In BERTopic, wrappers around these
+methods are created as a way to support whatever might be released in the future.
+
+Using a GPT-like model from the HuggingFace Hub is rather straightforward:
+
+```python
+from bertopic.representation import TextGeneration
+from bertopic import BERTopic
+
+# Create your representation model
+representation_model = TextGeneration('gpt2')
+
+# Use the representation model in BERTopic on top of the default pipeline
+topic_model = BERTopic(representation_model=representation_model)
+```
+
+GPT-2, however, is not the most accurate model on the HuggingFace Hub. You can get
+much better results with a `flan-T5`-like model:
+
+```python
+from transformers import pipeline
+from bertopic.representation import TextGeneration
+
+prompt = "I have a topic described by the following keywords: [KEYWORDS]. Based on the previous keywords, what is this topic about?"
+
+# Create your representation model
+generator = pipeline('text2text-generation', model='google/flan-t5-base')
+representation_model = TextGeneration(generator)
+```
+
+<br>
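+To actually apply the custom prompt defined above, it can be passed along when creating the representation model. A minimal sketch, reusing `generator` and `prompt` from the snippet above:
+
+```python
+from bertopic import BERTopic
+
+# Pass the custom prompt so `[KEYWORDS]` is replaced with each topic's keywords
+representation_model = TextGeneration(generator, prompt=prompt)
+topic_model = BERTopic(representation_model=representation_model)
+```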
+
+--8<-- "docs/getting_started/representation/hf.svg" +
+
+<br>
+
+As can be seen from the example above, if you would like to use a `text2text-generation` model, you will need to
+pass a `transformers.pipeline` with the `"text2text-generation"` parameter. Moreover, you can use a custom prompt and decide where the keywords should
+be inserted with the `[KEYWORDS]` tag, or the documents with the `[DOCUMENTS]` tag, as sketched above.
+
+
+## **OpenAI**
+
+Instead of using a language model from πŸ€— transformers, we can use external APIs that
+do the work for you. Here, we can use [OpenAI](https://openai.com/api/) to extract our topic labels from the candidate documents and keywords.
+To use this, you will need to install openai first:
+
+```bash
+pip install openai
+```
+
+Then, get yourself an API key and use OpenAI's API as follows:
+
+```python
+import openai
+from bertopic.representation import OpenAI
+from bertopic import BERTopic
+
+# Create your representation model
+openai.api_key = MY_API_KEY
+representation_model = OpenAI()
+
+# Use the representation model in BERTopic on top of the default pipeline
+topic_model = BERTopic(representation_model=representation_model)
+```
+
+<br>
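+Since rate limits are easy to hit with larger topic models, this release also adds an `exponential_backoff` parameter to the `OpenAI` model (see the changelog above). When enabled, rate-limited requests are retried with increasing sleeps, and it overrides `delay_in_seconds`. A minimal sketch, reusing the imports and API key from above:
+
+```python
+# Retry rate-limited requests with exponential backoff instead of a fixed delay
+representation_model = OpenAI(exponential_backoff=True)
+topic_model = BERTopic(representation_model=representation_model)
+```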
+
+--8<-- "docs/getting_started/representation/openai.svg" +
+
+<br>
+
+You can also use a custom prompt:
+
+```python
+prompt = "I have the following documents: [DOCUMENTS] \nThese documents are about the following topic: '"
+representation_model = OpenAI(prompt=prompt)
+```
+
+### **ChatGPT**
+
+Within OpenAI's API, the ChatGPT models use a different API structure compared to the GPT-3 models.
+In order to use ChatGPT with BERTopic, we need to define the model and make sure to enable `chat`:
+
+```python
+representation_model = OpenAI(model="gpt-3.5-turbo", delay_in_seconds=10, chat=True)
+```
+
+Prompting with ChatGPT is very satisfying and can be customized as follows:
+
+```python
+prompt = """
+I have a topic that contains the following documents:
+[DOCUMENTS]
+The topic is described by the following keywords: [KEYWORDS]
+
+Based on the information above, extract a short topic label in the following format:
+topic: <topic label>
+"""
+```
+
+!!! note
+    Whenever you create a custom prompt, it is important to add
+    ```
+    Based on the information above, extract a short topic label in the following format:
+    topic: <topic label>
+    ```
+    at the end of your prompt as BERTopic extracts everything that comes after `topic: `. Having
+    said that, if `topic: ` is not in the output, then it will simply extract the entire response, so
+    feel free to experiment with the prompts.
+
+### **Summarization**
+
+Due to the structure of the prompts in OpenAI's chat models, we can extract different types of topic representations from their GPT models.
+Instead of extracting a topic label, we can ask it to extract a short description of the topic:
+
+```python
+summarization_prompt = """
+I have a topic that is described by the following keywords: [KEYWORDS]
+In this topic, the following documents are a small but representative subset of all documents in the topic:
+[DOCUMENTS]
+
+Based on the information above, please give a description of this topic in the following format:
+topic: <description>
+"""
+
+representation_model = OpenAI(model="gpt-3.5-turbo", chat=True, prompt=summarization_prompt, nr_docs=5, delay_in_seconds=3)
+```
+
+The above is not constrained to just creating a short description or summary of the topic; we can extract labels, keywords, poems, example documents, extensive descriptions, and more using this method!
+If you want to have multiple representations of a single topic, it might be worthwhile to also check out [**multi-aspect**](https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html) topic modeling with BERTopic.
+
+
+## **LangChain**
+
+[Langchain](https://github.com/hwchase17/langchain) is a package that helps users with chaining large language models.
+In BERTopic, we can leverage this package in order to more efficiently combine external knowledge. Here, this
+external knowledge consists of the most representative documents in each topic.
+
+To use langchain, you will need to install the langchain package first.
Additionally, you will need an underlying LLM to support langchain,
+like openai:
+
+```bash
+pip install langchain openai
+```
+
+Then, you can create your chain as follows:
+
+```python
+from langchain.chains.question_answering import load_qa_chain
+from langchain.llms import OpenAI
+chain = load_qa_chain(OpenAI(temperature=0, openai_api_key=my_openai_api_key), chain_type="stuff")
+```
+
+Finally, you can pass the chain to BERTopic as follows:
+
+```python
+from bertopic.representation import LangChain
+
+# Create your representation model
+representation_model = LangChain(chain)
+
+# Use the representation model in BERTopic on top of the default pipeline
+topic_model = BERTopic(representation_model=representation_model)
+```
+
+You can also use a custom prompt:
+
+```python
+prompt = "What are these documents about? Please give a single label."
+representation_model = LangChain(chain, prompt=prompt)
+```
+
+!!! note Note
+    The prompt does not make use of `[KEYWORDS]` and `[DOCUMENTS]` tags as
+    the documents are already used within langchain's `load_qa_chain`.
+
+## **Cohere**
+
+Instead of using a language model from πŸ€— transformers, we can use external APIs that
+do the work for you. Here, we can use [Cohere](https://docs.cohere.ai/) to extract our topic labels from the candidate documents and keywords.
+To use this, you will need to install cohere first:
+
+```bash
+pip install cohere
+```
+
+Then, get yourself an API key and use Cohere's API as follows:
+
+```python
+import cohere
+from bertopic.representation import Cohere
+from bertopic import BERTopic
+
+# Create your representation model
+co = cohere.Client(my_api_key)
+representation_model = Cohere(co)
+
+# Use the representation model in BERTopic on top of the default pipeline
+topic_model = BERTopic(representation_model=representation_model)
+```
+
+<br>
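+This release also adds `nr_docs` and `diversity` parameters to the Cohere and OpenAI representation models (see the changelog), which control how many representative documents are passed to the prompt and how diverse those documents are. A sketch, reusing the `co` client created above:
+
+```python
+# Pass eight representative documents per topic, sampled with some diversity
+representation_model = Cohere(co, nr_docs=8, diversity=0.1)
+```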
+
+--8<-- "docs/getting_started/representation/cohere.svg" +
+
+ +You can also use a custom prompt: + +```python +prompt = """ +I have topic that contains the following documents: [DOCUMENTS] +The topic is described by the following keywords: [KEYWORDS]. +Based on the above information, can you give a short label of the topic? +""" +representation_model = Cohere(co, prompt=prompt) +``` \ No newline at end of file diff --git a/docs/getting_started/representation/representation.md b/docs/getting_started/representation/representation.md index faf00e1a..b1fbd247 100644 --- a/docs/getting_started/representation/representation.md +++ b/docs/getting_started/representation/representation.md @@ -163,267 +163,6 @@ topic_model = BERTopic(representation_model=representation_model)
-## **Text Generation & Prompts** - -Text generation models, like GPT-3 and the well-known ChatGPT, are becoming more and more capable of generating sensible output. -For that purpose, a number of models are exposed in BERTopic that allow topic labels to be created based on candidate documents and keywords -for each topic. These candidate documents and keywords are created from BERTopic's c-TF-IDF calculation. A huge benefit of this is that we can -describe a topic with only a few documents and we therefore do not need to pass all documents to the text generation model. Not only speeds -this the generation of topic labels up significantly, you also do not need a massive amount of credits when using an external API, such as Cohere or OpenAI. - -In most of the examples below, we use certain tags to customize our prompts. There are currently two tags, namely `"[KEYWORDS]"` and `"[DOCUMENTS]"`. -These tags indicate where in the prompt they are to be replaced with a topics keywords and top 4 most representative documents respectively. -For example, if we have the following prompt: - -```python -prompt = """ -I have topic that contains the following documents: \n[DOCUMENTS] -The topic is described by the following keywords: [KEYWORDS] - -Based on the above information, can you give a short label of the topic? -""" -``` - -then that will be rendered as follows: - -```python -""" -I have a topic that contains the following documents: -- Our videos are also made possible by your support on patreon.co. -- If you want to help us make more videos, you can do so on patreon.com or get one of our posters from our shop. -- If you want to help us make more videos, you can do so there. -- And if you want to support us in our endeavor to survive in the world of online video, and make more videos, you can do so on patreon.com. - -The topic is described by the following keywords: videos video you our support want this us channel patreon make on we if facebook to patreoncom can for and more watch - -Based on the above information, can you give a short label of the topic? -""" -``` - -!!! tip Tip - You can access the default prompts of these models with `representation_model.default_prompt_` - -### **πŸ€— Transformers** - -Nearly every week, there are new and improved models released on the πŸ€— [Model Hub](https://huggingface.co/models) that, with some creativity, allow for -further fine-tuning of our c-TF-IDF based topics. These models range from text generation to zero-classification. In BERTopic, wrappers around these -methods are created as a way to support whatever might be released in the future. - -Using a GPT-like model from the huggingface hub is rather straightforward: - -```python -from bertopic.representation import TextGeneration -from bertopic import BERTopic - -# Create your representation model -representation_model = TextGeneration('gpt2') - -# Use the representation model in BERTopic on top of the default pipeline -topic_model = BERTopic(representation_model=representation_model) -``` - -GPT2, however, is not the most accurate model out there on HuggingFace models. You can get -much better results with a `flan-T5` like model: - -```python -from transformers import pipeline -from bertopic.representation import TextGeneration - -prompt = "I have a topic described by the following keywords: [KEYWORDS]. 
Based on the previous keywords, what is this topic about?"" - -# Create your representation model -generator = pipeline('text2text-generation', model='google/flan-t5-base') -representation_model = TextGeneration(generator) -``` - -
-
---8<-- "docs/getting_started/representation/hf.svg" -
-
- -As can be seen from the example above, if you would like to use a `text2text-generation` model, you will to -pass a `transformers.pipeline` with the `"text2text-generation"` parameter. Moreover, you can use a custom prompt and decide where the keywords should -be inserted by using the `[KEYWORDS]` or documents with the `[DOCUMENTS]` tag: - -### **Cohere** - -Instead of using a language model from πŸ€— transformers, we can use external APIs instead that -do the work for you. Here, we can use [Cohere](https://docs.cohere.ai/) to extract our topic labels from the candidate documents and keywords. -To use this, you will need to install cohere first: - -```bash -pip install cohere -``` - -Then, get yourself an API key and use Cohere's API as follows: - -```python -import cohere -from bertopic.representation import Cohere -from bertopic import BERTopic - -# Create your representation model -co = cohere.Client(my_api_key) -representation_model = Cohere(co) - -# Use the representation model in BERTopic on top of the default pipeline -topic_model = BERTopic(representation_model=representation_model) -``` - -
-
---8<-- "docs/getting_started/representation/cohere.svg" -
-
- -You can also use a custom prompt: - -```python -prompt = """ -I have topic that contains the following documents: [DOCUMENTS] -The topic is described by the following keywords: [KEYWORDS]. -Based on the above information, can you give a short label of the topic? -""" -representation_model = Cohere(co, prompt=prompt) -``` - -### **OpenAI** - -Instead of using a language model from πŸ€— transformers, we can use external APIs instead that -do the work for you. Here, we can use [OpenAI](https://openai.com/api/) to extract our topic labels from the candidate documents and keywords. -To use this, you will need to install openai first: - -```bash -pip install openai -``` - -Then, get yourself an API key and use OpenAI's API as follows: - -```python -import openai -from bertopic.representation import OpenAI -from bertopic import BERTopic - -# Create your representation model -openai.api_key = MY_API_KEY -representation_model = OpenAI() - -# Use the representation model in BERTopic on top of the default pipeline -topic_model = BERTopic(representation_model=representation_model) -``` - -
-
---8<-- "docs/getting_started/representation/openai.svg" -
-
- -You can also use a custom prompt: - -```python -prompt = "I have the following documents: [DOCUMENTS] \nThese documents are about the following topic: '" -representation_model = OpenAI(prompt=prompt) -``` - -#### **ChatGPT** - -Within OpenAI's API, the ChatGPT models use a different API structure compared to the GPT-3 models. -In order to use ChatGPT with BERTopic, we need to define the model and make sure to enable `chat`: - -```python -representation_model = OpenAI(model="gpt-3.5-turbo", delay_in_seconds=10, chat=True) -``` - -Prompting with ChatGPT is very satisfying and is customizable as follows: - -```python -prompt = """ -I have a topic that contains the following documents: -[DOCUMENTS] -The topic is described by the following keywords: [KEYWORDS] - -Based on the information above, extract a short topic label in the following format: -topic: -""" -``` - -!!! note - Whenever you create a custom prompt, it is important to add - ``` - Based on the information above, extract a short topic label in the following format: - topic: - ``` - at the end of your prompt as BERTopic extracts everything that comes after `topic: `. Having - said that, if `topic: ` is not in the output, then it will simply extract the entire response, so - feel free to experiment with the prompts. - -#### **Summarization** - -Due to the structure of the prompts in OpenAI's chat models, we can extract different types of topic representations from their GPT models. -Instead of extracting a topic label, we can instead ask it to extract a short description of the topic instead: - -```python -summarization_prompt = """ -I have a topic that is described by the following keywords: [KEYWORDS] -In this topic, the following documents are a small but representative subset of all documents in the topic: -[DOCUMENTS] - -Based on the information above, please give a description of this topic in the following format: -topic: -""" - -representation_model = OpenAI(model="gpt-3.5-turbo", chat=True, prompt=summarization_prompt, nr_docs=5, delay_in_seconds=3) -``` - -The above is not constrained to just creating a short description or summary of the topic, we can extract labels, keywords, poems, example documents, extensitive descriptions, and more using this method! -If you want to have multiple representations of a single topic, it might be worthwhile to also check out [**multi-aspect**](https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html) topic modeling with BERTopic. - - -### **LangChain** - -[Langchain](https://github.com/hwchase17/langchain) is a package that helps users with chaining large language models. -In BERTopic, we can leverage this package in order to more efficiently combine external knowledge. Here, this -external knowledge are the most representative documents in each topic. - -To use langchain, you will need to install the langchain package first. 
Additionally, you will need an underlying LLM to support langchain,
-like openai:
-
-```bash
-pip install langchain, openai
-```
-
-Then, you can create your chain as follows:
-
-```python
-from langchain.chains.question_answering import load_qa_chain
-from langchain.llms import OpenAI
-chain = load_qa_chain(OpenAI(temperature=0, openai_api_key=my_openai_api_key), chain_type="stuff")
-```
-
-Finally, you can pass the chain to BERTopic as follows:
-
-```python
-from bertopic.representation import LangChain
-
-# Create your representation model
-representation_model = LangChain(chain)
-
-# Use the representation model in BERTopic on top of the default pipeline
-topic_model = BERTopic(representation_model=representation_model)
-```
-
-You can also use a custom prompt:
-
-```python
-prompt = "What are these documents about? Please give a single label."
-representation_model = LangChain(chain, prompt=prompt)
-```
-
-!!! note Note
-    The prompt does not make use of `[KEYWORDS]` and `[DOCUMENTS]` tags as
-    the documents are already used within langchain's `load_qa_chain`.
-
 ## **Chain Models**

 All of the above models can make use of the candidate topics, as generated by c-TF-IDF, to further fine-tune the topic representations. For example, `MaximalMarginalRelevance` takes the keywords in the candidate topics and re-ranks them. Similarly, the keywords in the candidate topic can be used as the input for GPT-prompts in `OpenAI`.
@@ -459,9 +198,6 @@ from bertopic.representation._base import BaseRepresentation

 class CustomRepresentationModel(BaseRepresentation):
-    def __init__(self):
-        pass
-
     def extract_topics(self, topic_model, documents, c_tf_idf, topics
                        ) -> Mapping[str, List[Tuple[str, float]]]:
         """ Extract topics
@@ -494,8 +230,16 @@ topic_model = BERTopic(representation_model=representation_model)

There are a few things to take into account when creating your custom model:

* It needs to have the exact same parameter input: `topic_model`, `documents`, `c_tf_idf`, `topics`.
-* You can change the `__init__` however you want, it does not influence the underlying structure
* Make sure that `updated_topics` has the exact same structure as `topics`:
-    * For example: `updated_topics = {"1", [("space", 0.9), ("nasa", 0.7)], "2": [("science", 0.66), ("article", 0.6)]`}
-    * Thus, it is a dictionary where each topic is represented by a list of keyword,value tuples.
-* Lastly, make sure that `updated_topics` contains at least 5 keywords, even if they are empty: `[("", 0), ("", 0), ...]`
\ No newline at end of file
+
+```python
+updated_topics = {
+    "1": [("space", 0.9), ("nasa", 0.7)],
+    "2": [("science", 0.66), ("article", 0.6)]
+}
+```
+
+!!! Tip
+    You can change the `__init__` however you want; it does not influence the underlying structure. This
+    also means that you can save data/embeddings/representations/sentiment in your custom representation
+    model.
diff --git a/docs/getting_started/serialization/serialization.md b/docs/getting_started/serialization/serialization.md
index 00fd07b5..cb15cd76 100644
--- a/docs/getting_started/serialization/serialization.md
+++ b/docs/getting_started/serialization/serialization.md
@@ -53,6 +53,12 @@ Saving the topic modeling with `.safetensors` or `pytorch` has a number of advan
 * More easily used in **production**
 * **Share** models with the HuggingFace Hub

+

+ +

+ +The above image, a model trained on 100,000 documents, demonstrates the differences in sizes comparing `safetensors`, `pytorch`, and `pickle`. The difference in sizes can mostly be explained due to the efficient saving procedure and that the clustering and dimensionality reductions are not saved in safetensors/pytorch since inference can be done based on the topic embeddings. + ## **HuggingFace Hub** diff --git a/docs/getting_started/serialization/serialization.png b/docs/getting_started/serialization/serialization.png new file mode 100644 index 00000000..38b40a64 Binary files /dev/null and b/docs/getting_started/serialization/serialization.png differ diff --git a/docs/getting_started/visualization/visualize_documents.md b/docs/getting_started/visualization/visualize_documents.md new file mode 100644 index 00000000..2e87aac6 --- /dev/null +++ b/docs/getting_started/visualization/visualize_documents.md @@ -0,0 +1,102 @@ +Using the `.visualize_topics`, we can visualize the topics and get insight into their relationships. However, +you might want a more fine-grained approach where we can visualize the documents inside the topics to see +if they were assigned correctly or whether they make sense. To do so, we can use the `topic_model.visualize_documents()` +function. This function recalculates the document embeddings and reduces them to 2-dimensional space for easier visualization +purposes. This process can be quite expensive, so it is advised to adhere to the following pipeline: + +```python +from sklearn.datasets import fetch_20newsgroups +from sentence_transformers import SentenceTransformer +from bertopic import BERTopic +from umap import UMAP + +# Prepare embeddings +docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] +sentence_model = SentenceTransformer("all-MiniLM-L6-v2") +embeddings = sentence_model.encode(docs, show_progress_bar=False) + +# Train BERTopic +topic_model = BERTopic().fit(docs, embeddings) + +# Run the visualization with the original embeddings +topic_model.visualize_documents(docs, embeddings=embeddings) + +# Reduce dimensionality of embeddings, this step is optional but much faster to perform iteratively: +reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings) +topic_model.visualize_documents(docs, reduced_embeddings=reduced_embeddings) +``` + + + + +!!! note + The visualization above was generated with the additional parameter `hide_document_hover=True` which disables the + option to hover over the individual points and see the content of the documents. This was done for demonstration purposes + as saving all those documents in the visualization can be quite expensive and result in large files. However, + it might be interesting to set `hide_document_hover=False` in order to hover over the points and see the content of the documents. + +### **Custom Hover** + +When you visualize the documents, you might not always want to see the complete document over hover. Many documents have shorter information that might be more interesting to visualize, such as its title. 
To create the hover based on a document's title instead of its content, you can simply pass a variable (`titles`) containing the title for each document: + +```python +topic_model.visualize_documents(titles, reduced_embeddings=reduced_embeddings) +``` + +## **Visualize Probabilities or Distribution** + +We can generate the topic-document probability matrix by simply setting `calculate_probabilities=True` if an HDBSCAN model is used: + +```python +from bertopic import BERTopic +topic_model = BERTopic(calculate_probabilities=True) +topics, probs = topic_model.fit_transform(docs) +``` + +The resulting `probs` variable contains the soft-clustering probabilities generated by HDBSCAN. + +If a non-HDBSCAN model is used, we can estimate the topic distributions after training our model: + +```python +from bertopic import BERTopic + +topic_model = BERTopic() +topics, _ = topic_model.fit_transform(docs) +topic_distr, _ = topic_model.approximate_distribution(docs, min_similarity=0) +``` + +Then, we either pass the `probs` or `topic_distr` variable to `.visualize_distribution` to visualize either the probability distributions or the topic distributions: + +```python +# To visualize the probabilities of topic assignment +topic_model.visualize_distribution(probs[0]) + +# To visualize the topic distributions in a document +topic_model.visualize_distribution(topic_distr[0]) +``` + + + +Although a topic distribution is nice, we may want to see how each token contributes to a specific topic. To do so, we need to first calculate topic distributions on a token level and then visualize the results: + +```python +# Calculate the topic distributions on a token level +topic_distr, topic_token_distr = topic_model.approximate_distribution(docs, calculate_tokens=True) + +# Visualize the token-level distributions +df = topic_model.visualize_approximate_distribution(docs[1], topic_token_distr[1]) +df +``` + 

+<!-- Example output: stylized dataframe of the token-level topic distributions -->
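If you want to keep the stylized output outside of a notebook, the returned stylized dataframe can be exported to HTML. This is a sketch and assumes a recent pandas version in which `Styler.to_html` accepts a path:

```python
# Write the stylized token-level distributions to a standalone HTML file
# (assumes pandas >= 1.3, where Styler.to_html is available)
df.to_html("token_topic_distributions.html")
```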

+ +!!! note + To get the stylized dataframe for `.visualize_approximate_distribution` you will need to have Jinja installed. If you do not have this installed, an unstylized dataframe will be returned instead. You can install Jinja via `pip install jinja2`. + +!!! note + The distribution of the probabilities does not give an indication of + the distribution of the frequencies of topics across a document. It merely shows + how confident BERTopic is that certain topics can be found in a document. + diff --git a/docs/getting_started/visualization/visualize_hierarchy.md b/docs/getting_started/visualization/visualize_hierarchy.md new file mode 100644 index 00000000..ca3784b9 --- /dev/null +++ b/docs/getting_started/visualization/visualize_hierarchy.md @@ -0,0 +1,360 @@ +The topics that you create can be hierarchically reduced. In order to understand the potential hierarchical +structure of the topics, we can use `scipy.cluster.hierarchy` to create clusters and visualize how +they relate to one another. This might help to select an appropriate `nr_topics` when reducing the number +of topics that you have created. To visualize this hierarchy, run the following: + +```python +topic_model.visualize_hierarchy() +``` + + + +!!! note + Do note that this is not the actual procedure of `.reduce_topics()` when `nr_topics` is set to + auto, since HDBSCAN is used to automatically extract topics. The visualization above closely resembles + the actual procedure of `.reduce_topics()` when any number of `nr_topics` is selected. + +### **Hierarchical labels** + +Although visualizing this hierarchy gives us information about the structure, it would be helpful to see what happens +to the topic representations when merging topics. To do so, we first need to calculate the representations of the +hierarchical topics: + + +First, we train a basic BERTopic model: + +```python +from bertopic import BERTopic +from sklearn.datasets import fetch_20newsgroups + +docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))["data"] +topic_model = BERTopic(verbose=True) +topics, probs = topic_model.fit_transform(docs) +hierarchical_topics = topic_model.hierarchical_topics(docs) +``` + +To visualize these results, we simply need to pass the resulting `hierarchical_topics` to our `.visualize_hierarchy` function: + +```python +topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics) +``` + + + +If you **hover** over the black circles, you will see the topic representation at that level of the hierarchy. These representations +help you understand the effect of merging certain topics. Some might be logical to merge whilst others might not. Moreover, +we can now see which sub-topics can be found within certain larger themes. + +### **Text-based topic tree** + +Although this gives a nice overview of the potential hierarchy, hovering over all black circles can be tiresome. Instead, we can +use `topic_model.get_topic_tree` to create a text-based representation of this hierarchy. Although the general structure is more difficult +to view, we can better see which topics could be logically merged: + +```python +>>> tree = topic_model.get_topic_tree(hierarchical_topics) +>>> print(tree) +. +└─atheists_atheism_god_moral_atheist + β”œβ”€atheists_atheism_god_atheist_argument + β”‚ β”œβ”€β– β”€β”€atheists_atheism_god_atheist_argument ── Topic: 21 + β”‚ └─■──br_god_exist_genetic_existence ── Topic: 124 + └─■──moral_morality_objective_immoral_morals ── Topic: 29 +```
+ Click here to view the full tree. + + ```bash + . + β”œβ”€people_armenian_said_god_armenians + β”‚ β”œβ”€god_jesus_jehovah_lord_christ + β”‚ β”‚ β”œβ”€god_jesus_jehovah_lord_christ + β”‚ β”‚ β”‚ β”œβ”€jehovah_lord_mormon_mcconkie_god + β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€ra_satan_thou_god_lucifer ── Topic: 94 + β”‚ β”‚ β”‚ β”‚ └─■──jehovah_lord_mormon_mcconkie_unto ── Topic: 78 + β”‚ β”‚ β”‚ └─jesus_mary_god_hell_sin + β”‚ β”‚ β”‚ β”œβ”€jesus_hell_god_eternal_heaven + β”‚ β”‚ β”‚ β”‚ β”œβ”€hell_jesus_eternal_god_heaven + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€jesus_tomb_disciples_resurrection_john ── Topic: 69 + β”‚ β”‚ β”‚ β”‚ β”‚ └─■──hell_eternal_god_jesus_heaven ── Topic: 53 + β”‚ β”‚ β”‚ β”‚ └─■──aaron_baptism_sin_law_god ── Topic: 89 + β”‚ β”‚ β”‚ └─■──mary_sin_maria_priest_conception ── Topic: 56 + β”‚ β”‚ └─■──marriage_married_marry_ceremony_marriages ── Topic: 110 + β”‚ └─people_armenian_armenians_said_mr + β”‚ β”œβ”€people_armenian_armenians_said_israel + β”‚ β”‚ β”œβ”€god_homosexual_homosexuality_atheists_sex + β”‚ β”‚ β”‚ β”œβ”€homosexual_homosexuality_sex_gay_homosexuals + β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€kinsey_sex_gay_men_sexual ── Topic: 44 + β”‚ β”‚ β”‚ β”‚ └─homosexuality_homosexual_sin_homosexuals_gay + β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€gay_homosexual_homosexuals_sexual_cramer ── Topic: 50 + β”‚ β”‚ β”‚ β”‚ └─■──homosexuality_homosexual_sin_paul_sex ── Topic: 27 + β”‚ β”‚ β”‚ └─god_atheists_atheism_moral_atheist + β”‚ β”‚ β”‚ β”œβ”€islam_quran_judas_islamic_book + β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€jim_context_challenges_articles_quote ── Topic: 36 + β”‚ β”‚ β”‚ β”‚ └─islam_quran_judas_islamic_book + β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€islam_quran_islamic_rushdie_muslims ── Topic: 31 + β”‚ β”‚ β”‚ β”‚ └─■──judas_scripture_bible_books_greek ── Topic: 33 + β”‚ β”‚ β”‚ └─atheists_atheism_god_moral_atheist + β”‚ β”‚ β”‚ β”œβ”€atheists_atheism_god_atheist_argument + β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€atheists_atheism_god_atheist_argument ── Topic: 21 + β”‚ β”‚ β”‚ β”‚ └─■──br_god_exist_genetic_existence ── Topic: 124 + β”‚ β”‚ β”‚ └─■──moral_morality_objective_immoral_morals ── Topic: 29 + β”‚ β”‚ └─armenian_armenians_people_israel_said + β”‚ β”‚ β”œβ”€armenian_armenians_israel_people_jews + β”‚ β”‚ β”‚ β”œβ”€tax_rights_government_income_taxes + β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€rights_right_slavery_slaves_residence ── Topic: 106 + β”‚ β”‚ β”‚ β”‚ └─tax_government_taxes_income_libertarians + β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€government_libertarians_libertarian_regulation_party ── Topic: 58 + β”‚ β”‚ β”‚ β”‚ └─■──tax_taxes_income_billion_deficit ── Topic: 41 + β”‚ β”‚ β”‚ └─armenian_armenians_israel_people_jews + β”‚ β”‚ β”‚ β”œβ”€gun_guns_militia_firearms_amendment + β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€blacks_penalty_death_cruel_punishment ── Topic: 55 + β”‚ β”‚ β”‚ β”‚ └─■──gun_guns_militia_firearms_amendment ── Topic: 7 + β”‚ β”‚ β”‚ └─armenian_armenians_israel_jews_turkish + β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€israel_israeli_jews_arab_jewish ── Topic: 4 + β”‚ β”‚ β”‚ └─■──armenian_armenians_turkish_armenia_azerbaijan ── Topic: 15 + β”‚ β”‚ └─stephanopoulos_president_mr_myers_ms + β”‚ β”‚ β”œβ”€β– β”€β”€serbs_muslims_stephanopoulos_mr_bosnia ── Topic: 35 + β”‚ β”‚ └─■──myers_stephanopoulos_president_ms_mr ── Topic: 87 + β”‚ └─batf_fbi_koresh_compound_gas + β”‚ β”œβ”€β– β”€β”€reno_workers_janet_clinton_waco ── Topic: 77 + β”‚ └─batf_fbi_koresh_gas_compound + β”‚ β”œβ”€batf_koresh_fbi_warrant_compound + β”‚ β”‚ β”œβ”€β– β”€β”€batf_warrant_raid_compound_fbi ── Topic: 42 + β”‚ β”‚ └─■──koresh_batf_fbi_children_compound ── 
Topic: 61 + β”‚ └─■──fbi_gas_tear_bds_building ── Topic: 23 + └─use_like_just_dont_new + β”œβ”€game_team_year_games_like + β”‚ β”œβ”€game_team_games_25_year + β”‚ β”‚ β”œβ”€game_team_games_25_season + β”‚ β”‚ β”‚ β”œβ”€window_printer_use_problem_mhz + β”‚ β”‚ β”‚ β”‚ β”œβ”€mhz_wire_simms_wiring_battery + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€simms_mhz_battery_cpu_heat + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€simms_pds_simm_vram_lc + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€pds_nubus_lc_slot_card ── Topic: 119 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─■──simms_simm_vram_meg_dram ── Topic: 32 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─mhz_battery_cpu_heat_speed + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€mhz_cpu_speed_heat_fan + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€mhz_cpu_speed_heat_fan + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€fan_cpu_heat_sink_fans ── Topic: 92 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─■──mhz_speed_cpu_fpu_clock ── Topic: 22 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─■──monitor_turn_power_computer_electricity ── Topic: 91 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─battery_batteries_concrete_duo_discharge + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€duo_battery_apple_230_problem ── Topic: 121 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─■──battery_batteries_concrete_discharge_temperature ── Topic: 75 + β”‚ β”‚ β”‚ β”‚ β”‚ └─wire_wiring_ground_neutral_outlets + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€wire_wiring_ground_neutral_outlets + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€wire_wiring_ground_neutral_outlets + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€leds_uv_blue_light_boards ── Topic: 66 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─■──wire_wiring_ground_neutral_outlets ── Topic: 120 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─scope_scopes_phone_dial_number + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€dial_number_phone_line_output ── Topic: 93 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─■──scope_scopes_motorola_generator_oscilloscope ── Topic: 113 + β”‚ β”‚ β”‚ β”‚ β”‚ └─celp_dsp_sampling_antenna_digital + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€antenna_antennas_receiver_cable_transmitter ── Topic: 70 + β”‚ β”‚ β”‚ β”‚ β”‚ └─■──celp_dsp_sampling_speech_voice ── Topic: 52 + β”‚ β”‚ β”‚ β”‚ └─window_printer_xv_mouse_windows + β”‚ β”‚ β”‚ β”‚ β”œβ”€window_xv_error_widget_problem + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€error_symbol_undefined_xterm_rx + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€symbol_error_undefined_doug_parse ── Topic: 63 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─■──rx_remote_server_xdm_xterm ── Topic: 45 + β”‚ β”‚ β”‚ β”‚ β”‚ └─window_xv_widget_application_expose + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€window_widget_expose_application_event + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€gc_mydisplay_draw_gxxor_drawing ── Topic: 103 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─■──window_widget_application_expose_event ── Topic: 25 + β”‚ β”‚ β”‚ β”‚ β”‚ └─xv_den_polygon_points_algorithm + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€den_polygon_points_algorithm_polygons ── Topic: 28 + β”‚ β”‚ β”‚ β”‚ β”‚ └─■──xv_24bit_image_bit_images ── Topic: 57 + β”‚ β”‚ β”‚ β”‚ └─printer_fonts_print_mouse_postscript + β”‚ β”‚ β”‚ β”‚ β”œβ”€printer_fonts_print_font_deskjet + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€scanner_logitech_grayscale_ocr_scanman ── Topic: 108 + β”‚ β”‚ β”‚ β”‚ β”‚ └─printer_fonts_print_font_deskjet + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€printer_print_deskjet_hp_ink ── Topic: 18 + β”‚ β”‚ β”‚ β”‚ β”‚ └─■──fonts_font_truetype_tt_atm ── Topic: 49 + β”‚ β”‚ β”‚ β”‚ └─mouse_ghostscript_midi_driver_postscript + β”‚ β”‚ β”‚ β”‚ β”œβ”€ghostscript_midi_postscript_files_file + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€ghostscript_postscript_pageview_ghostview_dsc ── Topic: 104 + β”‚ β”‚ β”‚ β”‚ β”‚ 
└─midi_sound_file_windows_driver + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€location_mar_file_host_rwrr ── Topic: 83 + β”‚ β”‚ β”‚ β”‚ β”‚ └─■──midi_sound_driver_blaster_soundblaster ── Topic: 98 + β”‚ β”‚ β”‚ β”‚ └─■──mouse_driver_mice_ball_problem ── Topic: 68 + β”‚ β”‚ β”‚ └─game_team_games_25_season + β”‚ β”‚ β”‚ β”œβ”€1st_sale_condition_comics_hulk + β”‚ β”‚ β”‚ β”‚ β”œβ”€sale_condition_offer_asking_cd + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€condition_stereo_amp_speakers_asking + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€miles_car_amfm_toyota_cassette ── Topic: 62 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─■──amp_speakers_condition_stereo_audio ── Topic: 24 + β”‚ β”‚ β”‚ β”‚ β”‚ └─games_sale_pom_cds_shipping + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€pom_cds_sale_shipping_cd + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€size_shipping_sale_condition_mattress ── Topic: 100 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─■──pom_cds_cd_sale_picture ── Topic: 37 + β”‚ β”‚ β”‚ β”‚ β”‚ └─■──games_game_snes_sega_genesis ── Topic: 40 + β”‚ β”‚ β”‚ β”‚ └─1st_hulk_comics_art_appears + β”‚ β”‚ β”‚ β”‚ β”œβ”€1st_hulk_comics_art_appears + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€lens_tape_camera_backup_lenses + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€tape_backup_tapes_drive_4mm ── Topic: 107 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─■──lens_camera_lenses_zoom_pouch ── Topic: 114 + β”‚ β”‚ β”‚ β”‚ β”‚ └─1st_hulk_comics_art_appears + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€1st_hulk_comics_art_appears ── Topic: 105 + β”‚ β”‚ β”‚ β”‚ β”‚ └─■──books_book_cover_trek_chemistry ── Topic: 125 + β”‚ β”‚ β”‚ β”‚ └─tickets_hotel_ticket_voucher_package + β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€hotel_voucher_package_vacation_room ── Topic: 74 + β”‚ β”‚ β”‚ β”‚ └─■──tickets_ticket_june_airlines_july ── Topic: 84 + β”‚ β”‚ β”‚ └─game_team_games_season_hockey + β”‚ β”‚ β”‚ β”œβ”€game_hockey_team_25_550 + β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€espn_pt_pts_game_la ── Topic: 17 + β”‚ β”‚ β”‚ β”‚ └─■──team_25_game_hockey_550 ── Topic: 2 + β”‚ β”‚ β”‚ └─■──year_game_hit_baseball_players ── Topic: 0 + β”‚ β”‚ └─bike_car_greek_insurance_msg + β”‚ β”‚ β”œβ”€car_bike_insurance_cars_engine + β”‚ β”‚ β”‚ β”œβ”€car_insurance_cars_radar_engine + β”‚ β”‚ β”‚ β”‚ β”œβ”€insurance_health_private_care_canada + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€insurance_health_private_care_canada ── Topic: 99 + β”‚ β”‚ β”‚ β”‚ β”‚ └─■──insurance_car_accident_rates_sue ── Topic: 82 + β”‚ β”‚ β”‚ β”‚ └─car_cars_radar_engine_detector + β”‚ β”‚ β”‚ β”‚ β”œβ”€car_radar_cars_detector_engine + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€radar_detector_detectors_ka_alarm ── Topic: 39 + β”‚ β”‚ β”‚ β”‚ β”‚ └─car_cars_mustang_ford_engine + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€clutch_shift_shifting_transmission_gear ── Topic: 88 + β”‚ β”‚ β”‚ β”‚ β”‚ └─■──car_cars_mustang_ford_v8 ── Topic: 14 + β”‚ β”‚ β”‚ β”‚ └─oil_diesel_odometer_diesels_car + β”‚ β”‚ β”‚ β”‚ β”œβ”€odometer_oil_sensor_car_drain + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€odometer_sensor_speedo_gauge_mileage ── Topic: 96 + β”‚ β”‚ β”‚ β”‚ β”‚ └─■──oil_drain_car_leaks_taillights ── Topic: 102 + β”‚ β”‚ β”‚ β”‚ └─■──diesel_diesels_emissions_fuel_oil ── Topic: 79 + β”‚ β”‚ β”‚ └─bike_riding_ride_bikes_motorcycle + β”‚ β”‚ β”‚ β”œβ”€bike_ride_riding_bikes_lane + β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€bike_ride_riding_lane_car ── Topic: 11 + β”‚ β”‚ β”‚ β”‚ └─■──bike_bikes_miles_honda_motorcycle ── Topic: 19 + β”‚ β”‚ β”‚ └─■──countersteering_bike_motorcycle_rear_shaft ── Topic: 46 + β”‚ β”‚ └─greek_msg_kuwait_greece_water + β”‚ β”‚ β”œβ”€greek_msg_kuwait_greece_water + β”‚ β”‚ β”‚ β”œβ”€greek_msg_kuwait_greece_dog + β”‚ β”‚ β”‚ β”‚ 
β”œβ”€greek_msg_kuwait_greece_dog + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€greek_kuwait_greece_turkish_greeks + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€greek_greece_turkish_greeks_cyprus ── Topic: 71 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─■──kuwait_iraq_iran_gulf_arabia ── Topic: 76 + β”‚ β”‚ β”‚ β”‚ β”‚ └─msg_dog_drugs_drug_food + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€dog_dogs_cooper_trial_weaver + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€clinton_bush_quayle_reagan_panicking ── Topic: 101 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─dog_dogs_cooper_trial_weaver + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€cooper_trial_weaver_spence_witnesses ── Topic: 90 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─■──dog_dogs_bike_trained_springer ── Topic: 67 + β”‚ β”‚ β”‚ β”‚ β”‚ └─msg_drugs_drug_food_chinese + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€msg_food_chinese_foods_taste ── Topic: 30 + β”‚ β”‚ β”‚ β”‚ β”‚ └─■──drugs_drug_marijuana_cocaine_alcohol ── Topic: 72 + β”‚ β”‚ β”‚ β”‚ └─water_theory_universe_science_larsons + β”‚ β”‚ β”‚ β”‚ β”œβ”€water_nuclear_cooling_steam_dept + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€rocketry_rockets_engines_nuclear_plutonium ── Topic: 115 + β”‚ β”‚ β”‚ β”‚ β”‚ └─water_cooling_steam_dept_plants + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€water_dept_phd_environmental_atmospheric ── Topic: 97 + β”‚ β”‚ β”‚ β”‚ β”‚ └─■──cooling_water_steam_towers_plants ── Topic: 109 + β”‚ β”‚ β”‚ β”‚ └─theory_universe_larsons_larson_science + β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€theory_universe_larsons_larson_science ── Topic: 54 + β”‚ β”‚ β”‚ β”‚ └─■──oort_cloud_grbs_gamma_burst ── Topic: 80 + β”‚ β”‚ β”‚ └─helmet_kirlian_photography_lock_wax + β”‚ β”‚ β”‚ β”œβ”€helmet_kirlian_photography_leaf_mask + β”‚ β”‚ β”‚ β”‚ β”œβ”€kirlian_photography_leaf_pictures_deleted + β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€deleted_joke_stuff_maddi_nickname + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€joke_maddi_nickname_nicknames_frank ── Topic: 43 + β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ └─■──deleted_stuff_bookstore_joke_motto ── Topic: 81 + β”‚ β”‚ β”‚ β”‚ β”‚ └─■──kirlian_photography_leaf_pictures_aura ── Topic: 85 + β”‚ β”‚ β”‚ β”‚ └─helmet_mask_liner_foam_cb + β”‚ β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€helmet_liner_foam_cb_helmets ── Topic: 112 + β”‚ β”‚ β”‚ β”‚ └─■──mask_goalies_77_santore_tl ── Topic: 123 + β”‚ β”‚ β”‚ └─lock_wax_paint_plastic_ear + β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€lock_cable_locks_bike_600 ── Topic: 117 + β”‚ β”‚ β”‚ └─wax_paint_ear_plastic_skin + β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€wax_paint_plastic_scratches_solvent ── Topic: 65 + β”‚ β”‚ β”‚ └─■──ear_wax_skin_greasy_acne ── Topic: 116 + β”‚ β”‚ └─m4_mp_14_mw_mo + β”‚ β”‚ β”œβ”€m4_mp_14_mw_mo + β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€m4_mp_14_mw_mo ── Topic: 111 + β”‚ β”‚ β”‚ └─■──test_ensign_nameless_deane_deanebinahccbrandeisedu ── Topic: 118 + β”‚ β”‚ └─■──ites_cheek_hello_hi_ken ── Topic: 3 + β”‚ └─space_medical_health_disease_cancer + β”‚ β”œβ”€medical_health_disease_cancer_patients + β”‚ β”‚ β”œβ”€β– β”€β”€cancer_centers_center_medical_research ── Topic: 122 + β”‚ β”‚ └─health_medical_disease_patients_hiv + β”‚ β”‚ β”œβ”€patients_medical_disease_candida_health + β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€candida_yeast_infection_gonorrhea_infections ── Topic: 48 + β”‚ β”‚ β”‚ └─patients_disease_cancer_medical_doctor + β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€hiv_medical_cancer_patients_doctor ── Topic: 34 + β”‚ β”‚ β”‚ └─■──pain_drug_patients_disease_diet ── Topic: 26 + β”‚ β”‚ └─■──health_newsgroup_tobacco_vote_votes ── Topic: 9 + β”‚ └─space_launch_nasa_shuttle_orbit + β”‚ β”œβ”€space_moon_station_nasa_launch + β”‚ β”‚ β”œβ”€β– β”€β”€sky_advertising_billboard_billboards_space ── Topic: 59 + β”‚ β”‚ 
└─■──space_station_moon_redesign_nasa ── Topic: 16 + β”‚ └─space_mission_hst_launch_orbit + β”‚ β”œβ”€space_launch_nasa_orbit_propulsion + β”‚ β”‚ β”œβ”€β– β”€β”€space_launch_nasa_propulsion_astronaut ── Topic: 47 + β”‚ β”‚ └─■──orbit_km_jupiter_probe_earth ── Topic: 86 + β”‚ └─■──hst_mission_shuttle_orbit_arrays ── Topic: 60 + └─drive_file_key_windows_use + β”œβ”€key_file_jpeg_encryption_image + β”‚ β”œβ”€key_encryption_clipper_chip_keys + β”‚ β”‚ β”œβ”€β– β”€β”€key_clipper_encryption_chip_keys ── Topic: 1 + β”‚ β”‚ └─■──entry_file_ripem_entries_key ── Topic: 73 + β”‚ └─jpeg_image_file_gif_images + β”‚ β”œβ”€motif_graphics_ftp_available_3d + β”‚ β”‚ β”œβ”€motif_graphics_openwindows_ftp_available + β”‚ β”‚ β”‚ β”œβ”€β– β”€β”€openwindows_motif_xview_windows_mouse ── Topic: 20 + β”‚ β”‚ β”‚ └─■──graphics_widget_ray_3d_available ── Topic: 95 + β”‚ β”‚ └─■──3d_machines_version_comments_contact ── Topic: 38 + β”‚ └─jpeg_image_gif_images_format + β”‚ β”œβ”€β– β”€β”€gopher_ftp_files_stuffit_images ── Topic: 51 + β”‚ └─■──jpeg_image_gif_format_images ── Topic: 13 + └─drive_db_card_scsi_windows + β”œβ”€db_windows_dos_mov_os2 + β”‚ β”œβ”€β– β”€β”€copy_protection_program_software_disk ── Topic: 64 + β”‚ └─■──db_windows_dos_mov_os2 ── Topic: 8 + └─drive_card_scsi_drives_ide + β”œβ”€drive_scsi_drives_ide_disk + β”‚ β”œβ”€β– β”€β”€drive_scsi_drives_ide_disk ── Topic: 6 + β”‚ └─■──meg_sale_ram_drive_shipping ── Topic: 12 + └─card_modem_monitor_video_drivers + β”œβ”€β– β”€β”€card_monitor_video_drivers_vga ── Topic: 5 + └─■──modem_port_serial_irq_com ── Topic: 10 + ``` +
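Once the tree has revealed which topics could logically be merged, one way to act on that insight is `.merge_topics`. The topic ids below are purely illustrative, picked from sibling leaves in the tree above:

```python
# Merge topics that the tree suggests belong together
# (example ids: 1 & 73 both concern keys/encryption, 13 & 51 concern image files)
topics_to_merge = [[1, 73],
                   [13, 51]]
topic_model.merge_topics(docs, topics_to_merge)
```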
+ +## **Visualize Hierarchical Documents** +We can extend the previous method by calculating the topic representation at different levels of the hierarchy and +plotting them on a 2D plane. To do so, we first need to calculate the hierarchical topics: + +```python +from sklearn.datasets import fetch_20newsgroups +from sentence_transformers import SentenceTransformer +from bertopic import BERTopic +from umap import UMAP + +# Prepare embeddings +docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] +sentence_model = SentenceTransformer("all-MiniLM-L6-v2") +embeddings = sentence_model.encode(docs, show_progress_bar=False) + +# Train BERTopic and extract hierarchical topics +topic_model = BERTopic().fit(docs, embeddings) +hierarchical_topics = topic_model.hierarchical_topics(docs) +``` +Then, we can visualize the hierarchical documents either by supplying the function with our embeddings or by +reducing their dimensionality ourselves: + +```python +# Run the visualization with the original embeddings +topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, embeddings=embeddings) + +# Reduce dimensionality of embeddings; this step is optional but makes iterating on the visualization much faster: +reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings) +topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings) +``` + + + +!!! note + The visualization above was generated with the additional parameter `hide_document_hover=True` which disables the + option to hover over the individual points and see the content of the documents. This makes the resulting visualization + smaller and allows it to fit into your RAM. However, it might be interesting to set `hide_document_hover=False` to hover + over the points and see the content of the documents. diff --git a/docs/getting_started/visualization/visualize_terms.md b/docs/getting_started/visualization/visualize_terms.md new file mode 100644 index 00000000..f7d40e2d --- /dev/null +++ b/docs/getting_started/visualization/visualize_terms.md @@ -0,0 +1,46 @@ +We can visualize the selected terms for a few topics by creating bar charts out of the c-TF-IDF scores +for each topic representation. Insights can be gained from the relative c-TF-IDF scores between and within +topics. Moreover, you can easily compare topic representations to each other. +To create the bar charts, run the following: + +```python +topic_model.visualize_barchart() +``` + + + + +## **Visualize Term Score Decline** +Topics are represented by a number of words starting with the best representative word. +Each word is represented by a c-TF-IDF score. The higher the score, the more representative the word +is of the topic. Since the topic words are sorted by their c-TF-IDF score, the scores slowly decline +with each word that is added. At some point adding words to the topic representation only marginally +increases the total c-TF-IDF score and would not be beneficial for its representation. + +To visualize this effect, we can plot the c-TF-IDF scores for each topic by the term rank of each word. +In other words, the position of the words (term rank), where the words with +the highest c-TF-IDF score will have a rank of 1, will be put on the x-axis, whereas the y-axis +will be populated by the c-TF-IDF scores. The result is a visualization that shows you the decline +of c-TF-IDF score when adding words to the topic representation. 
It allows you, using the elbow method, +to select the best number of words in a topic. + +To visualize the c-TF-IDF score decline, run the following: + +```python +topic_model.visualize_term_rank() +``` + + + +To enable the log scale on the y-axis for a better view of individual topics, run the following: + +```python +topic_model.visualize_term_rank(log_scale=True) +``` + + + +This visualization was heavily inspired by the "Term Probability Decline" visualization found in an +analysis by the amazing [tmtoolkit](https://tmtoolkit.readthedocs.io/). +Reference to that specific analysis can be found +[here](https://wzbsocialsciencecenter.github.io/tm_corona/tm_analysis.html). diff --git a/docs/getting_started/visualization/visualize_topics.md b/docs/getting_started/visualization/visualize_topics.md new file mode 100644 index 00000000..69500098 --- /dev/null +++ b/docs/getting_started/visualization/visualize_topics.md @@ -0,0 +1,119 @@ +Visualizing BERTopic and its derivatives is important in understanding the model, how it works, and more importantly, where it works. +Since topic modeling can be quite a subjective field, it is difficult for users to validate their models. Looking at the topics and seeing +if they make sense is an important factor in alleviating this issue. + +## **Visualize Topics** +After having trained our `BERTopic` model, we can iteratively go through hundreds of topics to get a good +understanding of the topics that were extracted. However, that takes quite some time and lacks a global representation. +Instead, we can visualize the topics that were generated in a way very similar to +[LDAvis](https://github.com/cpsievert/LDAvis). + +We embed our c-TF-IDF representation of the topics in 2D using UMAP and then visualize the two dimensions using +Plotly such that we can create an interactive view. + +First, we need to train our model: + +```python +from bertopic import BERTopic +from sklearn.datasets import fetch_20newsgroups + +docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'] +topic_model = BERTopic() +topics, probs = topic_model.fit_transform(docs) +``` + +Then, we can call `.visualize_topics` to create a 2D representation of our topics. The resulting graph is an +interactive Plotly graph which can be converted to HTML: + +```python +topic_model.visualize_topics() +``` + + + +You can use the slider to select the topic which then lights up red. If you hover over a topic, then general +information is given about the topic, including the size of the topic and its corresponding words. + + +## **Visualize Topic Similarity** +Having generated topic embeddings, through both c-TF-IDF and embeddings, we can create a similarity +matrix by simply applying cosine similarity to those topic embeddings. The result will be a +matrix indicating how similar certain topics are to each other. +To visualize the heatmap, run the following: + +```python +topic_model.visualize_heatmap() +``` + + + + +!!! note + You can set `n_clusters` in `visualize_heatmap` to order the topics by their similarity. + This will result in blocks being formed in the heatmap indicating which clusters of topics are + similar to each other. This step is very much recommended as it will make reading the heatmap easier. + +## **Visualize Topics over Time** +After creating topics over time with Dynamic Topic Modeling, we can visualize these topics by +leveraging the interactive abilities of Plotly. 
Plotly allows us to show the frequency +of topics over time whilst giving the option of hovering over the points to show the time-specific topic representations. +Simply call `.visualize_topics_over_time` with the newly created topics over time: + + +```python +import re +import pandas as pd +from bertopic import BERTopic + +# Prepare data +trump = pd.read_csv('https://drive.google.com/uc?export=download&id=1xRKHaP-QwACMydlDnyFPEaFdtskJuBa6') +trump.text = trump.apply(lambda row: re.sub(r"http\S+", "", row.text).lower(), 1) +trump.text = trump.apply(lambda row: " ".join(filter(lambda x:x[0]!="@", row.text.split())), 1) +trump.text = trump.apply(lambda row: " ".join(re.sub("[^a-zA-Z]+", " ", row.text).split()), 1) +trump = trump.loc[(trump.isRetweet == "f") & (trump.text != ""), :] +timestamps = trump.date.to_list() +tweets = trump.text.to_list() + +# Create topics over time +model = BERTopic(verbose=True) +topics, probs = model.fit_transform(tweets) +topics_over_time = model.topics_over_time(tweets, timestamps) +``` + +Then, we visualize some interesting topics: + +```python +model.visualize_topics_over_time(topics_over_time, topics=[9, 10, 72, 83, 87, 91]) +``` + + +## **Visualize Topics per Class** +You might want to extract and visualize the topic representation per class. For example, if you have +specific groups of users that might approach topics differently, then extracting them would help in understanding +how these users talk about certain topics. In other words, this is simply creating a topic representation for +certain classes that you might have in your data. + +First, we need to train our model: + +```python +from bertopic import BERTopic +from sklearn.datasets import fetch_20newsgroups + +# Prepare data and classes +data = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes')) +docs = data["data"] +classes = [data["target_names"][i] for i in data["target"]] + +# Create topic model and calculate topics per class +topic_model = BERTopic() +topics, probs = topic_model.fit_transform(docs) +topics_per_class = topic_model.topics_per_class(docs, classes=classes) +``` + +Then, we visualize the topic representation of major topics per class: + +```python +topic_model.visualize_topics_per_class(topics_per_class) +``` + + diff --git a/docs/index.md b/docs/index.md index b381a299..f33f7737 100644 --- a/docs/index.md +++ b/docs/index.md @@ -10,19 +10,29 @@ hide: BERTopic is a topic modeling technique that leverages πŸ€— transformers and c-TF-IDF to create dense clusters allowing for easily interpretable topics whilst keeping important words in the topic descriptions. 
-BERTopic supports [**guided**](https://maartengr.github.io/BERTopic/getting_started/guided/guided.html), -[**supervised**](https://maartengr.github.io/BERTopic/getting_started/supervised/supervised.html), -[**semi-supervised**](https://maartengr.github.io/BERTopic/getting_started/semisupervised/semisupervised.html), -[**manual**](https://maartengr.github.io/BERTopic/getting_started/manual/manual.html), -[**long-document**](https://maartengr.github.io/BERTopic/getting_started/distribution/distribution.html), -[**hierarchical**](https://maartengr.github.io/BERTopic/getting_started/hierarchicaltopics/hierarchicaltopics.html), -[**class-based**](https://maartengr.github.io/BERTopic/getting_started/topicsperclass/topicsperclass.html), -[**dynamic**](https://maartengr.github.io/BERTopic/getting_started/topicsovertime/topicsovertime.html), -[**online**](https://maartengr.github.io/BERTopic/getting_started/online/online.html), -[**multimodal**](https://maartengr.github.io/BERTopic/getting_started/multimodal/multimodal.html), and -[**multi-aspect**](https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html) topic modeling. - -It even supports visualizations similar to LDAvis! +BERTopic supports all kinds of topic modeling techniques: + + + + + + + + + + + + + + + + + + + + + +
Guided | Supervised | Semi-supervised
Manual | Multi-topic distributions | Hierarchical
Class-based | Dynamic | Online/Incremental
Multimodal | Multi-aspect | Text Generation
Corresponding medium posts can be found [here](https://towardsdatascience.com/topic-modeling-with-bert-779f7db187e6?source=friends_link&sk=0b5a470c006d1842ad4c8a3057063a99), [here](https://towardsdatascience.com/interactive-topic-modeling-with-bertopic-1ea55e7d73d8?sk=03c2168e9e74b6bda2a1f3ed953427e4) and [here](https://towardsdatascience.com/using-whisper-and-bertopic-to-model-kurzgesagts-videos-7d8a63139bdf?sk=b1e0fd46f70cb15e8422b4794a81161d). For a more detailed overview, you can read the [paper](https://arxiv.org/abs/2203.05794) or see a [brief overview](https://maartengr.github.io/BERTopic/algorithm/algorithm.html). @@ -38,13 +48,10 @@ You may want to install more depending on the transformers and language backends The possible installations are: ```bash -# Embedding models -pip install bertopic[flair] -pip install bertopic[gensim] -pip install bertopic[spacy] -pip install bertopic[use] +# Choose an embedding backend +pip install bertopic[flair,gensim,spacy,use] -# Vision topic modeling +# Topic modeling with images pip install bertopic[vision] ``` @@ -122,6 +129,18 @@ representation_model = KeyBERTInspired() topic_model = BERTopic(representation_model=representation_model) +However, you might want to use something more powerful to describe your clusters. You can even use ChatGPT or other models from OpenAI to generate labels, summaries, phrases, keywords, and more: + +```python +import openai +from bertopic.representation import OpenAI + +# Fine-tune topic representations with GPT +openai.api_key = "sk-..." +representation_model = OpenAI(model="gpt-3.5-turbo", chat=True) +topic_model = BERTopic(representation_model=representation_model) +``` + !!! tip "Multi-aspect Topic Modeling" Instead of iterating over all of these different topic representations, you can model them simultaneously with [multi-aspect topic representations](https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html) in BERTopic. @@ -204,11 +223,13 @@ There are many different use cases in which topic modeling can be used. As such, | [Semi-supervised Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/semisupervised/semisupervised.html) | `.fit(docs, y=y)` | | [Supervised Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/supervised/supervised.html) | `.fit(docs, y=y)` | | [Manual Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/manual/manual.html) | `.fit(docs, y=y)` | +| [Multimodal Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/multimodal/multimodal.html) | `.fit(docs, images=images)` | | [Topic Modeling per Class](https://maartengr.github.io/BERTopic/getting_started/topicsperclass/topicsperclass.html) | `.topics_per_class(docs, classes)` | | [Dynamic Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/topicsovertime/topicsovertime.html) | `.topics_over_time(docs, timestamps)` | | [Hierarchical Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/hierarchicaltopics/hierarchicaltopics.html) | `.hierarchical_topics(docs)` | | [Guided Topic Modeling](https://maartengr.github.io/BERTopic/getting_started/guided/guided.html) | `BERTopic(seed_topic_list=seed_topic_list)` | + ### Visualizations Evaluating topic models can be rather difficult due to the somewhat subjective nature of evaluation. 
Visualizing different aspects of the topic model helps in understanding the model and makes it easier diff --git a/mkdocs.yml b/mkdocs.yml index 2566e904..1df1a9c3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -11,32 +11,41 @@ nav: - Home: index.md - The Algorithm: algorithm/algorithm.md - Getting Started: - - getting_started/quickstart/quickstart.md - - Topic Visualization: getting_started/visualization/visualization.md - - Topic Reduction: getting_started/topicreduction/topicreduction.md - - Topic Representation: getting_started/topicrepresentation/topicrepresentation.md - - Search Topics: getting_started/search/search.md - - Parameter tuning: getting_started/parameter tuning/parametertuning.md - - Outlier reduction: getting_started/outlier_reduction/outlier_reduction.md + - Quick Start: getting_started/quickstart/quickstart.md - Serialization: getting_started/serialization/serialization.md - - Tips & Tricks: getting_started/tips_and_tricks/tips_and_tricks.md + - Search Topics: getting_started/search/search.md + - In-depth: + - Visualizations: + - Topics: getting_started/visualization/visualize_topics.md + - Documents: getting_started/visualization/visualize_documents.md + - Terms: getting_started/visualization/visualize_terms.md + - Hierarchy: getting_started/visualization/visualize_hierarchy.md + - Update Topics: + - Topic Reduction: getting_started/topicreduction/topicreduction.md + - Update Topic Representations: getting_started/topicrepresentation/topicrepresentation.md + - Outlier reduction: getting_started/outlier_reduction/outlier_reduction.md + - Parameter tuning: getting_started/parameter tuning/parametertuning.md + - Tips & Tricks: getting_started/tips_and_tricks/tips_and_tricks.md - Sub-models: - - Embeddings: getting_started/embeddings/embeddings.md - - Dimensionality Reduction: getting_started/dim_reduction/dim_reduction.md - - Clustering: getting_started/clustering/clustering.md - - Vectorizers: getting_started/vectorizers/vectorizers.md - - c-TF-IDF: getting_started/ctfidf/ctfidf.md - - (Optional) Representation: getting_started/representation/representation.md - - (Optional) Multi-Aspect: getting_started/multiaspect/multiaspect.md + - 1. Embeddings: getting_started/embeddings/embeddings.md + - 2. Dimensionality Reduction: getting_started/dim_reduction/dim_reduction.md + - 3. Clustering: getting_started/clustering/clustering.md + - 4. Vectorizers: getting_started/vectorizers/vectorizers.md + - 5. c-TF-IDF: getting_started/ctfidf/ctfidf.md + - 6. Fine-tune Topics: + - 6A. Representation Models: getting_started/representation/representation.md + - 6B. LLM & Generative AI: getting_started/representation/llm.md + - 6C. 
Multiple Representations: getting_started/multiaspect/multiaspect.md - Variations: - Dynamic Topic Modeling: getting_started/topicsovertime/topicsovertime.md - - Guided Topic Modeling: getting_started/guided/guided.md - Hierarchical Topic Modeling: getting_started/hierarchicaltopics/hierarchicaltopics.md - - Manual Topic Modeling: getting_started/manual/manual.md - Multimodal Topic Modeling: getting_started/multimodal/multimodal.md - Online Topic Modeling: getting_started/online/online.md - - Semi-supervised Topic Modeling: getting_started/semisupervised/semisupervised.md - - Supervised Topic Modeling: getting_started/supervised/supervised.md + - (semi)-supervised: + - Semi-supervised Topic Modeling: getting_started/semisupervised/semisupervised.md + - Supervised Topic Modeling: getting_started/supervised/supervised.md + - Manual Topic Modeling: getting_started/manual/manual.md + - Guided Topic Modeling: getting_started/guided/guided.md - Topic Distributions: getting_started/distribution/distribution.md - Topics per Class: getting_started/topicsperclass/topicsperclass.md - FAQ: faq.md diff --git a/setup.py b/setup.py index 64a84336..03fdbcb6 100644 --- a/setup.py +++ b/setup.py @@ -43,12 +43,13 @@ ] vision_packages = [ - "Pillow>=9.2.0" + "Pillow>=9.2.0", + "accelerate>=0.19.0" # To prevent "cannot import name 'PartialState' from 'accelerate'" ] extra_packages = flair_packages + spacy_packages + use_packages + gensim_packages -dev_packages = docs_packages + test_packages + extra_packages +dev_packages = docs_packages + test_packages with open("README.md", "r") as fh: long_description = fh.read() @@ -56,7 +57,7 @@ setup( name="bertopic", packages=find_packages(exclude=["notebooks", "docs"]), - version="0.14.1", + version="0.15.0", author="Maarten P. Grootendorst", author_email="maartengrootendorst@gmail.com", description="BERTopic performs topic Modeling with state-of-the-art transformer models.",