Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backmerge main into staging #211

Merged
merged 6 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chat/dependencies/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
boto3~=1.34.13
langchain~=0.1.8
langchain
langchain-community
openai~=0.27.8
opensearch-py
Expand Down
46 changes: 43 additions & 3 deletions chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import boto3
import json
import os
import sys
import traceback
from datetime import datetime
from event_config import EventConfig
from helpers.response import prepare_response

def handler(event, _context):
# Whitelists of response fields for each delivery channel:
#   "base"  - minimal payload sent to ordinary websocket clients
#   "debug" - full diagnostic payload (sent when debug mode is enabled;
#             tests suggest this is restricted to superusers - confirm)
#   "log"   - fields persisted to the CloudWatch metrics log stream
RESPONSE_TYPES = {
    "base": ["answer", "ref"],
    "debug": ["answer", "attributes", "azure_endpoint", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "text_key", "token_counts"],
    "log": ["answer", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "token_counts"]
}

def handler(event, context):
try:
config = EventConfig(event)
socket = event.get('socket', None)
Expand All @@ -14,17 +23,48 @@ def handler(event, _context):
config.socket.send({"type": "error", "message": "Unauthorized"})
return {"statusCode": 401, "body": "Unauthorized"}

debug_message = config.debug_message()
if config.debug_mode:
config.socket.send(config.debug_message())
config.socket.send(debug_message)

if not os.getenv("SKIP_WEAVIATE_SETUP"):
config.setup_llm_request()
final_response = prepare_response(config)
config.socket.send(final_response)
config.socket.send(reshape_response(final_response, 'debug' if config.debug_mode else 'base'))

log_group = os.getenv('METRICS_LOG_GROUP')
log_stream = context.log_stream_name
if log_group and ensure_log_stream_exists(log_group, log_stream):
log_client = boto3.client('logs')
log_message = reshape_response(final_response, 'log')
log_events = [
{
'timestamp': timestamp(),
'message': json.dumps(log_message)
}
]
log_client.put_log_events(logGroupName=log_group, logStreamName=log_stream, logEvents=log_events)
return {"statusCode": 200}

except Exception:
exc_info = sys.exc_info()
err_text = ''.join(traceback.format_exception(*exc_info))
print(err_text)
return {"statusCode": 500, "body": f'Unhandled error:\n{err_text}'}

def reshape_response(response, type):
    """Filter a response dict down to the fields allowed for the given view.

    Args:
        response: full response dict produced by prepare_response().
        type: one of the RESPONSE_TYPES keys ("base", "debug", or "log").
            (Name shadows the builtin; kept for caller compatibility.)

    Returns:
        A new dict containing only the whitelisted keys that are actually
        present in the response. Keys absent from the response (e.g. the
        error payload built when the LLM request fails, which has no
        "answer"/"ref") are skipped instead of raising KeyError.
    """
    return {k: response[k] for k in RESPONSE_TYPES[type] if k in response}

def ensure_log_stream_exists(log_group, log_stream):
    """Create the CloudWatch log stream if it does not already exist.

    Args:
        log_group: name of the CloudWatch log group.
        log_stream: name of the log stream to ensure within the group.

    Returns:
        True when the stream exists (created now or on a prior call),
        False when creation failed for any other reason.
    """
    client = boto3.client('logs')
    try:
        client.create_log_stream(logGroupName=log_group, logStreamName=log_stream)
    except client.exceptions.ResourceAlreadyExistsException:
        # Stream survives from a previous invocation — nothing to do.
        pass
    except Exception:
        # Best-effort: report the failure so the caller can skip logging.
        print(f'Could not create log stream: {log_group}:{log_stream}')
        return False
    return True

def timestamp():
    """Return the current time as epoch milliseconds, rounded to an int."""
    return round(datetime.now().timestamp() * 1000)
4 changes: 3 additions & 1 deletion chat/src/helpers/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@


def token_usage(config, response, original_question):
    """Tally token counts for each component of the exchange.

    Args:
        config: event configuration exposing `question` and `prompt_text`.
        response: chain output dict containing "output_text".
        original_question: dict containing "source_documents".

    Returns:
        A dict of per-component token counts plus a "total" key summing them.
    """
    counts = {
        "question": count_tokens(config.question),
        "answer": count_tokens(response["output_text"]),
        "prompt": count_tokens(config.prompt_text),
        "source_documents": count_tokens(original_question["source_documents"]),
    }
    return {**counts, "total": sum(counts.values())}


def count_tokens(val):
Expand Down
16 changes: 4 additions & 12 deletions chat/src/helpers/response.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,22 @@
from helpers.metrics import token_usage
from openai.error import InvalidRequestError

def base_response(config, response):
return {"answer": response["output_text"], "ref": config.ref}


