Commit
LLM to validate user Query (#365)
Backend Only
yuhongsun96 authored Aug 31, 2023
1 parent 0a77758 commit 51ec251
Showing 9 changed files with 205 additions and 62 deletions.
15 changes: 1 addition & 14 deletions backend/danswer/direct_qa/llm_utils.py
@@ -9,8 +9,6 @@
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
@@ -62,12 +60,10 @@ def get_default_qa_handler(model: str) -> QAHandler:

def get_default_qa_model(
internal_model: str = INTERNAL_MODEL_VERSION,
model_version: str = GEN_AI_MODEL_VERSION,
endpoint: str | None = GEN_AI_ENDPOINT,
model_host_type: str | None = GEN_AI_HOST_TYPE,
api_key: str | None = GEN_AI_API_KEY,
timeout: int = QA_TIMEOUT,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
**kwargs: Any,
) -> QAModel:
if not api_key:
@@ -79,16 +75,7 @@ def get_default_qa_model(
try:
# unused arguments will be ignored by the underlying `LLM` class
# if any args are missing, a `TypeError` will be thrown
llm = get_default_llm(
model=internal_model,
api_key=api_key,
model_version=model_version,
endpoint=endpoint,
model_host_type=model_host_type,
timeout=timeout,
max_output_tokens=max_output_tokens,
**kwargs,
)
llm = get_default_llm()
qa_handler = get_default_qa_handler(model=internal_model)

return QABlock(
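With the configuration now resolved inside get_default_llm (see build.py below), the default QA model is constructed with no arguments. A minimal sketch, assuming GEN_AI_API_KEY and the other model configs are set in the environment:

from danswer.direct_qa.llm_utils import get_default_qa_model

# All LLM settings (model version, endpoint, key, timeout, token limit)
# are picked up from danswer.configs rather than passed by the caller.
qa_model = get_default_qa_model()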
35 changes: 4 additions & 31 deletions backend/danswer/direct_qa/qa_block.py
@@ -3,10 +3,7 @@
from copy import copy

import tiktoken
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.interfaces import AnswerQuestionReturn
@@ -19,32 +16,8 @@
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.llm.llm import LLM


def _dict_based_prompt_to_langchain_prompt(
messages: list[dict[str, str]]
) -> list[BaseMessage]:
prompt: list[BaseMessage] = []
for message in messages:
role = message.get("role")
content = message.get("content")
if not role:
raise ValueError(f"Message missing `role`: {message}")
if not content:
raise ValueError(f"Message missing `content`: {message}")
elif role == "user":
prompt.append(HumanMessage(content=content))
elif role == "system":
prompt.append(SystemMessage(content=content))
elif role == "assistant":
prompt.append(AIMessage(content=content))
else:
raise ValueError(f"Unknown role: {role}")
return prompt


def _str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
return [HumanMessage(content=message)]
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import str_prompt_to_langchain_prompt


class QAHandler(abc.ABC):
@@ -69,7 +42,7 @@ class JsonChatQAHandler(QAHandler):
def build_prompt(
self, query: str, context_chunks: list[InferenceChunk]
) -> list[BaseMessage]:
return _dict_based_prompt_to_langchain_prompt(
return dict_based_prompt_to_langchain_prompt(
JsonChatProcessor.fill_prompt(
question=query, chunks=context_chunks, include_metadata=False
)
@@ -91,7 +64,7 @@ class SimpleChatQAHandler(QAHandler):
def build_prompt(
self, query: str, context_chunks: list[InferenceChunk]
) -> list[BaseMessage]:
return _str_prompt_to_langchain_prompt(
return str_prompt_to_langchain_prompt(
WeakModelFreeformProcessor.fill_prompt(
question=query,
chunks=context_chunks,
22 changes: 21 additions & 1 deletion backend/danswer/llm/build.py
@@ -1,16 +1,36 @@
from typing import Any

from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.model_configs import API_TYPE_OPENAI
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.llm.azure import AzureGPT
from danswer.llm.llm import LLM
from danswer.llm.openai import OpenAIGPT


def get_default_llm(model: str, **kwargs: Any) -> LLM:
def get_llm_from_model(model: str, **kwargs: Any) -> LLM:
if model == DanswerGenAIModel.OPENAI_CHAT.value:
if API_TYPE_OPENAI == "azure":
return AzureGPT(**kwargs)
return OpenAIGPT(**kwargs)

raise ValueError(f"Unknown LLM model: {model}")


def get_default_llm(**kwargs: Any) -> LLM:
return get_llm_from_model(
model=INTERNAL_MODEL_VERSION,
api_key=GEN_AI_API_KEY,
model_version=GEN_AI_MODEL_VERSION,
endpoint=GEN_AI_ENDPOINT,
model_host_type=GEN_AI_HOST_TYPE,
timeout=QA_TIMEOUT,
max_output_tokens=GEN_AI_MAX_OUTPUT_TOKENS,
**kwargs,
)
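get_default_llm is now the single place that reads the Gen-AI configuration, while get_llm_from_model keeps the model-to-class dispatch reusable. A minimal usage sketch; passing a plain string assumes the LLM wrapper routes its input through convert_input in llm/utils.py (shown below), which accepts str:

from danswer.llm.build import get_default_llm

# model version, API key, endpoint, timeout, and token limit all come
# from danswer.configs; no caller threads them through anymore
llm = get_default_llm()
print(llm.invoke("Reply with OK."))  # invoke returns the completion as a string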
50 changes: 39 additions & 11 deletions backend/danswer/llm/utils.py
@@ -2,37 +2,65 @@

from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import (
PromptValue,
)
from langchain.schema import PromptValue
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import BaseMessageChunk
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

from danswer.configs.app_configs import LOG_LEVEL


def dict_based_prompt_to_langchain_prompt(
messages: list[dict[str, str]]
) -> list[BaseMessage]:
prompt: list[BaseMessage] = []
for message in messages:
role = message.get("role")
content = message.get("content")
if not role:
raise ValueError(f"Message missing `role`: {message}")
if not content:
raise ValueError(f"Message missing `content`: {message}")
elif role == "user":
prompt.append(HumanMessage(content=content))
elif role == "system":
prompt.append(SystemMessage(content=content))
elif role == "assistant":
prompt.append(AIMessage(content=content))
else:
raise ValueError(f"Unknown role: {role}")
return prompt


def str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]:
return [HumanMessage(content=message)]


def message_generator_to_string_generator(
messages: Iterator[BaseMessageChunk],
) -> Iterator[str]:
for message in messages:
yield message.content


def convert_input(input: LanguageModelInput) -> str:
def convert_input(lm_input: LanguageModelInput) -> str:
"""Heavily inspired by:
https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py#L86
"""
prompt_value = None
if isinstance(input, PromptValue):
prompt_value = input
elif isinstance(input, str):
prompt_value = StringPromptValue(text=input)
elif isinstance(input, list):
prompt_value = ChatPromptValue(messages=input)
if isinstance(lm_input, PromptValue):
prompt_value = lm_input
elif isinstance(lm_input, str):
prompt_value = StringPromptValue(text=lm_input)
elif isinstance(lm_input, list):
prompt_value = ChatPromptValue(messages=lm_input)

if prompt_value is None:
raise ValueError(
f"Invalid input type {type(input)}. "
f"Invalid input type {type(lm_input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)

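With the prompt helpers promoted to danswer.llm.utils, any secondary flow can build LangChain messages from plain role/content dicts. A quick sketch of both helpers (the messages are illustrative):

from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import str_prompt_to_langchain_prompt

prompt = dict_based_prompt_to_langchain_prompt(
    [
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "What is Danswer?"},
    ]
)  # -> [SystemMessage(...), HumanMessage(...)]; a missing or unknown role raises ValueError

simple = str_prompt_to_langchain_prompt("What is Danswer?")  # -> [HumanMessage(...)]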
backend/danswer/secondary_llm_flows/__init__.py (new empty file)
107 changes: 107 additions & 0 deletions backend/danswer/secondary_llm_flows/query_validation.py
@@ -0,0 +1,107 @@
import re
from collections.abc import Iterator
from dataclasses import asdict

from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
from danswer.llm.build import get_default_llm
from danswer.server.models import QueryValidationResponse
from danswer.server.utils import get_json_line

REASONING_PAT = "REASONING: "
ANSWERABLE_PAT = "ANSWERABLE: "
COT_PAT = "\nLet's think step by step"


def get_query_validation_messages(user_query: str) -> list[dict[str, str]]:
messages = [
{
"role": "system",
"content": f"You are a helper tool to determine if a query is answerable using retrieval augmented "
f"generation. A system will try to answer the user query based on ONLY the top 5 most relevant "
f"documents found from search. Sources contain both up to date and proprietary information for "
f"the specific team. For named or unknown entities, assume the search will always find "
f"consistent knowledge about the entity. Determine if that system should attempt to answer. "
f'"{ANSWERABLE_PAT}" must be exactly "True" or "False"',
},
{"role": "user", "content": "What is this Slack channel about?"},
{
"role": "assistant",
"content": f"{REASONING_PAT}First the system must determine which Slack channel is being referred to."
f"By fetching 5 documents related to Slack channel contents, it is not possible to determine"
f"which Slack channel the user is referring to.\n{ANSWERABLE_PAT}False",
},
{
"role": "user",
"content": f"Danswer is unreachable.{COT_PAT}",
},
{
"role": "assistant",
"content": f"{REASONING_PAT}The system searches documents related to Danswer being "
f"unreachable. Assuming the documents from search contains situations where Danswer is not "
f"reachable and contains a fix, the query is answerable.\n{ANSWERABLE_PAT}True",
},
{"role": "user", "content": f"How many customers do we have?{COT_PAT}"},
{
"role": "assistant",
"content": f"{REASONING_PAT}Assuming the searched documents contains customer acquisition information"
f"including a list of customers, the query can be answered.\n{ANSWERABLE_PAT}True",
},
{"role": "user", "content": user_query + COT_PAT},
]

return messages


def extract_answerability_reasoning(model_raw: str) -> str:
reasoning_match = re.search(
f"{REASONING_PAT}(.*?){ANSWERABLE_PAT}", model_raw, re.DOTALL
)
reasoning_text = reasoning_match.group(1).strip() if reasoning_match else ""
return reasoning_text


def extract_answerability_bool(model_raw: str) -> bool:
answerable_match = re.search(f"{ANSWERABLE_PAT}(.+)", model_raw)
answerable_text = answerable_match.group(1).strip() if answerable_match else ""
answerable = answerable_text.lower() in ("true", "yes")
return answerable


def get_query_answerability(user_query: str) -> tuple[str, bool]:
messages = get_query_validation_messages(user_query)
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = get_default_llm().invoke(filled_llm_prompt)

reasoning = extract_answerability_reasoning(model_output)
answerable = extract_answerability_bool(model_output)

return reasoning, answerable


def stream_query_answerability(user_query: str) -> Iterator[str]:
messages = get_query_validation_messages(user_query)
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
tokens = get_default_llm().stream(filled_llm_prompt)
reasoning_pat_found = False
model_output = ""
for token in tokens:
model_output = model_output + token

if not reasoning_pat_found and REASONING_PAT in model_output:
reasoning_pat_found = True
remaining = model_output[model_output.find(REASONING_PAT) + len(REASONING_PAT) :]
if remaining:
yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=remaining)))
continue

if reasoning_pat_found:
yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token)))

reasoning = extract_answerability_reasoning(model_output)
answerable = extract_answerability_bool(model_output)

yield get_json_line(
QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict()
)
return
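To make the parsing concrete, here is a small sketch against a made-up completion in the REASONING:/ANSWERABLE: format that the few-shot prompt above coaches the model into:

from danswer.secondary_llm_flows.query_validation import (
    extract_answerability_bool,
    extract_answerability_reasoning,
)

model_raw = (
    "REASONING: Assuming the indexed documents cover outage reports, "
    "the query can be answered.\n"
    "ANSWERABLE: True"
)

assert extract_answerability_reasoning(model_raw) == (
    "Assuming the indexed documents cover outage reports, the query can be answered."
)
assert extract_answerability_bool(model_raw) is True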
5 changes: 5 additions & 0 deletions backend/danswer/server/models.py
@@ -151,6 +151,11 @@ class SearchFeedbackRequest(BaseModel):
search_feedback: SearchFeedbackType


class QueryValidationResponse(BaseModel):
reasoning: str
answerable: bool


class SearchResponse(BaseModel):
# For semantic search, top docs are reranked, the remaining are as ordered from retrieval
top_ranked_docs: list[SearchDoc] | None
28 changes: 23 additions & 5 deletions backend/danswer/server/search_backend.py
@@ -1,4 +1,3 @@
import json
from collections.abc import Generator
from dataclasses import asdict

@@ -29,12 +28,16 @@
from danswer.search.models import SearchType
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.secondary_llm_flows.query_validation import get_query_answerability
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
from danswer.server.models import HelperResponse
from danswer.server.models import QAFeedbackRequest
from danswer.server.models import QAResponse
from danswer.server.models import QueryValidationResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchFeedbackRequest
from danswer.server.models import SearchResponse
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time

@@ -43,10 +46,6 @@
router = APIRouter()


def get_json_line(json_dict: dict) -> str:
return json.dumps(json_dict) + "\n"


@router.get("/search-intent")
def get_search_type(
question: QuestionRequest = Depends(), _: User = Depends(current_user)
@@ -56,6 +55,25 @@ def get_search_type(
return recommend_search_flow(query, use_keyword)


@router.get("/query-validation")
def query_validation(
question: QuestionRequest = Depends(), _: User = Depends(current_user)
) -> QueryValidationResponse:
query = question.query
reasoning, answerable = get_query_answerability(query)
return QueryValidationResponse(reasoning=reasoning, answerable=answerable)


@router.get("/stream-query-validation")
def stream_query_validation(
question: QuestionRequest = Depends(), _: User = Depends(current_user)
) -> StreamingResponse:
query = question.query
return StreamingResponse(
stream_query_answerability(query), media_type="application/json"
)


@router.post("/semantic-search")
def semantic_search(
question: QuestionRequest,
backend/danswer/server/utils.py — diff not rendered; based on the imports above, get_json_line now lives in this shared module instead of search_backend.py.
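A hedged sketch of consuming the new streaming endpoint from a client. The host, port, and auth handling are assumptions (the route sits behind current_user), and QuestionRequest may require more query parameters than shown; the grounded part is the wire format: newline-delimited JSON carrying DanswerAnswerPiece dicts while the reasoning streams, then one final QueryValidationResponse object.

import json

import requests  # assumed HTTP client, not part of this commit

resp = requests.get(
    "http://localhost:8080/stream-query-validation",  # hypothetical base URL
    params={"query": "Danswer is unreachable."},      # other QuestionRequest fields may be required
    stream=True,
)
for line in resp.iter_lines():
    if not line:
        continue
    packet = json.loads(line)
    if "answer_piece" in packet:
        print(packet["answer_piece"], end="", flush=True)  # streamed reasoning text
    else:
        print(f"\nanswerable={packet['answerable']}")      # final validation verdict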

1 comment on commit 51ec251

@vercel vercel bot commented on 51ec251 Aug 31, 2023