Skip to content

Commit

Permalink
Merge pull request #152 from nulib/wip
Browse files Browse the repository at this point in the history
Turn a POSTed question into a valid JSON response from the LLM
  • Loading branch information
mbklein authored Jun 26, 2023
2 parents 8e8da1e + b9711a2 commit f0ceee6
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 33 deletions.
79 changes: 55 additions & 24 deletions python/src/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,79 @@
import json
import os
import setup
from langchain.chains import RetrievalQAWithSourcesChain

def handler(event, context):
question = event.get("body", "")
if event.get("isBase64Encoded", False):
question = base64.b64decode(question)

headers = event.get("headers")
token = headers.get("authorization", headers.get("Authorization", None))

if token is None:
for cookie in event.get("cookies", []):
[k, v] = cookie.split("=", 1)
if k == os.getenv("API_TOKEN_NAME"):
token = v
else:
token = token.replace("Bearer ", "")

if not setup.validate_token(token):
def handler(event, context):
if not is_authenticated(event):
return {
"statusCode": 401,
"headers": {
"Content-Type": "text/plain"
},
"body": "Unauthorized"
}


params = event.get("queryStringParameters", {})
index_name = params.get("index", "Work")
text_key = params.get("text_key", "title")
attributes = params.get("attributes", "identifier,title").split(",")
question = get_query(event)
index_name = get_param(event, "index", "Work")
text_key = get_param(event, "text_key", "title")
attributes = get_param(event,
"attributes",
"identifier,title,source,alternate_title,contributor,create_date,creator,date_created,description,genre,keywords,language,location,physical_description_material,physical_description_size,scope_and_contents,style_period,subject,table_of_contents,technique,work_type").split(",")

weaviate = setup.weaviate_vector_store(index_name=index_name,
text_key=text_key,
attributes=attributes)
result = weaviate.similarity_search_by_text(query=question,
additional="certainty")

client = setup.openai_chat_client()


chain = RetrievalQAWithSourcesChain.from_chain_type(
client,
chain_type="stuff",
retriever=weaviate.as_retriever(),
return_source_documents=True)

response = chain({"question": question})
print(response)
response['source_documents'] = [doc.__dict__ for doc in response['source_documents']]
return {
"statusCode": 200,
"headers": {
"Content-Type": "application/json"
},
"body": json.dumps([doc.__dict__ for doc in result])
}
"body": json.dumps(response)
}


def get_param(event, parameter, default):
params = event.get("queryStringParameters", {})
return params.get(parameter, default)


def get_query(event):
question = event.get("body", "")
if event.get("isBase64Encoded", False):
question = base64.b64decode(question)
return question


def is_authenticated(event):
headers = event.get("headers")
token = headers.get("authorization", headers.get("Authorization", None))

if token is None:
for cookie in event.get("cookies", []):
[k, v] = cookie.split("=", 1)
if k == os.getenv("API_TOKEN_NAME"):
token = v
else:
token = token.replace("Bearer ", "")

return setup.validate_token(token)



# result = weaviate.similarity_search_by_text(query=question,
# additional="certainty")
17 changes: 9 additions & 8 deletions python/src/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,25 @@

def openai_chat_client():
deployment = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
os.getenv("AZURE_OPENAI_RESOURCE_NAME")
os.getenv("AZURE_OPENAI_API_KEY")
key = os.getenv("AZURE_OPENAI_API_KEY")
resource = os.getenv("AZURE_OPENAI_RESOURCE_NAME")

return AzureChatOpenAI(deployment_name=deployment,
openai_api_key=key,
openai_api_base=f"https://{resource}.openai.azure.com/",
openai_api_version="2023-03-15-preview")

return AzureChatOpenAI(deployment_name=deployment)

def weaviate_vector_store(index_name: str, text_key: str, attributes: List[str] = []):
weaviate_url = os.environ['WEAVIATE_URL']
weaviate_api_key = os.environ['WEAVIATE_API_KEY']
openai_api_key = os.environ['AZURE_OPENAI_API_KEY']
# openai_api_key = os.environ['AZURE_OPENAI_API_KEY']

auth_config = weaviate.AuthApiKey(api_key=weaviate_api_key)

client = weaviate.Client(
url=weaviate_url,
auth_client_secret=auth_config,
additional_headers={
"X-OpenAI-Api-Key": openai_api_key
}
auth_client_secret=auth_config
)
return Weaviate(client=client,
index_name=index_name,
Expand Down
2 changes: 1 addition & 1 deletion template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ Resources:
CodeUri: ./python/src
Runtime: python3.9
Handler: handlers/chat.handler
Timeout: 30
Timeout: 300
Environment:
Variables:
AZURE_OPENAI_API_KEY: !Ref AzureOpenaiApiKey
Expand Down

0 comments on commit f0ceee6

Please sign in to comment.