Deploy to production #253

Merged · 3 commits · Aug 26, 2024
2 changes: 1 addition & 1 deletion chat/src/event_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
K_VALUE = 40
MAX_K = 100
MAX_TOKENS = 1000
-SIZE = 5
+SIZE = 20
TEMPERATURE = 0.2
TEXT_KEY = "id"
VERSION = "2024-02-01"
43 changes: 43 additions & 0 deletions chat/src/handlers/chat_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@

import json
import logging
import os
from http_event_config import HTTPEventConfig
from helpers.http_response import HTTPResponse
from honeybadger import honeybadger

honeybadger.configure()
logging.getLogger('honeybadger').addHandler(logging.StreamHandler())

RESPONSE_TYPES = {
    "base": ["answer", "ref"],
    "debug": ["answer", "attributes", "azure_endpoint", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "text_key", "token_counts"],
    "log": ["answer", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "size", "source_documents", "temperature", "token_counts"],
    "error": ["question", "error", "source_documents"]
}

def handler(event, context):
    print(f'Event: {event}')

    config = HTTPEventConfig(event)

    if not config.is_logged_in:
        return {"statusCode": 401, "body": "Unauthorized"}

    if config.question is None or config.question == "":
        return {"statusCode": 400, "body": "Question cannot be blank"}

    if not os.getenv("SKIP_WEAVIATE_SETUP"):
        config.setup_llm_request()
        response = HTTPResponse(config)
        final_response = response.prepare_response()
        if "error" in final_response:
            logging.error(f'Error: {final_response["error"]}')
            return {"statusCode": 500, "body": "Internal Server Error"}
        else:
            return {"statusCode": 200, "body": json.dumps(reshape_response(final_response, 'debug' if config.debug_mode else 'base'))}

    return {"statusCode": 200}

def reshape_response(response, type):
    return {k: response[k] for k in RESPONSE_TYPES[type]}
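A quick way to sanity-check the new handler is to invoke it locally with a fake Lambda event. This is only a sketch: the event shape and token below are illustrative, not part of this PR, and it assumes the chat/src dependencies are importable. Setting SKIP_WEAVIATE_SETUP short-circuits the handler before any OpenSearch or Azure calls are made.

import os
os.environ["SKIP_WEAVIATE_SETUP"] = "true"  # skip LLM/vector-store setup

from handlers.chat_sync import handler

# A made-up signed token will fail the login check and come back 401;
# a valid one (with SKIP_WEAVIATE_SETUP set) returns a bare 200.
event = {"body": '{"auth": "<signed-token>", "question": "Who painted this?"}'}
print(handler(event, context=None))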
60 changes: 60 additions & 0 deletions chat/src/helpers/http_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from helpers.metrics import debug_response
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

def extract_prompt_value(v):
    if isinstance(v, list):
        return [extract_prompt_value(item) for item in v]
    elif isinstance(v, dict) and 'label' in v:
        return [v.get('label')]
    else:
        return v

class HTTPResponse:
    def __init__(self, config):
        self.config = config
        self.store = {}

    def debug_response_passthrough(self):
        return RunnableLambda(lambda x: debug_response(self.config, x, self.original_question))

    def original_question_passthrough(self):
        def get_and_send_original_question(docs):
            source_documents = []
            for doc in docs["context"]:
                doc.metadata = {key: extract_prompt_value(doc.metadata.get(key)) for key in self.config.attributes if key in doc.metadata}
                source_document = doc.metadata.copy()
                source_document["content"] = doc.page_content
                source_documents.append(source_document)

            original_question = {
                "question": self.config.question,
                "source_documents": source_documents,
            }

            self.original_question = original_question
            return docs

        return RunnablePassthrough(get_and_send_original_question)

    def prepare_response(self):
        try:
            retriever = self.config.opensearch.as_retriever(search_type="similarity", search_kwargs={"k": self.config.k, "size": self.config.size, "_source": {"excludes": ["embedding"]}})
            chain = (
                {"context": retriever, "question": RunnablePassthrough()}
                | self.original_question_passthrough()
                | self.config.prompt
                | self.config.client
                | StrOutputParser()
                | self.debug_response_passthrough()
            )
            response = chain.invoke(self.config.question)
        except Exception as err:
            response = {
                "question": self.config.question,
                "error": str(err),
                "source_documents": [],
            }
        return response
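For readers skimming the diff, a short sketch of how extract_prompt_value flattens controlled-vocabulary metadata may help (assuming the module is importable as helpers.http_response; the sample values are made up):

from helpers.http_response import extract_prompt_value

extract_prompt_value({"label": "Painting", "id": "xyz"})   # -> ["Painting"]
extract_prompt_value([{"label": "A"}, {"label": "B"}])     # -> [["A"], ["B"]]
extract_prompt_value("plain string")                       # -> "plain string"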


4 changes: 2 additions & 2 deletions chat/src/helpers/hybrid_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def filter(query: dict):
        }
    }

-def hybrid_query(query: str, model_id: str, vector_field: str = "embedding", k: int = 10, **kwargs: Any):
+def hybrid_query(query: str, model_id: str, vector_field: str = "embedding", k: int = 40, **kwargs: Any):
    result = {
-        "size": kwargs.get("size", 5),
+        "size": kwargs.get("size", 20),
        "query": {
            "hybrid": {
                "queries": [
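A minimal sketch of the new defaults (the model_id is hypothetical, and this assumes hybrid_query returns the result dict it builds, which the truncated hunk does not show): callers that do not pass size now get 20 results instead of 5, and k defaults to 40 instead of 10.

from helpers.hybrid_query import hybrid_query

query = hybrid_query("medieval manuscripts", model_id="my-model-id")
assert query["size"] == 20  # new default (was 5)

query = hybrid_query("medieval manuscripts", model_id="my-model-id", size=3)
assert query["size"] == 3   # an explicit size still wins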
13 changes: 10 additions & 3 deletions chat/src/helpers/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,20 @@ def get_and_send_original_question(docs):
                source_document = doc.metadata.copy()
                source_document["content"] = doc.page_content
                source_documents.append(source_document)

+            socket_message = {
+                "question": self.config.question,
+                "source_documents": source_documents[:5]
+            }
+            self.config.socket.send(socket_message)
+
            original_question = {
                "question": self.config.question,
-                "source_documents": source_documents,
+                "source_documents": source_documents
            }
-            self.config.socket.send(original_question)
            self.original_question = original_question

+            docs["source_documents"] = source_documents
            return docs

        return RunnablePassthrough(get_and_send_original_question)
188 changes: 188 additions & 0 deletions chat/src/http_event_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import os
import json

from dataclasses import dataclass, field

from langchain_core.prompts import ChatPromptTemplate
from setup import (
    opensearch_client,
    opensearch_vector_store,
    openai_chat_client,
)
from typing import List
from helpers.apitoken import ApiToken
from helpers.prompts import document_template, prompt_template

CHAIN_TYPE = "stuff"
DOCUMENT_VARIABLE_NAME = "context"
K_VALUE = 40
MAX_K = 100
MAX_TOKENS = 1000
SIZE = 5
TEMPERATURE = 0.2
TEXT_KEY = "id"
VERSION = "2024-02-01"

@dataclass
class HTTPEventConfig:
"""
The EventConfig class represents the configuration for an event.
Default values are set for the following properties which can be overridden in the payload message.
"""

DEFAULT_ATTRIBUTES = ["accession_number", "alternate_title", "api_link", "canonical_link", "caption", "collection",
"contributor", "date_created", "date_created_edtf", "description", "genre", "id", "identifier",
"keywords", "language", "notes", "physical_description_material", "physical_description_size",
"provenance", "publisher", "rights_statement", "subject", "table_of_contents", "thumbnail",
"title", "visibility", "work_type"]

    api_token: ApiToken = field(init=False)
    attributes: List[str] = field(init=False)
    azure_endpoint: str = field(init=False)
    azure_resource_name: str = field(init=False)
    debug_mode: bool = field(init=False)
    deployment_name: str = field(init=False)
    document_prompt: ChatPromptTemplate = field(init=False)
    event: dict = field(default_factory=dict)
    is_logged_in: bool = field(init=False)
    k: int = field(init=False)
    max_tokens: int = field(init=False)
    openai_api_version: str = field(init=False)
    payload: dict = field(default_factory=dict)
    prompt_text: str = field(init=False)
    prompt: ChatPromptTemplate = field(init=False)
    question: str = field(init=False)
    ref: str = field(init=False)
    request_context: dict = field(init=False)
    temperature: float = field(init=False)
    size: int = field(init=False)
    stream_response: bool = field(init=False)
    text_key: str = field(init=False)

    def __post_init__(self):
        self.payload = json.loads(self.event.get("body", "{}"))
        self.api_token = ApiToken(signed_token=self.payload.get("auth"))
        self.attributes = self._get_attributes()
        self.azure_endpoint = self._get_azure_endpoint()
        self.azure_resource_name = self._get_azure_resource_name()
        self.debug_mode = self._is_debug_mode_enabled()
        self.deployment_name = self._get_deployment_name()
        self.is_logged_in = self.api_token.is_logged_in()
        self.k = self._get_k()
        self.max_tokens = min(self.payload.get("max_tokens", MAX_TOKENS), MAX_TOKENS)
        self.openai_api_version = self._get_openai_api_version()
        self.prompt_text = self._get_prompt_text()
        self.request_context = self.event.get("requestContext", {})
        self.question = self.payload.get("question")
        self.ref = self.payload.get("ref")
        self.size = self._get_size()
        self.stream_response = self.payload.get("stream_response", not self.debug_mode)
        self.temperature = self._get_temperature()
        self.text_key = self._get_text_key()
        self.document_prompt = self._get_document_prompt()
        self.prompt = ChatPromptTemplate.from_template(self.prompt_text)

    def _get_payload_value_with_superuser_check(self, key, default):
        if self.api_token.is_superuser():
            return self.payload.get(key, default)
        else:
            return default

    def _get_attributes_function(self):
        try:
            opensearch = opensearch_client()
            mapping = opensearch.indices.get_mapping(index="dc-v2-work")
            return list(next(iter(mapping.values()))['mappings']['properties'].keys())
        except StopIteration:
            return []

    def _get_attributes(self):
        return self._get_payload_value_with_superuser_check("attributes", self.DEFAULT_ATTRIBUTES)

    def _get_azure_endpoint(self):
        default = f"https://{self._get_azure_resource_name()}.openai.azure.com/"
        return self._get_payload_value_with_superuser_check("azure_endpoint", default)

    def _get_azure_resource_name(self):
        azure_resource_name = self._get_payload_value_with_superuser_check(
            "azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
        )
        if not azure_resource_name:
            raise EnvironmentError(
                "Either payload must contain 'azure_resource_name' or environment variable 'AZURE_OPENAI_RESOURCE_NAME' must be set"
            )
        return azure_resource_name

    def _get_deployment_name(self):
        return self._get_payload_value_with_superuser_check(
            "deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
        )

    def _get_k(self):
        value = self._get_payload_value_with_superuser_check("k", K_VALUE)
        return min(value, MAX_K)

    def _get_openai_api_version(self):
        return self._get_payload_value_with_superuser_check(
            "openai_api_version", VERSION
        )

    def _get_prompt_text(self):
        return self._get_payload_value_with_superuser_check("prompt", prompt_template())

    def _get_size(self):
        return self._get_payload_value_with_superuser_check("size", SIZE)

    def _get_temperature(self):
        return self._get_payload_value_with_superuser_check("temperature", TEMPERATURE)

    def _get_text_key(self):
        return self._get_payload_value_with_superuser_check("text_key", TEXT_KEY)

    def _get_document_prompt(self):
        return ChatPromptTemplate.from_template(document_template(self.attributes))

    def debug_message(self):
        return {
            "type": "debug",
            "message": {
                "attributes": self.attributes,
                "azure_endpoint": self.azure_endpoint,
                "deployment_name": self.deployment_name,
                "k": self.k,
                "openai_api_version": self.openai_api_version,
                "prompt": self.prompt_text,
                "question": self.question,
                "ref": self.ref,
                "size": self.size,
                "temperature": self.temperature,
                "text_key": self.text_key,
            },
        }

    def setup_llm_request(self):
        self._setup_vector_store()
        self._setup_chat_client()

    def _setup_vector_store(self):
        self.opensearch = opensearch_vector_store()

    def _setup_chat_client(self):
        self.client = openai_chat_client(
            azure_deployment=self.deployment_name,
            azure_endpoint=self.azure_endpoint,
            openai_api_version=self.openai_api_version,
            max_tokens=self.max_tokens
        )

    def _is_debug_mode_enabled(self):
        debug = self.payload.get("debug", False)
        return debug and self.api_token.is_superuser()

    def _to_bool(self, val):
        """Converts a value to boolean. If the value is a string, it considers
        "", "no", "false", "0" as False. Otherwise, it returns the boolean of the value.
        """
        if isinstance(val, str):
            return val.lower() not in ["", "no", "false", "0"]
        return bool(val)
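The superuser gate plus clamp in _get_k is easy to miss, so here is an isolated sketch of its behavior (the payload values are hypothetical): a superuser may raise k via the payload but is still capped at MAX_K, while non-superusers always get the default regardless of what the payload asks for.

K_VALUE, MAX_K = 40, 100

def clamped_k(payload: dict, is_superuser: bool) -> int:
    # Mirrors _get_payload_value_with_superuser_check followed by _get_k
    value = payload.get("k", K_VALUE) if is_superuser else K_VALUE
    return min(value, MAX_K)

assert clamped_k({"k": 500}, is_superuser=True) == 100   # clamped to MAX_K
assert clamped_k({"k": 500}, is_superuser=False) == 40   # override ignored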
42 changes: 42 additions & 0 deletions chat/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,48 @@ Resources:
Resource: !Sub "${ChatMetricsLog.Arn}:*"
#* Metadata:
#* BuildMethod: nodejs20.x
ChatSyncFunction:
Type: AWS::Serverless::Function
Properties:
CodeUri: ./src
Runtime: python3.10
Architectures:
- x86_64
#* Layers:
#* - !Ref ChatDependencies
MemorySize: 1024
Handler: handlers/chat_sync.handler
Timeout: 300
Environment:
Variables:
API_TOKEN_SECRET: !Ref ApiTokenSecret
AZURE_OPENAI_API_KEY: !Ref AzureOpenaiApiKey
AZURE_OPENAI_LLM_DEPLOYMENT_ID: !Ref AzureOpenaiLlmDeploymentId
AZURE_OPENAI_RESOURCE_NAME: !Ref AzureOpenaiResourceName
ENV_PREFIX: !Ref EnvironmentPrefix
HONEYBADGER_API_KEY: !Ref HoneybadgerApiKey
HONEYBADGER_ENVIRONMENT: !Ref HoneybadgerEnv
HONEYBADGER_REVISION: !Ref HoneybadgerRevision
METRICS_LOG_GROUP: !Ref ChatMetricsLog
OPENSEARCH_ENDPOINT: !Ref OpenSearchEndpoint
OPENSEARCH_MODEL_ID: !Ref OpenSearchModelId
FunctionUrlConfig:
AuthType: NONE
Policies:
- Statement:
- Effect: Allow
Action:
- 'es:ESHttpGet'
- 'es:ESHttpPost'
Resource: '*'
# - Statement:
# - Effect: Allow
# Action:
# - logs:CreateLogStream
# - logs:PutLogEvents
# Resource: !Sub "${ChatMetricsLog.Arn}:*"
#* Metadata:
#* BuildMethod: nodejs20.x
ChatMetricsLog:
Type: AWS::Logs::LogGroup
Properties:
Expand Down
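Since FunctionUrlConfig uses AuthType: NONE, the deployed function can be exercised with a plain HTTP POST; authorization happens inside the handler via the signed token in the body. A hedged sketch follows (the URL and token are placeholders, not values from this PR; the "base" response shape is {"answer": ..., "ref": ...} per RESPONSE_TYPES):

import json
import urllib.request

url = "https://<function-url-id>.lambda-url.<region>.on.aws/"  # placeholder
payload = {"auth": "<signed-api-token>", "question": "What works mention alchemy?"}
req = urllib.request.Request(
    url,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))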