Skip to content

Commit

Permalink
Allow for overriding parameters to the LLM with default configuration…
Browse files Browse the repository at this point in the history
… in place

- Large refactor of configuration handling, adds the ability to override many more parameters via websocket messages
- Tests passing in dev environment using the Makefile and Github actions
- Allow for skipping weaviate setup in Github actions via environment variable
  • Loading branch information
bmquinn committed Feb 2, 2024
1 parent 0dd7f7f commit e50173f
Show file tree
Hide file tree
Showing 22 changed files with 721 additions and 298 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
env:
AWS_ACCESS_KEY_ID: ci
AWS_SECRET_ACCESS_KEY: ci
SKIP_WEAVIATE_SETUP: 'True'
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
Expand Down
118 changes: 60 additions & 58 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,58 +1,60 @@
ifndef VERBOSE
.SILENT:
endif
ENV=dev

help:
echo "make build | build the SAM project"
echo "make serve | run the SAM server locally"
echo "make clean | remove all installed dependencies and build artifacts"
echo "make deps | install all dependencies"
echo "make link | create hard links to allow for hot reloading of a built project"
echo "make secrets | symlink secrets files from ../tfvars"
echo "make style | run all style checks"
echo "make test | run all tests"
echo "make cover | run all tests with coverage"
echo "make env ENV=[env] | activate env.\$$ENV.json file (default: dev)"
echo "make deps-node | install node dependencies"
echo "make deps-python | install python dependencies"
echo "make style-node | run node code style check"
echo "make style-python | run python code style check"
echo "make test-node | run node tests"
echo "make test-python | run python tests"
echo "make cover-node | run node tests with coverage"
echo "make cover-python | run python tests with coverage"
.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json python/requirements.txt python/src/requirements.txt
sam build --cached --parallel
deps-node:
cd node && npm ci
cover-node:
cd node && npm run test:coverage
style-node:
cd node && npm run prettier
test-node:
cd node && npm run test
deps-python:
cd chat/src && pip install -r requirements.txt
cover-python:
cd chat/src && coverage run --include='src/**/*' -m unittest -v && coverage report
style-python:
cd chat && ruff check .
test-python:
cd chat && python -m unittest -v
build: .aws-sam/build.toml
link: build
cd chat/src && for src in *.py **/*.py; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done
cd node/src && for src in *.js *.json **/*.js **/*.json; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done
serve: link
sam local start-api --host 0.0.0.0 --log-file dc-api.log
deps: deps-node deps-python
style: style-node style-python
test: test-node test-python
cover: cover-node cover-python
env:
ln -fs ./env.${ENV}.json ./env.json
secrets:
ln -s ../tfvars/dc-api/* .
clean:
rm -rf .aws-sam node/node_modules node/src/node_modules python/**/__pycache__ python/.coverage python/.ruff_cache
ifndef VERBOSE
.SILENT:
endif
ENV=dev

help:
echo "make build | build the SAM project"
echo "make serve | run the SAM server locally"
echo "make clean | remove all installed dependencies and build artifacts"
echo "make deps | install all dependencies"
echo "make link | create hard links to allow for hot reloading of a built project"
echo "make secrets | symlink secrets files from ../tfvars"
echo "make style | run all style checks"
echo "make test | run all tests"
echo "make cover | run all tests with coverage"
echo "make env ENV=[env] | activate env.\$$ENV.json file (default: dev)"
echo "make deps-node | install node dependencies"
echo "make deps-python | install python dependencies"
echo "make style-node | run node code style check"
echo "make style-python | run python code style check"
echo "make test-node | run node tests"
echo "make test-python | run python tests"
echo "make cover-node | run node tests with coverage"
echo "make cover-python | run python tests with coverage"
.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json chat/dependencies/requirements.txt chat/src/requirements.txt
sam build --cached --parallel
deps-node:
cd node && npm ci
cover-node:
cd node && npm run test:coverage
style-node:
cd node && npm run prettier
test-node:
cd node && npm run test
deps-python:
cd chat/src && pip install -r requirements.txt
cover-python: deps-python
cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run -m unittest -v && coverage report
style-python: deps-python
cd chat && ruff check .
test-python: deps-python
cd chat && export SKIP_WEAVIATE_SETUP=True && PYTHONPATH=src:test && python -m unittest discover -v
python-version:
cd chat && python --version
build: .aws-sam/build.toml
link: build
cd chat/src && for src in *.py **/*.py; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done
cd node/src && for src in *.js *.json **/*.js **/*.json; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done
serve: link
sam local start-api --host 0.0.0.0 --log-file dc-api.log
deps: deps-node deps-python
style: style-node style-python
test: test-node test-python
cover: cover-node cover-python
env:
ln -fs ./env.${ENV}.json ./env.json
secrets:
ln -s ../tfvars/dc-api/* .
clean:
rm -rf .aws-sam node/node_modules node/src/node_modules python/**/__pycache__ python/.coverage python/.ruff_cache
3 changes: 1 addition & 2 deletions chat/dependencies/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
boto3~=1.34.13
langchain~=0.0.208
nbformat~=5.9.0
openai~=0.27.8
pandas~=2.0.2
pyjwt~=2.6.0
python-dotenv~=1.0.0
tiktoken~=0.4.0
Expand Down
178 changes: 178 additions & 0 deletions chat/src/event_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import os
import json

