From b9711a2d0f678da79e992676b1da0b0359a0edd7 Mon Sep 17 00:00:00 2001 From: Brendan Quinn Date: Fri, 23 Jun 2023 19:13:17 +0000 Subject: [PATCH] Turn a POSTed question into a valid JSON response from the LLM --- python/src/handlers/chat.py | 79 ++++++++++++++++++++++++++----------- python/src/setup.py | 17 ++++---- template.yaml | 2 +- 3 files changed, 65 insertions(+), 33 deletions(-) diff --git a/python/src/handlers/chat.py b/python/src/handlers/chat.py index d2ae9b1a..c34831b6 100644 --- a/python/src/handlers/chat.py +++ b/python/src/handlers/chat.py @@ -2,24 +2,13 @@ import json import os import setup +from langchain.chains import RetrievalQAWithSourcesChain -def handler(event, context): - question = event.get("body", "") - if event.get("isBase64Encoded", False): - question = base64.b64decode(question) - headers = event.get("headers") - token = headers.get("authorization", headers.get("Authorization", None)) - if token is None: - for cookie in event.get("cookies", []): - [k, v] = cookie.split("=", 1) - if k == os.getenv("API_TOKEN_NAME"): - token = v - else: - token = token.replace("Bearer ", "") - if not setup.validate_token(token): +def handler(event, context): + if not is_authenticated(event): return { "statusCode": 401, "headers": { @@ -27,23 +16,65 @@ def handler(event, context): }, "body": "Unauthorized" } - - - params = event.get("queryStringParameters", {}) - index_name = params.get("index", "Work") - text_key = params.get("text_key", "title") - attributes = params.get("attributes", "identifier,title").split(",") + question = get_query(event) + index_name = get_param(event, "index", "Work") + text_key = get_param(event, "text_key", "title") + attributes = get_param(event, + "attributes", + "identifier,title,source,alternate_title,contributor,create_date,creator,date_created,description,genre,keywords,language,location,physical_description_material,physical_description_size,scope_and_contents,style_period,subject,table_of_contents,technique,work_type").split(",") weaviate = setup.weaviate_vector_store(index_name=index_name, text_key=text_key, attributes=attributes) - result = weaviate.similarity_search_by_text(query=question, - additional="certainty") + + client = setup.openai_chat_client() + + chain = RetrievalQAWithSourcesChain.from_chain_type( + client, + chain_type="stuff", + retriever=weaviate.as_retriever(), + return_source_documents=True) + + response = chain({"question": question}) + print(response) + response['source_documents'] = [doc.__dict__ for doc in response['source_documents']] return { "statusCode": 200, "headers": { "Content-Type": "application/json" }, - "body": json.dumps([doc.__dict__ for doc in result]) - } \ No newline at end of file + "body": json.dumps(response) + } + + +def get_param(event, parameter, default): + params = event.get("queryStringParameters", {}) + return params.get(parameter, default) + + +def get_query(event): + question = event.get("body", "") + if event.get("isBase64Encoded", False): + question = base64.b64decode(question) + return question + + +def is_authenticated(event): + headers = event.get("headers") + token = headers.get("authorization", headers.get("Authorization", None)) + + if token is None: + for cookie in event.get("cookies", []): + [k, v] = cookie.split("=", 1) + if k == os.getenv("API_TOKEN_NAME"): + token = v + else: + token = token.replace("Bearer ", "") + + return setup.validate_token(token) + + + +# result = weaviate.similarity_search_by_text(query=question, +# additional="certainty") \ No newline at end of file diff --git a/python/src/setup.py b/python/src/setup.py index 72a4f474..d9cb9ad6 100644 --- a/python/src/setup.py +++ b/python/src/setup.py @@ -7,24 +7,25 @@ def openai_chat_client(): deployment = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID") - os.getenv("AZURE_OPENAI_RESOURCE_NAME") - os.getenv("AZURE_OPENAI_API_KEY") + key = os.getenv("AZURE_OPENAI_API_KEY") + resource = os.getenv("AZURE_OPENAI_RESOURCE_NAME") + + return AzureChatOpenAI(deployment_name=deployment, + openai_api_key=key, + openai_api_base=f"https://{resource}.openai.azure.com/", + openai_api_version="2023-03-15-preview") - return AzureChatOpenAI(deployment_name=deployment) def weaviate_vector_store(index_name: str, text_key: str, attributes: List[str] = []): weaviate_url = os.environ['WEAVIATE_URL'] weaviate_api_key = os.environ['WEAVIATE_API_KEY'] - openai_api_key = os.environ['AZURE_OPENAI_API_KEY'] + # openai_api_key = os.environ['AZURE_OPENAI_API_KEY'] auth_config = weaviate.AuthApiKey(api_key=weaviate_api_key) client = weaviate.Client( url=weaviate_url, - auth_client_secret=auth_config, - additional_headers={ - "X-OpenAI-Api-Key": openai_api_key - } + auth_client_secret=auth_config ) return Weaviate(client=client, index_name=index_name, diff --git a/template.yaml b/template.yaml index 1fab7388..e24f33f2 100644 --- a/template.yaml +++ b/template.yaml @@ -554,7 +554,7 @@ Resources: CodeUri: ./python/src Runtime: python3.9 Handler: handlers/chat.handler - Timeout: 30 + Timeout: 300 Environment: Variables: AZURE_OPENAI_API_KEY: !Ref AzureOpenaiApiKey