Big configuration refactor, adds overrides
Showing 9 changed files with 286 additions and 247 deletions.
import os
import json

from handlers.streaming_socket_callback_handler import StreamingSocketCallbackHandler
from helpers.apitoken import ApiToken
from helpers.prompts import document_template, prompt_template
from helpers.utils import to_bool
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts import PromptTemplate
from setup import (
    weaviate_client,
    weaviate_vector_store,
    openai_chat_client,
)
from websocket import Websocket

CHAIN_TYPE = "stuff"
DOCUMENT_VARIABLE_NAME = "context"
INDEX_NAME = "DCWork"
K_VALUE = 10
TEXT_KEY = "title"
MAX_K = 100
TEMPERATURE = 0.2
VERSION = "2023-07-01-preview"

class EventConfig:
    """
    The EventConfig class represents the configuration for an event. Default values are set for the following properties, which can be overridden in the payload message:

    - attributes: Document attributes sent to the LLM.
    - auth: Authentication token for the connection.
    - azure_endpoint: Full URI for the Azure OpenAI endpoint.
    - debug: Debug mode status (requires a superuser token).
    - deployment_name: Name of the Azure AI deployment.
    - index_name: Name of the vector database index.
    - k: The number of documents retrieved in vector database searches.
    - message: Type of socket communication.
    - openai_api_version: Version of the Azure AI model.
    - prompt_text*: System prompt (the string must contain both page_content and source: "{page_content} {source}").
    - question: User prompt, typically sent via frontend input.
    - ref: Reference for uniquely identifying the request.
    - text_key: Attribute used to name each document.

    * requires debug mode to be enabled
    """

    def __init__(self, event):
        self.payload = json.loads(event.get("body", "{}"))
        self.api_token = ApiToken(signed_token=self.payload.get("auth"))
        self.azure_endpoint = self.payload.get(
            "azure_endpoint",
            f"https://{os.getenv('AZURE_OPENAI_RESOURCE_NAME')}.openai.azure.com/",
        )
        self.debug_mode = self._is_debug_mode_enabled()
        self.deployment_name = self.payload.get(
            "deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
        )
        self.index_name = self.payload.get("index", INDEX_NAME)
        self.is_logged_in = self.api_token.is_logged_in()
        self.k = min(self.payload.get("k", K_VALUE), MAX_K)
        self.openai_api_version = self.payload.get("openai_api_version", VERSION)
        self.prompt_text = (
            self.payload.get("prompt", prompt_template())
            if self.api_token.is_superuser()
            else prompt_template()
        )
        self.request_context = event.get("requestContext", {})
        self.question = self.payload.get("question")
        self.ref = self.payload.get("ref")
        self.temperature = self.payload.get("temperature", TEMPERATURE)
        self.text_key = self.payload.get("text_key", TEXT_KEY)

        self.attributes = [
            item
            for item in self._get_request_attributes()
            if item not in [self.text_key, "source"]
        ]
        self.document_prompt = PromptTemplate(
            template=document_template(self.attributes),
            input_variables=["page_content", "source"] + self.attributes,
        )
        self.prompt = PromptTemplate(
            template=self.prompt_text, input_variables=["question", "context"]
        )

    def debug_message(self):
        return {
            "type": "debug",
            "message": {
                "azure_endpoint": self.azure_endpoint,
                "deployment_name": self.deployment_name,
                "index": self.index_name,
                "k": self.k,
                "openai_api_version": self.openai_api_version,
                "prompt": self.prompt_text,
                "question": self.question,
                "ref": self.ref,
                "temperature": self.temperature,
                "text_key": self.text_key,
            },
        }

    def setup_websocket(self):
        connection_id = self.request_context.get("connectionId")
        endpoint_url = f'https://{self.request_context.get("domainName")}/{self.request_context.get("stage")}'
        self.socket = Websocket(
            connection_id=connection_id, endpoint_url=endpoint_url, ref=self.ref
        )

    def setup_llm_request(self):
        self._setup_vector_store()
        self._setup_chat_client()
        self._setup_chain()

    def _setup_vector_store(self):
        self.weaviate = weaviate_vector_store(
            index_name=self.index_name,
            text_key=self.text_key,
            attributes=self.attributes + ["source"],
        )

    def _setup_chat_client(self):
        self.client = openai_chat_client(
            deployment_name=self.deployment_name,
            openai_api_base=self.azure_endpoint,
            openai_api_version=self.openai_api_version,
            callbacks=[StreamingSocketCallbackHandler(self.socket, self.debug_mode)],
            streaming=True,
        )

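    # CHAIN_TYPE = "stuff" means all retrieved documents are combined into a
    # single prompt (LangChain's "stuff" documents strategy) rather than being
    # summarized or processed iteratively.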
    def _setup_chain(self):
        self.chain = load_qa_with_sources_chain(
            self.client,
            chain_type=CHAIN_TYPE,
            prompt=self.prompt,
            document_prompt=self.document_prompt,
            document_variable_name=DOCUMENT_VARIABLE_NAME,
            verbose=self._to_bool(os.getenv("VERBOSE")),
        )

    def _is_debug_mode_enabled(self):
        debug = self.payload.get("debug", False)
        return debug and self.api_token.is_superuser()

    def _get_request_attributes(self):
        request_attributes = self.payload.get("attributes", None)
        if request_attributes is not None:
            return request_attributes

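        # No attributes supplied in the payload: fall back to every property
        # defined on the index's Weaviate schema.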
        client = weaviate_client()
        schema = client.schema.get(self.index_name)
        names = [prop["name"] for prop in schema.get("properties")]
        return names

    def _to_bool(self, val):
        """Converts a value to boolean. If the value is a string, it considers
        "", "no", "false", "0" as False. Otherwise, it returns the boolean of the value.
        """
        if isinstance(val, str):
            return val.lower() not in ["", "no", "false", "0"]
        return bool(val)
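
For context, here is a minimal sketch of how a Lambda-style WebSocket handler might consume this class. The handler name, the module name event_config, the response shape, and the Websocket.send() call are assumptions for illustration only; the EventConfig methods are the only part taken from this commit.

# Hypothetical handler sketch -- names and response shape are illustrative,
# not part of this commit.
import json

from event_config import EventConfig  # assumed module name for the file above


def handler(event, _context):
    config = EventConfig(event)
    if not config.is_logged_in:
        return {"statusCode": 401, "body": "Unauthorized"}

    config.setup_websocket()
    if config.debug_mode:
        # Assumes the Websocket helper exposes a send() method.
        config.socket.send(config.debug_message())

    config.setup_llm_request()
    docs = config.weaviate.similarity_search(config.question, k=config.k)
    response = config.chain({"input_documents": docs, "question": config.question})

    return {
        "statusCode": 200,
        "body": json.dumps({"ref": config.ref, "answer": response.get("output_text")}),
    }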