-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow for overriding parameters to the LLM with default configuration…
… in place - Large refactor of configuration handling, adds the ability to override many more parameters via websocket messages - Tests passing in dev environment using the Makefile and Github actions - Allow for skipping weaviate setup in Github actions via environment variable
- Loading branch information
Showing
22 changed files
with
721 additions
and
298 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,58 +1,60 @@ | ||
ifndef VERBOSE | ||
.SILENT: | ||
endif | ||
ENV=dev | ||
|
||
help: | ||
echo "make build | build the SAM project" | ||
echo "make serve | run the SAM server locally" | ||
echo "make clean | remove all installed dependencies and build artifacts" | ||
echo "make deps | install all dependencies" | ||
echo "make link | create hard links to allow for hot reloading of a built project" | ||
echo "make secrets | symlink secrets files from ../tfvars" | ||
echo "make style | run all style checks" | ||
echo "make test | run all tests" | ||
echo "make cover | run all tests with coverage" | ||
echo "make env ENV=[env] | activate env.\$$ENV.json file (default: dev)" | ||
echo "make deps-node | install node dependencies" | ||
echo "make deps-python | install python dependencies" | ||
echo "make style-node | run node code style check" | ||
echo "make style-python | run python code style check" | ||
echo "make test-node | run node tests" | ||
echo "make test-python | run python tests" | ||
echo "make cover-node | run node tests with coverage" | ||
echo "make cover-python | run python tests with coverage" | ||
.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json python/requirements.txt python/src/requirements.txt | ||
sam build --cached --parallel | ||
deps-node: | ||
cd node && npm ci | ||
cover-node: | ||
cd node && npm run test:coverage | ||
style-node: | ||
cd node && npm run prettier | ||
test-node: | ||
cd node && npm run test | ||
deps-python: | ||
cd chat/src && pip install -r requirements.txt | ||
cover-python: | ||
cd chat/src && coverage run --include='src/**/*' -m unittest -v && coverage report | ||
style-python: | ||
cd chat && ruff check . | ||
test-python: | ||
cd chat && python -m unittest -v | ||
build: .aws-sam/build.toml | ||
link: build | ||
cd chat/src && for src in *.py **/*.py; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done | ||
cd node/src && for src in *.js *.json **/*.js **/*.json; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done | ||
serve: link | ||
sam local start-api --host 0.0.0.0 --log-file dc-api.log | ||
deps: deps-node deps-python | ||
style: style-node style-python | ||
test: test-node test-python | ||
cover: cover-node cover-python | ||
env: | ||
ln -fs ./env.${ENV}.json ./env.json | ||
secrets: | ||
ln -s ../tfvars/dc-api/* . | ||
clean: | ||
rm -rf .aws-sam node/node_modules node/src/node_modules python/**/__pycache__ python/.coverage python/.ruff_cache | ||
ifndef VERBOSE | ||
.SILENT: | ||
endif | ||
ENV=dev | ||
|
||
help: | ||
echo "make build | build the SAM project" | ||
echo "make serve | run the SAM server locally" | ||
echo "make clean | remove all installed dependencies and build artifacts" | ||
echo "make deps | install all dependencies" | ||
echo "make link | create hard links to allow for hot reloading of a built project" | ||
echo "make secrets | symlink secrets files from ../tfvars" | ||
echo "make style | run all style checks" | ||
echo "make test | run all tests" | ||
echo "make cover | run all tests with coverage" | ||
echo "make env ENV=[env] | activate env.\$$ENV.json file (default: dev)" | ||
echo "make deps-node | install node dependencies" | ||
echo "make deps-python | install python dependencies" | ||
echo "make style-node | run node code style check" | ||
echo "make style-python | run python code style check" | ||
echo "make test-node | run node tests" | ||
echo "make test-python | run python tests" | ||
echo "make cover-node | run node tests with coverage" | ||
echo "make cover-python | run python tests with coverage" | ||
.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json chat/dependencies/requirements.txt chat/src/requirements.txt | ||
sam build --cached --parallel | ||
deps-node: | ||
cd node && npm ci | ||
cover-node: | ||
cd node && npm run test:coverage | ||
style-node: | ||
cd node && npm run prettier | ||
test-node: | ||
cd node && npm run test | ||
deps-python: | ||
cd chat/src && pip install -r requirements.txt | ||
cover-python: deps-python | ||
cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run -m unittest -v && coverage report | ||
style-python: deps-python | ||
cd chat && ruff check . | ||
test-python: deps-python | ||
cd chat && export SKIP_WEAVIATE_SETUP=True && PYTHONPATH=src:test && python -m unittest discover -v | ||
python-version: | ||
cd chat && python --version | ||
build: .aws-sam/build.toml | ||
link: build | ||
cd chat/src && for src in *.py **/*.py; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done | ||
cd node/src && for src in *.js *.json **/*.js **/*.json; do for target in $$(find ../../.aws-sam/build -maxdepth 1 -type d); do if [[ -f $$target/$$src ]]; then ln -f $$src $$target/$$src; fi; done; done | ||
serve: link | ||
sam local start-api --host 0.0.0.0 --log-file dc-api.log | ||
deps: deps-node deps-python | ||
style: style-node style-python | ||
test: test-node test-python | ||
cover: cover-node cover-python | ||
env: | ||
ln -fs ./env.${ENV}.json ./env.json | ||
secrets: | ||
ln -s ../tfvars/dc-api/* . | ||
clean: | ||
rm -rf .aws-sam node/node_modules node/src/node_modules python/**/__pycache__ python/.coverage python/.ruff_cache |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
import os | ||
import json | ||
|
||
from dataclasses import dataclass, field | ||
from langchain.chains.qa_with_sources import load_qa_with_sources_chain | ||
from langchain.prompts import PromptTemplate | ||
from setup import ( | ||
weaviate_client, | ||
weaviate_vector_store, | ||
openai_chat_client, | ||
) | ||
from typing import List | ||
from handlers.streaming_socket_callback_handler import StreamingSocketCallbackHandler | ||
from helpers.apitoken import ApiToken | ||
from helpers.prompts import document_template, prompt_template | ||
from websocket import Websocket | ||
|
||
|
||
CHAIN_TYPE = "stuff" | ||
DOCUMENT_VARIABLE_NAME = "context" | ||
INDEX_NAME = "DCWork" | ||
K_VALUE = 10 | ||
TEXT_KEY = "title" | ||
MAX_K = 100 | ||
TEMPERATURE = 0.2 | ||
VERSION = "2023-07-01-preview" | ||
|
||
|
||
@dataclass | ||
class EventConfig: | ||
""" | ||
The EventConfig class represents the configuration for an event. | ||
Default values are set for the following properties which can be overridden in the payload message. | ||
""" | ||
|
||
api_token: ApiToken = field(init=False) | ||
attributes: List[str] = field(init=False) | ||
azure_endpoint: str = field(init=False) | ||
debug_mode: bool = field(init=False) | ||
deployment_name: str = field(init=False) | ||
document_prompt: PromptTemplate = field(init=False) | ||
event: dict = field(default_factory=dict) | ||
index_name: str = field(init=False) | ||
is_logged_in: bool = field(init=False) | ||
k: int = field(init=False) | ||
openai_api_version: str = field(init=False) | ||
payload: dict = field(default_factory=dict) | ||
prompt_text: str = field(init=False) | ||
prompt: PromptTemplate = field(init=False) | ||
question: str = field(init=False) | ||
ref: str = field(init=False) | ||
request_context: dict = field(init=False) | ||
temperature: float = field(init=False) | ||
socket: Websocket = field(init=False, default=None) | ||
text_key: str = field(init=False) | ||
|
||
def __post_init__(self): | ||
self.payload = json.loads(self.event.get("body", "{}")) | ||
|
||
self.api_token = ApiToken(signed_token=self.payload.get("auth")) | ||
self.azure_resource_name = self.payload.get( | ||
"azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME") | ||
) | ||
if not self.azure_resource_name: | ||
raise EnvironmentError( | ||
"Either payload must contain 'azure_resource_name' or environment variable 'AZURE_OPENAI_RESOURCE_NAME' must be set" | ||
) | ||
|
||
self.azure_endpoint = f"https://{self.azure_resource_name}.openai.azure.com/" | ||
self.debug_mode = self._is_debug_mode_enabled() | ||
self.deployment_name = self.payload.get( | ||
"deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID") | ||
) | ||
self.index_name = self.payload.get("index", self.payload.get("index", INDEX_NAME)) | ||
self.is_logged_in = self.api_token.is_logged_in() | ||
self.k = min(self.payload.get("k", K_VALUE), MAX_K) | ||
self.openai_api_version = self.payload.get("openai_api_version", VERSION) | ||
self.prompt_text = ( | ||
self.payload.get("prompt", prompt_template()) | ||
if self.api_token.is_superuser() | ||
else prompt_template() | ||
) | ||
self.request_context = self.event.get("requestContext", {}) | ||
self.question = self.payload.get("question") | ||
self.ref = self.payload.get("ref") | ||
self.temperature = self.payload.get("temperature", TEMPERATURE) | ||
self.text_key = self.payload.get("text_key", TEXT_KEY) | ||
|
||
self.attributes = [ | ||
item | ||
for item in self._get_request_attributes() | ||
if item not in [self.text_key, "source"] | ||
] | ||
self.document_prompt = PromptTemplate( | ||
template=document_template(self.attributes), | ||
input_variables=["page_content", "source"] + self.attributes, | ||
) | ||
self.prompt = PromptTemplate( | ||
template=self.prompt_text, input_variables=["question", "context"] | ||
) | ||
|
||
def debug_message(self): | ||
return { | ||
"type": "debug", | ||
"message": { | ||
"azure_endpoint": self.azure_endpoint, | ||
"deployment_name": self.deployment_name, | ||
"index": self.index_name, | ||
"k": self.k, | ||
"openai_api_version": self.openai_api_version, | ||
"prompt": self.prompt_text, | ||
"question": self.question, | ||
"ref": self.ref, | ||
"temperature": self.temperature, | ||
"text_key": self.text_key, | ||
}, | ||
} | ||
|
||
def setup_websocket(self): | ||
connection_id = self.request_context.get("connectionId") | ||
endpoint_url = f'https://{self.request_context.get("domainName")}/{self.request_context.get("stage")}' | ||
self.socket = Websocket(endpoint_url=endpoint_url, connection_id=connection_id, ref=self.ref) | ||
|
||
def setup_llm_request(self): | ||
self._setup_vector_store() | ||
self._setup_chat_client() | ||
self._setup_chain() | ||
|
||
def _setup_vector_store(self): | ||
self.weaviate = weaviate_vector_store( | ||
index_name=self.index_name, | ||
text_key=self.text_key, | ||
attributes=self.attributes + ["source"], | ||
) | ||
|
||
def _setup_chat_client(self): | ||
self.client = openai_chat_client( | ||
deployment_name=self.deployment_name, | ||
openai_api_base=self.azure_endpoint, | ||
openai_api_version=self.openai_api_version, | ||
callbacks=[StreamingSocketCallbackHandler(self.socket, self.debug_mode)], | ||
streaming=True, | ||
) | ||
|
||
def _setup_chain(self): | ||
self.chain = load_qa_with_sources_chain( | ||
self.client, | ||
chain_type=CHAIN_TYPE, | ||
prompt=self.prompt, | ||
document_prompt=self.document_prompt, | ||
document_variable_name=DOCUMENT_VARIABLE_NAME, | ||
verbose=self._to_bool(os.getenv("VERBOSE")), | ||
) | ||
|
||
def _is_debug_mode_enabled(self): | ||
debug = self.payload.get("debug", False) | ||
return debug and self.api_token.is_superuser() | ||
|
||
def _get_request_attributes(self): | ||
request_attributes = self.payload.get("attributes", None) | ||
if request_attributes is not None: | ||
return request_attributes | ||
|
||
if os.getenv("SKIP_WEAVIATE_SETUP"): | ||
return [] | ||
|
||
client = weaviate_client() | ||
schema = client.schema.get(self.index_name) | ||
names = [prop["name"] for prop in schema.get("properties")] | ||
return names | ||
|
||
def _to_bool(self, val): | ||
"""Converts a value to boolean. If the value is a string, it considers | ||
"", "no", "false", "0" as False. Otherwise, it returns the boolean of the value. | ||
""" | ||
if isinstance(val, str): | ||
return val.lower() not in ["", "no", "false", "0"] | ||
return bool(val) |
Empty file.
Oops, something went wrong.