Skip to content

Commit

Permalink
Limit the maximum length of the chat response
Browse files Browse the repository at this point in the history
- Add language to prompt to encourage concise answers
- Limit responses to 1000 tokens by default, with `max_tokens` superuser override
- Add `{"end": {"reason": reason}}` message
- Update chat/response to use latest langchain and LCEL
- Turn streaming on or off independent of debug mode
- Split runtime and dev dependencies into separate `requirements.txt` files to keep the `sam sync` layer size below 10MB
- Update `chat/template.yml` and `Makefile` to support rapid iterations without separate dependency layer

Co-Authored-By: Brendan Quinn <[email protected]>
  • Loading branch information
mbklein and bmquinn committed Jun 26, 2024
1 parent 5ded219 commit 1907b58
Show file tree
Hide file tree
Showing 15 changed files with 182 additions and 140 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
with:
python-version: '3.9'
cache-dependency-path: chat/src/requirements.txt
- run: pip install -r requirements.txt
- run: pip install -r requirements.txt && pip install -r requirements-dev.txt
working-directory: ./chat/src
- name: Check code style
run: ruff check .
Expand Down
4 changes: 3 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ help:
echo "make cover-python | run python tests with coverage"
.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json chat/dependencies/requirements.txt chat/src/requirements.txt
sed -Ei.orig 's/^(\s+)#\*\s/\1/' template.yaml
sed -Ei.orig 's/^(\s+)#\*\s/\1/' chat/template.yaml
sam build --cached --parallel
mv template.yaml.orig template.yaml
mv chat/template.yaml.orig chat/template.yaml
deps-node:
cd node/src ;\
npm list >/dev/null 2>&1 ;\
Expand All @@ -48,7 +50,7 @@ style-node: deps-node
test-node: deps-node
cd node && npm run test
deps-python:
cd chat/src && pip install -r requirements.txt
cd chat/src && pip install -r requirements.txt && pip install -r requirements-dev.txt
cover-python: deps-python
cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run --source=src -m unittest -v && coverage report --skip-empty
cover-html-python: deps-python
Expand Down
13 changes: 7 additions & 6 deletions chat/dependencies/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
boto3~=1.34.13
boto3~=1.34
honeybadger
langchain
langchain-community
openai~=0.27.8
langchain~=0.2
langchain-aws~=0.1
langchain-openai~=0.1
openai~=1.35
opensearch-py
pyjwt~=2.6.0
python-dotenv~=1.0.0
requests
requests-aws4auth
tiktoken~=0.4.0
wheel~=0.40.0
tiktoken~=0.7
wheel~=0.40
40 changes: 15 additions & 25 deletions chat/src/event_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import json

from dataclasses import dataclass, field
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts import PromptTemplate

