From 5017697ee4bcba42a0670b6d78d938d2b6cbb600 Mon Sep 17 00:00:00 2001 From: Karen Shaw Date: Fri, 23 Aug 2024 20:08:52 +0000 Subject: [PATCH 1/2] Create http endpoint for chat for testing --- chat/src/handlers/chat_sync.py | 43 +++++ chat/src/helpers/http_response.py | 60 +++++++ chat/src/http_event_config.py | 188 ++++++++++++++++++++ chat/template.yaml | 42 +++++ chat/test/handlers/test_chat_sync.py | 35 ++++ chat/test/helpers/test_http_event_config.py | 95 ++++++++++ 6 files changed, 463 insertions(+) create mode 100644 chat/src/handlers/chat_sync.py create mode 100644 chat/src/helpers/http_response.py create mode 100644 chat/src/http_event_config.py create mode 100644 chat/test/handlers/test_chat_sync.py create mode 100644 chat/test/helpers/test_http_event_config.py diff --git a/chat/src/handlers/chat_sync.py b/chat/src/handlers/chat_sync.py new file mode 100644 index 00000000..8166870e --- /dev/null +++ b/chat/src/handlers/chat_sync.py @@ -0,0 +1,43 @@ + +import json +import logging +import os +from http_event_config import HTTPEventConfig +from helpers.http_response import HTTPResponse +from honeybadger import honeybadger + +honeybadger.configure() +logging.getLogger('honeybadger').addHandler(logging.StreamHandler()) + +RESPONSE_TYPES = { + "base": ["answer", "ref"], + "debug": ["answer", "attributes", "azure_endpoint", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "text_key", "token_counts"], + "log": ["answer", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "size", "source_documents", "temperature", "token_counts"], + "error": ["question", "error", "source_documents"] +} + +def handler(event, context): + print(f'Event: {event}') + + config = HTTPEventConfig(event) + + if not config.is_logged_in: + return {"statusCode": 401, "body": "Unauthorized"} + + if config.question is None or config.question == "": + return {"statusCode": 400, "body": "Question cannot be blank"} + + if not os.getenv("SKIP_WEAVIATE_SETUP"): + config.setup_llm_request() + response = HTTPResponse(config) + final_response = response.prepare_response() + if "error" in final_response: + logging.error(f'Error: {final_response["error"]}') + return {"statusCode": 500, "body": "Internal Server Error"} + else: + return {"statusCode": 200, "body": json.dumps(reshape_response(final_response, 'debug' if config.debug_mode else 'base'))} + + return {"statusCode": 200} + +def reshape_response(response, type): + return {k: response[k] for k in RESPONSE_TYPES[type]} \ No newline at end of file diff --git a/chat/src/helpers/http_response.py b/chat/src/helpers/http_response.py new file mode 100644 index 00000000..fc6abc6f --- /dev/null +++ b/chat/src/helpers/http_response.py @@ -0,0 +1,60 @@ +from helpers.metrics import debug_response +from langchain_core.output_parsers import StrOutputParser +from langchain_core.runnables import RunnableLambda, RunnablePassthrough + +def extract_prompt_value(v): + if isinstance(v, list): + return [extract_prompt_value(item) for item in v] + elif isinstance(v, dict) and 'label' in v: + return [v.get('label')] + else: + return v + +class HTTPResponse: + def __init__(self, config): + self.config = config + self.store = {} + + def debug_response_passthrough(self): + return RunnableLambda(lambda x: debug_response(self.config, x, self.original_question)) + + def original_question_passthrough(self): + def get_and_send_original_question(docs): + source_documents = [] + for doc in docs["context"]: + 
doc.metadata = {key: extract_prompt_value(doc.metadata.get(key)) for key in self.config.attributes if key in doc.metadata} + source_document = doc.metadata.copy() + source_document["content"] = doc.page_content + source_documents.append(source_document) + + original_question = { + "question": self.config.question, + "source_documents": source_documents, + } + + self.original_question = original_question + return docs + + return RunnablePassthrough(get_and_send_original_question) + + def prepare_response(self): + try: + retriever = self.config.opensearch.as_retriever(search_type="similarity", search_kwargs={"k": self.config.k, "size": self.config.size, "_source": {"excludes": ["embedding"]}}) + chain = ( + {"context": retriever, "question": RunnablePassthrough()} + | self.original_question_passthrough() + | self.config.prompt + | self.config.client + | StrOutputParser() + | self.debug_response_passthrough() + ) + response = chain.invoke(self.config.question) + except Exception as err: + response = { + "question": self.config.question, + "error": str(err), + "source_documents": [], + } + return response + + \ No newline at end of file diff --git a/chat/src/http_event_config.py b/chat/src/http_event_config.py new file mode 100644 index 00000000..47f479aa --- /dev/null +++ b/chat/src/http_event_config.py @@ -0,0 +1,188 @@ +import os +import json + +from dataclasses import dataclass, field + +from langchain_core.prompts import ChatPromptTemplate +from setup import ( + opensearch_client, + opensearch_vector_store, + openai_chat_client, +) +from typing import List +from helpers.apitoken import ApiToken +from helpers.prompts import document_template, prompt_template + +CHAIN_TYPE = "stuff" +DOCUMENT_VARIABLE_NAME = "context" +K_VALUE = 40 +MAX_K = 100 +MAX_TOKENS = 1000 +SIZE = 5 +TEMPERATURE = 0.2 +TEXT_KEY = "id" +VERSION = "2024-02-01" + +@dataclass +class HTTPEventConfig: + """ + The EventConfig class represents the configuration for an event. + Default values are set for the following properties which can be overridden in the payload message. 
+ """ + + DEFAULT_ATTRIBUTES = ["accession_number", "alternate_title", "api_link", "canonical_link", "caption", "collection", + "contributor", "date_created", "date_created_edtf", "description", "genre", "id", "identifier", + "keywords", "language", "notes", "physical_description_material", "physical_description_size", + "provenance", "publisher", "rights_statement", "subject", "table_of_contents", "thumbnail", + "title", "visibility", "work_type"] + + api_token: ApiToken = field(init=False) + attributes: List[str] = field(init=False) + azure_endpoint: str = field(init=False) + azure_resource_name: str = field(init=False) + debug_mode: bool = field(init=False) + deployment_name: str = field(init=False) + document_prompt: ChatPromptTemplate = field(init=False) + event: dict = field(default_factory=dict) + is_logged_in: bool = field(init=False) + k: int = field(init=False) + max_tokens: int = field(init=False) + openai_api_version: str = field(init=False) + payload: dict = field(default_factory=dict) + prompt_text: str = field(init=False) + prompt: ChatPromptTemplate = field(init=False) + question: str = field(init=False) + ref: str = field(init=False) + request_context: dict = field(init=False) + temperature: float = field(init=False) + size: int = field(init=False) + stream_response: bool = field(init=False) + text_key: str = field(init=False) + + def __post_init__(self): + self.payload = json.loads(self.event.get("body", "{}")) + self.api_token = ApiToken(signed_token=self.payload.get("auth")) + self.attributes = self._get_attributes() + self.azure_endpoint = self._get_azure_endpoint() + self.azure_resource_name = self._get_azure_resource_name() + self.debug_mode = self._is_debug_mode_enabled() + self.deployment_name = self._get_deployment_name() + self.is_logged_in = self.api_token.is_logged_in() + self.k = self._get_k() + self.max_tokens = min(self.payload.get("max_tokens", MAX_TOKENS), MAX_TOKENS) + self.openai_api_version = self._get_openai_api_version() + self.prompt_text = self._get_prompt_text() + self.request_context = self.event.get("requestContext", {}) + self.question = self.payload.get("question") + self.ref = self.payload.get("ref") + self.size = self._get_size() + self.stream_response = self.payload.get("stream_response", not self.debug_mode) + self.temperature = self._get_temperature() + self.text_key = self._get_text_key() + self.document_prompt = self._get_document_prompt() + self.prompt = ChatPromptTemplate.from_template(self.prompt_text) + + def _get_payload_value_with_superuser_check(self, key, default): + if self.api_token.is_superuser(): + return self.payload.get(key, default) + else: + return default + + def _get_attributes_function(self): + try: + opensearch = opensearch_client() + mapping = opensearch.indices.get_mapping(index="dc-v2-work") + return list(next(iter(mapping.values()))['mappings']['properties'].keys()) + except StopIteration: + return [] + + def _get_attributes(self): + return self._get_payload_value_with_superuser_check("attributes", self.DEFAULT_ATTRIBUTES) + + def _get_azure_endpoint(self): + default = f"https://{self._get_azure_resource_name()}.openai.azure.com/" + return self._get_payload_value_with_superuser_check("azure_endpoint", default) + + def _get_azure_resource_name(self): + azure_resource_name = self._get_payload_value_with_superuser_check( + "azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME") + ) + if not azure_resource_name: + raise EnvironmentError( + "Either payload must contain 'azure_resource_name' or environment 
variable 'AZURE_OPENAI_RESOURCE_NAME' must be set"
+            )
+        return azure_resource_name
+
+    def _get_deployment_name(self):
+        return self._get_payload_value_with_superuser_check(
+            "deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
+        )
+
+    def _get_k(self):
+        value = self._get_payload_value_with_superuser_check("k", K_VALUE)
+        return min(value, MAX_K)
+
+    def _get_openai_api_version(self):
+        return self._get_payload_value_with_superuser_check(
+            "openai_api_version", VERSION
+        )
+
+    def _get_prompt_text(self):
+        return self._get_payload_value_with_superuser_check("prompt", prompt_template())
+
+    def _get_size(self):
+        return self._get_payload_value_with_superuser_check("size", SIZE)
+
+    def _get_temperature(self):
+        return self._get_payload_value_with_superuser_check("temperature", TEMPERATURE)
+
+    def _get_text_key(self):
+        return self._get_payload_value_with_superuser_check("text_key", TEXT_KEY)
+
+    def _get_document_prompt(self):
+        return ChatPromptTemplate.from_template(document_template(self.attributes))
+
+    def debug_message(self):
+        return {
+            "type": "debug",
+            "message": {
+                "attributes": self.attributes,
+                "azure_endpoint": self.azure_endpoint,
+                "deployment_name": self.deployment_name,
+                "k": self.k,
+                "openai_api_version": self.openai_api_version,
+                "prompt": self.prompt_text,
+                "question": self.question,
+                "ref": self.ref,
+                "size": self.size,
+                "temperature": self.temperature,
+                "text_key": self.text_key,
+            },
+        }
+
+    def setup_llm_request(self):
+        self._setup_vector_store()
+        self._setup_chat_client()
+
+    def _setup_vector_store(self):
+        self.opensearch = opensearch_vector_store()
+
+    def _setup_chat_client(self):
+        self.client = openai_chat_client(
+            azure_deployment=self.deployment_name,
+            azure_endpoint=self.azure_endpoint,
+            openai_api_version=self.openai_api_version,
+            max_tokens=self.max_tokens
+        )
+
+    def _is_debug_mode_enabled(self):
+        debug = self.payload.get("debug", False)
+        return debug and self.api_token.is_superuser()
+
+    def _to_bool(self, val):
+        """Converts a value to boolean. If the value is a string, it considers
+        "", "no", "false", "0" as False. Otherwise, it returns the boolean of the value.
+ """ + if isinstance(val, str): + return val.lower() not in ["", "no", "false", "0"] + return bool(val) diff --git a/chat/template.yaml b/chat/template.yaml index 04d59379..98ac2ed7 100644 --- a/chat/template.yaml +++ b/chat/template.yaml @@ -242,6 +242,48 @@ Resources: Resource: !Sub "${ChatMetricsLog.Arn}:*" #* Metadata: #* BuildMethod: nodejs20.x + ChatSyncFunction: + Type: AWS::Serverless::Function + Properties: + CodeUri: ./src + Runtime: python3.10 + Architectures: + - x86_64 + #* Layers: + #* - !Ref ChatDependencies + MemorySize: 1024 + Handler: handlers/chat_sync.handler + Timeout: 300 + Environment: + Variables: + API_TOKEN_SECRET: !Ref ApiTokenSecret + AZURE_OPENAI_API_KEY: !Ref AzureOpenaiApiKey + AZURE_OPENAI_LLM_DEPLOYMENT_ID: !Ref AzureOpenaiLlmDeploymentId + AZURE_OPENAI_RESOURCE_NAME: !Ref AzureOpenaiResourceName + ENV_PREFIX: !Ref EnvironmentPrefix + HONEYBADGER_API_KEY: !Ref HoneybadgerApiKey + HONEYBADGER_ENVIRONMENT: !Ref HoneybadgerEnv + HONEYBADGER_REVISION: !Ref HoneybadgerRevision + METRICS_LOG_GROUP: !Ref ChatMetricsLog + OPENSEARCH_ENDPOINT: !Ref OpenSearchEndpoint + OPENSEARCH_MODEL_ID: !Ref OpenSearchModelId + FunctionUrlConfig: + AuthType: NONE + Policies: + - Statement: + - Effect: Allow + Action: + - 'es:ESHttpGet' + - 'es:ESHttpPost' + Resource: '*' + # - Statement: + # - Effect: Allow + # Action: + # - logs:CreateLogStream + # - logs:PutLogEvents + # Resource: !Sub "${ChatMetricsLog.Arn}:*" + #* Metadata: + #* BuildMethod: nodejs20.x ChatMetricsLog: Type: AWS::Logs::LogGroup Properties: diff --git a/chat/test/handlers/test_chat_sync.py b/chat/test/handlers/test_chat_sync.py new file mode 100644 index 00000000..773ebfe0 --- /dev/null +++ b/chat/test/handlers/test_chat_sync.py @@ -0,0 +1,35 @@ +# ruff: noqa: E402 + +import os +import sys + +sys.path.append('./src') + +from unittest import mock, TestCase +from unittest.mock import patch +from handlers.chat_sync import handler +from helpers.apitoken import ApiToken + +class MockContext: + def __init__(self): + self.log_stream_name = 'test' + +@mock.patch.dict( + os.environ, + { + "AZURE_OPENAI_RESOURCE_NAME": "test", + }, +) +class TestHandler(TestCase): + def test_handler_unauthorized(self): + self.assertEqual(handler({"body": '{ "question": "Question?"}'}, MockContext()), {'body': 'Unauthorized', 'statusCode': 401}) + + @patch.object(ApiToken, 'is_logged_in') + def test_no_question(self, mock_is_logged_in): + mock_is_logged_in.return_value = True + self.assertEqual(handler({"body": '{ "question": ""}'}, MockContext()), {'statusCode': 400, 'body': 'Question cannot be blank'}) + + @patch.object(ApiToken, 'is_logged_in') + def test_handler_success(self, mock_is_logged_in): + mock_is_logged_in.return_value = True + self.assertEqual(handler({"body": '{"question": "Question?"}'}, MockContext()), {'statusCode': 200}) diff --git a/chat/test/helpers/test_http_event_config.py b/chat/test/helpers/test_http_event_config.py new file mode 100644 index 00000000..3bc67075 --- /dev/null +++ b/chat/test/helpers/test_http_event_config.py @@ -0,0 +1,95 @@ +# ruff: noqa: E402 +import json +import os +import sys +sys.path.append('./src') + +from http_event_config import HTTPEventConfig +from unittest import TestCase, mock + + +class TestEventConfigWithoutAzureResource(TestCase): + def test_requires_an_azure_resource(self): + with self.assertRaises(EnvironmentError): + HTTPEventConfig() + + +@mock.patch.dict( + os.environ, + { + "AZURE_OPENAI_RESOURCE_NAME": "test", + }, +) +class TestHTTPEventConfig(TestCase): + def 
test_requires_azure_resource_name(self):
+        os.environ.pop("AZURE_OPENAI_RESOURCE_NAME", None)
+        with self.assertRaises(EnvironmentError):
+            HTTPEventConfig()
+
+    def test_defaults(self):
+        actual = HTTPEventConfig(event={"body": json.dumps({"attributes": ["title"]})})
+        expected_defaults = {"azure_endpoint": "https://test.openai.azure.com/"}
+        self.assertEqual(actual.azure_endpoint, expected_defaults["azure_endpoint"])
+
+    def test_attempt_override_without_superuser_status(self):
+        actual = HTTPEventConfig(
+            event={
+                "body": json.dumps(
+                    {
+                        "azure_resource_name": "new_name_for_test",
+                        "attributes": ["title", "subject", "date_created"],
+                        "index": "testIndex",
+                        "k": 100,
+                        "openai_api_version": "2024-01-01",
+                        "question": "test question",
+                        "ref": "test ref",
+                        "size": 90,
+                        "temperature": 0.9,
+                        "text_key": "accession_number",
+                    }
+                )
+            }
+        )
+        expected_output = {
+            "attributes": HTTPEventConfig.DEFAULT_ATTRIBUTES,
+            "azure_endpoint": "https://test.openai.azure.com/",
+            "k": 40,
+            "openai_api_version": "2024-02-01",
+            "question": "test question",
+            "size": 5,
+            "ref": "test ref",
+            "temperature": 0.2,
+            "text_key": "id",
+        }
+        self.assertEqual(actual.azure_endpoint, expected_output["azure_endpoint"])
+        self.assertEqual(actual.attributes, expected_output["attributes"])
+        self.assertEqual(actual.k, expected_output["k"])
+        self.assertEqual(
+            actual.openai_api_version, expected_output["openai_api_version"]
+        )
+        self.assertEqual(actual.question, expected_output["question"])
+        self.assertEqual(actual.ref, expected_output["ref"])
+        self.assertEqual(actual.temperature, expected_output["temperature"])
+        self.assertEqual(actual.text_key, expected_output["text_key"])
+
+    def test_debug_message(self):
+        self.assertEqual(
+            HTTPEventConfig(
+                event={"body": json.dumps({"attributes": ["source"]})}
+            ).debug_message()["type"],
+            "debug",
+        )
+
+    def test_to_bool(self):
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool(""), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("0"), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("no"), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("false"), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("False"), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("FALSE"), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("no"), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("No"), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("NO"), False)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool("true"), True)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool(True), True)
+        self.assertEqual(HTTPEventConfig(event={"body": json.dumps({"attributes": ["source"]})})._to_bool(False), False)

