Send chat handler metrics to a dedicated CloudWatch Logs group
mbklein committed Apr 10, 2024
1 parent 64630b1 commit e07637c
Showing 6 changed files with 71 additions and 20 deletions.
46 changes: 43 additions & 3 deletions chat/src/handlers/chat.py
@@ -1,10 +1,19 @@
+import boto3
+import json
 import os
 import sys
 import traceback
+from datetime import datetime
 from event_config import EventConfig
 from helpers.response import prepare_response
 
-def handler(event, _context):
+RESPONSE_TYPES = {
+    "base": ["answer", "ref"],
+    "debug": ["answer", "attributes", "azure_endpoint", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "text_key", "token_counts"],
+    "log": ["answer", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "token_counts"]
+}
+
+def handler(event, context):
     try:
         config = EventConfig(event)
         socket = event.get('socket', None)
@@ -14,17 +23,48 @@ def handler(event, _context):
             config.socket.send({"type": "error", "message": "Unauthorized"})
             return {"statusCode": 401, "body": "Unauthorized"}
 
+        debug_message = config.debug_message()
         if config.debug_mode:
-            config.socket.send(config.debug_message())
+            config.socket.send(debug_message)
 
         if not os.getenv("SKIP_WEAVIATE_SETUP"):
             config.setup_llm_request()
         final_response = prepare_response(config)
-        config.socket.send(final_response)
+        config.socket.send(reshape_response(final_response, 'debug' if config.debug_mode else 'base'))
+
+        log_group = os.getenv('METRICS_LOG_GROUP')
+        log_stream = context.log_stream_name
+        if log_group and ensure_log_stream_exists(log_group, log_stream):
+            log_client = boto3.client('logs')
+            log_message = reshape_response(final_response, 'log')
+            log_events = [
+                {
+                    'timestamp': timestamp(),
+                    'message': json.dumps(log_message)
+                }
+            ]
+            log_client.put_log_events(logGroupName=log_group, logStreamName=log_stream, logEvents=log_events)
         return {"statusCode": 200}
 
     except Exception:
         exc_info = sys.exc_info()
         err_text = ''.join(traceback.format_exception(*exc_info))
         print(err_text)
         return {"statusCode": 500, "body": f'Unhandled error:\n{err_text}'}
+
+def reshape_response(response, type):
+    return {k: response[k] for k in RESPONSE_TYPES[type]}
+
+def ensure_log_stream_exists(log_group, log_stream):
+    log_client = boto3.client('logs')
+    try:
+        log_client.create_log_stream(logGroupName=log_group, logStreamName=log_stream)
+        return True
+    except log_client.exceptions.ResourceAlreadyExistsException:
+        return True
+    except Exception:
+        print(f'Could not create log stream: {log_group}:{log_stream}')
+        return False
+
+def timestamp():
+    return round(datetime.timestamp(datetime.now()) * 1000)
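Note: reshape_response is a straight key projection over the RESPONSE_TYPES table above; it raises KeyError if the response is missing any key in the selected list, so every key in the 'log' and 'debug' lists must be present in prepare_response's output. A minimal, self-contained sketch of the behavior (the sample payload values are made up):

    RESPONSE_TYPES = {"base": ["answer", "ref"]}

    def reshape_response(response, type):
        return {k: response[k] for k in RESPONSE_TYPES[type]}

    sample = {"answer": "A 19th-century map.", "ref": "abc123", "k": 10, "temperature": 0.2}
    assert reshape_response(sample, "base") == {"answer": "A 19th-century map.", "ref": "abc123"}

timestamp() returns epoch milliseconds, which is what put_log_events expects in each event's 'timestamp' field.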
4 changes: 3 additions & 1 deletion chat/src/helpers/metrics.py
@@ -2,12 +2,14 @@
 
 
 def token_usage(config, response, original_question):
-    return {
+    data = {
         "question": count_tokens(config.question),
         "answer": count_tokens(response["output_text"]),
         "prompt": count_tokens(config.prompt_text),
         "source_documents": count_tokens(original_question["source_documents"]),
     }
+    data["total"] = sum(data.values())
+    return data
 
 
 def count_tokens(val):
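Note: the only behavioral change here is the derived total. With the counts asserted in the updated test below (which imply an answer count of 6 in the truncated part of the fixture), the arithmetic works out:

    data = {"question": 15, "answer": 6, "prompt": 328, "source_documents": 1}
    data["total"] = sum(data.values())
    assert data["total"] == 350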
16 changes: 4 additions & 12 deletions chat/src/helpers/response.py
@@ -1,27 +1,22 @@
 from helpers.metrics import token_usage
 from openai.error import InvalidRequestError
 
-
-def base_response(config, response):
-    return {"answer": response["output_text"], "ref": config.ref}
-
-
 def debug_response(config, response, original_question):
-    response_base = base_response(config, response)
-    debug_info = {
+    return {
+        "answer": response["output_text"],
         "attributes": config.attributes,
         "azure_endpoint": config.azure_endpoint,
         "deployment_name": config.deployment_name,
         "is_superuser": config.api_token.is_superuser(),
         "k": config.k,
         "openai_api_version": config.openai_api_version,
         "prompt": config.prompt_text,
         "question": config.question,
+        "ref": config.ref,
         "temperature": config.temperature,
         "text_key": config.text_key,
         "token_counts": token_usage(config, response, original_question),
     }
-    return {**response_base, **debug_info}
 
 
 def get_and_send_original_question(config, docs):
     doc_response = []
@@ -55,10 +50,7 @@ def prepare_response(config):
         original_question = get_and_send_original_question(config, docs)
         response = config.chain({"question": config.question, "input_documents": docs})
 
-        if config.debug_mode:
-            prepared_response = debug_response(config, response, original_question)
-        else:
-            prepared_response = base_response(config, response)
+        prepared_response = debug_response(config, response, original_question)
     except InvalidRequestError as err:
         prepared_response = {
             "question": config.question,
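Note: with base_response gone, debug_response runs on every request (so token counts are always computed, which is what feeds the metrics log), and the handler recovers the old minimal payload by projecting the full response at send time. A standalone check of that equivalence, with placeholder values:

    # Stand-in for a debug_response(...) payload; only the projected keys matter here
    full = {"answer": "A 19th-century map.", "ref": "abc123", "k": 10, "temperature": 0.2}
    base_view = {k: full[k] for k in ["answer", "ref"]}  # RESPONSE_TYPES["base"]
    assert base_view == {"answer": "A 19th-century map.", "ref": "abc123"}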
13 changes: 13 additions & 0 deletions chat/template.yaml
@@ -203,6 +203,7 @@ Resources:
           AZURE_OPENAI_LLM_DEPLOYMENT_ID: !Ref AzureOpenaiLlmDeploymentId
           AZURE_OPENAI_RESOURCE_NAME: !Ref AzureOpenaiResourceName
           ENV_PREFIX: !Ref EnvironmentPrefix
+          METRICS_LOG_GROUP: !Ref ChatMetricsLog
           OPENSEARCH_ENDPOINT: !Ref OpenSearchEndpoint
           OPENSEARCH_MODEL_ID: !Ref OpenSearchModelId
       Policies:
@@ -218,8 +219,20 @@
                - 'es:ESHttpGet'
                - 'es:ESHttpPost'
              Resource: '*'
+        - Statement:
+            - Effect: Allow
+              Action:
+                - logs:CreateLogStream
+                - logs:PutLogEvents
+              Resource: !Sub "${ChatMetricsLog.Arn}:*"
     Metadata:
       BuildMethod: nodejs18.x
+  ChatMetricsLog:
+    Type: AWS::Logs::LogGroup
+    Properties:
+      LogGroupName: !Sub "/nul/${AWS::StackName}/ChatFunction-Metrics"
+      LogGroupClass: STANDARD
+      RetentionInDays: 90
   Deployment:
     Type: AWS::ApiGatewayV2::Deployment
     DependsOn:
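Note: because each metrics event is a single JSON document, the new log group is directly queryable with CloudWatch Logs Insights. A sketch of pulling hourly average token usage via boto3 (the log group name assumes a stack named "my-stack"; substitute your deployment's stack name):

    import time
    from datetime import datetime, timedelta

    import boto3

    logs = boto3.client("logs")
    group = "/nul/my-stack/ChatFunction-Metrics"  # hypothetical stack name, per the template above

    query = logs.start_query(
        logGroupName=group,
        startTime=int((datetime.now() - timedelta(days=1)).timestamp()),
        endTime=int(datetime.now().timestamp()),
        queryString="stats avg(token_counts.total) by bin(1h)",  # fields as written by the handler
    )

    # Logs Insights queries are asynchronous; poll until the query settles
    results = logs.get_query_results(queryId=query["queryId"])
    while results["status"] in ("Scheduled", "Running"):
        time.sleep(1)
        results = logs.get_query_results(queryId=query["queryId"])
    print(results["results"])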
11 changes: 7 additions & 4 deletions chat/test/handlers/test_chat.py
@@ -23,6 +23,9 @@ def post_to_connection(self, Data, ConnectionId):
         self.received_data = Data
         return Data
 
+class MockContext:
+    def __init__(self):
+        self.log_stream_name = 'test'
 
 @mock.patch.dict(
     os.environ,
@@ -33,13 +36,13 @@
 class TestHandler(TestCase):
     def test_handler_unauthorized(self):
         event = {"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test")}
-        self.assertEqual(handler(event, {}), {'body': 'Unauthorized', 'statusCode': 401})
+        self.assertEqual(handler(event, MockContext()), {'body': 'Unauthorized', 'statusCode': 401})
 
     @patch.object(ApiToken, 'is_logged_in')
     def test_handler_success(self, mock_is_logged_in):
         mock_is_logged_in.return_value = True
         event = {"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test")}
-        self.assertEqual(handler(event, {}), {'statusCode': 200})
+        self.assertEqual(handler(event, MockContext()), {'statusCode': 200})
 
     @patch.object(ApiToken, 'is_logged_in')
     @patch.object(ApiToken, 'is_superuser')
@@ -51,7 +54,7 @@ def test_handler_debug_mode(self, mock_is_debug_enabled, mock_is_logged_in, mock
         mock_client = MockClient()
         mock_websocket = Websocket(client=mock_client, endpoint_url="test", connection_id="test", ref="test")
         event = {"socket": mock_websocket, "debug": True}
-        handler(event, {})
+        handler(event, MockContext())
         response = json.loads(mock_client.received_data)
         self.assertEqual(response["type"], "debug")
 
@@ -65,7 +68,7 @@ def test_handler_debug_mode_for_superusers_only(self, mock_is_debug_enabled, moc
         mock_client = MockClient()
         mock_websocket = Websocket(client=mock_client, endpoint_url="test", connection_id="test", ref="test")
         event = {"socket": mock_websocket, "debug": True}
-        handler(event, {})
+        handler(event, MockContext())
         response = json.loads(mock_client.received_data)
         self.assertEqual(response["type"], "error")

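Note: the tests swap the empty-dict context for MockContext because the handler now reads context.log_stream_name on the success path; METRICS_LOG_GROUP is presumably unset in the test environment, so the CloudWatch branch itself stays skipped. Any attribute bag works as the stub, e.g.:

    from types import SimpleNamespace

    context = SimpleNamespace(log_stream_name="test")  # equivalent to MockContext()
    assert context.log_stream_name == "test"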
1 change: 1 addition & 0 deletions chat/test/helpers/test_metrics.py
@@ -51,6 +51,7 @@ def test_token_usage(self):
             "prompt": 328,
             "question": 15,
             "source_documents": 1,
+            "total": 350
         }
 
         self.assertEqual(result, expected_result)
