diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index ff995d1c..c1f3e46d 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -14,6 +14,7 @@ jobs: env: AWS_ACCESS_KEY_ID: ci AWS_SECRET_ACCESS_KEY: ci + SKIP_WEAVIATE_SETUP: 'True' steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 diff --git a/Makefile b/Makefile index 41f41972..5af49d7f 100644 --- a/Makefile +++ b/Makefile @@ -1,58 +1,60 @@ -ifndef VERBOSE -.SILENT: -endif -ENV=dev - -help: - echo "make build | build the SAM project" - echo "make serve | run the SAM server locally" - echo "make clean | remove all installed dependencies and build artifacts" - echo "make deps | install all dependencies" - echo "make link | create hard links to allow for hot reloading of a built project" - echo "make secrets | symlink secrets files from ../tfvars" - echo "make style | run all style checks" - echo "make test | run all tests" - echo "make cover | run all tests with coverage" - echo "make env ENV=[env] | activate env.\$$ENV.json file (default: dev)" - echo "make deps-node | install node dependencies" - echo "make deps-python | install python dependencies" - echo "make style-node | run node code style check" - echo "make style-python | run python code style check" - echo "make test-node | run node tests" - echo "make test-python | run python tests" - echo "make cover-node | run node tests with coverage" - echo "make cover-python | run python tests with coverage" -.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json python/requirements.txt python/src/requirements.txt - sam build --cached --parallel -deps-node: - cd node && npm ci -cover-node: - cd node && npm run test:coverage -style-node: - cd node && npm run prettier -test-node: - cd node && npm run test -deps-python: - cd chat/src && pip install -r requirements.txt -cover-python: - cd chat/src && coverage run --include='src/**/*' -m unittest -v && coverage report -style-python: - cd chat && ruff check . -test-python: - cd chat && python -m unittest -v -build: .aws-sam/build.toml -link: build - cd chat/src && for src in *.py **/*.py; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done - cd node/src && for src in *.js *.json **/*.js **/*.json; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done -serve: link - sam local start-api --host 0.0.0.0 --log-file dc-api.log -deps: deps-node deps-python -style: style-node style-python -test: test-node test-python -cover: cover-node cover-python -env: - ln -fs ./env.${ENV}.json ./env.json -secrets: - ln -s ../tfvars/dc-api/* . 
-clean: - rm -rf .aws-sam node/node_modules node/src/node_modules python/**/__pycache__ python/.coverage python/.ruff_cache +ifndef VERBOSE +.SILENT: +endif +ENV=dev + +help: + echo "make build | build the SAM project" + echo "make serve | run the SAM server locally" + echo "make clean | remove all installed dependencies and build artifacts" + echo "make deps | install all dependencies" + echo "make link | create hard links to allow for hot reloading of a built project" + echo "make secrets | symlink secrets files from ../tfvars" + echo "make style | run all style checks" + echo "make test | run all tests" + echo "make cover | run all tests with coverage" + echo "make env ENV=[env] | activate env.\$$ENV.json file (default: dev)" + echo "make deps-node | install node dependencies" + echo "make deps-python | install python dependencies" + echo "make style-node | run node code style check" + echo "make style-python | run python code style check" + echo "make test-node | run node tests" + echo "make test-python | run python tests" + echo "make cover-node | run node tests with coverage" + echo "make cover-python | run python tests with coverage" +.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json chat/dependencies/requirements.txt chat/src/requirements.txt + sam build --cached --parallel +deps-node: + cd node && npm ci +cover-node: + cd node && npm run test:coverage +style-node: + cd node && npm run prettier +test-node: + cd node && npm run test +deps-python: + cd chat/src && pip install -r requirements.txt +cover-python: deps-python + cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run -m unittest -v && coverage report +style-python: deps-python + cd chat && ruff check . +test-python: deps-python + cd chat && export SKIP_WEAVIATE_SETUP=True && PYTHONPATH=src:test && python -m unittest discover -v +python-version: + cd chat && python --version +build: .aws-sam/build.toml +link: build + cd chat/src && for src in *.py **/*.py; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done + cd node/src && for src in *.js *.json **/*.js **/*.json; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done +serve: link + sam local start-api --host 0.0.0.0 --log-file dc-api.log +deps: deps-node deps-python +style: style-node style-python +test: test-node test-python +cover: cover-node cover-python +env: + ln -fs ./env.${ENV}.json ./env.json +secrets: + ln -s ../tfvars/dc-api/* . 
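# Editor's note — illustrative sketch, not part of this diff. The SKIP_WEAVIATE_SETUP
# variable added to the CI workflow and exported by the new test-python / cover-python
# targets is what lets the suite run without a live Weaviate instance: chat/src/setup.py
# and the streaming callback handler check it with a plain truthiness test, so any
# non-empty value (even "False") skips setup; "True" is just the conventional value.
# The helper name below is hypothetical.
import os

def weaviate_setup_skipped() -> bool:
    # Mirrors the guard pattern used in chat/src/setup.py.
    return bool(os.getenv("SKIP_WEAVIATE_SETUP"))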
+clean: + rm -rf .aws-sam node/node_modules node/src/node_modules python/**/__pycache__ python/.coverage python/.ruff_cache \ No newline at end of file diff --git a/chat/dependencies/requirements.txt b/chat/dependencies/requirements.txt index 68ec56c3..6bee442a 100644 --- a/chat/dependencies/requirements.txt +++ b/chat/dependencies/requirements.txt @@ -1,7 +1,6 @@ +boto3~=1.34.13 langchain~=0.0.208 -nbformat~=5.9.0 openai~=0.27.8 -pandas~=2.0.2 pyjwt~=2.6.0 python-dotenv~=1.0.0 tiktoken~=0.4.0 diff --git a/chat/src/event_config.py b/chat/src/event_config.py new file mode 100644 index 00000000..8aa6a194 --- /dev/null +++ b/chat/src/event_config.py @@ -0,0 +1,178 @@ +import os +import json + +from dataclasses import dataclass, field +from langchain.chains.qa_with_sources import load_qa_with_sources_chain +from langchain.prompts import PromptTemplate +from setup import ( + weaviate_client, + weaviate_vector_store, + openai_chat_client, +) +from typing import List +from handlers.streaming_socket_callback_handler import StreamingSocketCallbackHandler +from helpers.apitoken import ApiToken +from helpers.prompts import document_template, prompt_template +from websocket import Websocket + + +CHAIN_TYPE = "stuff" +DOCUMENT_VARIABLE_NAME = "context" +INDEX_NAME = "DCWork" +K_VALUE = 10 +TEXT_KEY = "title" +MAX_K = 100 +TEMPERATURE = 0.2 +VERSION = "2023-07-01-preview" + + +@dataclass +class EventConfig: + """ + The EventConfig class represents the configuration for an event. + Default values are set for the following properties which can be overridden in the payload message. + """ + + api_token: ApiToken = field(init=False) + attributes: List[str] = field(init=False) + azure_endpoint: str = field(init=False) + debug_mode: bool = field(init=False) + deployment_name: str = field(init=False) + document_prompt: PromptTemplate = field(init=False) + event: dict = field(default_factory=dict) + index_name: str = field(init=False) + is_logged_in: bool = field(init=False) + k: int = field(init=False) + openai_api_version: str = field(init=False) + payload: dict = field(default_factory=dict) + prompt_text: str = field(init=False) + prompt: PromptTemplate = field(init=False) + question: str = field(init=False) + ref: str = field(init=False) + request_context: dict = field(init=False) + temperature: float = field(init=False) + socket: Websocket = field(init=False, default=None) + text_key: str = field(init=False) + + def __post_init__(self): + self.payload = json.loads(self.event.get("body", "{}")) + + self.api_token = ApiToken(signed_token=self.payload.get("auth")) + self.azure_resource_name = self.payload.get( + "azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME") + ) + if not self.azure_resource_name: + raise EnvironmentError( + "Either payload must contain 'azure_resource_name' or environment variable 'AZURE_OPENAI_RESOURCE_NAME' must be set" + ) + + self.azure_endpoint = f"https://{self.azure_resource_name}.openai.azure.com/" + self.debug_mode = self._is_debug_mode_enabled() + self.deployment_name = self.payload.get( + "deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID") + ) + self.index_name = self.payload.get("index", self.payload.get("index", INDEX_NAME)) + self.is_logged_in = self.api_token.is_logged_in() + self.k = min(self.payload.get("k", K_VALUE), MAX_K) + self.openai_api_version = self.payload.get("openai_api_version", VERSION) + self.prompt_text = ( + self.payload.get("prompt", prompt_template()) + if self.api_token.is_superuser() + else prompt_template() + ) + 
self.request_context = self.event.get("requestContext", {}) + self.question = self.payload.get("question") + self.ref = self.payload.get("ref") + self.temperature = self.payload.get("temperature", TEMPERATURE) + self.text_key = self.payload.get("text_key", TEXT_KEY) + + self.attributes = [ + item + for item in self._get_request_attributes() + if item not in [self.text_key, "source"] + ] + self.document_prompt = PromptTemplate( + template=document_template(self.attributes), + input_variables=["page_content", "source"] + self.attributes, + ) + self.prompt = PromptTemplate( + template=self.prompt_text, input_variables=["question", "context"] + ) + + def debug_message(self): + return { + "type": "debug", + "message": { + "azure_endpoint": self.azure_endpoint, + "deployment_name": self.deployment_name, + "index": self.index_name, + "k": self.k, + "openai_api_version": self.openai_api_version, + "prompt": self.prompt_text, + "question": self.question, + "ref": self.ref, + "temperature": self.temperature, + "text_key": self.text_key, + }, + } + + def setup_websocket(self): + connection_id = self.request_context.get("connectionId") + endpoint_url = f'https://{self.request_context.get("domainName")}/{self.request_context.get("stage")}' + self.socket = Websocket(endpoint_url=endpoint_url, connection_id=connection_id, ref=self.ref) + + def setup_llm_request(self): + self._setup_vector_store() + self._setup_chat_client() + self._setup_chain() + + def _setup_vector_store(self): + self.weaviate = weaviate_vector_store( + index_name=self.index_name, + text_key=self.text_key, + attributes=self.attributes + ["source"], + ) + + def _setup_chat_client(self): + self.client = openai_chat_client( + deployment_name=self.deployment_name, + openai_api_base=self.azure_endpoint, + openai_api_version=self.openai_api_version, + callbacks=[StreamingSocketCallbackHandler(self.socket, self.debug_mode)], + streaming=True, + ) + + def _setup_chain(self): + self.chain = load_qa_with_sources_chain( + self.client, + chain_type=CHAIN_TYPE, + prompt=self.prompt, + document_prompt=self.document_prompt, + document_variable_name=DOCUMENT_VARIABLE_NAME, + verbose=self._to_bool(os.getenv("VERBOSE")), + ) + + def _is_debug_mode_enabled(self): + debug = self.payload.get("debug", False) + return debug and self.api_token.is_superuser() + + def _get_request_attributes(self): + request_attributes = self.payload.get("attributes", None) + if request_attributes is not None: + return request_attributes + + if os.getenv("SKIP_WEAVIATE_SETUP"): + return [] + + client = weaviate_client() + schema = client.schema.get(self.index_name) + names = [prop["name"] for prop in schema.get("properties")] + return names + + def _to_bool(self, val): + """Converts a value to boolean. If the value is a string, it considers + "", "no", "false", "0" as False. Otherwise, it returns the boolean of the value. 
+ """ + if isinstance(val, str): + return val.lower() not in ["", "no", "false", "0"] + return bool(val) diff --git a/chat/src/handlers/__init__.py b/chat/src/handlers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chat/src/handlers/chat.py b/chat/src/handlers/chat.py index 76ed50e9..c13be6ab 100644 --- a/chat/src/handlers/chat.py +++ b/chat/src/handlers/chat.py @@ -1,170 +1,26 @@ -import boto3 -import json -import os -import setup -import tiktoken -from helpers.apitoken import ApiToken -from helpers.prompts import document_template, prompt_template -from langchain.callbacks.base import BaseCallbackHandler -from langchain.chains.qa_with_sources import load_qa_with_sources_chain -from langchain.prompts import PromptTemplate -from openai.error import InvalidRequestError +import traceback +from event_config import EventConfig +from helpers.response import prepare_response -DEFAULT_INDEX = "DCWork" -DEFAULT_KEY = "title" -DEFAULT_K = 10 -MAX_K = 100 - -class Websocket: - def __init__(self, endpoint_url, connection_id, ref): - self.client = boto3.client("apigatewaymanagementapi", endpoint_url=endpoint_url) - self.connection_id = connection_id - self.ref = ref - - def send(self, data): - data["ref"] = self.ref - data_as_bytes = bytes(json.dumps(data), "utf-8") - self.client.post_to_connection( - Data=data_as_bytes, ConnectionId=self.connection_id - ) - - -class StreamingSocketCallbackHandler(BaseCallbackHandler): - def __init__(self, socket: Websocket, debug_mode: bool): - self.socket = socket - self.debug_mode = debug_mode - - def on_llm_new_token(self, token: str, **kwargs): - if not self.debug_mode: - self.socket.send({"token": token}) - - -def handler(event, context): +def handler(event, _context): try: - payload = json.loads(event.get("body", "{}")) - - request_context = event.get("requestContext", {}) - connection_id = request_context.get("connectionId") - endpoint_url = f'https://{request_context.get("domainName")}/{request_context.get("stage")}' - ref = payload.get("ref") - socket = Websocket( - connection_id=connection_id, endpoint_url=endpoint_url, ref=ref - ) + config = EventConfig(event) + config.setup_websocket() - api_token = ApiToken(signed_token=payload.get("auth")) - if not api_token.is_logged_in(): - socket.send({"statusCode": 401, "body": "Unauthorized"}) + if not config.is_logged_in: + config.socket.send({"type": "error", "message": "Unauthorized"}) return {"statusCode": 401, "body": "Unauthorized"} - debug_mode = payload.get("debug", False) and api_token.is_superuser() - - question = payload.get("question") - index_name = payload.get("index", payload.get("index", DEFAULT_INDEX)) - print(f"Searching index {index_name}") - text_key = payload.get("text_key", DEFAULT_KEY) - attributes = [ - item - for item in get_attributes( - index_name, payload if api_token.is_superuser() else {} - ) - if item not in [text_key, "source"] - ] - - weaviate = setup.weaviate_vector_store( - index_name=index_name, text_key=text_key, attributes=attributes + ["source"] - ) - - client = setup.openai_chat_client( - callbacks=[StreamingSocketCallbackHandler(socket, debug_mode)], - streaming=True, - ) - - prompt_text = ( - payload.get("prompt", prompt_template()) - if api_token.is_superuser() - else prompt_template() - ) - prompt = PromptTemplate( - template=prompt_text, input_variables=["question", "context"] - ) + if config.debug_mode: + config.socket.send(config.debug_message()) - document_prompt = PromptTemplate( - template=document_template(attributes), - 
input_variables=["page_content", "source"] + attributes, - ) - - k = min(payload.get("k", DEFAULT_K), MAX_K) - docs = weaviate.similarity_search(question, k=k, additional="certainty") - chain = load_qa_with_sources_chain( - client, - chain_type="stuff", - prompt=prompt, - document_prompt=document_prompt, - document_variable_name="context", - verbose=to_bool(os.getenv("VERBOSE")), - ) - - try: - for doc in docs: - doc.metadata['full_text'] = '' - doc_response = [doc.__dict__ for doc in docs] - original_question = {"question": question, "source_documents": doc_response} - socket.send(original_question) - response = chain({"question": question, "input_documents": docs}) - if debug_mode: - final_response = { - "answer": response["output_text"], - "attributes": attributes, - "isSuperuser": api_token.is_superuser(), - "prompt": prompt_text, - "ref": ref, - "k": k, - "original_question": original_question, - "token_counts": { - "question": count_tokens(question), - "answer": count_tokens(response["output_text"]), - "prompt": count_tokens(prompt_text), - "source_documents": count_tokens(doc_response), - }, - } - else: - final_response = {"answer": response["output_text"], "ref": ref} - except InvalidRequestError as err: - final_response = { - "question": question, - "error": str(err), - "source_documents": [], - } - - socket.send(final_response) + config.setup_llm_request() + final_response = prepare_response(config) + config.socket.send(final_response) return {"statusCode": 200} except Exception as err: + error_message = traceback.format_exc() + config.socket.send(error_message) print(event) - raise err - - -def get_attributes(index, payload): - request_attributes = payload.get("attributes", None) - if request_attributes is not None: - return request_attributes - - client = setup.weaviate_client() - schema = client.schema.get(index) - names = [prop["name"] for prop in schema.get("properties")] - print(f"Retrieved attributes: {names}") - return names - - -def count_tokens(val): - encoding = tiktoken.encoding_for_model("gpt-4") - token_integers = encoding.encode(str(val)) - num_tokens = len(token_integers) - - return num_tokens - - -def to_bool(val): - if isinstance(val, str): - return val.lower() not in ["", "no", "false", "0"] - return bool(val) + raise err \ No newline at end of file diff --git a/chat/src/handlers/streaming_socket_callback_handler.py b/chat/src/handlers/streaming_socket_callback_handler.py new file mode 100644 index 00000000..b4feed92 --- /dev/null +++ b/chat/src/handlers/streaming_socket_callback_handler.py @@ -0,0 +1,14 @@ +from langchain.callbacks.base import BaseCallbackHandler +from websocket import Websocket +import os + +class StreamingSocketCallbackHandler(BaseCallbackHandler): + def __init__(self, socket: Websocket, debug_mode: bool): + self.socket = socket + self.debug_mode = debug_mode + + def on_llm_new_token(self, token: str, **kwargs): + if os.getenv("SKIP_WEAVIATE_SETUP"): + return token + elif not self.debug_mode: + self.socket.send({"token": token}) diff --git a/chat/src/helpers/__init__.py b/chat/src/helpers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chat/src/helpers/apitoken.py b/chat/src/helpers/apitoken.py index f8685357..46c97263 100644 --- a/chat/src/helpers/apitoken.py +++ b/chat/src/helpers/apitoken.py @@ -2,30 +2,34 @@ import jwt import os + class ApiToken: - @classmethod - def empty_token(cls): - time = int(datetime.now().timestamp()) - return { - 'iss': os.getenv('DC_API_ENDPOINT'), - 'exp': datetime.fromtimestamp(time + 12 
* 60 * 60).timestamp(), # 12 hours - 'iat': time, - 'entitlements': [], - 'isLoggedIn': False, - } - - def __init__(self, signed_token=None): - if signed_token is None: - self.token = ApiToken.empty_token() - else: - try: - secret = os.getenv("API_TOKEN_SECRET") - self.token = jwt.decode(signed_token, secret, algorithms=["HS256"]) - except Exception: - self.token = ApiToken.empty_token() + @classmethod + def empty_token(cls): + time = int(datetime.now().timestamp()) + return { + "iss": os.getenv("DC_API_ENDPOINT"), + "exp": datetime.fromtimestamp(time + 12 * 60 * 60).timestamp(), # 12 hours + "iat": time, + "entitlements": [], + "isLoggedIn": False, + } + + def __init__(self, signed_token=None): + if signed_token is None: + self.token = ApiToken.empty_token() + else: + try: + secret = os.getenv("API_TOKEN_SECRET") + self.token = jwt.decode(signed_token, secret, algorithms=["HS256"]) + except Exception: + self.token = ApiToken.empty_token() + + def __str__(self): + return f"ApiToken(token={self.token})" + + def is_logged_in(self): + return self.token.get("isLoggedIn", False) - def is_logged_in(self): - return self.token.get("isLoggedIn", False) - - def is_superuser(self): - return self.token.get("isSuperUser", False) + def is_superuser(self): + return self.token.get("isSuperUser", False) diff --git a/chat/src/helpers/metrics.py b/chat/src/helpers/metrics.py new file mode 100644 index 00000000..168cd02f --- /dev/null +++ b/chat/src/helpers/metrics.py @@ -0,0 +1,18 @@ +import tiktoken + + +def token_usage(config, response, original_question): + return { + "question": count_tokens(config.question), + "answer": count_tokens(response["output_text"]), + "prompt": count_tokens(config.prompt_text), + "source_documents": count_tokens(original_question["source_documents"]), + } + + +def count_tokens(val): + encoding = tiktoken.encoding_for_model("gpt-4") + token_integers = encoding.encode(str(val)) + num_tokens = len(token_integers) + + return num_tokens diff --git a/chat/src/helpers/prompts.py b/chat/src/helpers/prompts.py index 8624d385..32ffbc46 100644 --- a/chat/src/helpers/prompts.py +++ b/chat/src/helpers/prompts.py @@ -1,5 +1,7 @@ -# ruff: noqa: E501 -def prompt_template(): +from typing import List, Optional + + +def prompt_template() -> str: return """Please answer the question based on the documents provided, and include some details about why the documents might be relevant to the particular question: Documents: @@ -10,7 +12,9 @@ def prompt_template(): """ -def document_template(attributes): +def document_template(attributes: Optional[List[str]] = None) -> str: + if attributes is None: + attributes = [] lines = ( ["Content: {page_content}", "Metadata:"] + [f" {attribute}: {{{attribute}}}" for attribute in attributes] diff --git a/chat/src/helpers/response.py b/chat/src/helpers/response.py new file mode 100644 index 00000000..655c4c46 --- /dev/null +++ b/chat/src/helpers/response.py @@ -0,0 +1,49 @@ +from helpers.metrics import token_usage +from openai.error import InvalidRequestError + + +def base_response(config, response): + return {"answer": response["output_text"], "ref": config.ref} + + +def debug_response(config, response, original_question): + response_base = base_response(config, response) + debug_info = { + "attributes": config.attributes.append("source"), + "is_superuser": config.api_token.is_superuser(), + "prompt": config.prompt_text, + "k": config.k, + "token_counts": token_usage(config, response, original_question), + } + return {**response_base, **debug_info} + + +def 
get_and_send_original_question(config, docs): + doc_response = [doc.__dict__ for doc in docs] + original_question = { + "question": config.question, + "source_documents": doc_response, + } + config.socket.send(original_question) + return original_question + + +def prepare_response(config): + try: + docs = config.weaviate.similarity_search( + config.question, k=config.k, additional="certainty" + ) + original_question = get_and_send_original_question(config, docs) + response = config.chain({"question": config.question, "input_documents": docs}) + + if config.debug_mode: + prepared_response = debug_response(config, response, original_question) + else: + prepared_response = base_response(config, response) + except InvalidRequestError as err: + prepared_response = { + "question": config.question, + "error": str(err), + "source_documents": [], + } + return prepared_response diff --git a/chat/src/helpers/utils.py b/chat/src/helpers/utils.py new file mode 100644 index 00000000..d0d243d4 --- /dev/null +++ b/chat/src/helpers/utils.py @@ -0,0 +1,7 @@ +def to_bool(val): + """Converts a value to boolean. If the value is a string, it considers + "", "no", "false", "0" as False. Otherwise, it returns the boolean of the value. + """ + if isinstance(val, str): + return val.lower() not in ["", "no", "false", "0"] + return bool(val) diff --git a/chat/src/requirements.txt b/chat/src/requirements.txt index aa6d612d..8cb0270e 100644 --- a/chat/src/requirements.txt +++ b/chat/src/requirements.txt @@ -1,8 +1,7 @@ # Runtime Dependencies +boto3~=1.34.13 langchain~=0.0.208 -nbformat~=5.9.0 openai~=0.27.8 -pandas~=2.0.2 pyjwt~=2.6.0 python-dotenv~=1.0.0 tiktoken~=0.4.0 diff --git a/chat/src/setup.py b/chat/src/setup.py index da9dbbf1..cc70c653 100644 --- a/chat/src/setup.py +++ b/chat/src/setup.py @@ -3,34 +3,57 @@ from typing import List import os import weaviate +import boto3 + def openai_chat_client(**kwargs): - deployment = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID") - key = os.getenv("AZURE_OPENAI_API_KEY") - resource = os.getenv("AZURE_OPENAI_RESOURCE_NAME") - version = "2023-07-01-preview" - - return AzureChatOpenAI(deployment_name=deployment, - openai_api_key=key, - openai_api_base=f"https://{resource}.openai.azure.com/", - openai_api_version=version, - **kwargs) - + return AzureChatOpenAI( + openai_api_key=os.getenv("AZURE_OPENAI_API_KEY"), + **kwargs, + ) + def weaviate_client(): - weaviate_url = os.environ['WEAVIATE_URL'] - weaviate_api_key = os.environ['WEAVIATE_API_KEY'] - auth_config = weaviate.AuthApiKey(api_key=weaviate_api_key) + if os.getenv("SKIP_WEAVIATE_SETUP"): + return None + + weaviate_url = os.environ.get("WEAVIATE_URL") + try: + if weaviate_url is None: + raise EnvironmentError( + "WEAVIATE_URL is not set in the environment variables" + ) + + weaviate_api_key = os.environ.get("WEAVIATE_API_KEY") + if weaviate_api_key is None: + raise EnvironmentError( + "WEAVIATE_API_KEY is not set in the environment variables" + ) + + auth_config = weaviate.AuthApiKey(api_key=weaviate_api_key) + + client = weaviate.Client(url=weaviate_url, auth_client_secret=auth_config) + except Exception as e: + print(f"An error occurred: {e}") + client = None + return client - return weaviate.Client( - url=weaviate_url, - auth_client_secret=auth_config - ) def weaviate_vector_store(index_name: str, text_key: str, attributes: List[str] = []): - client = weaviate_client() + if os.getenv("SKIP_WEAVIATE_SETUP"): + return None + + client = weaviate_client() + + return Weaviate( + client=client, index_name=index_name, 
text_key=text_key, attributes=attributes + ) + - return Weaviate(client=client, - index_name=index_name, - text_key=text_key, - attributes=attributes) +def websocket_client(endpoint_url: str): + endpoint_url = endpoint_url or os.getenv("APIGATEWAY_URL") + try: + client = boto3.client("apigatewaymanagementapi", endpoint_url=endpoint_url) + return client + except Exception as e: + raise e \ No newline at end of file diff --git a/chat/src/websocket.py b/chat/src/websocket.py new file mode 100644 index 00000000..dc81179a --- /dev/null +++ b/chat/src/websocket.py @@ -0,0 +1,16 @@ +import json +from setup import websocket_client + +class Websocket: + def __init__(self, client=None, endpoint_url=None, connection_id=None, ref=None): + self.client = client or websocket_client(endpoint_url) + self.connection_id = connection_id + self.ref = ref if ref else {} + + def send(self, data): + if isinstance(data, str): + data = {"message": data} + data["ref"] = self.ref + data_as_bytes = bytes(json.dumps(data), "utf-8") + self.client.post_to_connection(Data=data_as_bytes, ConnectionId=self.connection_id) + return data diff --git a/chat/test/handlers/test_streaming_socket_callback_handler.py b/chat/test/handlers/test_streaming_socket_callback_handler.py new file mode 100644 index 00000000..b814a8ff --- /dev/null +++ b/chat/test/handlers/test_streaming_socket_callback_handler.py @@ -0,0 +1,26 @@ +# ruff: noqa: E402 +import sys +sys.path.append('./src') + +from unittest import TestCase +from handlers.streaming_socket_callback_handler import ( + StreamingSocketCallbackHandler, +) +from websocket import Websocket + + + +class MockClient: + def post_to_connection(self, Data, ConnectionId): + return Data + +class TestMyStreamingSocketCallbackHandler(TestCase): + def test_on_new_llm_token(self): + handler = StreamingSocketCallbackHandler(Websocket(client=MockClient()), False) + result = handler.on_llm_new_token(token="test") + self.assertEqual(result, "test") + self.assertFalse(handler.debug_mode) + + def test_debug_mode(self): + handler = StreamingSocketCallbackHandler(Websocket(client=MockClient()), debug_mode=True) + self.assertTrue(handler.debug_mode) diff --git a/chat/test/helpers/test_apitoken.py b/chat/test/helpers/test_apitoken.py index 51b93629..a330f56a 100644 --- a/chat/test/helpers/test_apitoken.py +++ b/chat/test/helpers/test_apitoken.py @@ -1,30 +1,44 @@ +# ruff: noqa: E402 import os -from src.helpers.apitoken import ApiToken +import sys +sys.path.append('./src') + +from helpers.apitoken import ApiToken from test.fixtures.apitoken import SUPER_TOKEN, TEST_SECRET, TEST_TOKEN from unittest import mock, TestCase -@mock.patch.dict( - os.environ, - { - "API_TOKEN_SECRET": TEST_SECRET - } -) + + + +@mock.patch.dict(os.environ, {"API_TOKEN_SECRET": TEST_SECRET}) class TestFunction(TestCase): - def test_empty_token(self): - subject = ApiToken() - self.assertFalse(subject.is_logged_in()) - - def test_valid_token(self): - subject = ApiToken(TEST_TOKEN) - self.assertTrue(subject.is_logged_in()) - self.assertFalse(subject.is_superuser()) - - def test_superuser_token(self): - subject = ApiToken(SUPER_TOKEN) - self.assertTrue(subject.is_logged_in()) - self.assertTrue(subject.is_superuser()) - - def test_invalid_token(self): - subject = ApiToken("INVALID_TOKEN") - self.assertFalse(subject.is_logged_in()) - \ No newline at end of file + def test_empty_token(self): + subject = ApiToken() + self.assertIsInstance(subject, ApiToken) + self.assertFalse(subject.is_logged_in()) + + def test_valid_token(self): + subject = 
ApiToken(TEST_TOKEN) + self.assertIsInstance(subject, ApiToken) + self.assertTrue(subject.is_logged_in()) + self.assertFalse(subject.is_superuser()) + + def test_superuser_token(self): + subject = ApiToken(SUPER_TOKEN) + self.assertIsInstance(subject, ApiToken) + self.assertTrue(subject.is_logged_in()) + self.assertTrue(subject.is_superuser()) + + def test_invalid_token(self): + subject = ApiToken("INVALID_TOKEN") + self.assertIsInstance(subject, ApiToken) + self.assertFalse(subject.is_logged_in()) + + def test_empty_token_class_method(self): + empty_token = ApiToken.empty_token() + self.assertIsInstance(empty_token, dict) + self.assertFalse(empty_token["isLoggedIn"]) + + def test_str_method(self): + subject = ApiToken(TEST_TOKEN) + self.assertEqual(str(subject), f"ApiToken(token={subject.token})") diff --git a/chat/test/helpers/test_metrics.py b/chat/test/helpers/test_metrics.py new file mode 100644 index 00000000..651043eb --- /dev/null +++ b/chat/test/helpers/test_metrics.py @@ -0,0 +1,64 @@ +# ruff: noqa: E402 +import json +import os +import sys +sys.path.append('./src') + +from unittest import TestCase, mock +from helpers.metrics import count_tokens, token_usage +from event_config import EventConfig + + + +@mock.patch.dict( + os.environ, + { + "AZURE_OPENAI_RESOURCE_NAME": "test", + "WEAVIATE_URL": "http://test", + "WEAVIATE_API_KEY": "test" + }, +) +class TestMetrics(TestCase): + def test_token_usage(self): + original_question = { + "question": "What is your name?", + "source_documents": [], + } + event = { + "body": json.dumps({ + "deployment_name": "test", + "index": "test", + "k": 1, + "openai_api_version": "2019-05-06", + "prompt": "This is a test prompt.", + "question": original_question, + "ref": "test", + "temperature": 0.5, + "text_key": "text", + "auth": "test123" + }) + } + config = EventConfig(event=event) + + response = { + "output_text": "This is a test response.", + } + + result = token_usage(config, response, original_question) + + expected_result = { + "answer": 6, + "prompt": 36, + "question": 15, + "source_documents": 1, + } + + self.assertEqual(result, expected_result) + + def test_count_tokens(self): + val = "Hello, world!" 
+ expected_result = 4 + + result = count_tokens(val) + + self.assertEqual(result, expected_result) diff --git a/chat/test/helpers/test_prompts.py b/chat/test/helpers/test_prompts.py new file mode 100644 index 00000000..9508f32a --- /dev/null +++ b/chat/test/helpers/test_prompts.py @@ -0,0 +1,33 @@ +# ruff: noqa: E402 +import sys +sys.path.append('./src') + +from helpers.prompts import prompt_template, document_template +from unittest import TestCase + + +class TestPromptTemplate(TestCase): + def test_prompt_template(self): + prompt = prompt_template() + assert isinstance(prompt, str) + assert len(prompt) > 0 + + +class TestDocumentTemplate(TestCase): + def test_empty_attributes(self): + self.assertEqual( + document_template(), + "Content: {page_content}\nMetadata:\nSource: {source}", + ) + + def test_single_attribute(self): + self.assertEqual( + document_template(["title"]), + "Content: {page_content}\nMetadata:\n title: {title}\nSource: {source}", + ) + + def test_multiple_attributes(self): + self.assertEqual( + document_template(["title", "author", "subject", "description"]), + "Content: {page_content}\nMetadata:\n title: {title}\n author: {author}\n subject: {subject}\n description: {description}\nSource: {source}", + ) diff --git a/chat/test/test_event_config.py b/chat/test/test_event_config.py new file mode 100644 index 00000000..0fae8072 --- /dev/null +++ b/chat/test/test_event_config.py @@ -0,0 +1,98 @@ +# ruff: noqa: E402 +import json +import os +import sys +sys.path.append('./src') + +from event_config import EventConfig +from unittest import TestCase, mock + + +class TestEventConfigWithoutAzureResource(TestCase): + def test_requires_an_azure_resource(self): + with self.assertRaises(EnvironmentError): + EventConfig() + + +@mock.patch.dict( + os.environ, + { + "AZURE_OPENAI_RESOURCE_NAME": "test", + }, +) +class TestEventConfig(TestCase): + def test_fetches_attributes_from_vector_database(self): + os.environ.pop("AZURE_OPENAI_RESOURCE_NAME", None) + with self.assertRaises(EnvironmentError): + EventConfig() + + def test_defaults(self): + actual = EventConfig(event={"body": json.dumps({"attributes": ["title"]})}) + expected_defaults = {"azure_endpoint": "https://test.openai.azure.com/"} + self.assertEqual(actual.azure_endpoint, expected_defaults["azure_endpoint"]) + + def test_overrides(self): + actual = EventConfig( + event={ + "body": json.dumps( + { + "azure_resource_name": "new_name_for_test", + "attributes": ["subject", "date_created"], + "index": "testIndex", + "k": 100, + "openai_api_version": "2024-01-01", + "question": "test question", + "ref": "test ref", + "temperature": 0.9, + "text_key": "accession_number", + } + ) + } + ) + expected_overrides = { + "attributes": ["subject", "date_created"], + "azure_endpoint": "https://new_name_for_test.openai.azure.com/", + "index_name": "testIndex", + "k": 100, + "openai_api_version": "2024-01-01", + "question": "test question", + "ref": "test ref", + "temperature": 0.9, + "text_key": "accession_number", + } + self.assertEqual(actual.azure_endpoint, expected_overrides["azure_endpoint"]) + self.assertEqual(actual.index_name, expected_overrides["index_name"]) + self.assertEqual(actual.attributes, expected_overrides["attributes"]) + self.assertEqual(actual.k, expected_overrides["k"]) + self.assertEqual( + actual.openai_api_version, expected_overrides["openai_api_version"] + ) + self.assertEqual(actual.question, expected_overrides["question"]) + self.assertEqual(actual.ref, expected_overrides["ref"]) + self.assertEqual(actual.temperature, 
expected_overrides["temperature"]) + self.assertEqual(actual.text_key, expected_overrides["text_key"]) + + def test_text_key_removed_from_attributes_list(self): + actual = EventConfig( + event={ + "body": json.dumps( + { + "attributes": ["title", "description"], + "text_key": "description", + } + ) + } + ) + self.assertNotIn(actual.text_key, actual.attributes) + + def test_source_removed_from_attributes_list(self): + actual = EventConfig(event={"body": json.dumps({"attributes": ["source"]})}) + self.assertNotIn("source", actual.attributes) + + def test_debug_message(self): + self.assertEqual( + EventConfig( + event={"body": json.dumps({"attributes": ["source"]})} + ).debug_message()["type"], + "debug", + ) diff --git a/chat/test/test_websocket.py b/chat/test/test_websocket.py new file mode 100644 index 00000000..4d4d8b76 --- /dev/null +++ b/chat/test/test_websocket.py @@ -0,0 +1,18 @@ +# ruff: noqa: E402 +import sys +sys.path.append('./src') + +from unittest import TestCase +from websocket import Websocket + + +class MockClient: + def post_to_connection(self, Data, ConnectionId): + return Data + +class TestWebsocket(TestCase): + def test_post_to_connection(self): + websocket = Websocket(client=MockClient(), connection_id="test_connection_id", ref="test_ref") + message = "test_message" + expected = {"message": "test_message", "ref": "test_ref"} + self.assertEqual(websocket.send(message), expected) \ No newline at end of file