from langchain_core.prompts import ChatPromptTemplate
from setup import (
opensearch_client,
opensearch_vector_store,
Expand All @@ -19,6 +19,7 @@
DOCUMENT_VARIABLE_NAME = "context"
K_VALUE = 5
MAX_K = 100
MAX_TOKENS = 1000
TEMPERATURE = 0.2
TEXT_KEY = "id"
VERSION = "2024-02-01"
Expand All @@ -42,19 +43,21 @@ class EventConfig:
azure_resource_name: str = field(init=False)
debug_mode: bool = field(init=False)
deployment_name: str = field(init=False)
document_prompt: PromptTemplate = field(init=False)
document_prompt: ChatPromptTemplate = field(init=False)
event: dict = field(default_factory=dict)
is_logged_in: bool = field(init=False)
k: int = field(init=False)
max_tokens: int = field(init=False)
openai_api_version: str = field(init=False)
payload: dict = field(default_factory=dict)
prompt_text: str = field(init=False)
prompt: PromptTemplate = field(init=False)
prompt: ChatPromptTemplate = field(init=False)
question: str = field(init=False)
ref: str = field(init=False)
request_context: dict = field(init=False)
temperature: float = field(init=False)
socket: Websocket = field(init=False, default=None)
stream_response: bool = field(init=False)
text_key: str = field(init=False)

def __post_init__(self):
Expand All @@ -67,17 +70,17 @@ def __post_init__(self):
self.deployment_name = self._get_deployment_name()
self.is_logged_in = self.api_token.is_logged_in()
self.k = self._get_k()
self.max_tokens = min(self.payload.get("max_tokens", MAX_TOKENS), MAX_TOKENS)
self.openai_api_version = self._get_openai_api_version()
self.prompt_text = self._get_prompt_text()
self.request_context = self.event.get("requestContext", {})
self.question = self.payload.get("question")
self.ref = self.payload.get("ref")
self.stream_response = self.payload.get("stream_response", not self.debug_mode)
self.temperature = self._get_temperature()
self.text_key = self._get_text_key()
self.document_prompt = self._get_document_prompt()
self.prompt = PromptTemplate(
template=self.prompt_text, input_variables=["question", "context"]
)
self.prompt = ChatPromptTemplate.from_template(self.prompt_text)

def _get_payload_value_with_superuser_check(self, key, default):
if self.api_token.is_superuser():
Expand Down Expand Up @@ -134,10 +137,7 @@ def _get_text_key(self):
return self._get_payload_value_with_superuser_check("text_key", TEXT_KEY)

def _get_document_prompt(self):
return PromptTemplate(
template=document_template(self.attributes),
input_variables=["title", "id"] + self.attributes,
)
return ChatPromptTemplate.from_template(document_template(self.attributes))

def debug_message(self):
return {
Expand Down Expand Up @@ -170,28 +170,18 @@ def setup_websocket(self, socket=None):
def setup_llm_request(self):
self._setup_vector_store()
self._setup_chat_client()
self._setup_chain()

def _setup_vector_store(self):
self.opensearch = opensearch_vector_store()

def _setup_chat_client(self):
self.client = openai_chat_client(
deployment_name=self.deployment_name,
openai_api_base=self.azure_endpoint,
azure_deployment=self.deployment_name,
azure_endpoint=self.azure_endpoint,
openai_api_version=self.openai_api_version,
callbacks=[StreamingSocketCallbackHandler(self.socket, self.debug_mode)],
callbacks=[StreamingSocketCallbackHandler(self.socket, stream=self.stream_response)],
streaming=True,
)

def _setup_chain(self):
self.chain = load_qa_with_sources_chain(
self.client,
chain_type=CHAIN_TYPE,
prompt=self.prompt,
document_prompt=self.document_prompt,
document_variable_name=DOCUMENT_VARIABLE_NAME,
verbose=self._to_bool(os.getenv("VERBOSE")),
max_tokens=self.max_tokens
)

def _is_debug_mode_enabled(self):
Expand Down
5 changes: 3 additions & 2 deletions chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
from datetime import datetime
from event_config import EventConfig
from helpers.response import prepare_response
from helpers.response import Response
from honeybadger import honeybadger

honeybadger.configure()
Expand Down Expand Up @@ -35,7 +35,8 @@ def handler(event, context):

if not os.getenv("SKIP_WEAVIATE_SETUP"):
config.setup_llm_request()
final_response = prepare_response(config)
response = Response(config)
final_response = response.prepare_response()
config.socket.send(reshape_response(final_response, 'debug' if config.debug_mode else 'base'))

log_group = os.getenv('METRICS_LOG_GROUP')
Expand Down
17 changes: 14 additions & 3 deletions chat/src/handlers/streaming_socket_callback_handler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
from langchain.callbacks.base import BaseCallbackHandler
from websocket import Websocket
from typing import Any
from langchain_core.outputs.llm_result import LLMResult

class StreamingSocketCallbackHandler(BaseCallbackHandler):
def __init__(self, socket: Websocket, debug_mode: bool):
def __init__(self, socket: Websocket, stream: bool = True):
self.socket = socket
self.debug_mode = debug_mode
self.stream = stream

def on_llm_new_token(self, token: str, **kwargs):
if self.socket and not self.debug_mode:
if len(token) > 0 and self.socket and self.stream:
return self.socket.send({"token": token})

def on_llm_end(self, response: LLMResult, **kwargs: Any):
try:
finish_reason = response.generations[0][0].generation_info["finish_reason"]
if self.socket:
return self.socket.send({"end": {"reason": finish_reason}})
except Exception as err:
finish_reason = f'Unknown ({str(err)})'
print(f"Stream ended: {finish_reason}")
2 changes: 1 addition & 1 deletion chat/src/helpers/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
def token_usage(config, response, original_question):
data = {
"question": count_tokens(config.question),
"answer": count_tokens(response["output_text"]),
"answer": count_tokens(response),
"prompt": count_tokens(config.prompt_text),
"source_documents": count_tokens(original_question["source_documents"]),
}
Expand Down
15 changes: 7 additions & 8 deletions chat/src/helpers/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@


def prompt_template() -> str:
return """Please provide an answer to the question based on the documents provided. Include specific details from the documents that support your answer. Each document is identified by a 'title' and a unique 'source' UUID:
return """Please provide a brief answer to the question based on the documents provided. Include specific details from the documents that support your answer. Keep your answer concise. Each document is identified by a 'title' and a unique 'source' UUID:
Documents:
{context}
Answer in raw markdown. When referencing a document by title, link to it using its UUID like this: [title](https://dc.library.northwestern.edu/items/UUID). For example: [Judy Collins, Jackson Hole Folk Festival](https://dc.library.northwestern.edu/items/f1ca513b-7d13-4af6-ad7b-8c7ffd1d3a37). Suggest keyword searches using this format: [keyword](https://dc.library.northwestern.edu/search?q=keyword). Offer a variety of search terms that cover different aspects of the topic. Include as many direct links to Digital Collections searches as necessary for a thorough study. The `collection` field contains information about the collection the document belongs to. In the summary, mention the top 1 or 2 collections, explain why they are relevant and link to them using the collection title and id: [collection['title']](https://dc.library.northwestern.edu/collections/collection['id']), for example [World War II Poster Collection](https://dc.library.northwestern.edu/collections/faf4f60e-78e0-4fbf-96ce-4ca8b4df597a):
Question:
{question}
"""
Documents:
{context}
Answer in raw markdown. When referencing a document by title, link to it using its UUID like this: [title](https://dc.library.northwestern.edu/items/UUID). For example: [Judy Collins, Jackson Hole Folk Festival](https://dc.library.northwestern.edu/items/f1ca513b-7d13-4af6-ad7b-8c7ffd1d3a37). Suggest keyword searches using this format: [keyword](https://dc.library.northwestern.edu/search?q=keyword). Offer a variety of search terms that cover different aspects of the topic. Include as many direct links to Digital Collections searches as necessary for a thorough study. The `collection` field contains information about the collection the document belongs to. In the summary, mention the top 1 or 2 collections, explain why they are relevant and link to them using the collection title and id: [collection['title']](https://dc.library.northwestern.edu/collections/collection['id']), for example [World War II Poster Collection](https://dc.library.northwestern.edu/collections/faf4f60e-78e0-4fbf-96ce-4ca8b4df597a):
Question:
{question}
"""

def document_template(attributes: Optional[List[str]] = None) -> str:
if attributes is None:
Expand Down
130 changes: 73 additions & 57 deletions chat/src/helpers/response.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,6 @@
from helpers.metrics import token_usage
from openai.error import InvalidRequestError

def debug_response(config, response, original_question):
return {
"answer": response["output_text"],
"attributes": config.attributes,
"azure_endpoint": config.azure_endpoint,
"deployment_name": config.deployment_name,
"is_superuser": config.api_token.is_superuser(),
"k": config.k,
"openai_api_version": config.openai_api_version,
"prompt": config.prompt_text,
"question": config.question,
"ref": config.ref,
"temperature": config.temperature,
"text_key": config.text_key,
"token_counts": token_usage(config, response, original_question),
}

def get_and_send_original_question(config, docs):
doc_response = []
for doc in docs:
doc_dict = doc.__dict__
metadata = doc_dict.get('metadata', {})
new_doc = {key: extract_prompt_value(metadata.get(key)) for key in config.attributes if key in metadata}
doc_response.append(new_doc)

original_question = {
"question": config.question,
"source_documents": doc_response,
}
config.socket.send(original_question)
return original_question
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

def extract_prompt_value(v):
if isinstance(v, list):
Expand All @@ -40,29 +9,76 @@ def extract_prompt_value(v):
return [v.get('label')]
else:
return v

def prepare_response(config):
try:
subquery = {
"match": {
"all_titles": {
"query": config.question,
"operator": "AND",
"analyzer": "english"
}

class Response:
def __init__(self, config):
self.config = config
self.store = {}

def debug_response_passthrough(self):
def debug_response(config, response, original_question):
return {
"answer": response,
"attributes": config.attributes,
"azure_endpoint": config.azure_endpoint,
"deployment_name": config.deployment_name,
"is_superuser": config.api_token.is_superuser(),
"k": config.k,
"openai_api_version": config.openai_api_version,
"prompt": config.prompt_text,
"question": config.question,
"ref": config.ref,
"temperature": config.temperature,
"text_key": config.text_key,
"token_counts": token_usage(config, response, original_question),
}
}
docs = config.opensearch.similarity_search(
query=config.question, k=config.k, subquery=subquery, _source={"excludes": ["embedding"]}
)
original_question = get_and_send_original_question(config, docs)
response = config.chain({"question": config.question, "input_documents": docs})

prepared_response = debug_response(config, response, original_question)
except InvalidRequestError as err:
prepared_response = {
"question": config.question,
"error": str(err),
"source_documents": [],
}
return prepared_response
return RunnableLambda(lambda x: debug_response(self.config, x, self.original_question))

def original_question_passthrough(self):
def get_and_send_original_question(docs):
source_documents = []
for doc in docs["context"]:
doc.metadata = {key: extract_prompt_value(doc.metadata.get(key)) for key in self.config.attributes if key in doc.metadata}
source_document = doc.metadata.copy()
source_document["content"] = doc.page_content
source_documents.append(source_document)

original_question = {
"question": self.config.question,
"source_documents": source_documents,
}
self.config.socket.send(original_question)
self.original_question = original_question
return docs

return RunnablePassthrough(get_and_send_original_question)

def prepare_response(self):
try:
subquery = {
"match": {
"all_titles": {
"query": self.config.question,
"operator": "AND",
"analyzer": "english"
}
}
}
retriever = self.config.opensearch.as_retriever(search_type="similarity", search_kwargs={"k": self.config.k, "subquery": subquery, "_source": {"excludes": ["embedding"]}})
chain = (
{"context": retriever, "question": RunnablePassthrough()}
| self.original_question_passthrough()
| self.config.prompt
| self.config.client
| StrOutputParser()
| self.debug_response_passthrough()
)
response = chain.invoke(self.config.question)
except Exception as err:
response = {
"question": self.config.question,
"error": str(err),
"source_documents": [],
}
return response
3 changes: 3 additions & 0 deletions chat/src/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Dev/Test Dependencies
ruff~=0.1.0
coverage~=7.3.2
Loading

0 comments on commit 1907b58

Please sign in to comment.