diff --git a/.gitignore b/.gitignore index 2eea525d885..6069bf98f11 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -.env \ No newline at end of file +.env +.DS_store diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py index 6f9fa6e737d..7af0b668828 100644 --- a/backend/danswer/direct_qa/answer_question.py +++ b/backend/danswer/direct_qa/answer_question.py @@ -6,7 +6,7 @@ from danswer.db.models import User from danswer.direct_qa.exceptions import OpenAIKeyMissing from danswer.direct_qa.exceptions import UnknownModelError -from danswer.direct_qa.llm_utils import get_default_llm +from danswer.direct_qa.llm_utils import get_default_qa_model from danswer.search.danswer_helper import query_intent from danswer.search.keyword_search import retrieve_keyword_documents from danswer.search.models import QueryFlow @@ -73,7 +73,7 @@ def answer_question( ) try: - qa_model = get_default_llm(timeout=answer_generation_timeout) + qa_model = get_default_qa_model(timeout=answer_generation_timeout) except (UnknownModelError, OpenAIKeyMissing) as e: return QAResponse( answer=None, diff --git a/backend/danswer/direct_qa/llm_utils.py b/backend/danswer/direct_qa/llm_utils.py index a49052c9f11..453ca182052 100644 --- a/backend/danswer/direct_qa/llm_utils.py +++ b/backend/danswer/direct_qa/llm_utils.py @@ -9,6 +9,8 @@ from danswer.configs.model_configs import GEN_AI_API_KEY from danswer.configs.model_configs import GEN_AI_ENDPOINT from danswer.configs.model_configs import GEN_AI_HOST_TYPE +from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS +from danswer.configs.model_configs import GEN_AI_MODEL_VERSION from danswer.configs.model_configs import INTERNAL_MODEL_VERSION from danswer.direct_qa.exceptions import UnknownModelError from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA @@ -17,12 +19,16 @@ from danswer.direct_qa.huggingface import HuggingFaceCompletionQA from danswer.direct_qa.interfaces import QAModel from danswer.direct_qa.local_transformers import TransformerQA -from danswer.direct_qa.open_ai import OpenAIChatCompletionQA from danswer.direct_qa.open_ai import OpenAICompletionQA +from danswer.direct_qa.qa_block import JsonChatQAHandler +from danswer.direct_qa.qa_block import QABlock +from danswer.direct_qa.qa_block import QAHandler +from danswer.direct_qa.qa_block import SimpleChatQAHandler from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor from danswer.direct_qa.qa_utils import get_gen_ai_api_key from danswer.direct_qa.request_model import RequestCompletionQA from danswer.dynamic_configs.interface import ConfigNotFoundError +from danswer.llm.build import get_default_llm from danswer.utils.logger import setup_logger logger = setup_logger() @@ -32,7 +38,7 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool: if not model_api_key: return False - qa_model = get_default_llm(api_key=model_api_key, timeout=5) + qa_model = get_default_qa_model(api_key=model_api_key, timeout=5) # try for up to 2 timeouts (e.g. 
10 seconds in total) for _ in range(2): @@ -47,12 +53,21 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool: return False -def get_default_llm( +def get_default_qa_handler(model: str) -> QAHandler: + if model == DanswerGenAIModel.OPENAI_CHAT.value: + return JsonChatQAHandler() + + return SimpleChatQAHandler() + + +def get_default_qa_model( internal_model: str = INTERNAL_MODEL_VERSION, + model_version: str = GEN_AI_MODEL_VERSION, endpoint: str | None = GEN_AI_ENDPOINT, model_host_type: str | None = GEN_AI_HOST_TYPE, api_key: str | None = GEN_AI_API_KEY, timeout: int = QA_TIMEOUT, + max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, **kwargs: Any, ) -> QAModel: if not api_key: @@ -61,6 +76,31 @@ def get_default_llm( except ConfigNotFoundError: pass + try: + # un-used arguments will be ignored by the underlying `LLM` class + # if any args are missing, a `TypeError` will be thrown + llm = get_default_llm( + model=internal_model, + api_key=api_key, + model_version=model_version, + endpoint=endpoint, + model_host_type=model_host_type, + timeout=timeout, + max_output_tokens=max_output_tokens, + **kwargs, + ) + qa_handler = get_default_qa_handler(model=internal_model) + + return QABlock( + llm=llm, + qa_handler=qa_handler, + ) + except: + logger.exception( + "Unable to build a QABlock with the new approach, going back to the " + "legacy approach" + ) + if internal_model in [ DanswerGenAIModel.GPT4ALL.value, DanswerGenAIModel.GPT4ALL_CHAT.value, @@ -70,8 +110,6 @@ def get_default_llm( if internal_model == DanswerGenAIModel.OPENAI.value: return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs) - elif internal_model == DanswerGenAIModel.OPENAI_CHAT.value: - return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs) elif internal_model == DanswerGenAIModel.GPT4ALL.value: return GPT4AllCompletionQA(**kwargs) elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value: diff --git a/backend/danswer/direct_qa/open_ai.py b/backend/danswer/direct_qa/open_ai.py index 8e4dbdb1462..ddb3cba7f1e 100644 --- a/backend/danswer/direct_qa/open_ai.py +++ b/backend/danswer/direct_qa/open_ai.py @@ -25,9 +25,6 @@ from danswer.direct_qa.interfaces import AnswerQuestionReturn from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn from danswer.direct_qa.interfaces import QAModel -from danswer.direct_qa.qa_prompts import ChatPromptProcessor -from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg -from danswer.direct_qa.qa_prompts import JsonChatProcessor from danswer.direct_qa.qa_prompts import JsonProcessor from danswer.direct_qa.qa_prompts import NonChatPromptProcessor from danswer.direct_qa.qa_utils import get_gen_ai_api_key @@ -207,107 +204,3 @@ def answer_question_stream( context_docs=context_docs, is_json_prompt=self.prompt_processor.specifies_json_output, ) - - -class OpenAIChatCompletionQA(OpenAIQAModel): - def __init__( - self, - prompt_processor: ChatPromptProcessor = JsonChatProcessor(), - model_version: str = GEN_AI_MODEL_VERSION, - max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, - timeout: int | None = None, - reflexion_try_count: int = 0, - api_key: str | None = None, - include_metadata: bool = INCLUDE_METADATA, - ) -> None: - self.prompt_processor = prompt_processor - self.model_version = model_version - self.max_output_tokens = max_output_tokens - self.reflexion_try_count = reflexion_try_count - self.timeout = timeout - self.include_metadata = include_metadata - self.api_key = api_key - - @staticmethod - def 
_generate_tokens_from_response(response: Any) -> Generator[str, None, None]: - for event in response: - event_dict = cast(dict[str, Any], event["choices"][0]["delta"]) - if ( - "content" not in event_dict - ): # could be a role message or empty termination - continue - yield event_dict["content"] - - @log_function_time() - def answer_question( - self, - query: str, - context_docs: list[InferenceChunk], - ) -> AnswerQuestionReturn: - context_docs = _tiktoken_trim_chunks(context_docs, self.model_version) - - messages = self.prompt_processor.fill_prompt( - query, context_docs, self.include_metadata - ) - logger.debug(json.dumps(messages, indent=4)) - model_output = "" - for _ in range(self.reflexion_try_count + 1): - openai_call = _handle_openai_exceptions_wrapper( - openai_call=openai.ChatCompletion.create, - query=query, - ) - response = openai_call( - **_build_openai_settings( - api_key=_ensure_openai_api_key(self.api_key), - messages=messages, - model=self.model_version, - max_tokens=self.max_output_tokens, - request_timeout=self.timeout, - ), - ) - model_output = cast( - str, response["choices"][0]["message"]["content"] - ).strip() - assistant_msg = {"content": model_output, "role": "assistant"} - messages.extend([assistant_msg, get_json_chat_reflexion_msg()]) - logger.info( - "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "") - ) - - logger.debug(model_output) - - answer, quotes = process_answer(model_output, context_docs) - return answer, quotes - - def answer_question_stream( - self, query: str, context_docs: list[InferenceChunk] - ) -> AnswerQuestionStreamReturn: - context_docs = _tiktoken_trim_chunks(context_docs, self.model_version) - - messages = self.prompt_processor.fill_prompt( - query, context_docs, self.include_metadata - ) - logger.debug(json.dumps(messages, indent=4)) - - openai_call = _handle_openai_exceptions_wrapper( - openai_call=openai.ChatCompletion.create, - query=query, - ) - response = openai_call( - **_build_openai_settings( - api_key=_ensure_openai_api_key(self.api_key), - messages=messages, - model=self.model_version, - max_tokens=self.max_output_tokens, - request_timeout=self.timeout, - stream=True, - ), - ) - - tokens = self._generate_tokens_from_response(response) - - yield from process_model_tokens( - tokens=tokens, - context_docs=context_docs, - is_json_prompt=self.prompt_processor.specifies_json_output, - ) diff --git a/backend/danswer/direct_qa/qa_block.py b/backend/danswer/direct_qa/qa_block.py new file mode 100644 index 00000000000..2af2c20a981 --- /dev/null +++ b/backend/danswer/direct_qa/qa_block.py @@ -0,0 +1,176 @@ +import abc +import json +from collections.abc import Iterator +from copy import copy + +import tiktoken +from langchain.schema.messages import AIMessage +from langchain.schema.messages import BaseMessage +from langchain.schema.messages import HumanMessage +from langchain.schema.messages import SystemMessage + +from danswer.chunking.models import InferenceChunk +from danswer.direct_qa.interfaces import AnswerQuestionReturn +from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn +from danswer.direct_qa.interfaces import DanswerAnswer +from danswer.direct_qa.interfaces import DanswerAnswerPiece +from danswer.direct_qa.interfaces import DanswerQuotes +from danswer.direct_qa.interfaces import QAModel +from danswer.direct_qa.qa_prompts import JsonChatProcessor +from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor +from danswer.direct_qa.qa_utils import process_model_tokens +from danswer.llm.llm import 
LLM + + +def _dict_based_prompt_to_langchain_prompt( + messages: list[dict[str, str]] +) -> list[BaseMessage]: + prompt: list[BaseMessage] = [] + for message in messages: + role = message.get("role") + content = message.get("content") + if not role: + raise ValueError(f"Message missing `role`: {message}") + if not content: + raise ValueError(f"Message missing `content`: {message}") + elif role == "user": + prompt.append(HumanMessage(content=content)) + elif role == "system": + prompt.append(SystemMessage(content=content)) + elif role == "assistant": + prompt.append(AIMessage(content=content)) + else: + raise ValueError(f"Unknown role: {role}") + return prompt + + +def _str_prompt_to_langchain_prompt(message: str) -> list[BaseMessage]: + return [HumanMessage(content=message)] + + +class QAHandler(abc.ABC): + """Evolution of the `PromptProcessor` - handles both building the prompt and + processing the response. These are neccessarily coupled, since the prompt determines + the response format (and thus how it should be parsed into an answer + quotes).""" + + @abc.abstractmethod + def build_prompt( + self, query: str, context_chunks: list[InferenceChunk] + ) -> list[BaseMessage]: + raise NotImplementedError + + @abc.abstractmethod + def process_response( + self, tokens: Iterator[str], context_chunks: list[InferenceChunk] + ) -> AnswerQuestionStreamReturn: + raise NotImplementedError + + +class JsonChatQAHandler(QAHandler): + def build_prompt( + self, query: str, context_chunks: list[InferenceChunk] + ) -> list[BaseMessage]: + return _dict_based_prompt_to_langchain_prompt( + JsonChatProcessor.fill_prompt( + question=query, chunks=context_chunks, include_metadata=False + ) + ) + + def process_response( + self, + tokens: Iterator[str], + context_chunks: list[InferenceChunk], + ) -> AnswerQuestionStreamReturn: + yield from process_model_tokens( + tokens=tokens, + context_docs=context_chunks, + is_json_prompt=True, + ) + + +class SimpleChatQAHandler(QAHandler): + def build_prompt( + self, query: str, context_chunks: list[InferenceChunk] + ) -> list[BaseMessage]: + return _str_prompt_to_langchain_prompt( + WeakModelFreeformProcessor.fill_prompt( + question=query, + chunks=context_chunks, + include_metadata=False, + ) + ) + + def process_response( + self, + tokens: Iterator[str], + context_chunks: list[InferenceChunk], + ) -> AnswerQuestionStreamReturn: + yield from process_model_tokens( + tokens=tokens, + context_docs=context_chunks, + is_json_prompt=False, + ) + + +def _tiktoken_trim_chunks( + chunks: list[InferenceChunk], max_chunk_toks: int = 512 +) -> list[InferenceChunk]: + """Edit chunks that have too high token count. 
Generally due to parsing issues or + characters from another language that are 1 char = 1 token + Trimming by tokens leads to information loss but currently no better way of handling + NOTE: currently gpt-3.5 / gpt-4 tokenizer across all LLMs currently + TODO: make "chunk modification" its own step in the pipeline + """ + encoder = tiktoken.get_encoding("cl100k_base") + new_chunks = copy(chunks) + for ind, chunk in enumerate(new_chunks): + tokens = encoder.encode(chunk.content) + if len(tokens) > max_chunk_toks: + new_chunk = copy(chunk) + new_chunk.content = encoder.decode(tokens[:max_chunk_toks]) + new_chunks[ind] = new_chunk + return new_chunks + + +class QABlock(QAModel): + def __init__(self, llm: LLM, qa_handler: QAHandler) -> None: + self._llm = llm + self._qa_handler = qa_handler + + def warm_up_model(self) -> None: + """This is called during server start up to load the models into memory + in case the chosen LLM is not accessed via API""" + self._llm.stream("Ignore this!") + + def answer_question( + self, + query: str, + context_docs: list[InferenceChunk], + ) -> AnswerQuestionReturn: + trimmed_context_docs = _tiktoken_trim_chunks(context_docs) + prompt = self._qa_handler.build_prompt(query, trimmed_context_docs) + tokens = self._llm.stream(prompt) + + final_answer = "" + quotes = DanswerQuotes([]) + for output in self._qa_handler.process_response(tokens, trimmed_context_docs): + if output is None: + continue + + if isinstance(output, DanswerAnswerPiece): + if output.answer_piece: + final_answer += output.answer_piece + elif isinstance(output, DanswerQuotes): + quotes = output + + return DanswerAnswer(final_answer), quotes + + def answer_question_stream( + self, + query: str, + context_docs: list[InferenceChunk], + ) -> AnswerQuestionStreamReturn: + trimmed_context_docs = _tiktoken_trim_chunks(context_docs) + prompt = self._qa_handler.build_prompt(query, trimmed_context_docs) + tokens = self._llm.stream(prompt) + yield from self._qa_handler.process_response(tokens, trimmed_context_docs) diff --git a/backend/danswer/direct_qa/qa_utils.py b/backend/danswer/direct_qa/qa_utils.py index c667c81d66a..70078c48280 100644 --- a/backend/danswer/direct_qa/qa_utils.py +++ b/backend/danswer/direct_qa/qa_utils.py @@ -2,6 +2,7 @@ import math import re from collections.abc import Generator +from collections.abc import Iterator from typing import cast from typing import Optional from typing import Tuple @@ -191,7 +192,7 @@ def extract_quotes_from_completed_token_stream( def process_model_tokens( - tokens: Generator[str, None, None], + tokens: Iterator[str], context_docs: list[InferenceChunk], is_json_prompt: bool = True, ) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]: diff --git a/backend/danswer/llm/azure.py b/backend/danswer/llm/azure.py new file mode 100644 index 00000000000..49a91afacf6 --- /dev/null +++ b/backend/danswer/llm/azure.py @@ -0,0 +1,45 @@ +from typing import Any + +from langchain.chat_models.azure_openai import AzureChatOpenAI + +from danswer.configs.model_configs import API_BASE_OPENAI +from danswer.configs.model_configs import API_VERSION_OPENAI +from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID +from danswer.llm.llm import LangChainChatLLM +from danswer.llm.utils import should_be_verbose + + +class AzureGPT(LangChainChatLLM): + def __init__( + self, + api_key: str, + max_output_tokens: int, + timeout: int, + model_version: str, + api_base: str = API_BASE_OPENAI, + api_version: str = API_VERSION_OPENAI, + deployment_name: str = AZURE_DEPLOYMENT_ID, + 
*args: list[Any], + **kwargs: dict[str, Any] + ): + self._llm = AzureChatOpenAI( + model=model_version, + openai_api_type="azure", + openai_api_base=api_base, + openai_api_version=api_version, + deployment_name=deployment_name, + openai_api_key=api_key, + max_tokens=max_output_tokens, + temperature=0, + request_timeout=timeout, + model_kwargs={ + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + }, + verbose=should_be_verbose(), + ) + + @property + def llm(self) -> AzureChatOpenAI: + return self._llm diff --git a/backend/danswer/llm/build.py b/backend/danswer/llm/build.py new file mode 100644 index 00000000000..6208204c8fc --- /dev/null +++ b/backend/danswer/llm/build.py @@ -0,0 +1,16 @@ +from typing import Any + +from danswer.configs.constants import DanswerGenAIModel +from danswer.configs.model_configs import API_TYPE_OPENAI +from danswer.llm.azure import AzureGPT +from danswer.llm.llm import LLM +from danswer.llm.openai import OpenAIGPT + + +def get_default_llm(model: str, **kwargs: Any) -> LLM: + if model == DanswerGenAIModel.OPENAI_CHAT.value: + if API_TYPE_OPENAI == "azure": + return AzureGPT(**kwargs) + return OpenAIGPT(**kwargs) + + raise ValueError(f"Unknown LLM model: {model}") diff --git a/backend/danswer/llm/google_colab_demo.py b/backend/danswer/llm/google_colab_demo.py new file mode 100644 index 00000000000..2ad85583abe --- /dev/null +++ b/backend/danswer/llm/google_colab_demo.py @@ -0,0 +1,53 @@ +import json +from collections.abc import Iterator +from typing import Any + +import requests +from langchain.schema.language_model import LanguageModelInput +from langchain.schema.messages import BaseMessageChunk +from requests import Timeout + +from danswer.llm.llm import LLM +from danswer.llm.utils import convert_input + + +class GoogleColabDemo(LLM): + def __init__( + self, + endpoint: str, + max_output_tokens: int, + timeout: int, + *args: list[Any], + **kwargs: dict[str, Any], + ): + self._endpoint = endpoint + self._max_output_tokens = max_output_tokens + self._timeout = timeout + + def _execute(self, input: LanguageModelInput) -> str: + headers = { + "Content-Type": "application/json", + } + + data = { + "inputs": convert_input(input), + "parameters": { + "temperature": 0.0, + "max_tokens": self._max_output_tokens, + }, + } + try: + response = requests.post( + self._endpoint, headers=headers, json=data, timeout=self._timeout + ) + except Timeout as error: + raise Timeout(f"Model inference to {self._endpoint} timed out") from error + + response.raise_for_status() + return json.loads(response.content).get("generated_text", "") + + def invoke(self, input: LanguageModelInput) -> str: + return self._execute(input) + + def stream(self, input: LanguageModelInput) -> Iterator[str]: + yield self._execute(input) diff --git a/backend/danswer/llm/llm.py b/backend/danswer/llm/llm.py new file mode 100644 index 00000000000..985994769bd --- /dev/null +++ b/backend/danswer/llm/llm.py @@ -0,0 +1,44 @@ +import abc +from collections.abc import Iterator + +from langchain.chat_models.base import BaseChatModel +from langchain.schema.language_model import LanguageModelInput + +from danswer.llm.utils import message_generator_to_string_generator +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + + +class LLM(abc.ABC): + """Mimics the LangChain LLM / BaseChatModel interfaces to make it easy + to use these implementations to connect to a variety of LLM providers.""" + + @abc.abstractmethod + def invoke(self, input: LanguageModelInput) -> str: + raise 
NotImplementedError + + @abc.abstractmethod + def stream(self, input: LanguageModelInput) -> Iterator[str]: + raise NotImplementedError + + +class LangChainChatLLM(LLM, abc.ABC): + @property + @abc.abstractmethod + def llm(self) -> BaseChatModel: + raise NotImplementedError + + def _log_model_config(self) -> None: + logger.debug( + f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}" + ) + + def invoke(self, input: LanguageModelInput) -> str: + self._log_model_config() + return self.llm.invoke(input).content + + def stream(self, input: LanguageModelInput) -> Iterator[str]: + self._log_model_config() + yield from message_generator_to_string_generator(self.llm.stream(input)) diff --git a/backend/danswer/llm/openai.py b/backend/danswer/llm/openai.py new file mode 100644 index 00000000000..4aa9274a0bc --- /dev/null +++ b/backend/danswer/llm/openai.py @@ -0,0 +1,35 @@ +from typing import Any + +from langchain.chat_models.openai import ChatOpenAI + +from danswer.llm.llm import LangChainChatLLM +from danswer.llm.utils import should_be_verbose + + +class OpenAIGPT(LangChainChatLLM): + def __init__( + self, + api_key: str, + max_output_tokens: int, + timeout: int, + model_version: str, + *args: list[Any], + **kwargs: dict[str, Any] + ): + self._llm = ChatOpenAI( + model=model_version, + openai_api_key=api_key, + max_tokens=max_output_tokens, + temperature=0, + request_timeout=timeout, + model_kwargs={ + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + }, + verbose=should_be_verbose(), + ) + + @property + def llm(self) -> ChatOpenAI: + return self._llm diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py new file mode 100644 index 00000000000..3e10577d586 --- /dev/null +++ b/backend/danswer/llm/utils.py @@ -0,0 +1,43 @@ +from collections.abc import Iterator + +from langchain.prompts.base import StringPromptValue +from langchain.prompts.chat import ChatPromptValue +from langchain.schema import ( + PromptValue, +) +from langchain.schema.language_model import LanguageModelInput +from langchain.schema.messages import BaseMessageChunk + +from danswer.configs.app_configs import LOG_LEVEL + + +def message_generator_to_string_generator( + messages: Iterator[BaseMessageChunk], +) -> Iterator[str]: + for message in messages: + yield message.content + + +def convert_input(input: LanguageModelInput) -> str: + """Heavily inspired by: + https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chat_models/base.py#L86 + """ + prompt_value = None + if isinstance(input, PromptValue): + prompt_value = input + elif isinstance(input, str): + prompt_value = StringPromptValue(text=input) + elif isinstance(input, list): + prompt_value = ChatPromptValue(messages=input) + + if prompt_value is None: + raise ValueError( + f"Invalid input type {type(input)}. " + "Must be a PromptValue, str, or list of BaseMessages." 
+ ) + + return prompt_value.to_string() + + +def should_be_verbose() -> bool: + return LOG_LEVEL == "debug" diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 197bec11473..a6eec74b89c 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -29,7 +29,7 @@ from danswer.configs.model_configs import INTERNAL_MODEL_VERSION from danswer.datastores.document_index import get_default_document_index from danswer.db.credentials import create_initial_public_credential -from danswer.direct_qa.llm_utils import get_default_llm +from danswer.direct_qa.llm_utils import get_default_qa_model from danswer.server.credential import router as credential_router from danswer.server.event_loading import router as event_processing_router from danswer.server.health import router as health_router @@ -178,7 +178,7 @@ def startup_event() -> None: logger.info("Warming up local NLP models.") warm_up_models() - qa_model = get_default_llm() + qa_model = get_default_qa_model() qa_model.warm_up_model() logger.info("Verifying query preprocessing (NLTK) data is downloaded") diff --git a/backend/danswer/server/manage.py b/backend/danswer/server/manage.py index 6dbcea67a06..67813cf7873 100644 --- a/backend/danswer/server/manage.py +++ b/backend/danswer/server/manage.py @@ -57,7 +57,7 @@ from danswer.db.models import DeletionAttempt from danswer.db.models import User from danswer.direct_qa.llm_utils import check_model_api_key_is_valid -from danswer.direct_qa.llm_utils import get_default_llm +from danswer.direct_qa.llm_utils import get_default_qa_model from danswer.direct_qa.open_ai import get_gen_ai_api_key from danswer.dynamic_configs import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError @@ -423,7 +423,7 @@ def validate_existing_genai_api_key( ) -> None: # OpenAI key is only used for generative QA, so no need to validate this # if it's turned off or if a non-OpenAI model is being used - if DISABLE_GENERATIVE_AI or not get_default_llm().requires_api_key: + if DISABLE_GENERATIVE_AI or not get_default_qa_model().requires_api_key: return # Only validate every so often diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py index 084fc255e51..1e4cc47f3eb 100644 --- a/backend/danswer/server/search_backend.py +++ b/backend/danswer/server/search_backend.py @@ -15,7 +15,7 @@ from danswer.direct_qa.answer_question import answer_question from danswer.direct_qa.exceptions import OpenAIKeyMissing from danswer.direct_qa.exceptions import UnknownModelError -from danswer.direct_qa.llm_utils import get_default_llm +from danswer.direct_qa.llm_utils import get_default_qa_model from danswer.search.danswer_helper import query_intent from danswer.search.danswer_helper import recommend_search_flow from danswer.search.keyword_search import retrieve_keyword_documents @@ -174,7 +174,7 @@ def stream_qa_portions( return try: - qa_model = get_default_llm() + qa_model = get_default_qa_model() except (UnknownModelError, OpenAIKeyMissing) as e: logger.exception("Unable to get QA model") yield get_json_line({"error": str(e)}) @@ -199,6 +199,7 @@ def stream_qa_portions( except Exception as e: # exception is logged in the answer_question method, no need to re-log yield get_json_line({"error": str(e)}) + logger.exception("Failed to run QA") return diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 7a3970ce829..f4341d3f0b7 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -20,6 
+20,7 @@ httpx==0.23.3
 httpx-oauth==0.11.2
 huggingface-hub==0.16.4
 jira==3.5.1
+langchain==0.0.273
 Mako==1.2.4
 nltk==3.8.1
 docx2txt==0.8
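
The patch replaces the monolithic OpenAIChatCompletionQA with three smaller pieces: an LLM (provider wrapper), a QAHandler (prompt building plus response parsing), and a QABlock that composes the two. A rough sketch of the wiring that get_default_qa_model() now attempts first follows; the API key, model version, query, and empty chunk list are placeholder values for illustration, not taken from the patch:

# Illustrative sketch only, not part of the patch.
from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import DanswerGenAIModel
from danswer.direct_qa.qa_block import JsonChatQAHandler
from danswer.direct_qa.qa_block import QABlock
from danswer.llm.build import get_default_llm

llm = get_default_llm(
    model=DanswerGenAIModel.OPENAI_CHAT.value,  # routes to OpenAIGPT, or AzureGPT when API_TYPE_OPENAI == "azure"
    api_key="sk-...",               # placeholder
    model_version="gpt-3.5-turbo",  # placeholder
    max_output_tokens=512,
    timeout=10,
)
# JsonChatQAHandler matches what get_default_qa_handler() returns for OPENAI_CHAT
qa_model = QABlock(llm=llm, qa_handler=JsonChatQAHandler())

context_docs: list[InferenceChunk] = []  # normally the chunks returned by retrieval
answer, quotes = qa_model.answer_question("What is Danswer?", context_docs)

If this construction fails, get_default_qa_model() logs the exception and falls back to the legacy per-model QA classes, so existing GPT4All / HuggingFace / request-based setups keep working.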
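
For providers that are not reachable through a LangChain chat model, the base LLM class only requires invoke() and stream() over a LanguageModelInput, which is how GoogleColabDemo is implemented. A minimal sketch of a custom provider under that contract; EchoLLM is a hypothetical name used purely for illustration:

from collections.abc import Iterator

from langchain.schema.language_model import LanguageModelInput

from danswer.llm.llm import LLM
from danswer.llm.utils import convert_input


class EchoLLM(LLM):
    """Hypothetical provider used only to show the interface contract."""

    def invoke(self, input: LanguageModelInput) -> str:
        # A real provider would call its model here; convert_input() flattens
        # str / PromptValue / message-list inputs into a single prompt string.
        return convert_input(input)

    def stream(self, input: LanguageModelInput) -> Iterator[str]:
        # Providers without true streaming can yield the full completion at
        # once, as GoogleColabDemo does in this patch.
        yield self.invoke(input)

Wrappers around LangChain chat models (OpenAIGPT, AzureGPT) instead subclass LangChainChatLLM and only expose the underlying BaseChatModel via the llm property.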
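
Both QAHandler implementations reuse the existing prompt processors and only translate their dict-based output into LangChain messages. A small illustration of that conversion, with made-up message contents:

# Illustrative only: how handlers turn dict-style prompts into LangChain
# messages before handing them to LLM.stream().
from danswer.direct_qa.qa_block import _dict_based_prompt_to_langchain_prompt

messages = [
    {"role": "system", "content": "Answer strictly from the provided context."},
    {"role": "user", "content": "What does this patch change?"},
]
# -> [SystemMessage(...), HumanMessage(...)]; a missing role/content or an
#    unknown role raises ValueError
langchain_prompt = _dict_based_prompt_to_langchain_prompt(messages)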