def debug_response(config, response, original_question):
response_base = base_response(config, response)
debug_info = {
return {
"answer": response["output_text"],
"attributes": config.attributes,
"azure_endpoint": config.azure_endpoint,
"deployment_name": config.deployment_name,
"is_superuser": config.api_token.is_superuser(),
"k": config.k,
"openai_api_version": config.openai_api_version,
"prompt": config.prompt_text,
"question": config.question,
"ref": config.ref,
"temperature": config.temperature,
"text_key": config.text_key,
"token_counts": token_usage(config, response, original_question),
}
return {**response_base, **debug_info}


def get_and_send_original_question(config, docs):
doc_response = []
Expand Down Expand Up @@ -63,10 +58,7 @@ def prepare_response(config):
original_question = get_and_send_original_question(config, docs)
response = config.chain({"question": config.question, "input_documents": docs})

if config.debug_mode:
prepared_response = debug_response(config, response, original_question)
else:
prepared_response = base_response(config, response)
prepared_response = debug_response(config, response, original_question)
except InvalidRequestError as err:
prepared_response = {
"question": config.question,
Expand Down
2 changes: 1 addition & 1 deletion chat/src/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Runtime Dependencies
boto3~=1.34.13
langchain~=0.1.8
langchain
langchain-community
openai~=0.27.8
opensearch-py
Expand Down
13 changes: 13 additions & 0 deletions chat/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ Resources:
AZURE_OPENAI_LLM_DEPLOYMENT_ID: !Ref AzureOpenaiLlmDeploymentId
AZURE_OPENAI_RESOURCE_NAME: !Ref AzureOpenaiResourceName
ENV_PREFIX: !Ref EnvironmentPrefix
METRICS_LOG_GROUP: !Ref ChatMetricsLog
OPENSEARCH_ENDPOINT: !Ref OpenSearchEndpoint
OPENSEARCH_MODEL_ID: !Ref OpenSearchModelId
Policies:
Expand All @@ -218,8 +219,20 @@ Resources:
- 'es:ESHttpGet'
- 'es:ESHttpPost'
Resource: '*'
- Statement:
- Effect: Allow
Action:
- logs:CreateLogStream
- logs:PutLogEvents
Resource: !Sub "${ChatMetricsLog.Arn}:*"
Metadata:
BuildMethod: nodejs18.x
ChatMetricsLog:
Type: AWS::Logs::LogGroup
Properties:
LogGroupName: !Sub "/nul/${AWS::StackName}/ChatFunction-Metrics"
LogGroupClass: STANDARD
RetentionInDays: 90
Deployment:
Type: AWS::ApiGatewayV2::Deployment
DependsOn:
Expand Down
11 changes: 7 additions & 4 deletions chat/test/handlers/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ def post_to_connection(self, Data, ConnectionId):
self.received_data = Data
return Data

class MockContext:
    """Minimal stand-in for the AWS Lambda context passed to handler()."""

    def __init__(self):
        # The only context attribute the handler reads.
        self.log_stream_name = 'test'

@mock.patch.dict(
os.environ,
Expand All @@ -33,13 +36,13 @@ def post_to_connection(self, Data, ConnectionId):
class TestHandler(TestCase):
def test_handler_unauthorized(self):
event = {"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test")}
self.assertEqual(handler(event, {}), {'body': 'Unauthorized', 'statusCode': 401})
self.assertEqual(handler(event, MockContext()), {'body': 'Unauthorized', 'statusCode': 401})

@patch.object(ApiToken, 'is_logged_in')
def test_handler_success(self, mock_is_logged_in):
mock_is_logged_in.return_value = True
event = {"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test")}
self.assertEqual(handler(event, {}), {'statusCode': 200})
self.assertEqual(handler(event, MockContext()), {'statusCode': 200})

@patch.object(ApiToken, 'is_logged_in')
@patch.object(ApiToken, 'is_superuser')
Expand All @@ -51,7 +54,7 @@ def test_handler_debug_mode(self, mock_is_debug_enabled, mock_is_logged_in, mock
mock_client = MockClient()
mock_websocket = Websocket(client=mock_client, endpoint_url="test", connection_id="test", ref="test")
event = {"socket": mock_websocket, "debug": True}
handler(event, {})
handler(event, MockContext())
response = json.loads(mock_client.received_data)
self.assertEqual(response["type"], "debug")

Expand All @@ -65,7 +68,7 @@ def test_handler_debug_mode_for_superusers_only(self, mock_is_debug_enabled, moc
mock_client = MockClient()
mock_websocket = Websocket(client=mock_client, endpoint_url="test", connection_id="test", ref="test")
event = {"socket": mock_websocket, "debug": True}
handler(event, {})
handler(event, MockContext())
response = json.loads(mock_client.received_data)
self.assertEqual(response["type"], "error")

Expand Down
1 change: 1 addition & 0 deletions chat/test/helpers/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def test_token_usage(self):
"prompt": 302,
"question": 15,
"source_documents": 1,
"total": 350
}

self.assertEqual(result, expected_result)
Expand Down
Loading