from dataclasses import dataclass, field
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts import PromptTemplate
from setup import (
weaviate_client,
weaviate_vector_store,
openai_chat_client,
)
from typing import List
from handlers.streaming_socket_callback_handler import StreamingSocketCallbackHandler
from helpers.apitoken import ApiToken
from helpers.prompts import document_template, prompt_template
from websocket import Websocket


CHAIN_TYPE = "stuff"
DOCUMENT_VARIABLE_NAME = "context"
INDEX_NAME = "DCWork"
K_VALUE = 10
TEXT_KEY = "title"
MAX_K = 100
TEMPERATURE = 0.2
VERSION = "2023-07-01-preview"


@dataclass
class EventConfig:
"""
The EventConfig class represents the configuration for an event.
Default values are set for the following properties which can be overridden in the payload message.
"""

api_token: ApiToken = field(init=False)
attributes: List[str] = field(init=False)
azure_endpoint: str = field(init=False)
debug_mode: bool = field(init=False)
deployment_name: str = field(init=False)
document_prompt: PromptTemplate = field(init=False)
event: dict = field(default_factory=dict)
index_name: str = field(init=False)
is_logged_in: bool = field(init=False)
k: int = field(init=False)
openai_api_version: str = field(init=False)
payload: dict = field(default_factory=dict)
prompt_text: str = field(init=False)
prompt: PromptTemplate = field(init=False)
question: str = field(init=False)
ref: str = field(init=False)
request_context: dict = field(init=False)
temperature: float = field(init=False)
socket: Websocket = field(init=False, default=None)
text_key: str = field(init=False)

def __post_init__(self):
self.payload = json.loads(self.event.get("body", "{}"))

self.api_token = ApiToken(signed_token=self.payload.get("auth"))
self.azure_resource_name = self.payload.get(
"azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
)
if not self.azure_resource_name:
raise EnvironmentError(
"Either payload must contain 'azure_resource_name' or environment variable 'AZURE_OPENAI_RESOURCE_NAME' must be set"
)

self.azure_endpoint = f"https://{self.azure_resource_name}.openai.azure.com/"
self.debug_mode = self._is_debug_mode_enabled()
self.deployment_name = self.payload.get(
"deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
)
self.index_name = self.payload.get("index", self.payload.get("index", INDEX_NAME))
self.is_logged_in = self.api_token.is_logged_in()
self.k = min(self.payload.get("k", K_VALUE), MAX_K)
self.openai_api_version = self.payload.get("openai_api_version", VERSION)
self.prompt_text = (
self.payload.get("prompt", prompt_template())
if self.api_token.is_superuser()
else prompt_template()
)
self.request_context = self.event.get("requestContext", {})
self.question = self.payload.get("question")
self.ref = self.payload.get("ref")
self.temperature = self.payload.get("temperature", TEMPERATURE)
self.text_key = self.payload.get("text_key", TEXT_KEY)

self.attributes = [
item
for item in self._get_request_attributes()
if item not in [self.text_key, "source"]
]
self.document_prompt = PromptTemplate(
template=document_template(self.attributes),
input_variables=["page_content", "source"] + self.attributes,
)
self.prompt = PromptTemplate(
template=self.prompt_text, input_variables=["question", "context"]
)

def debug_message(self):
return {
"type": "debug",
"message": {
"azure_endpoint": self.azure_endpoint,
"deployment_name": self.deployment_name,
"index": self.index_name,
"k": self.k,
"openai_api_version": self.openai_api_version,
"prompt": self.prompt_text,
"question": self.question,
"ref": self.ref,
"temperature": self.temperature,
"text_key": self.text_key,
},
}

def setup_websocket(self):
connection_id = self.request_context.get("connectionId")
endpoint_url = f'https://{self.request_context.get("domainName")}/{self.request_context.get("stage")}'
self.socket = Websocket(endpoint_url=endpoint_url, connection_id=connection_id, ref=self.ref)

def setup_llm_request(self):
self._setup_vector_store()
self._setup_chat_client()
self._setup_chain()

def _setup_vector_store(self):
self.weaviate = weaviate_vector_store(
index_name=self.index_name,
text_key=self.text_key,
attributes=self.attributes + ["source"],
)

def _setup_chat_client(self):
self.client = openai_chat_client(
deployment_name=self.deployment_name,
openai_api_base=self.azure_endpoint,
openai_api_version=self.openai_api_version,
callbacks=[StreamingSocketCallbackHandler(self.socket, self.debug_mode)],
streaming=True,
)

def _setup_chain(self):
self.chain = load_qa_with_sources_chain(
self.client,
chain_type=CHAIN_TYPE,
prompt=self.prompt,
document_prompt=self.document_prompt,
document_variable_name=DOCUMENT_VARIABLE_NAME,
verbose=self._to_bool(os.getenv("VERBOSE")),
)

def _is_debug_mode_enabled(self):
debug = self.payload.get("debug", False)
return debug and self.api_token.is_superuser()

def _get_request_attributes(self):
request_attributes = self.payload.get("attributes", None)
if request_attributes is not None:
return request_attributes

if os.getenv("SKIP_WEAVIATE_SETUP"):
return []

client = weaviate_client()
schema = client.schema.get(self.index_name)
names = [prop["name"] for prop in schema.get("properties")]
return names

def _to_bool(self, val):
"""Converts a value to boolean. If the value is a string, it considers
"", "no", "false", "0" as False. Otherwise, it returns the boolean of the value.
"""
if isinstance(val, str):
return val.lower() not in ["", "no", "false", "0"]
return bool(val)
Empty file added chat/src/handlers/__init__.py
Empty file.
Loading

0 comments on commit e50173f

Please sign in to comment.