Skip to content

Commit

Permalink
Allow superuser to override prompt and attributes
Browse files Browse the repository at this point in the history
Allow any request to override index name and k value
  • Loading branch information
mbklein committed Oct 17, 2023
1 parent 2320c3d commit dcecfe0
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ on:
paths:
- ".github/workflows/deploy.yml"
- "node/**"
- "python/**"
- "chat/**"
- "template.yaml"
workflow_dispatch:
concurrency:
Expand Down
32 changes: 21 additions & 11 deletions chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,8 @@

DEFAULT_INDEX = "Work"
DEFAULT_KEY = "title"
DEFAULT_ATTRIBUTES = ("title,alternate_title,collection,contributor,creator,"
"date_created,description,genre,language,library_unit,"
"location,physical_description_material,physical_description_size,"
"published,rights_statement,scope_and_contents,series,source,"
"style_period,subject,table_of_contents,technique,visibility,"
"work_type")
DEFAULT_K = 10
MAX_K = 100

class Websocket:
def __init__(self, endpoint_url, connection_id, ref):
Expand Down Expand Up @@ -56,11 +52,12 @@ def handler(event, context):
}

question = payload.get("question")
index_name = payload.get("index", DEFAULT_INDEX)
text_key = payload.get("text_key", DEFAULT_KEY)
index_name = payload.get("index", payload.get('index', DEFAULT_INDEX))
print(f'Searching index {index_name}')
text_key = payload.get("text_key", DEFAULT_KEY)
attributes = [
item for item
in set(payload.get("attributes", DEFAULT_ATTRIBUTES).split(","))
in get_attributes(index_name, payload if api_token.is_superuser() else {})
if item not in [text_key, "source"]
]

Expand All @@ -70,8 +67,9 @@ def handler(event, context):

client = setup.openai_chat_client(callbacks=[StreamingSocketCallbackHandler(socket)], streaming=True)

prompt_text = payload.get("prompt", prompt_template()) if api_token.is_superuser() else prompt_template()
prompt = PromptTemplate(
template=prompt_template(),
template=prompt_text,
input_variables=["question", "context"]
)

Expand All @@ -80,7 +78,8 @@ def handler(event, context):
input_variables=["page_content", "source"] + attributes,
)

docs = weaviate.similarity_search(question, k=10, additional="certainty")
k = min(payload.get("k", DEFAULT_K), MAX_K)
docs = weaviate.similarity_search(question, k=k, additional="certainty")
chain = load_qa_with_sources_chain(
client,
chain_type="stuff",
Expand Down Expand Up @@ -111,6 +110,17 @@ def handler(event, context):
print(event)
raise err

def get_attributes(index, payload):
request_attributes = payload.get('attributes', None)
if request_attributes is not None:
return ','.split(request_attributes)

client = setup.weaviate_client()
schema = client.schema.get(index)
names = [prop['name'] for prop in schema.get('properties')]
print(f'Retrieved attributes: {names}')
return names

def to_bool(val):
if isinstance(val, str):
return val.lower() not in ["", "no", "false", "0"]
Expand Down
3 changes: 3 additions & 0 deletions chat/src/helpers/apitoken.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@ def __init__(self, signed_token=None):

def is_logged_in(self):
return self.token.get("isLoggedIn", False)

def is_superuser(self):
return self.token.get("isSuperUser", False)
11 changes: 6 additions & 5 deletions chat/src/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,19 @@ def openai_chat_client(**kwargs):
**kwargs)



def weaviate_vector_store(index_name: str, text_key: str, attributes: List[str] = []):
def weaviate_client():
weaviate_url = os.environ['WEAVIATE_URL']
weaviate_api_key = os.environ['WEAVIATE_API_KEY']
# openai_api_key = os.environ['AZURE_OPENAI_API_KEY']

auth_config = weaviate.AuthApiKey(api_key=weaviate_api_key)

client = weaviate.Client(
return weaviate.Client(
url=weaviate_url,
auth_client_secret=auth_config
)

def weaviate_vector_store(index_name: str, text_key: str, attributes: List[str] = []):
client = weaviate_client()

return Weaviate(client=client,
index_name=index_name,
text_key=text_key,
Expand Down
3 changes: 3 additions & 0 deletions chat/test/fixtures/apitoken.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
TEST_SECRET = "TEST_SECRET"
TEST_TOKEN_NAME = "dcTestToken"
SUPER_TOKEN = ('eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4NDM1NzU2ODg5MTIs'
'ImlhdCI6MTY4Nzg4MDI0NywiaXNMb2dnZWRJbiI6dHJ1ZSwic3ViIjoiYXBpVGVzdF'
'N1cGVyVXNlciIsImlzU3VwZXJVc2VyIjp0cnVlfQ.uGEdWlhwUr8RHrC6CLCV5_pOrQDTw41kM6_X99AEg1Q')
TEST_TOKEN = ('eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjQ4NDM1ODY2MDYxNjUs'
'ImlhdCI6MTY4Nzg5MTM2OSwiZW50aXRsZW1lbnRzIjpbXSwiaXNMb2dnZWRJbiI6d'
'HJ1ZSwic3ViIjoidGVzdFVzZXIifQ.vIZag1pHE1YyrxsKKlakXX_44ckAvkg7xWOoA_w4x58')
8 changes: 7 additions & 1 deletion chat/test/helpers/test_apitoken.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from src.helpers.apitoken import ApiToken
from test.fixtures.apitoken import TEST_SECRET, TEST_TOKEN
from test.fixtures.apitoken import SUPER_TOKEN, TEST_SECRET, TEST_TOKEN
from unittest import mock, TestCase

@mock.patch.dict(
Expand All @@ -17,6 +17,12 @@ def test_empty_token(self):
def test_valid_token(self):
subject = ApiToken(TEST_TOKEN)
self.assertTrue(subject.is_logged_in())
self.assertFalse(subject.is_superuser())

def test_superuser_token(self):
subject = ApiToken(SUPER_TOKEN)
self.assertTrue(subject.is_logged_in())
self.assertTrue(subject.is_superuser())

def test_invalid_token(self):
subject = ApiToken("INVALID_TOKEN")
Expand Down

0 comments on commit dcecfe0

Please sign in to comment.