diff --git a/chat/src/handlers/chat.py b/chat/src/handlers/chat.py
index 8757b286..dc76b959 100644
--- a/chat/src/handlers/chat.py
+++ b/chat/src/handlers/chat.py
@@ -1,10 +1,19 @@
+import boto3
+import json
 import os
 import sys
 import traceback
+from datetime import datetime
 from event_config import EventConfig
 from helpers.response import prepare_response
 
-def handler(event, _context):
+RESPONSE_TYPES = {
+    "base": ["answer", "ref"],
+    "debug": ["answer", "attributes", "azure_endpoint", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "text_key", "token_counts"],
+    "log": ["answer", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "token_counts"]
+}
+
+def handler(event, context):
     try:
         config = EventConfig(event)
         socket = event.get('socket', None)
@@ -14,13 +23,27 @@ def handler(event, _context):
             config.socket.send({"type": "error", "message": "Unauthorized"})
             return {"statusCode": 401, "body": "Unauthorized"}
 
+        debug_message = config.debug_message()
         if config.debug_mode:
-            config.socket.send(config.debug_message())
+            config.socket.send(debug_message)
         if not os.getenv("SKIP_WEAVIATE_SETUP"):
             config.setup_llm_request()
 
         final_response = prepare_response(config)
-        config.socket.send(final_response)
+        config.socket.send(reshape_response(final_response, 'debug' if config.debug_mode else 'base'))
+
+        log_group = os.getenv('METRICS_LOG_GROUP')
+        log_stream = context.log_stream_name
+        if log_group and ensure_log_stream_exists(log_group, log_stream):
+            log_client = boto3.client('logs')
+            log_message = reshape_response(final_response, 'log')
+            log_events = [
+                {
+                    'timestamp': timestamp(),
+                    'message': json.dumps(log_message)
+                }
+            ]
+            log_client.put_log_events(logGroupName=log_group, logStreamName=log_stream, logEvents=log_events)
         return {"statusCode": 200}
 
     except Exception:
@@ -28,3 +51,20 @@ def handler(event, _context):
         err_text = ''.join(traceback.format_exception(*exc_info))
         print(err_text)
         return {"statusCode": 500, "body": f'Unhandled error:\n{err_text}'}
+
+def reshape_response(response, type):
+    return {k: response[k] for k in RESPONSE_TYPES[type]}
+
+def ensure_log_stream_exists(log_group, log_stream):
+    log_client = boto3.client('logs')
+    try:
+        log_client.create_log_stream(logGroupName=log_group, logStreamName=log_stream)
+        return True
+    except log_client.exceptions.ResourceAlreadyExistsException:
+        return True
+    except Exception:
+        print(f'Could not create log stream: {log_group}:{log_stream}')
+        return False
+
+def timestamp():
+    return round(datetime.timestamp(datetime.now()) * 1000)
\ No newline at end of file
diff --git a/chat/src/helpers/metrics.py b/chat/src/helpers/metrics.py
index 168cd02f..6cb13efb 100644
--- a/chat/src/helpers/metrics.py
+++ b/chat/src/helpers/metrics.py
@@ -2,12 +2,14 @@
 
 
 def token_usage(config, response, original_question):
