Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend OpenAI finish_reason handling #1985

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
25 changes: 18 additions & 7 deletions bertopic/representation/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Any, Union, Callable
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import retry_with_exponential_backoff, truncate_document
from bertopic.representation._utils import retry_with_exponential_backoff, truncate_document, MyLogger

logger = MyLogger("WARNING")
soonernotfaster marked this conversation as resolved.
Show resolved Hide resolved

DEFAULT_PROMPT = """
This is a list of texts where each collection of texts describe a topic. After each collection of texts, the name of the topic they represent is mentioned as a short-highly-descriptive title
Expand Down Expand Up @@ -37,7 +38,7 @@
Topic name:"""

DEFAULT_CHAT_PROMPT = """
I have a topic that contains the following documents:
I have a topic that contains the following documents:
soonernotfaster marked this conversation as resolved.
Show resolved Hide resolved
[DOCUMENTS]
The topic is described by the following keywords: [KEYWORDS]

Expand Down Expand Up @@ -193,7 +194,7 @@ def extract_topics(self,
updated_topics: Updated topic representations
"""
# Extract the top n representative documents per topic
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity)
repr_docs_mappings, _, _, repr_doc_ids = topic_model._extract_representative_docs(c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity)

# Generate using OpenAI's Language Model
updated_topics = {}
Expand All @@ -217,11 +218,21 @@ def extract_topics(self,
else:
response = self.client.chat.completions.create(**kwargs)

# Check whether content was actually generated
# Adresses #1570 for potential issues with OpenAI's content filter
if hasattr(response.choices[0].message, "content"):
label = response.choices[0].message.content.strip().replace("topic: ", "")
choice = response.choices[0]

if choice.finish_reason == "stop":
label = choice.message.content.strip().replace("topic: ", "")
elif choice.finish_reason == "length":
logger.warn(f"Extracing Topics - Length limit reached for doc_ids ({repr_doc_ids})")
if hasattr(response.choices[0].message, "content"):
soonernotfaster marked this conversation as resolved.
Show resolved Hide resolved
label = choice.message.content.strip().replace("topic: ", "")
else:
label = "Incomple output due to token limit being reached"
soonernotfaster marked this conversation as resolved.
Show resolved Hide resolved
elif choice.finish_reason == "content_filter":
logger.warn(f"Extracing Topics - Content filtered for doc_ids ({repr_doc_ids})")
soonernotfaster marked this conversation as resolved.
Show resolved Hide resolved
label = "Output content filtered by OpenAI"
else:
logger.warn(f"Extracing Topics - No label due to finish_reason {choice.finish_reason} for doc_ids ({repr_doc_ids})")
soonernotfaster marked this conversation as resolved.
Show resolved Hide resolved
label = "No label returned"
else:
if self.exponential_backoff:
Expand Down