Better QA Prompts #409

Merged 4 commits on Sep 7, 2023
2 changes: 1 addition & 1 deletion backend/danswer/configs/model_configs.py
@@ -77,7 +77,7 @@
GEN_AI_HOST_TYPE = os.environ.get("GEN_AI_HOST_TYPE", ModelHostType.HUGGINGFACE.value)

# Set this to be enough for an answer + quotes
-GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS", "512"))
+GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS", "1024"))

# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"
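The default output budget doubles from 512 to 1024 tokens so the model has room for both the answer and its supporting quotes. A minimal sketch of the override behavior (the 2048 value is a hypothetical deployment override):

import os

os.environ["GEN_AI_MAX_OUTPUT_TOKENS"] = "2048"  # hypothetical override
max_output_tokens = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS", "1024"))
print(max_output_tokens)  # 2048 with the override, 1024 otherwise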
4 changes: 2 additions & 2 deletions backend/danswer/direct_qa/llm_utils.py
@@ -18,10 +18,10 @@
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.local_transformers import TransformerQA
from danswer.direct_qa.open_ai import OpenAICompletionQA
-from danswer.direct_qa.qa_block import JsonChatQAHandler
from danswer.direct_qa.qa_block import QABlock
from danswer.direct_qa.qa_block import QAHandler
from danswer.direct_qa.qa_block import SimpleChatQAHandler
+from danswer.direct_qa.qa_block import SingleMessageQAHandler
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.direct_qa.request_model import RequestCompletionQA
@@ -53,7 +53,7 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:

def get_default_qa_handler(model: str) -> QAHandler:
    if model == DanswerGenAIModel.OPENAI_CHAT.value:
-        return JsonChatQAHandler()
+        return SingleMessageQAHandler()

    return SimpleChatQAHandler()

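For illustration, a sketch of the new dispatch, assuming DanswerGenAIModel is importable in scope (its module is not shown in this diff):

from danswer.direct_qa.llm_utils import get_default_qa_handler
from danswer.direct_qa.qa_block import SingleMessageQAHandler

# OpenAI chat models now get the single-message JSON prompt; every other
# model keeps the freeform SimpleChatQAHandler.
handler = get_default_qa_handler(model=DanswerGenAIModel.OPENAI_CHAT.value)
assert isinstance(handler, SingleMessageQAHandler)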
46 changes: 46 additions & 0 deletions backend/danswer/direct_qa/qa_block.py
@@ -16,7 +16,10 @@
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
+from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
from danswer.direct_qa.qa_prompts import JsonChatProcessor
+from danswer.direct_qa.qa_prompts import QUESTION_PAT
from danswer.direct_qa.qa_prompts import SAMPLE_JSON_RESPONSE
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
@@ -93,6 +96,49 @@ def process_response(
        )


class SingleMessageQAHandler(QAHandler):
    def build_prompt(
        self, query: str, context_chunks: list[InferenceChunk]
    ) -> list[BaseMessage]:
        complete_answer_not_found_response = (
            '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
        )

        context_docs_str = "\n".join(
            f"{CODE_BLOCK_PAT.format(c.content)}" for c in context_chunks
        )

        prompt: list[BaseMessage] = [
            HumanMessage(
                content="You are a question answering system that is constantly learning and improving. "
                "You can process and comprehend vast amounts of text and utilize this knowledge "
                "to provide accurate and detailed answers to diverse queries.\n"
                "You ALWAYS respond in JSON containing an answer and quotes that support the answer.\n"
                "Your responses are as informative and detailed as possible.\n"
                "If you don't know the answer, respond with "
                f"{CODE_BLOCK_PAT.format(complete_answer_not_found_response)}"
                "\nSample response:"
                f"{CODE_BLOCK_PAT.format(json.dumps(SAMPLE_JSON_RESPONSE))}"
                f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}"
                f"{GENERAL_SEP_PAT}{QUESTION_PAT} {query}"
                "\nHint: Make the answer as informative as possible and use JSON! "
                "Quotes MUST be EXACT substrings from the provided documents!"
            )
        ]
        return prompt

    def process_response(
        self,
        tokens: Iterator[str],
        context_chunks: list[InferenceChunk],
    ) -> AnswerQuestionStreamReturn:
        yield from process_model_tokens(
            tokens=tokens,
            context_docs=context_chunks,
            is_json_prompt=True,
        )


class JsonChatQAUnshackledHandler(QAHandler):
    def build_prompt(
        self, query: str, context_chunks: list[InferenceChunk]
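A minimal usage sketch for the new handler; the chunk object is a hypothetical stand-in, since build_prompt only reads .content from each InferenceChunk:

from types import SimpleNamespace

from danswer.direct_qa.qa_block import SingleMessageQAHandler

# Hypothetical stand-in for InferenceChunk; only .content is accessed.
chunks = [SimpleNamespace(content="Danswer ships Slack and Confluence connectors.")]

handler = SingleMessageQAHandler()
messages = handler.build_prompt(query="Which connectors exist?", context_chunks=chunks)
print(messages[0].content)  # one HumanMessage: instructions, context docs, then the query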
3 changes: 2 additions & 1 deletion backend/danswer/direct_qa/qa_prompts.py
@@ -6,7 +6,8 @@
from danswer.connectors.factory import identify_connector_class


GENERAL_SEP_PAT = "---\n"
GENERAL_SEP_PAT = "\n-----\n"
CODE_BLOCK_PAT = "\n```\n{}\n```\n"
DOC_SEP_PAT = "---NEW DOCUMENT---"
DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n"
QUESTION_PAT = "Query:"
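For reference, a quick sketch of how the two new patterns render once formatted (the payload string is hypothetical):

CODE_BLOCK_PAT = "\n```\n{}\n```\n"
GENERAL_SEP_PAT = "\n-----\n"

# CODE_BLOCK_PAT fences its payload in triple backticks padded by blank lines;
# GENERAL_SEP_PAT renders as a five-dash rule between prompt sections.
print(CODE_BLOCK_PAT.format('{"answer": "42", "quotes": []}'))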
3 changes: 3 additions & 0 deletions backend/danswer/direct_qa/qa_utils.py
@@ -246,6 +246,9 @@ def process_model_tokens(
        json_answer_ind = model_output.index('{"answer":')
        if json_answer_ind != 0:
            model_output = model_output[json_answer_ind:]
+        end = model_output.rfind("}")
+        if end != -1:
+            model_output = model_output[: end + 1]
    except ValueError:
        logger.exception("Did not find answer pattern in response for JSON prompt")

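A standalone worked example of what the added trimming accomplishes; the raw completion string is hypothetical:

model_output = 'Sure! {"answer": "42", "quotes": ["life"]} Hope that helps!'

# Pre-existing step: drop any preamble before the JSON object.
json_answer_ind = model_output.index('{"answer":')
if json_answer_ind != 0:
    model_output = model_output[json_answer_ind:]
# New step: also drop any trailing chatter after the final closing brace.
end = model_output.rfind("}")
if end != -1:
    model_output = model_output[: end + 1]

print(model_output)  # {"answer": "42", "quotes": ["life"]}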
67 changes: 38 additions & 29 deletions backend/danswer/secondary_llm_flows/query_validation.py
@@ -4,50 +4,59 @@

from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
+from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
from danswer.llm.build import get_default_llm
from danswer.server.models import QueryValidationResponse
from danswer.server.utils import get_json_line

QUERY_PAT = "QUERY: "
REASONING_PAT = "REASONING: "
ANSWERABLE_PAT = "ANSWERABLE: "
COT_PAT = "\nLet's think step by step"


def get_query_validation_messages(user_query: str) -> list[dict[str, str]]:
+    ambiguous_example = (
+        f"{QUERY_PAT}What is this Slack channel about?\n"
+        f"{REASONING_PAT}First the system must determine which Slack channel is "
+        f"being referred to. By fetching 5 documents related to Slack channel contents, "
+        f"it is not possible to determine which Slack channel the user is referring to.\n"
+        f"{ANSWERABLE_PAT}False"
+    )
+
+    debug_example = (
+        f"{QUERY_PAT}Danswer is unreachable.\n"
+        f"{REASONING_PAT}The system searches documents related to Danswer being "
+        f"unreachable. Assuming the documents from search contain situations where "
+        f"Danswer is not reachable and contain a fix, the query is answerable.\n"
+        f"{ANSWERABLE_PAT}True"
+    )
+
+    up_to_date_example = (
+        f"{QUERY_PAT}How many customers do we have?\n"
+        f"{REASONING_PAT}Assuming the retrieved documents contain up to date customer "
+        f"acquisition information including a list of customers, the query can be answered. "
+        f"It is important to note that if the information only exists in a database, "
+        f"the system is unable to execute SQL and won't find an answer."
+        f"\n{ANSWERABLE_PAT}True"
+    )
+
    messages = [
        {
-            "role": "system",
-            "content": f"You are a helper tool to determine if a query is answerable using retrieval augmented "
+            "role": "user",
+            "content": "You are a helper tool to determine if a query is answerable using retrieval augmented "
            f"generation. A system will try to answer the user query based on ONLY the top 5 most relevant "
            f"documents found from search. Sources contain both up to date and proprietary information for "
            f"the specific team. For named or unknown entities, assume the search will always find "
-            f"consistent knowledge about the entity. Determine if that system should attempt to answer. "
-            f'"{ANSWERABLE_PAT}" must be exactly "True" or "False"',
-        },
-        {"role": "user", "content": "What is this Slack channel about?"},
-        {
-            "role": "assistant",
-            "content": f"{REASONING_PAT}First the system must determine which Slack channel is being referred to."
-            f"By fetching 5 documents related to Slack channel contents, it is not possible to determine"
-            f"which Slack channel the user is referring to.\n{ANSWERABLE_PAT}False",
-        },
-        {
-            "role": "user",
-            "content": f"Danswer is unreachable.{COT_PAT}",
-        },
-        {
-            "role": "assistant",
-            "content": f"{REASONING_PAT}The system searches documents related to Danswer being "
-            f"unreachable. Assuming the documents from search contains situations where Danswer is not "
-            f"reachable and contains a fix, the query is answerable.\n{ANSWERABLE_PAT}True",
-        },
-        {"role": "user", "content": f"How many customers do we have?{COT_PAT}"},
-        {
-            "role": "assistant",
-            "content": f"{REASONING_PAT}Assuming the searched documents contains customer acquisition information"
-            f"including a list of customers, the query can be answered.\n{ANSWERABLE_PAT}True",
+            f"consistent knowledge about the entity.\n"
+            f"The system is not tuned for writing code nor for interfacing with structured data "
+            f"via query languages like SQL.\n"
+            f"Determine if that system should attempt to answer. "
+            f'"{ANSWERABLE_PAT}" must be exactly "True" or "False"\n'
+            f"{CODE_BLOCK_PAT.format(ambiguous_example)}\n"
+            f"{CODE_BLOCK_PAT.format(debug_example)}\n"
+            f"{CODE_BLOCK_PAT.format(up_to_date_example)}\n"
+            f"{CODE_BLOCK_PAT.format(QUERY_PAT + user_query)}\n",
        },
-        {"role": "user", "content": user_query + COT_PAT},
    ]

    return messages
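A usage sketch for the reworked prompt builder; the query string is hypothetical:

from danswer.secondary_llm_flows.query_validation import (
    QUERY_PAT,
    get_query_validation_messages,
)

messages = get_query_validation_messages("How do I rotate my Danswer API key?")

# Everything now travels in a single user message: the instructions, three
# fenced few-shot examples, and the fenced user query at the end.
assert len(messages) == 1 and messages[0]["role"] == "user"
assert QUERY_PAT + "How do I rotate my Danswer API key?" in messages[0]["content"]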