From 730c6caf1219d8a7592a5669d562cff28d652057 Mon Sep 17 00:00:00 2001
From: Brendan Quinn
Date: Mon, 26 Aug 2024 16:35:40 -0500
Subject: [PATCH 2/2] Bump size default from 5 to 20 (#252)

* Bumps size default from 5 to 20
* Limits source documents sent via websockets to 5
* Sends all source documents to the chat metrics endpoint --- chat/src/event_config.py | 2 +- chat/src/helpers/hybrid_query.py | 4 +-- chat/src/helpers/response.py | 13 +++++-- chat/test/helpers/test_metrics.py | 59 +++++++++---------------------- chat/test/test_event_config.py | 2 +- 5 files changed, 31 insertions(+), 49 deletions(-) diff --git a/chat/src/event_config.py b/chat/src/event_config.py index 07e42ee7..3b1e8ae7 100644 --- a/chat/src/event_config.py +++ b/chat/src/event_config.py @@ -20,7 +20,7 @@ K_VALUE = 40 MAX_K = 100 MAX_TOKENS = 1000 -SIZE = 5 +SIZE = 20 TEMPERATURE = 0.2 TEXT_KEY = "id" VERSION = "2024-02-01" diff --git a/chat/src/helpers/hybrid_query.py b/chat/src/helpers/hybrid_query.py index 8e202d5f..2cee1b6e 100644 --- a/chat/src/helpers/hybrid_query.py +++ b/chat/src/helpers/hybrid_query.py @@ -11,9 +11,9 @@ def filter(query: dict): } } -def hybrid_query(query: str, model_id: str, vector_field: str = "embedding", k: int = 10, **kwargs: Any): +def hybrid_query(query: str, model_id: str, vector_field: str = "embedding", k: int = 40, **kwargs: Any): result = { - "size": kwargs.get("size", 5), + "size": kwargs.get("size", 20), "query": { "hybrid": { "queries": [ diff --git a/chat/src/helpers/response.py b/chat/src/helpers/response.py index 94c9678c..b7c24836 100644 --- a/chat/src/helpers/response.py +++ b/chat/src/helpers/response.py @@ -26,13 +26,20 @@ def get_and_send_original_question(docs): source_document = doc.metadata.copy() source_document["content"] = doc.page_content source_documents.append(source_document) - + + socket_message = { + "question": self.config.question, + "source_documents": source_documents[:5] + } + self.config.socket.send(socket_message) + original_question = { "question": self.config.question, - "source_documents": source_documents, + "source_documents": source_documents } - self.config.socket.send(original_question) self.original_question = original_question + + docs["source_documents"] = source_documents return docs return RunnablePassthrough(get_and_send_original_question) diff --git a/chat/test/helpers/test_metrics.py b/chat/test/helpers/test_metrics.py index 3cd8314f..84a9f9d0 100644 --- a/chat/test/helpers/test_metrics.py +++ b/chat/test/helpers/test_metrics.py @@ -23,38 +23,7 @@ def setUp(self): self.question = "What is your name?" self.original_question = { "question": self.question, - "source_documents": [ - { - "accession_number": "SourceDoc:1", - "api_link": "https://api.dc.library.northwestern.edu/api/v2/works/881e8cae-67be-4e04-9970-7eafb52b2c5c", - "canonical_link": "https://dc.library.northwestern.edu/items/881e8cae-67be-4e04-9970-7eafb52b2c5c", - "title": "Source Document One!" - }, - { - "accession_number": "SourceDoc:2", - "api_link": "https://api.dc.library.northwestern.edu/api/v2/works/ac0b2a0d-8f80-420a-b1a1-63b6ac2299f1", - "canonical_link": "https://dc.library.northwestern.edu/items/ac0b2a0d-8f80-420a-b1a1-63b6ac2299f1", - "title": "Source Document Two!" - }, - { - "accession_number": "SourceDoc:3", - "api_link": "https://api.dc.library.northwestern.edu/api/v2/works/11569bb5-1b89-4fa9-bdfb-2caf2ded5aa5", - "canonical_link": "https://dc.library.northwestern.edu/items/11569bb5-1b89-4fa9-bdfb-2caf2ded5aa5", - "title": "Source Document Three!" 
- }, - { - "accession_number": "SourceDoc:4", - "api_link": "https://api.dc.library.northwestern.edu/api/v2/works/211eeeca-d56e-4c6e-9123-1612d72258f9", - "canonical_link": "https://dc.library.northwestern.edu/items/211eeeca-d56e-4c6e-9123-1612d72258f9", - "title": "Source Document Four!" - }, - { - "accession_number": "SourceDoc:5", - "api_link": "https://api.dc.library.northwestern.edu/api/v2/works/10e45e7a-8011-4ac5-97df-efa6a5439d0e", - "canonical_link": "https://dc.library.northwestern.edu/items/10e45e7a-8011-4ac5-97df-efa6a5439d0e", - "title": "Source Document Five!" - } - ], + "source_documents": self.generate_source_documents(20), } self.event = { "body": json.dumps({ @@ -75,6 +44,17 @@ def setUp(self): self.response = { "output_text": "This is a test response.", } + + def generate_source_documents(self, count): + return [ + { + "accession_number": f"SourceDoc:{i+1}", + "api_link": f"https://api.dc.library.northwestern.edu/api/v2/works/{i+1:0>32}", + "canonical_link": f"https://dc.library.northwestern.edu/items/{i+1:0>32}", + "title": f"Source Document {i+1}!" + } + for i in range(count) + ] def test_debug_response(self): result = debug_response(self.config, self.response, self.original_question) @@ -82,16 +62,11 @@ def test_debug_response(self): self.assertEqual(result["k"], 40) self.assertEqual(result["question"], self.question) self.assertEqual(result["ref"], "test") - self.assertEqual(result["size"], 5) + self.assertEqual(result["size"], 20) + self.assertEqual(len(result["source_documents"]), 20) self.assertEqual( result["source_documents"], - [ - "https://api.dc.library.northwestern.edu/api/v2/works/881e8cae-67be-4e04-9970-7eafb52b2c5c", - "https://api.dc.library.northwestern.edu/api/v2/works/ac0b2a0d-8f80-420a-b1a1-63b6ac2299f1", - "https://api.dc.library.northwestern.edu/api/v2/works/11569bb5-1b89-4fa9-bdfb-2caf2ded5aa5", - "https://api.dc.library.northwestern.edu/api/v2/works/211eeeca-d56e-4c6e-9123-1612d72258f9", - "https://api.dc.library.northwestern.edu/api/v2/works/10e45e7a-8011-4ac5-97df-efa6a5439d0e" - ] + [doc["api_link"] for doc in self.original_question["source_documents"]] ) def test_token_usage(self): @@ -101,8 +76,8 @@ def test_token_usage(self): "answer": 12, "prompt": 329, "question": 5, - "source_documents": 527, - "total": 873 + "source_documents": 1602, + "total": 1948 } self.assertEqual(result, expected_result) diff --git a/chat/test/test_event_config.py b/chat/test/test_event_config.py index 9f9f39f5..0d1a4654 100644 --- a/chat/test/test_event_config.py +++ b/chat/test/test_event_config.py @@ -56,7 +56,7 @@ def test_attempt_override_without_superuser_status(self): "k": 40, "openai_api_version": "2024-02-01", "question": "test question", - "size": 5, + "size": 20, "ref": "test ref", "temperature": 0.2, "text_key": "id",