-    return {
+    data = {
         "question": count_tokens(config.question),
         "answer": count_tokens(response["output_text"]),
         "prompt": count_tokens(config.prompt_text),
         "source_documents": count_tokens(original_question["source_documents"]),
     }
+    data["total"] = sum(data.values())
+    return data
 
 
 def count_tokens(val):
diff --git a/chat/src/helpers/response.py b/chat/src/helpers/response.py
index 351be8ff..9bb55617 100644
--- a/chat/src/helpers/response.py
+++ b/chat/src/helpers/response.py
@@ -1,13 +1,9 @@
 from helpers.metrics import token_usage
 from openai.error import InvalidRequestError
 
-def base_response(config, response):
-    return {"answer": response["output_text"], "ref": config.ref}
-
-
 def debug_response(config, response, original_question):
-    response_base = base_response(config, response)
-    debug_info = {
+    return {
+        "answer": response["output_text"],
         "attributes": config.attributes,
         "azure_endpoint": config.azure_endpoint,
         "deployment_name": config.deployment_name,
@@ -15,13 +11,12 @@ def debug_response(config, response, original_question):
         "k": config.k,
         "openai_api_version": config.openai_api_version,
         "prompt": config.prompt_text,
+        "question": config.question,
         "ref": config.ref,
         "temperature": config.temperature,
         "text_key": config.text_key,
         "token_counts": token_usage(config, response, original_question),
     }
-    return {**response_base, **debug_info}
-
 
 def get_and_send_original_question(config, docs):
     doc_response = []
@@ -55,10 +50,7 @@ def prepare_response(config):
         original_question = get_and_send_original_question(config, docs)
         response = config.chain({"question": config.question, "input_documents": docs})
 
-        if config.debug_mode:
-            prepared_response = debug_response(config, response, original_question)
-        else:
-            prepared_response = base_response(config, response)
+        prepared_response = debug_response(config, response, original_question)
     except InvalidRequestError as err:
         prepared_response = {
             "question": config.question,
diff --git a/chat/template.yaml b/chat/template.yaml
index 24b95aac..6d509b8b 100644
--- a/chat/template.yaml
+++ b/chat/template.yaml
@@ -203,6 +203,7 @@ Resources:
          AZURE_OPENAI_LLM_DEPLOYMENT_ID: !Ref AzureOpenaiLlmDeploymentId
          AZURE_OPENAI_RESOURCE_NAME: !Ref AzureOpenaiResourceName
          ENV_PREFIX: !Ref EnvironmentPrefix
+         METRICS_LOG_GROUP: !Ref ChatMetricsLog
          OPENSEARCH_ENDPOINT: !Ref OpenSearchEndpoint
          OPENSEARCH_MODEL_ID: !Ref OpenSearchModelId
      Policies:
@@ -218,8 +219,20 @@ Resources:
              - 'es:ESHttpGet'
              - 'es:ESHttpPost'
            Resource: '*'
+       - Statement:
+           - Effect: Allow
+             Action:
+               - logs:CreateLogStream
+               - logs:PutLogEvents
+             Resource: !Sub "${ChatMetricsLog.Arn}:*"
    Metadata:
      BuildMethod: nodejs18.x
+ ChatMetricsLog:
+   Type: AWS::Logs::LogGroup
+   Properties:
+     LogGroupName: !Sub "/nul/${AWS::StackName}/ChatFunction-Metrics"
+     LogGroupClass: STANDARD
+     RetentionInDays: 90
  Deployment:
    Type: AWS::ApiGatewayV2::Deployment
    DependsOn:
diff --git a/chat/test/handlers/test_chat.py b/chat/test/handlers/test_chat.py
index ebce7e51..c8bad364 100644
--- a/chat/test/handlers/test_chat.py
+++ b/chat/test/handlers/test_chat.py
@@ -23,6 +23,9 @@ def post_to_connection(self, Data, ConnectionId):
         self.received_data = Data
         return Data
 
+class MockContext:
+    def __init__(self):
+        self.log_stream_name = 'test'
 @mock.patch.dict(
     os.environ,
     {
@@ -33,13 +36,13 @@
 class TestHandler(TestCase):
     def test_handler_unauthorized(self):
         event = {"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test")}
-        self.assertEqual(handler(event, {}), {'body': 'Unauthorized', 'statusCode': 401})
+        self.assertEqual(handler(event, MockContext()), {'body': 'Unauthorized', 'statusCode': 401})
 
     @patch.object(ApiToken, 'is_logged_in')
     def test_handler_success(self, mock_is_logged_in):
         mock_is_logged_in.return_value = True
         event = {"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test")}
-        self.assertEqual(handler(event, {}), {'statusCode': 200})
+        self.assertEqual(handler(event, MockContext()), {'statusCode': 200})
 
     @patch.object(ApiToken, 'is_logged_in')
     @patch.object(ApiToken, 'is_superuser')
@@ -51,7 +54,7 @@ def test_handler_debug_mode(self, mock_is_debug_enabled, mock_is_logged_in, mock
         mock_client = MockClient()
         mock_websocket = Websocket(client=mock_client, endpoint_url="test", connection_id="test", ref="test")
         event = {"socket": mock_websocket, "debug": True}
-        handler(event, {})
+        handler(event, MockContext())
         response = json.loads(mock_client.received_data)
         self.assertEqual(response["type"], "debug")
 
@@ -65,7 +68,7 @@ def test_handler_debug_mode_for_superusers_only(self, mock_is_debug_enabled, moc
         mock_client = MockClient()
         mock_websocket = Websocket(client=mock_client, endpoint_url="test", connection_id="test", ref="test")
         event = {"socket": mock_websocket, "debug": True}
-        handler(event, {})
+        handler(event, MockContext())
         response = json.loads(mock_client.received_data)
         self.assertEqual(response["type"], "error")
 
diff --git a/chat/test/helpers/test_metrics.py b/chat/test/helpers/test_metrics.py
index efab07cd..389d8ea1 100644
--- a/chat/test/helpers/test_metrics.py
+++ b/chat/test/helpers/test_metrics.py
@@ -51,6 +51,7 @@ def test_token_usage(self):
             "prompt": 328,
             "question": 15,
             "source_documents": 1,
+            "total": 350
         }
         self.assertEqual(result, expected_result)
 