Add LangChain-based LLM
Weves committed Aug 27, 2023
1 parent 642862b commit 627af67
Showing 16 changed files with 469 additions and 122 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -1 +1,2 @@
-.env
+.env
+.DS_store
4 changes: 2 additions & 2 deletions backend/danswer/direct_qa/answer_question.py
@@ -6,7 +6,7 @@
 from danswer.db.models import User
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
-from danswer.direct_qa.llm_utils import get_default_llm
+from danswer.direct_qa.llm_utils import get_default_qa_model
 from danswer.search.danswer_helper import query_intent
 from danswer.search.keyword_search import retrieve_keyword_documents
 from danswer.search.models import QueryFlow
@@ -73,7 +73,7 @@ def answer_question(
     )

     try:
-        qa_model = get_default_llm(timeout=answer_generation_timeout)
+        qa_model = get_default_qa_model(timeout=answer_generation_timeout)
     except (UnknownModelError, OpenAIKeyMissing) as e:
         return QAResponse(
             answer=None,
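For callers, the change is just the new entry-point name: get_default_llm is now reserved for the LangChain-based factory (see llm_utils.py below), and QA-model construction goes through get_default_qa_model. A minimal usage sketch, illustrative only and not part of the commit:

# Illustrative only -- not part of this commit. Shows how a caller might obtain
# the default QA model after the rename, mirroring the call site above.
from danswer.direct_qa.exceptions import OpenAIKeyMissing, UnknownModelError
from danswer.direct_qa.llm_utils import get_default_qa_model


def try_build_qa_model(timeout: int = 10):
    try:
        # Raises UnknownModelError / OpenAIKeyMissing when misconfigured,
        # exactly as handled in answer_question above.
        return get_default_qa_model(timeout=timeout)
    except (UnknownModelError, OpenAIKeyMissing):
        return None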
48 changes: 43 additions & 5 deletions backend/danswer/direct_qa/llm_utils.py
@@ -9,6 +9,8 @@
 from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.configs.model_configs import GEN_AI_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_HOST_TYPE
+from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
@@ -17,12 +19,16 @@
 from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.local_transformers import TransformerQA
-from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
 from danswer.direct_qa.open_ai import OpenAICompletionQA
+from danswer.direct_qa.qa_block import JsonChatQAHandler
+from danswer.direct_qa.qa_block import QABlock
+from danswer.direct_qa.qa_block import QAHandler
+from danswer.direct_qa.qa_block import SimpleChatQAHandler
 from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
 from danswer.direct_qa.qa_utils import get_gen_ai_api_key
 from danswer.direct_qa.request_model import RequestCompletionQA
 from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.llm.build import get_default_llm
 from danswer.utils.logger import setup_logger

 logger = setup_logger()
@@ -32,7 +38,7 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
     if not model_api_key:
         return False

-    qa_model = get_default_llm(api_key=model_api_key, timeout=5)
+    qa_model = get_default_qa_model(api_key=model_api_key, timeout=5)

     # try for up to 2 timeouts (e.g. 10 seconds in total)
     for _ in range(2):
@@ -47,12 +53,21 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
     return False


-def get_default_llm(
+def get_default_qa_handler(model: str) -> QAHandler:
+    if model == DanswerGenAIModel.OPENAI_CHAT.value:
+        return JsonChatQAHandler()
+
+    return SimpleChatQAHandler()
+
+
+def get_default_qa_model(
     internal_model: str = INTERNAL_MODEL_VERSION,
+    model_version: str = GEN_AI_MODEL_VERSION,
     endpoint: str | None = GEN_AI_ENDPOINT,
     model_host_type: str | None = GEN_AI_HOST_TYPE,
     api_key: str | None = GEN_AI_API_KEY,
     timeout: int = QA_TIMEOUT,
+    max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
     **kwargs: Any,
 ) -> QAModel:
     if not api_key:
@@ -61,6 +76,31 @@ def get_default_llm(
         except ConfigNotFoundError:
             pass

+    try:
+        # un-used arguments will be ignored by the underlying `LLM` class
+        # if any args are missing, a `TypeError` will be thrown
+        llm = get_default_llm(
+            model=internal_model,
+            api_key=api_key,
+            model_version=model_version,
+            endpoint=endpoint,
+            model_host_type=model_host_type,
+            timeout=timeout,
+            max_output_tokens=max_output_tokens,
+            **kwargs,
+        )
+        qa_handler = get_default_qa_handler(model=internal_model)
+
+        return QABlock(
+            llm=llm,
+            qa_handler=qa_handler,
+        )
+    except:
+        logger.exception(
+            "Unable to build a QABlock with the new approach, going back to the "
+            "legacy approach"
+        )
+
     if internal_model in [
         DanswerGenAIModel.GPT4ALL.value,
         DanswerGenAIModel.GPT4ALL_CHAT.value,
@@ -70,8 +110,6 @@

     if internal_model == DanswerGenAIModel.OPENAI.value:
         return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
-    elif internal_model == DanswerGenAIModel.OPENAI_CHAT.value:
-        return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
     elif internal_model == DanswerGenAIModel.GPT4ALL.value:
         return GPT4AllCompletionQA(**kwargs)
     elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
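The get_default_llm factory imported above from danswer.llm.build is one of the new files in this commit, but its diff is not shown in this excerpt. A minimal sketch of what a LangChain-based factory along these lines could look like, assuming it builds on langchain.chat_models.ChatOpenAI; only the function name and its keyword arguments are taken from the call site above, everything else is an assumption:

# Hypothetical sketch of backend/danswer/llm/build.py -- the real module is
# part of this commit but not shown here. The diff's comment suggests the real
# code returns danswer's own `LLM` wrapper class; this sketch returns the
# LangChain chat model directly, using LangChain's 2023-era API.
from typing import Any

from langchain.chat_models import ChatOpenAI


def get_default_llm(
    model: str,
    api_key: str | None,
    model_version: str,
    timeout: int,
    max_output_tokens: int,
    **kwargs: Any,  # e.g. endpoint / model_host_type, unused for OpenAI chat
) -> ChatOpenAI:
    # A real implementation would dispatch on `model` / host type and raise
    # TypeError for unsupported combinations (per the comment in the diff);
    # this sketch covers only the OpenAI chat case.
    return ChatOpenAI(
        model_name=model_version,
        openai_api_key=api_key,
        request_timeout=timeout,
        max_tokens=max_output_tokens,
    )

Keeping provider setup behind a single factory is what lets llm_utils.py shrink to one job: deciding which prompt handler to pair with the model.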
107 changes: 0 additions & 107 deletions backend/danswer/direct_qa/open_ai.py
@@ -25,9 +25,6 @@
 from danswer.direct_qa.interfaces import AnswerQuestionReturn
 from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.qa_prompts import ChatPromptProcessor
-from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg
-from danswer.direct_qa.qa_prompts import JsonChatProcessor
 from danswer.direct_qa.qa_prompts import JsonProcessor
 from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
 from danswer.direct_qa.qa_utils import get_gen_ai_api_key
@@ -207,107 +204,3 @@ def answer_question_stream(
             context_docs=context_docs,
             is_json_prompt=self.prompt_processor.specifies_json_output,
         )
-
-
-class OpenAIChatCompletionQA(OpenAIQAModel):
-    def __init__(
-        self,
-        prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
-        model_version: str = GEN_AI_MODEL_VERSION,
-        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
-        timeout: int | None = None,
-        reflexion_try_count: int = 0,
-        api_key: str | None = None,
-        include_metadata: bool = INCLUDE_METADATA,
-    ) -> None:
-        self.prompt_processor = prompt_processor
-        self.model_version = model_version
-        self.max_output_tokens = max_output_tokens
-        self.reflexion_try_count = reflexion_try_count
-        self.timeout = timeout
-        self.include_metadata = include_metadata
-        self.api_key = api_key
-
-    @staticmethod
-    def _generate_tokens_from_response(response: Any) -> Generator[str, None, None]:
-        for event in response:
-            event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
-            if (
-                "content" not in event_dict
-            ):  # could be a role message or empty termination
-                continue
-            yield event_dict["content"]
-
-    @log_function_time()
-    def answer_question(
-        self,
-        query: str,
-        context_docs: list[InferenceChunk],
-    ) -> AnswerQuestionReturn:
-        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
-
-        messages = self.prompt_processor.fill_prompt(
-            query, context_docs, self.include_metadata
-        )
-        logger.debug(json.dumps(messages, indent=4))
-        model_output = ""
-        for _ in range(self.reflexion_try_count + 1):
-            openai_call = _handle_openai_exceptions_wrapper(
-                openai_call=openai.ChatCompletion.create,
-                query=query,
-            )
-            response = openai_call(
-                **_build_openai_settings(
-                    api_key=_ensure_openai_api_key(self.api_key),
-                    messages=messages,
-                    model=self.model_version,
-                    max_tokens=self.max_output_tokens,
-                    request_timeout=self.timeout,
-                ),
-            )
-            model_output = cast(
-                str, response["choices"][0]["message"]["content"]
-            ).strip()
-            assistant_msg = {"content": model_output, "role": "assistant"}
-            messages.extend([assistant_msg, get_json_chat_reflexion_msg()])
-            logger.info(
-                "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
-            )
-
-        logger.debug(model_output)
-
-        answer, quotes = process_answer(model_output, context_docs)
-        return answer, quotes
-
-    def answer_question_stream(
-        self, query: str, context_docs: list[InferenceChunk]
-    ) -> AnswerQuestionStreamReturn:
-        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
-
-        messages = self.prompt_processor.fill_prompt(
-            query, context_docs, self.include_metadata
-        )
-        logger.debug(json.dumps(messages, indent=4))
-
-        openai_call = _handle_openai_exceptions_wrapper(
-            openai_call=openai.ChatCompletion.create,
-            query=query,
-        )
-        response = openai_call(
-            **_build_openai_settings(
-                api_key=_ensure_openai_api_key(self.api_key),
-                messages=messages,
-                model=self.model_version,
-                max_tokens=self.max_output_tokens,
-                request_timeout=self.timeout,
-                stream=True,
-            ),
-        )
-
-        tokens = self._generate_tokens_from_response(response)
-
-        yield from process_model_tokens(
-            tokens=tokens,
-            context_docs=context_docs,
-            is_json_prompt=self.prompt_processor.specifies_json_output,
-        )
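The deleted class's responsibilities (fill a chat prompt, call the model, parse answer and quotes) move to the new QABlock/QAHandler pair imported in llm_utils.py from danswer.direct_qa.qa_block, whose diff is also not shown in this excerpt. A rough sketch of how the pieces could fit together; only the class names come from the imports above, all method names and signatures are assumptions:

# Hypothetical sketch of the QABlock / QAHandler split -- qa_block.py is part
# of this commit but not shown here. Class names match the llm_utils.py
# imports; method names and signatures are assumptions for illustration.
from abc import ABC, abstractmethod


class QAHandler(ABC):
    """Owns prompt construction and output parsing for one prompt style."""

    @abstractmethod
    def build_prompt(self, query: str, context: str) -> str:
        ...

    @abstractmethod
    def parse_output(self, model_output: str) -> str:
        ...


class SimpleChatQAHandler(QAHandler):
    def build_prompt(self, query: str, context: str) -> str:
        return f"Answer using only this context:\n{context}\n\nQuestion: {query}"

    def parse_output(self, model_output: str) -> str:
        return model_output.strip()


class QABlock:
    """Pairs any LangChain-backed LLM with a prompt/parse handler, replacing
    per-provider classes like the deleted OpenAIChatCompletionQA."""

    def __init__(self, llm, qa_handler: QAHandler) -> None:
        self.llm = llm
        self.qa_handler = qa_handler

    def answer_question(self, query: str, context: str) -> str:
        prompt = self.qa_handler.build_prompt(query, context)
        # `predict` is LangChain's 2023-era string-in / string-out call.
        return self.qa_handler.parse_output(self.llm.predict(prompt))

This split also explains the try/except fallback in get_default_qa_model: if the new QABlock path cannot be built for a configured model, the legacy per-provider classes still apply.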
(diffs for the remaining 12 changed files not shown)