diff --git a/backend/danswer/direct_qa/llm_utils.py b/backend/danswer/direct_qa/llm_utils.py
index 453ca182052..146c2f5805d 100644
--- a/backend/danswer/direct_qa/llm_utils.py
+++ b/backend/danswer/direct_qa/llm_utils.py
@@ -9,8 +9,6 @@
 from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.configs.model_configs import GEN_AI_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_HOST_TYPE
-from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
@@ -62,12 +60,10 @@ def get_default_qa_handler(model: str) -> QAHandler:
 
 def get_default_qa_model(
     internal_model: str = INTERNAL_MODEL_VERSION,
-    model_version: str = GEN_AI_MODEL_VERSION,
     endpoint: str | None = GEN_AI_ENDPOINT,
     model_host_type: str | None = GEN_AI_HOST_TYPE,
     api_key: str | None = GEN_AI_API_KEY,
     timeout: int = QA_TIMEOUT,
-    max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
     **kwargs: Any,
 ) -> QAModel:
     if not api_key:
@@ -79,16 +75,7 @@ def get_default_qa_model(
     try:
         # un-used arguments will be ignored by the underlying `LLM` class
         # if any args are missing, a `TypeError` will be thrown
-        llm = get_default_llm(
-            model=internal_model,
-            api_key=api_key,
-            model_version=model_version,
-            endpoint=endpoint,
-            model_host_type=model_host_type,
-            timeout=timeout,
-            max_output_tokens=max_output_tokens,
-            **kwargs,
-        )
+        llm = get_default_llm()
         qa_handler = get_default_qa_handler(model=internal_model)
 
         return QABlock(
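With this change, `get_default_qa_model` no longer threads LLM settings through to the factory. A minimal call-site sketch under the new signature, assuming `GEN_AI_API_KEY` and the other `GEN_AI_*` configs are set in the environment:

```python
from danswer.direct_qa.llm_utils import get_default_qa_model

# No model_version / max_output_tokens arguments anymore; all GEN_AI_*
# settings are now resolved inside get_default_llm() from the configs.
qa_model = get_default_qa_model()
```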
diff --git a/backend/danswer/direct_qa/qa_block.py b/backend/danswer/direct_qa/qa_block.py
index b6cabafc722..c0f7e2d335d 100644
--- a/backend/danswer/direct_qa/qa_block.py
+++ b/backend/danswer/direct_qa/qa_block.py
@@ -3,10 +3,7 @@
 from copy import copy
 
 import tiktoken
-from langchain.schema.messages import AIMessage
 from langchain.schema.messages import BaseMessage
-from langchain.schema.messages import HumanMessage
-from langchain.schema.messages import SystemMessage
 
 from danswer.chunking.models import InferenceChunk
 from danswer.direct_qa.interfaces import AnswerQuestionReturn
@@ -19,32 +16,8 @@
 from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
 from danswer.direct_qa.qa_utils import process_model_tokens
 from danswer.llm.llm import LLM
-
-
-def _dict_based_prompt_to_langchain_prompt(
-    messages: list[dict[str, str]]
-) -> list[BaseMessage]:
-    prompt: list[BaseMessage] = []
-    for message in messages:
-        role = message.get("role")
-        content = message.get("content")
-        if not role:
-            raise ValueError(f"Message missing `role`: {message}")
-        if not content:
-            raise ValueError(f"Message missing `content`: {message}")
-        elif role == "user":
-            prompt.append(HumanMessage(content=content))
-        elif role == "system":
-            prompt.append(SystemMessage(content=content))
-        elif role == "assistant":
-            prompt.append(AIMessage(content=content))
-        else:
-            raise ValueError(f"Unknown role: {role}")
-    return prompt
-
-
-def _str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
-    return [HumanMessage(content=message)]
+from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
+from danswer.llm.utils import str_prompt_to_langchain_prompt
 
 
 class QAHandler(abc.ABC):
@@ -69,7 +42,7 @@ class JsonChatQAHandler(QAHandler):
     def build_prompt(
         self, query: str, context_chunks: list[InferenceChunk]
     ) -> list[BaseMessage]:
-        return _dict_based_prompt_to_langchain_prompt(
+        return dict_based_prompt_to_langchain_prompt(
             JsonChatProcessor.fill_prompt(
                 question=query, chunks=context_chunks, include_metadata=False
             )
@@ -91,7 +64,7 @@ class SimpleChatQAHandler(QAHandler):
     def build_prompt(
         self, query: str, context_chunks: list[InferenceChunk]
     ) -> list[BaseMessage]:
-        return _str_prompt_to_langchain_prompt(
+        return str_prompt_to_langchain_prompt(
             WeakModelFreeformProcessor.fill_prompt(
                 question=query,
                 chunks=context_chunks,
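The prompt-building helpers move to `danswer.llm.utils` and lose their leading underscore. For reference, a quick sketch of how the relocated `dict_based_prompt_to_langchain_prompt` maps OpenAI-style role dicts onto LangChain messages (the example messages are invented):

```python
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt

messages = [
    {"role": "system", "content": "You answer questions about internal docs."},
    {"role": "user", "content": "How do I add a connector?"},
]
# Returns [SystemMessage(...), HumanMessage(...)], ready to hand to an LLM
langchain_prompt = dict_based_prompt_to_langchain_prompt(messages)
```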
diff --git a/backend/danswer/llm/build.py b/backend/danswer/llm/build.py
index 6208204c8fc..31939065c0c 100644
--- a/backend/danswer/llm/build.py
+++ b/backend/danswer/llm/build.py
@@ -1,16 +1,36 @@
 from typing import Any
 
+from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.configs.constants import DanswerGenAIModel
 from danswer.configs.model_configs import API_TYPE_OPENAI
+from danswer.configs.model_configs import GEN_AI_API_KEY
+from danswer.configs.model_configs import GEN_AI_ENDPOINT
+from danswer.configs.model_configs import GEN_AI_HOST_TYPE
+from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.llm.azure import AzureGPT
 from danswer.llm.llm import LLM
 from danswer.llm.openai import OpenAIGPT
 
 
-def get_default_llm(model: str, **kwargs: Any) -> LLM:
+def get_llm_from_model(model: str, **kwargs: Any) -> LLM:
     if model == DanswerGenAIModel.OPENAI_CHAT.value:
         if API_TYPE_OPENAI == "azure":
             return AzureGPT(**kwargs)
         return OpenAIGPT(**kwargs)
 
     raise ValueError(f"Unknown LLM model: {model}")
+
+
+def get_default_llm(**kwargs: Any) -> LLM:
+    return get_llm_from_model(
+        model=INTERNAL_MODEL_VERSION,
+        api_key=GEN_AI_API_KEY,
+        model_version=GEN_AI_MODEL_VERSION,
+        endpoint=GEN_AI_ENDPOINT,
+        model_host_type=GEN_AI_HOST_TYPE,
+        timeout=QA_TIMEOUT,
+        max_output_tokens=GEN_AI_MAX_OUTPUT_TOKENS,
+        **kwargs,
+    )
diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py
index 3e10577d586..2f5a80a773c 100644
--- a/backend/danswer/llm/utils.py
+++ b/backend/danswer/llm/utils.py
@@ -2,15 +2,43 @@
 
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
-from langchain.schema import (
-    PromptValue,
-)
+from langchain.schema import PromptValue
 from langchain.schema.language_model import LanguageModelInput
+from langchain.schema.messages import AIMessage
+from langchain.schema.messages import BaseMessage
 from langchain.schema.messages import BaseMessageChunk
+from langchain.schema.messages import HumanMessage
+from langchain.schema.messages import SystemMessage
 
 from danswer.configs.app_configs import LOG_LEVEL
 
 
+def dict_based_prompt_to_langchain_prompt(
+    messages: list[dict[str, str]]
+) -> list[BaseMessage]:
+    prompt: list[BaseMessage] = []
+    for message in messages:
+        role = message.get("role")
+        content = message.get("content")
+        if not role:
+            raise ValueError(f"Message missing `role`: {message}")
+        if not content:
+            raise ValueError(f"Message missing `content`: {message}")
+        if role == "user":
+            prompt.append(HumanMessage(content=content))
+        elif role == "system":
+            prompt.append(SystemMessage(content=content))
+        elif role == "assistant":
+            prompt.append(AIMessage(content=content))
+        else:
+            raise ValueError(f"Unknown role: {role}")
+    return prompt
+
+
+def str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
+    return [HumanMessage(content=message)]
+
+
 def message_generator_to_string_generator(
     messages: Iterator[BaseMessageChunk],
 ) -> Iterator[str]:
@@ -18,21 +46,21 @@ def message_generator_to_string_generator(
         yield message.content
 
 
-def convert_input(input: LanguageModelInput) -> str:
+def convert_input(lm_input: LanguageModelInput) -> str:
     """Heavily inspired by:
     https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py#L86
     """
     prompt_value = None
-    if isinstance(input, PromptValue):
-        prompt_value = input
-    elif isinstance(input, str):
-        prompt_value = StringPromptValue(text=input)
-    elif isinstance(input, list):
-        prompt_value = ChatPromptValue(messages=input)
+    if isinstance(lm_input, PromptValue):
+        prompt_value = lm_input
+    elif isinstance(lm_input, str):
+        prompt_value = StringPromptValue(text=lm_input)
+    elif isinstance(lm_input, list):
+        prompt_value = ChatPromptValue(messages=lm_input)
 
     if prompt_value is None:
         raise ValueError(
-            f"Invalid input type {type(input)}. "
+            f"Invalid input type {type(lm_input)}. "
             "Must be a PromptValue, str, or list of BaseMessages."
         )
 
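`get_default_llm` becomes the zero-argument entry point used by the new flows below. A minimal usage sketch, assuming the `GEN_AI_*` configs are set and that `LLM.invoke` accepts any `LanguageModelInput` as `convert_input` suggests:

```python
from danswer.llm.build import get_default_llm

# Model, endpoint, timeout, and token limits all come from
# INTERNAL_MODEL_VERSION and the GEN_AI_* configs; extra kwargs,
# if any, are forwarded to the underlying LLM implementation.
llm = get_default_llm()
response_text = llm.invoke("Summarize what Danswer does in one sentence.")
```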
diff --git a/backend/danswer/secondary_llm_flows/__init__.py b/backend/danswer/secondary_llm_flows/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/backend/danswer/secondary_llm_flows/query_validation.py b/backend/danswer/secondary_llm_flows/query_validation.py
new file mode 100644
index 00000000000..4b536b2f2b6
--- /dev/null
+++ b/backend/danswer/secondary_llm_flows/query_validation.py
@@ -0,0 +1,107 @@
+import re
+from collections.abc import Iterator
+from dataclasses import asdict
+
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
+from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
+from danswer.llm.build import get_default_llm
+from danswer.server.models import QueryValidationResponse
+from danswer.server.utils import get_json_line
+
+REASONING_PAT = "REASONING: "
+ANSWERABLE_PAT = "ANSWERABLE: "
+COT_PAT = "\nLet's think step by step"
+
+
+def get_query_validation_messages(user_query: str) -> list[dict[str, str]]:
+    messages = [
+        {
+            "role": "system",
+            "content": f"You are a helper tool to determine if a query is answerable using retrieval augmented "
+            f"generation. A system will try to answer the user query based on ONLY the top 5 most relevant "
+            f"documents found from search. Sources contain both up to date and proprietary information for "
+            f"the specific team. For named or unknown entities, assume the search will always find "
+            f"consistent knowledge about the entity. Determine if that system should attempt to answer. "
+            f'"{ANSWERABLE_PAT}" must be exactly "True" or "False"',
+        },
+        {"role": "user", "content": "What is this Slack channel about?"},
+        {
+            "role": "assistant",
+            "content": f"{REASONING_PAT}First the system must determine which Slack channel is being referred to. "
+            f"By fetching 5 documents related to Slack channel contents, it is not possible to determine "
+            f"which Slack channel the user is referring to.\n{ANSWERABLE_PAT}False",
+        },
+        {
+            "role": "user",
+            "content": f"Danswer is unreachable.{COT_PAT}",
+        },
+        {
+            "role": "assistant",
+            "content": f"{REASONING_PAT}The system searches documents related to Danswer being "
+            f"unreachable. Assuming the documents from search contain situations where Danswer is not "
+            f"reachable and contain a fix, the query is answerable.\n{ANSWERABLE_PAT}True",
+        },
+        {"role": "user", "content": f"How many customers do we have?{COT_PAT}"},
+        {
+            "role": "assistant",
+            "content": f"{REASONING_PAT}Assuming the searched documents contain customer acquisition information "
+            f"including a list of customers, the query can be answered.\n{ANSWERABLE_PAT}True",
+        },
+        {"role": "user", "content": user_query + COT_PAT},
+    ]
+
+    return messages
+
+
+def extract_answerability_reasoning(model_raw: str) -> str:
+    reasoning_match = re.search(
+        f"{REASONING_PAT}(.*?){ANSWERABLE_PAT}", model_raw, re.DOTALL
+    )
+    reasoning_text = reasoning_match.group(1).strip() if reasoning_match else ""
+    return reasoning_text
+
+
+def extract_answerability_bool(model_raw: str) -> bool:
+    answerable_match = re.search(f"{ANSWERABLE_PAT}(.+)", model_raw)
+    answerable_text = answerable_match.group(1).strip() if answerable_match else ""
+    answerable = answerable_text.lower() in ["true", "yes"]
+    return answerable
+
+
+def get_query_answerability(user_query: str) -> tuple[str, bool]:
+    messages = get_query_validation_messages(user_query)
+    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
+    model_output = get_default_llm().invoke(filled_llm_prompt)
+
+    reasoning = extract_answerability_reasoning(model_output)
+    answerable = extract_answerability_bool(model_output)
+
+    return reasoning, answerable
+
+
+def stream_query_answerability(user_query: str) -> Iterator[str]:
+    messages = get_query_validation_messages(user_query)
+    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
+    tokens = get_default_llm().stream(filled_llm_prompt)
+    reasoning_pat_found = False
+    model_output = ""
+    for token in tokens:
+        model_output = model_output + token
+
+        if not reasoning_pat_found and REASONING_PAT in model_output:
+            reasoning_pat_found = True
+            remaining = model_output.split(REASONING_PAT, 1)[1]
+            if remaining:
+                yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=remaining)))
+            continue
+
+        if reasoning_pat_found:
+            yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token)))
+
+    reasoning = extract_answerability_reasoning(model_output)
+    answerable = extract_answerability_bool(model_output)
+
+    yield get_json_line(
+        QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict()
+    )
+    return
diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py
index db9efd6bd3a..292fba396b5 100644
--- a/backend/danswer/server/models.py
+++ b/backend/danswer/server/models.py
@@ -151,6 +151,11 @@ class SearchFeedbackRequest(BaseModel):
     search_feedback: SearchFeedbackType
 
 
+class QueryValidationResponse(BaseModel):
+    reasoning: str
+    answerable: bool
+
+
 class SearchResponse(BaseModel):
     # For semantic search, top docs are reranked, the remaining are as ordered from retrieval
     top_ranked_docs: list[SearchDoc] | None
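`stream_query_answerability` yields newline-delimited JSON: `DanswerAnswerPiece` dicts while reasoning tokens stream, then one final `QueryValidationResponse` dict. A hypothetical client-side consumer; the URL, port, and the `requests` dependency are assumptions, not part of this PR:

```python
import json

import requests  # assumed HTTP client for the example

resp = requests.get(
    "http://localhost:8080/stream-query-validation",  # hypothetical deployment URL
    params={"query": "How many customers do we have?"},
    stream=True,
)
for line in resp.iter_lines():
    if not line:
        continue
    packet = json.loads(line)
    if "answer_piece" in packet:
        print(packet["answer_piece"], end="")  # streamed reasoning text
    else:
        # final line: {"reasoning": ..., "answerable": true/false}
        print("\nanswerable:", packet["answerable"])
```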
diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py
index 49709598028..37950735517 100644
--- a/backend/danswer/server/search_backend.py
+++ b/backend/danswer/server/search_backend.py
@@ -1,4 +1,3 @@
-import json
 from collections.abc import Generator
 from dataclasses import asdict
 
@@ -29,12 +28,16 @@
 from danswer.search.models import SearchType
 from danswer.search.semantic_search import chunks_to_search_docs
 from danswer.search.semantic_search import retrieve_ranked_documents
+from danswer.secondary_llm_flows.query_validation import get_query_answerability
+from danswer.secondary_llm_flows.query_validation import stream_query_answerability
 from danswer.server.models import HelperResponse
 from danswer.server.models import QAFeedbackRequest
 from danswer.server.models import QAResponse
+from danswer.server.models import QueryValidationResponse
 from danswer.server.models import QuestionRequest
 from danswer.server.models import SearchFeedbackRequest
 from danswer.server.models import SearchResponse
+from danswer.server.utils import get_json_line
 from danswer.utils.logger import setup_logger
 from danswer.utils.timing import log_generator_function_time
 
@@ -43,10 +46,6 @@
 router = APIRouter()
 
 
-def get_json_line(json_dict: dict) -> str:
-    return json.dumps(json_dict) + "\n"
-
-
 @router.get("/search-intent")
 def get_search_type(
     question: QuestionRequest = Depends(), _: User = Depends(current_user)
@@ -56,6 +55,25 @@ def get_search_type(
     return recommend_search_flow(query, use_keyword)
 
 
+@router.get("/query-validation")
+def query_validation(
+    question: QuestionRequest = Depends(), _: User = Depends(current_user)
+) -> QueryValidationResponse:
+    query = question.query
+    reasoning, answerable = get_query_answerability(query)
+    return QueryValidationResponse(reasoning=reasoning, answerable=answerable)
+
+
+@router.get("/stream-query-validation")
+def stream_query_validation(
+    question: QuestionRequest = Depends(), _: User = Depends(current_user)
+) -> StreamingResponse:
+    query = question.query
+    return StreamingResponse(
+        stream_query_answerability(query), media_type="application/json"
+    )
+
+
 @router.post("/semantic-search")
 def semantic_search(
     question: QuestionRequest,
diff --git a/backend/danswer/server/utils.py b/backend/danswer/server/utils.py
index f18db93a74f..bf535661878 100644
--- a/backend/danswer/server/utils.py
+++ b/backend/danswer/server/utils.py
@@ -1,6 +1,11 @@
+import json
 from typing import Any
 
 
+def get_json_line(json_dict: dict) -> str:
+    return json.dumps(json_dict) + "\n"
+
+
 def mask_string(sensitive_str: str) -> str:
     return "****...**" + sensitive_str[-4:]
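To make the REASONING/ANSWERABLE parsing contract concrete, a small self-contained illustration of the extraction logic on an invented completion:

```python
import re

REASONING_PAT = "REASONING: "
ANSWERABLE_PAT = "ANSWERABLE: "

sample_output = (
    "REASONING: Assuming the searched documents contain a list of customers, "
    "the query can be answered.\nANSWERABLE: True"
)

# The reasoning is everything between the two markers (DOTALL spans newlines).
reasoning_match = re.search(
    f"{REASONING_PAT}(.*?){ANSWERABLE_PAT}", sample_output, re.DOTALL
)
print(reasoning_match.group(1).strip() if reasoning_match else "")

# Answerability is whatever follows the ANSWERABLE marker.
answerable_match = re.search(f"{ANSWERABLE_PAT}(.+)", sample_output)
answerable_text = answerable_match.group(1).strip() if answerable_match else ""
print(answerable_text.lower() in ["true", "yes"])  # -> True
```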