From 0b40e2a413eb383e84444a452cd58b167d7eb500 Mon Sep 17 00:00:00 2001 From: Carlos Samey Date: Thu, 15 Aug 2024 16:11:01 +0300 Subject: [PATCH 01/10] First commit --- Makefile | 2 +- core_backend/app/config.py | 15 ++++++++++----- core_backend/app/llm_call/process_output.py | 12 +++++++++++- core_backend/app/question_answer/routers.py | 2 +- core_backend/requirements.txt | 1 + 5 files changed, 24 insertions(+), 8 deletions(-) diff --git a/Makefile b/Makefile index f5b0fafc1..7381cce73 100644 --- a/Makefile +++ b/Makefile @@ -49,7 +49,7 @@ setup-db: @sleep 2 @docker run --name postgres-local \ --env-file "$(CURDIR)/deployment/docker-compose/.core_backend.env" \ - -p 5432:5432 \ + -p 5436:5432 \ -d pgvector/pgvector:pg16 set -a && \ source $(CURDIR)/deployment/docker-compose/.base.env && \ diff --git a/core_backend/app/config.py b/core_backend/app/config.py index f12213189..42587947b 100644 --- a/core_backend/app/config.py +++ b/core_backend/app/config.py @@ -9,7 +9,7 @@ POSTGRES_USER = os.environ.get("POSTGRES_USER", "postgres") POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD", "postgres") POSTGRES_HOST = os.environ.get("POSTGRES_HOST", "localhost") -POSTGRES_PORT = os.environ.get("POSTGRES_PORT", "5432") +POSTGRES_PORT = os.environ.get("POSTGRES_PORT", "5436") POSTGRES_DB = os.environ.get("POSTGRES_DB", "postgres") # LiteLLM proxy variables @@ -25,9 +25,10 @@ LITELLM_MODEL_EMBEDDING = os.environ.get("LITELLM_MODEL_EMBEDDING", "openai/embeddings") LITELLM_MODEL_DEFAULT = os.environ.get("LITELLM_MODEL_DEFAULT", "openai/default") LITELLM_MODEL_GENERATION = os.environ.get( + # "LITELLM_MODEL_GENERATION", + # "openai/generate-gemini-response", "LITELLM_MODEL_GENERATION", - "openai/generate-gemini-response", - # "LITELLM_MODEL_GENERATION", "openai/generate-response" + "openai/generate-response", ) LITELLM_MODEL_LANGUAGE_DETECT = os.environ.get( "LITELLM_MODEL_LANGUAGE_DETECT", "openai/detect-language" @@ -55,9 +56,13 @@ # Alignment Score variables ALIGN_SCORE_THRESHOLD = os.environ.get("ALIGN_SCORE_THRESHOLD", 0.7) # Method: LLM, AlignScore, or None -ALIGN_SCORE_METHOD = os.environ.get("ALIGN_SCORE_METHOD", "LLM") +# ALIGN_SCORE_METHOD = os.environ.get("ALIGN_SCORE_METHOD", "AlignScore") +ALIGN_SCORE_METHOD = "adadad" # if AlignScore, set ALIGN_SCORE_API. If LLM, set LITELLM_MODEL_ALIGNSCORE above. -ALIGN_SCORE_API = os.environ.get("ALIGN_SCORE_API", "") +ALIGN_SCORE_API = os.environ.get( + "ALIGN_SCORE_API", "http://alignscore:5001/alignscore_base" +) +ALIGN_SCORE_N_RETRIES = os.environ.get("ALIGN_SCORE_N_RETRIES", 3) # Backend paths BACKEND_ROOT_PATH = os.environ.get("BACKEND_ROOT_PATH", "") diff --git a/core_backend/app/llm_call/process_output.py b/core_backend/app/llm_call/process_output.py index 4a2eaf650..bc876c51b 100644 --- a/core_backend/app/llm_call/process_output.py +++ b/core_backend/app/llm_call/process_output.py @@ -6,11 +6,13 @@ from typing import Any, Callable, Optional, TypedDict import aiohttp +import backoff from pydantic import ValidationError from ..config import ( ALIGN_SCORE_API, ALIGN_SCORE_METHOD, + ALIGN_SCORE_N_RETRIES, ALIGN_SCORE_THRESHOLD, LITELLM_MODEL_ALIGNSCORE, ) @@ -177,6 +179,7 @@ async def wrapper( metadata = create_langfuse_metadata( query_id=response.query_id, user_id=query_refined.user_id ) + print("We are here in check_align_score__after") response = await _check_align_score(response, metadata) return response @@ -209,11 +212,12 @@ async def _check_align_score( return response align_score_data = AlignScoreData(evidence=evidence, claim=claim) - + print(ALIGN_SCORE_METHOD) if ALIGN_SCORE_METHOD is None: logger.warning("No alignment score method specified.") return response elif ALIGN_SCORE_METHOD == "AlignScore": + if ALIGN_SCORE_API is not None: align_score = await _get_alignScore_score(ALIGN_SCORE_API, align_score_data) else: @@ -256,6 +260,9 @@ async def _check_align_score( return response +@backoff.on_exception( + backoff.expo, RuntimeError, max_tries=ALIGN_SCORE_N_RETRIES, jitter=None +) async def _get_alignScore_score( api_url: str, align_score_date: AlignScoreData ) -> AlignmentScore: @@ -265,6 +272,8 @@ async def _get_alignScore_score( http_client = get_http_client() assert isinstance(http_client, aiohttp.ClientSession) async with http_client.post(api_url, json=align_score_date) as resp: + print("Alignscore tried") + logger.info("AlignScore retried") if resp.status != 200: logger.error(f"AlignScore API request failed with status {resp.status}") raise RuntimeError( @@ -273,6 +282,7 @@ async def _get_alignScore_score( result = await resp.json() logger.info(f"AlignScore result: {result}") + alignment_score = AlignmentScore(score=result["alignscore"], reason="N/A") return alignment_score diff --git a/core_backend/app/question_answer/routers.py b/core_backend/app/question_answer/routers.py index 811cdd6ed..4a49150e5 100644 --- a/core_backend/app/question_answer/routers.py +++ b/core_backend/app/question_answer/routers.py @@ -288,7 +288,7 @@ async def search_base( exclude_archived=exclude_archived, ) response.search_results = search_results - + print("We are here") return response diff --git a/core_backend/requirements.txt b/core_backend/requirements.txt index 6e6b15134..667d5ab9d 100644 --- a/core_backend/requirements.txt +++ b/core_backend/requirements.txt @@ -19,3 +19,4 @@ types-openpyxl==3.1.4.20240621 redis==5.0.8 python-dateutil==2.8.2 gTTS==2.5.1 +backoff==2.2.1 From 98536dab52246252f51c30fe6b8ce7d9fcf92d37 Mon Sep 17 00:00:00 2001 From: Carlos Samey Date: Mon, 19 Aug 2024 15:39:31 +0300 Subject: [PATCH 02/10] Add retry logic --- core_backend/app/config.py | 4 +- core_backend/app/llm_call/process_output.py | 13 ++--- core_backend/app/question_answer/routers.py | 53 +++++++++++++++++++-- 3 files changed, 56 insertions(+), 14 deletions(-) diff --git a/core_backend/app/config.py b/core_backend/app/config.py index 66ac8e37e..2c952c199 100644 --- a/core_backend/app/config.py +++ b/core_backend/app/config.py @@ -63,12 +63,12 @@ ALIGN_SCORE_THRESHOLD = os.environ.get("ALIGN_SCORE_THRESHOLD", 0.7) # Method: LLM, AlignScore, or None # ALIGN_SCORE_METHOD = os.environ.get("ALIGN_SCORE_METHOD", "AlignScore") -ALIGN_SCORE_METHOD = "adadad" +ALIGN_SCORE_METHOD = "LLM" # if AlignScore, set ALIGN_SCORE_API. If LLM, set LITELLM_MODEL_ALIGNSCORE above. ALIGN_SCORE_API = os.environ.get( "ALIGN_SCORE_API", "http://alignscore:5001/alignscore_base" ) -ALIGN_SCORE_N_RETRIES = os.environ.get("ALIGN_SCORE_N_RETRIES", 3) +ALIGN_SCORE_N_RETRIES = os.environ.get("ALIGN_SCORE_N_RETRIES", 1) # Backend paths BACKEND_ROOT_PATH = os.environ.get("BACKEND_ROOT_PATH", "") diff --git a/core_backend/app/llm_call/process_output.py b/core_backend/app/llm_call/process_output.py index bc876c51b..ab61204ba 100644 --- a/core_backend/app/llm_call/process_output.py +++ b/core_backend/app/llm_call/process_output.py @@ -6,13 +6,11 @@ from typing import Any, Callable, Optional, TypedDict import aiohttp -import backoff from pydantic import ValidationError from ..config import ( ALIGN_SCORE_API, ALIGN_SCORE_METHOD, - ALIGN_SCORE_N_RETRIES, ALIGN_SCORE_THRESHOLD, LITELLM_MODEL_ALIGNSCORE, ) @@ -173,13 +171,13 @@ async def wrapper( response = await func(query_refined, response, *args, **kwargs) - if not kwargs.get("generate_llm_response", False): + if not query_refined.generate_llm_response: return response metadata = create_langfuse_metadata( query_id=response.query_id, user_id=query_refined.user_id ) - print("We are here in check_align_score__after") + response = await _check_align_score(response, metadata) return response @@ -196,6 +194,7 @@ async def _check_align_score( Only runs if the generate_llm_response flag is set to True. Requires "llm_response" and "search_results" in the response. """ + if isinstance(response, QueryResponseError) or response.llm_response is None: return response @@ -212,7 +211,7 @@ async def _check_align_score( return response align_score_data = AlignScoreData(evidence=evidence, claim=claim) - print(ALIGN_SCORE_METHOD) + if ALIGN_SCORE_METHOD is None: logger.warning("No alignment score method specified.") return response @@ -260,9 +259,6 @@ async def _check_align_score( return response -@backoff.on_exception( - backoff.expo, RuntimeError, max_tries=ALIGN_SCORE_N_RETRIES, jitter=None -) async def _get_alignScore_score( api_url: str, align_score_date: AlignScoreData ) -> AlignmentScore: @@ -272,7 +268,6 @@ async def _get_alignScore_score( http_client = get_http_client() assert isinstance(http_client, aiohttp.ClientSession) async with http_client.post(api_url, json=align_score_date) as resp: - print("Alignscore tried") logger.info("AlignScore retried") if resp.status != 200: logger.error(f"AlignScore API request failed with status {resp.status}") diff --git a/core_backend/app/question_answer/routers.py b/core_backend/app/question_answer/routers.py index a38f3ebfc..752bb1436 100644 --- a/core_backend/app/question_answer/routers.py +++ b/core_backend/app/question_answer/routers.py @@ -5,13 +5,14 @@ import os from typing import Tuple +import backoff from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status from fastapi.responses import JSONResponse from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from ..auth.dependencies import authenticate_key, rate_limiter -from ..config import SPEECH_ENDPOINT +from ..config import ALIGN_SCORE_N_RETRIES, SPEECH_ENDPOINT from ..contents.models import ( get_similar_content_async, increment_query_count, @@ -42,6 +43,7 @@ ) from .schemas import ( ContentFeedback, + ErrorType, QueryBase, QueryRefined, QueryResponse, @@ -216,9 +218,21 @@ async def search( contents=response.search_results, asession=asession, ) + if is_unable_to_generate_response(response): + failure_reason = response.debug_info["factual_consistency"] + response = await retry_search( + query_refined=user_query_refined_template, + response=response_template, + user_id=user_db.user_id, + n_similar=int(N_TOP_CONTENT), + asession=asession, + exclude_archived=True, + ) + response.debug_info["past_failure"] = failure_reason if type(response) is QueryResponse: return response + elif type(response) is QueryResponseError: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content=response.model_dump() @@ -234,8 +248,8 @@ async def search( @classify_safety__before @translate_question__before @paraphrase_question__before -@generate_llm_response__after @check_align_score__after +@generate_llm_response__after async def search_base( query_refined: QueryRefined, response: QueryResponse, @@ -284,10 +298,43 @@ async def search_base( exclude_archived=exclude_archived, ) response.search_results = search_results - print("We are here") + return response +def is_unable_to_generate_response(response: QueryResponse) -> bool: + """ + Check if the response is of type QueryResponseError and caused + by low alignment score. + """ + return ( + isinstance(response, QueryResponseError) + and response.error_type == ErrorType.ALIGNMENT_TOO_LOW + ) + + +@backoff.on_predicate( + backoff.expo, + max_tries=int(ALIGN_SCORE_N_RETRIES), + predicate=is_unable_to_generate_response, +) +async def retry_search( + query_refined: QueryRefined, + response: QueryResponse | QueryResponseError, + user_id: int, + n_similar: int, + asession: AsyncSession, + exclude_archived: bool = True, +) -> QueryResponse | QueryResponseError: + """ + Retry wrapper for search_base. + """ + + return await search_base( + query_refined, response, user_id, n_similar, asession, exclude_archived + ) + + async def get_user_query_and_response( user_id: int, user_query: QueryBase, asession: AsyncSession ) -> Tuple[QueryDB, QueryRefined, QueryResponse]: From 421637677230b5e01dc5386550af7e72afdba453 Mon Sep 17 00:00:00 2001 From: Carlos Samey Date: Mon, 19 Aug 2024 15:58:16 +0300 Subject: [PATCH 03/10] Add retry logic --- Makefile | 2 +- core_backend/app/config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index f17f5cbe5..bf5d57973 100644 --- a/Makefile +++ b/Makefile @@ -49,7 +49,7 @@ setup-db: @sleep 2 @docker run --name postgres-local \ --env-file "$(CURDIR)/deployment/docker-compose/.core_backend.env" \ - -p 5436:5432 \ + -p 5432:5432 \ -d pgvector/pgvector:pg16 set -a && \ source "$(CURDIR)/deployment/docker-compose/.base.env" && \ diff --git a/core_backend/app/config.py b/core_backend/app/config.py index 2c952c199..3a9fc0ce6 100644 --- a/core_backend/app/config.py +++ b/core_backend/app/config.py @@ -9,7 +9,7 @@ POSTGRES_USER = os.environ.get("POSTGRES_USER", "postgres") POSTGRES_PASSWORD = os.environ.get("POSTGRES_PASSWORD", "postgres") POSTGRES_HOST = os.environ.get("POSTGRES_HOST", "localhost") -POSTGRES_PORT = os.environ.get("POSTGRES_PORT", "5436") +POSTGRES_PORT = os.environ.get("POSTGRES_PORT", "5432") POSTGRES_DB = os.environ.get("POSTGRES_DB", "postgres") # PGVector variables From 7561d8c90fdc1665bb1de0d710089f7c65f46577 Mon Sep 17 00:00:00 2001 From: Carlos Samey Date: Mon, 19 Aug 2024 16:02:56 +0300 Subject: [PATCH 04/10] Cleanup --- core_backend/app/config.py | 13 +++++-------- core_backend/app/llm_call/process_output.py | 1 - 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/core_backend/app/config.py b/core_backend/app/config.py index 3a9fc0ce6..501417626 100644 --- a/core_backend/app/config.py +++ b/core_backend/app/config.py @@ -31,10 +31,10 @@ LITELLM_MODEL_EMBEDDING = os.environ.get("LITELLM_MODEL_EMBEDDING", "openai/embeddings") LITELLM_MODEL_DEFAULT = os.environ.get("LITELLM_MODEL_DEFAULT", "openai/default") LITELLM_MODEL_GENERATION = os.environ.get( - # "LITELLM_MODEL_GENERATION", - # "openai/generate-gemini-response", "LITELLM_MODEL_GENERATION", - "openai/generate-response", + "openai/generate-gemini-response", + # "LITELLM_MODEL_GENERATION", + # "openai/generate-response", ) LITELLM_MODEL_LANGUAGE_DETECT = os.environ.get( "LITELLM_MODEL_LANGUAGE_DETECT", "openai/detect-language" @@ -62,12 +62,9 @@ # Alignment Score variables ALIGN_SCORE_THRESHOLD = os.environ.get("ALIGN_SCORE_THRESHOLD", 0.7) # Method: LLM, AlignScore, or None -# ALIGN_SCORE_METHOD = os.environ.get("ALIGN_SCORE_METHOD", "AlignScore") -ALIGN_SCORE_METHOD = "LLM" +ALIGN_SCORE_METHOD = os.environ.get("ALIGN_SCORE_METHOD", "LLM") # if AlignScore, set ALIGN_SCORE_API. If LLM, set LITELLM_MODEL_ALIGNSCORE above. -ALIGN_SCORE_API = os.environ.get( - "ALIGN_SCORE_API", "http://alignscore:5001/alignscore_base" -) +ALIGN_SCORE_API = os.environ.get("ALIGN_SCORE_API", "") ALIGN_SCORE_N_RETRIES = os.environ.get("ALIGN_SCORE_N_RETRIES", 1) # Backend paths diff --git a/core_backend/app/llm_call/process_output.py b/core_backend/app/llm_call/process_output.py index ab61204ba..70a80e811 100644 --- a/core_backend/app/llm_call/process_output.py +++ b/core_backend/app/llm_call/process_output.py @@ -268,7 +268,6 @@ async def _get_alignScore_score( http_client = get_http_client() assert isinstance(http_client, aiohttp.ClientSession) async with http_client.post(api_url, json=align_score_date) as resp: - logger.info("AlignScore retried") if resp.status != 200: logger.error(f"AlignScore API request failed with status {resp.status}") raise RuntimeError( From a5a0ac9ca89e3dced01cc6915569ef20e0b68439 Mon Sep 17 00:00:00 2001 From: Carlos Samey Date: Mon, 19 Aug 2024 16:07:55 +0300 Subject: [PATCH 05/10] Cleanup --- core_backend/app/llm_call/process_output.py | 4 ---- core_backend/app/question_answer/routers.py | 1 - 2 files changed, 5 deletions(-) diff --git a/core_backend/app/llm_call/process_output.py b/core_backend/app/llm_call/process_output.py index 70a80e811..3f550e2f9 100644 --- a/core_backend/app/llm_call/process_output.py +++ b/core_backend/app/llm_call/process_output.py @@ -177,7 +177,6 @@ async def wrapper( metadata = create_langfuse_metadata( query_id=response.query_id, user_id=query_refined.user_id ) - response = await _check_align_score(response, metadata) return response @@ -194,7 +193,6 @@ async def _check_align_score( Only runs if the generate_llm_response flag is set to True. Requires "llm_response" and "search_results" in the response. """ - if isinstance(response, QueryResponseError) or response.llm_response is None: return response @@ -216,7 +214,6 @@ async def _check_align_score( logger.warning("No alignment score method specified.") return response elif ALIGN_SCORE_METHOD == "AlignScore": - if ALIGN_SCORE_API is not None: align_score = await _get_alignScore_score(ALIGN_SCORE_API, align_score_data) else: @@ -276,7 +273,6 @@ async def _get_alignScore_score( result = await resp.json() logger.info(f"AlignScore result: {result}") - alignment_score = AlignmentScore(score=result["alignscore"], reason="N/A") return alignment_score diff --git a/core_backend/app/question_answer/routers.py b/core_backend/app/question_answer/routers.py index 752bb1436..7a72abbd0 100644 --- a/core_backend/app/question_answer/routers.py +++ b/core_backend/app/question_answer/routers.py @@ -232,7 +232,6 @@ async def search( if type(response) is QueryResponse: return response - elif type(response) is QueryResponseError: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content=response.model_dump() From ef7cef013040c4c74db44692af1a8d0c27effa68 Mon Sep 17 00:00:00 2001 From: Carlos Samey Date: Tue, 20 Aug 2024 10:48:40 +0300 Subject: [PATCH 06/10] Fix linting --- core_backend/app/question_answer/routers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core_backend/app/question_answer/routers.py b/core_backend/app/question_answer/routers.py index f3e40da5f..755ab348b 100644 --- a/core_backend/app/question_answer/routers.py +++ b/core_backend/app/question_answer/routers.py @@ -13,7 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from ..auth.dependencies import authenticate_key, rate_limiter -from ..config import ALIGN_SCORE_N_RETRIES,GCS_SPEECH_BUCKET, SPEECH_ENDPOINT +from ..config import ALIGN_SCORE_N_RETRIES, GCS_SPEECH_BUCKET, SPEECH_ENDPOINT from ..contents.models import ( get_similar_content_async, increment_query_count, From d08e4437edd53cf5c2e62b1f858c613b74439557 Mon Sep 17 00:00:00 2001 From: Carlos Samey Date: Tue, 20 Aug 2024 17:06:02 +0300 Subject: [PATCH 07/10] Fix linting --- core_backend/app/question_answer/routers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core_backend/app/question_answer/routers.py b/core_backend/app/question_answer/routers.py index 6cf507c91..70d2267a7 100644 --- a/core_backend/app/question_answer/routers.py +++ b/core_backend/app/question_answer/routers.py @@ -7,7 +7,7 @@ from typing import Tuple import backoff -from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status from fastapi.responses import JSONResponse from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession From 8824d09fa547180a42d5d61bbb592516915c0505 Mon Sep 17 00:00:00 2001 From: lickem22 <44327443+lickem22@users.noreply.github.com> Date: Mon, 26 Aug 2024 10:36:26 +0300 Subject: [PATCH 08/10] Update core_backend/requirements.txt Co-authored-by: Suzin You <7042047+suzinyou@users.noreply.github.com> --- core_backend/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core_backend/requirements.txt b/core_backend/requirements.txt index 02e112984..67807ad34 100644 --- a/core_backend/requirements.txt +++ b/core_backend/requirements.txt @@ -20,4 +20,4 @@ redis==5.0.8 python-dateutil==2.8.2 gTTS==2.5.1 backoff==2.2.1 -google-cloud-storage==2.18.2 \ No newline at end of file +google-cloud-storage==2.18.2 From 3bde1734f9db48ba27b8fd6aa7eaebe01d6621ef Mon Sep 17 00:00:00 2001 From: Carlos Samey Date: Wed, 28 Aug 2024 12:59:30 +0300 Subject: [PATCH 09/10] Add retry prompt for llm alignscore failure --- core_backend/app/llm_call/llm_prompts.py | 6 +++++ core_backend/app/llm_call/llm_rag.py | 9 +++++++- core_backend/app/llm_call/process_output.py | 2 +- core_backend/app/question_answer/routers.py | 25 ++++++++++++--------- 4 files changed, 30 insertions(+), 12 deletions(-) diff --git a/core_backend/app/llm_call/llm_prompts.py b/core_backend/app/llm_call/llm_prompts.py index 1c8b31f82..1c1281265 100644 --- a/core_backend/app/llm_call/llm_prompts.py +++ b/core_backend/app/llm_call/llm_prompts.py @@ -177,6 +177,11 @@ def get_prompt(cls) -> str: You are a helpful question-answering AI. You understand user question and answer their \ question using the REFERENCE TEXT below. """ +RETRY_PROMPT_SUFFIX = """ +If the response above is not aligned with the question, please rectify this by considering \ +the following reason(s) for misalignment: "{failure_reason}". Make necessary adjustments \ +to ensure the answer is aligned with the question. +""" RAG_RESPONSE_PROMPT = ( _RAG_PROFILE_PROMPT + """ @@ -224,6 +229,7 @@ class RAG(BaseModel): answer: str prompt: ClassVar[str] = RAG_RESPONSE_PROMPT + retry_prompt: ClassVar[str] = RAG_RESPONSE_PROMPT + RETRY_PROMPT_SUFFIX class AlignmentScore(BaseModel): diff --git a/core_backend/app/llm_call/llm_rag.py b/core_backend/app/llm_call/llm_rag.py index 0f86fa9cb..6c0da0e16 100644 --- a/core_backend/app/llm_call/llm_rag.py +++ b/core_backend/app/llm_call/llm_rag.py @@ -37,7 +37,14 @@ async def get_llm_rag_answer( """ metadata = metadata or {} - prompt = RAG.prompt.format(context=context, original_language=original_language) + if "failure_reason" in metadata and metadata["failure_reason"]: + prompt = RAG.retry_prompt.format( + context=context, + original_language=original_language, + failure_reason=metadata["failure_reason"], + ) + else: + prompt = RAG.prompt.format(context=context, original_language=original_language) result = await _ask_llm_async( user_message=question, diff --git a/core_backend/app/llm_call/process_output.py b/core_backend/app/llm_call/process_output.py index 85695f1f7..f7370d3e0 100644 --- a/core_backend/app/llm_call/process_output.py +++ b/core_backend/app/llm_call/process_output.py @@ -56,7 +56,7 @@ async def generate_llm_query_response( Only runs if the generate_llm_response flag is set to True. Requires "search_results" and "original_language" in the response. """ - if isinstance(response, QueryResponseError): + if isinstance(response, QueryResponseError) and not metadata["failure_reason"]: return response if response.search_results is None: diff --git a/core_backend/app/question_answer/routers.py b/core_backend/app/question_answer/routers.py index d1931a89b..40fb39ec2 100644 --- a/core_backend/app/question_answer/routers.py +++ b/core_backend/app/question_answer/routers.py @@ -125,6 +125,12 @@ async def search( query_refined=user_query_refined_template, response=response, ) + if is_unable_to_generate_response(response): + failure_reason = response.debug_info["factual_consistency"] + response = await retry_search( + query_refined=user_query_refined_template, response=response + ) + response.debug_info["past_failure"] = failure_reason await save_query_response_to_db(user_query_db, response, asession) await increment_query_count( @@ -230,7 +236,6 @@ async def voice_search( asession=asession, exclude_archived=True, ) - if user_query.generate_llm_response: response = await get_generation_response( query_refined=user_query_refined_template, @@ -343,18 +348,15 @@ def is_unable_to_generate_response(response: QueryResponse) -> bool: async def retry_search( query_refined: QueryRefined, response: QueryResponse | QueryResponseError, - user_id: int, - n_similar: int, - asession: AsyncSession, - exclude_archived: bool = True, ) -> QueryResponse | QueryResponseError: """ - Retry wrapper for search_base. + Retry wrapper for get_generation_response. """ - return await search_base( - query_refined, response, user_id, n_similar, asession, exclude_archived - ) + metadata = query_refined.query_metadata + metadata["failure_reason"] = response.debug_info["factual_consistency"]["reason"] + query_refined.query_metadata = metadata + return await get_generation_response(query_refined, response) @generate_tts__after @@ -376,10 +378,13 @@ async def get_generation_response( query_id=response.query_id, user_id=query_refined.user_id ) + metadata["failure_reason"] = query_refined.query_metadata.get( + "failure_reason", None + ) + response = await generate_llm_query_response( query_refined=query_refined, response=response, metadata=metadata ) - return response From 3858453882c358fea4b15e6ceb8f63a0353460ab Mon Sep 17 00:00:00 2001 From: Carlos Samey Date: Wed, 28 Aug 2024 13:12:56 +0300 Subject: [PATCH 10/10] Cleanups --- core_backend/app/llm_call/llm_prompts.py | 6 +++--- core_backend/app/llm_call/process_output.py | 6 +++++- core_backend/app/question_answer/routers.py | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/core_backend/app/llm_call/llm_prompts.py b/core_backend/app/llm_call/llm_prompts.py index 1c1281265..1ff5ade80 100644 --- a/core_backend/app/llm_call/llm_prompts.py +++ b/core_backend/app/llm_call/llm_prompts.py @@ -178,9 +178,9 @@ def get_prompt(cls) -> str: question using the REFERENCE TEXT below. """ RETRY_PROMPT_SUFFIX = """ -If the response above is not aligned with the question, please rectify this by considering \ -the following reason(s) for misalignment: "{failure_reason}". Make necessary adjustments \ -to ensure the answer is aligned with the question. +If the response above is not aligned with the question, please rectify this by \ +considering the following reason(s) for misalignment: "{failure_reason}". +Make necessary adjustments to ensure the answer is aligned with the question. """ RAG_RESPONSE_PROMPT = ( _RAG_PROFILE_PROMPT diff --git a/core_backend/app/llm_call/process_output.py b/core_backend/app/llm_call/process_output.py index f7370d3e0..694db60ec 100644 --- a/core_backend/app/llm_call/process_output.py +++ b/core_backend/app/llm_call/process_output.py @@ -56,7 +56,11 @@ async def generate_llm_query_response( Only runs if the generate_llm_response flag is set to True. Requires "search_results" and "original_language" in the response. """ - if isinstance(response, QueryResponseError) and not metadata["failure_reason"]: + if ( + isinstance(response, QueryResponseError) + and metadata + and not metadata["failure_reason"] + ): return response if response.search_results is None: diff --git a/core_backend/app/question_answer/routers.py b/core_backend/app/question_answer/routers.py index 40fb39ec2..7980070d0 100644 --- a/core_backend/app/question_answer/routers.py +++ b/core_backend/app/question_answer/routers.py @@ -302,7 +302,7 @@ async def get_search_response( n_similar The number of similar contents to retrieve. asession - `AsyncSession` object for database POtransactions. + `AsyncSession` object for database transactions. exclude_archived Specifies whether to exclude archived content.