[DEPRIORITIZED][AAQ-765] Retry LLM generation when AlignScore fails #399

Open
wants to merge 15 commits into main
1 change: 1 addition & 0 deletions core_backend/app/config.py
@@ -70,6 +70,7 @@
ALIGN_SCORE_METHOD = os.environ.get("ALIGN_SCORE_METHOD", "LLM")
# if AlignScore, set ALIGN_SCORE_API. If LLM, set LITELLM_MODEL_ALIGNSCORE above.
ALIGN_SCORE_API = os.environ.get("ALIGN_SCORE_API", "")
ALIGN_SCORE_N_RETRIES = os.environ.get("ALIGN_SCORE_N_RETRIES", 1)

# Backend paths
BACKEND_ROOT_PATH = os.environ.get("BACKEND_ROOT_PATH", "")
6 changes: 6 additions & 0 deletions core_backend/app/llm_call/llm_prompts.py
@@ -177,6 +177,11 @@ def get_prompt(cls) -> str:
You are a helpful question-answering AI. You understand user question and answer their \
question using the REFERENCE TEXT below.
"""
RETRY_PROMPT_SUFFIX = """
If the response above is not aligned with the question, please rectify this by \
considering the following reason(s) for misalignment: "{failure_reason}".
Make necessary adjustments to ensure the answer is aligned with the question.
"""

Contributor Author:
Added a suffix to the prompt to incorporate the failure reason.
Comment on lines +180 to +184
Collaborator:
Right now, we are only passing failure_reason, which is response.debug_info["factual_consistency"]["reason"], but we should also include the LLM response in this prompt.

Collaborator:
Also, shouldn't the prompt define what we mean by alignment?

Contributor Author:
That makes sense. To be honest, I was just having a go at updating the prompt to take the output into consideration. I am not exactly an expert in prompt engineering. Should we discuss that in a tech session?

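A rough sketch of how the suffix could look if it also carried the previous answer and spelled out what alignment means (the previous_response placeholder and the wording below are illustrative assumptions, not part of this PR):

    RETRY_PROMPT_SUFFIX = """
    Your previous answer was: "{previous_response}".
    It was flagged as not aligned for the following reason(s): "{failure_reason}".
    An answer is aligned when it directly addresses the question and every claim in it \
    is supported by the REFERENCE TEXT. Revise your answer accordingly.
    """
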
RAG_RESPONSE_PROMPT = (
_RAG_PROFILE_PROMPT
+ """
@@ -224,6 +229,7 @@ class RAG(BaseModel):
answer: str

prompt: ClassVar[str] = RAG_RESPONSE_PROMPT
retry_prompt: ClassVar[str] = RAG_RESPONSE_PROMPT + RETRY_PROMPT_SUFFIX


class AlignmentScore(BaseModel):
9 changes: 8 additions & 1 deletion core_backend/app/llm_call/llm_rag.py
@@ -37,7 +37,14 @@ async def get_llm_rag_answer(
"""

    metadata = metadata or {}
    prompt = RAG.prompt.format(context=context, original_language=original_language)
    if "failure_reason" in metadata and metadata["failure_reason"]:
Collaborator:
How about we create a new arg, "retry=False"?

Collaborator:
The downside is:

  1. We would have to create it for all the parent functions, and
  2. We need both is_retry and metadata["failure_reason"] to actually do the retry.

But I think it would be easier to understand the code, and we won't be hiding any unexpected actions! What do you think?

Something like

    if is_retry:
        if "failure_reason" not in metadata:
            raise ValueError("failure_reason is required for retry requests")
        
        prompt = RAG.retry_prompt.format(
            context=context,
            original_language=original_language,
            failure_reason=metadata["failure_reason"],
        )

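If that route were taken, the flag would presumably also be threaded through the signature, along these lines (a hypothetical sketch; the type annotations are simplified assumptions and is_retry is not part of this PR):

    async def get_llm_rag_answer(
        question: str,
        context: str,
        original_language: str,
        metadata: dict | None = None,
        is_retry: bool = False,  # hypothetical explicit flag instead of checking metadata
    ) -> RAG:
        ...

Each caller on the retry path (retry_search, get_generation_response, and presumably generate_llm_query_response) would then pass is_retry=True explicitly, which is the fan-out mentioned in point 1 above.
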
Contributor Author:
My initial understanding was that we are using this to try out the functionality. What if we keep it like this while testing, and if it turns out to be something we want to keep, then we will explicitly set it as a feature by adding the is_retry parameter. What do you think?

        prompt = RAG.retry_prompt.format(
            context=context,
            original_language=original_language,
            failure_reason=metadata["failure_reason"],
        )
    else:
        prompt = RAG.prompt.format(context=context, original_language=original_language)

    result = await _ask_llm_async(
        user_message=question,
6 changes: 5 additions & 1 deletion core_backend/app/llm_call/process_output.py
@@ -56,7 +56,11 @@ async def generate_llm_query_response(
    Only runs if the generate_llm_response flag is set to True.
    Requires "search_results" and "original_language" in the response.
    """
    if isinstance(response, QueryResponseError):
    if (
        isinstance(response, QueryResponseError)
        and metadata
        and not metadata["failure_reason"]
    ):
        return response

    if response.search_results is None:
46 changes: 43 additions & 3 deletions core_backend/app/question_answer/routers.py
@@ -6,13 +6,14 @@
from io import BytesIO
from typing import Tuple

import backoff
from fastapi import APIRouter, Depends, File, UploadFile, status
from fastapi.responses import JSONResponse
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from ..auth.dependencies import authenticate_key, rate_limiter
from ..config import CUSTOM_SPEECH_ENDPOINT, GCS_SPEECH_BUCKET
from ..config import ALIGN_SCORE_N_RETRIES, CUSTOM_SPEECH_ENDPOINT, GCS_SPEECH_BUCKET
from ..contents.models import (
get_similar_content_async,
increment_query_count,
@@ -50,6 +51,7 @@
)
from .schemas import (
ContentFeedback,
ErrorType,
QueryAudioResponse,
QueryBase,
QueryRefined,
@@ -123,6 +125,12 @@ async def search(
query_refined=user_query_refined_template,
response=response,
)
if is_unable_to_generate_response(response):
failure_reason = response.debug_info["factual_consistency"]
response = await retry_search(
query_refined=user_query_refined_template, response=response
)
response.debug_info["past_failure"] = failure_reason

await save_query_response_to_db(user_query_db, response, asession)
await increment_query_count(
@@ -228,7 +236,6 @@ async def voice_search(
asession=asession,
exclude_archived=True,
)

if user_query.generate_llm_response:
response = await get_generation_response(
query_refined=user_query_refined_template,
@@ -322,6 +329,36 @@ async def get_search_response(
return response


def is_unable_to_generate_response(response: QueryResponse) -> bool:
    """
    Check if the response is of type QueryResponseError and caused
    by low alignment score.
    """
    return (
        isinstance(response, QueryResponseError)
        and response.error_type == ErrorType.ALIGNMENT_TOO_LOW
    )

Contributor Author:
Added this function so that we retry only if that condition is met.


@backoff.on_predicate(
backoff.expo,
Contributor Author:
What backoff.expo does is basically wait a little longer every time the function is rerun, with the wait growing exponentially, just to handle the load better.

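For reference, a standalone toy example of that mechanism (not code from this PR): backoff.on_predicate re-invokes the wrapped function while the predicate is true for its return value, and backoff.expo makes the wait between attempts grow exponentially (with jitter applied by default).

    import itertools

    import backoff

    _calls = itertools.count(1)

    @backoff.on_predicate(backoff.expo, lambda result: result is None, max_tries=3)
    def flaky_lookup() -> str | None:
        """Return None on the first two calls, then succeed."""
        return None if next(_calls) < 3 else "ok"

    print(flaky_lookup())  # retried with short, exponentially growing waits; prints "ok"
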
Collaborator:
Should we just have logic that retries once, instead of adding a config (number of retries) we don't know if we'll use 🤔?

Contributor Author:
I guess it depends on how useful the approach is, since we haven't done any analysis to see how well it works.
But personally I think that since it doesn't add a new dependency (backoff is already used by litellm), and since the only code we would change if we retried just once is the decorator and the config variable, the cost is pretty low, so we can just keep it.

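For comparison, the retry-once variant being discussed would mostly just change the decorator (a hypothetical sketch, not what this PR implements):

    # Hypothetical retry-once variant: hard-code max_tries and drop the
    # ALIGN_SCORE_N_RETRIES config entirely.
    @backoff.on_predicate(
        backoff.expo,
        max_tries=2,  # the original attempt plus exactly one retry
        predicate=is_unable_to_generate_response,
    )
    async def retry_search_once(
        query_refined: QueryRefined,
        response: QueryResponse | QueryResponseError,
    ) -> QueryResponse | QueryResponseError:
        ...  # body identical to retry_search below
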
max_tries=int(ALIGN_SCORE_N_RETRIES),
predicate=is_unable_to_generate_response,
)
async def retry_search(
    query_refined: QueryRefined,
    response: QueryResponse | QueryResponseError,
) -> QueryResponse | QueryResponseError:
    """
    Retry wrapper for get_generation_response.
    """

    metadata = query_refined.query_metadata
    metadata["failure_reason"] = response.debug_info["factual_consistency"]["reason"]
    query_refined.query_metadata = metadata
    return await get_generation_response(query_refined, response)


@generate_tts__after
@check_align_score__after
async def get_generation_response(
@@ -341,10 +378,13 @@ async def get_generation_response(
query_id=response.query_id, user_id=query_refined.user_id
)

metadata["failure_reason"] = query_refined.query_metadata.get(
"failure_reason", None
)

response = await generate_llm_query_response(
query_refined=query_refined, response=response, metadata=metadata
)

return response


1 change: 1 addition & 0 deletions core_backend/requirements.txt
@@ -18,6 +18,7 @@ pandas-stubs==2.2.2.240603
types-openpyxl==3.1.4.20240621
redis==5.0.8
python-dateutil==2.8.2
backoff==2.2.1
google-cloud-storage==2.18.2
google-cloud-texttospeech==2.16.5
google-cloud-speech==2.27.0