[DEPRIORITIZED][AAQ-765] Retry LLM generation when AlignScore fails #399
@@ -177,6 +177,11 @@ def get_prompt(cls) -> str:
You are a helpful question-answering AI. You understand user question and answer their \
question using the REFERENCE TEXT below.
"""
RETRY_PROMPT_SUFFIX = """
If the response above is not aligned with the question, please rectify this by \
considering the following reason(s) for misalignment: "{failure_reason}".
Make necessary adjustments to ensure the answer is aligned with the question.
"""
Comment on lines +180 to +184

Right now, we are only passing …

Also, shouldn't the prompt define what we mean by alignment?

That makes sense. To be honest, I was just having a go at updating the prompt to take the output into consideration. I am not exactly an expert in prompt engineering. Should we discuss that in a tech session?
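For illustration, a minimal sketch of how the suffix would be filled in on a retry (the failure reason string here is made up, not from the codebase):

    # Only {failure_reason} lives in the suffix; {context} and {original_language}
    # come from the base RAG prompt that the suffix is appended to.
    retry_suffix = RETRY_PROMPT_SUFFIX.format(
        failure_reason="The answer discusses side effects, but the question asks about dosage."
    )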
RAG_RESPONSE_PROMPT = (
    _RAG_PROFILE_PROMPT
    + """

@@ -224,6 +229,7 @@ class RAG(BaseModel):
    answer: str

    prompt: ClassVar[str] = RAG_RESPONSE_PROMPT
    retry_prompt: ClassVar[str] = RAG_RESPONSE_PROMPT + RETRY_PROMPT_SUFFIX


class AlignmentScore(BaseModel):
@@ -37,7 +37,14 @@ async def get_llm_rag_answer(
    """

    metadata = metadata or {}
    prompt = RAG.prompt.format(context=context, original_language=original_language)
    if "failure_reason" in metadata and metadata["failure_reason"]:
How about we create a new arg, "retry=False"?

The downside is …

But I think it would be easier to understand the code, and we won't be hiding any unexpected actions! What do you think? Something like:

    if is_retry:
        if "failure_reason" not in metadata:
            raise ValueError("failure_reason is required for retry requests")
        prompt = RAG.retry_prompt.format(
            context=context,
            original_language=original_language,
            failure_reason=metadata["failure_reason"],
        )

My initial understanding was that we are using this to try out the functionality. What if we keep it like this while testing, and if it turns out to be something we want to keep, then we will explicitly set it as a functionality by adding the …
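For comparison, a fuller sketch of the flag-based variant discussed above (the is_retry parameter is only a reviewer suggestion, not part of this PR):

    # Hypothetical: an explicit flag selects the prompt, and the metadata check
    # becomes validation rather than the trigger for the retry path.
    if is_retry:
        if "failure_reason" not in metadata:
            raise ValueError("failure_reason is required for retry requests")
        prompt = RAG.retry_prompt.format(
            context=context,
            original_language=original_language,
            failure_reason=metadata["failure_reason"],
        )
    else:
        prompt = RAG.prompt.format(context=context, original_language=original_language)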
        prompt = RAG.retry_prompt.format(
            context=context,
            original_language=original_language,
            failure_reason=metadata["failure_reason"],
        )
    else:
        prompt = RAG.prompt.format(context=context, original_language=original_language)

    result = await _ask_llm_async(
        user_message=question,
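As a usage sketch (argument names are inferred from this diff; the values are placeholders), the retry path is taken simply by passing a non-empty failure_reason in the metadata:

    answer = await get_llm_rag_answer(
        question="What is the recommended dosage?",
        context=retrieved_context,            # content found by the search step
        original_language=original_language,  # identified language of the query
        metadata={"failure_reason": "The answer does not address dosage."},
    )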
@@ -6,13 +6,14 @@
from io import BytesIO
from typing import Tuple

import backoff
from fastapi import APIRouter, Depends, File, UploadFile, status
from fastapi.responses import JSONResponse
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from ..auth.dependencies import authenticate_key, rate_limiter
from ..config import CUSTOM_SPEECH_ENDPOINT, GCS_SPEECH_BUCKET
from ..config import ALIGN_SCORE_N_RETRIES, CUSTOM_SPEECH_ENDPOINT, GCS_SPEECH_BUCKET
from ..contents.models import (
    get_similar_content_async,
    increment_query_count,
@@ -50,6 +51,7 @@
)
from .schemas import (
    ContentFeedback,
    ErrorType,
    QueryAudioResponse,
    QueryBase,
    QueryRefined,
@@ -123,6 +125,12 @@ async def search(
            query_refined=user_query_refined_template,
            response=response,
        )
        if is_unable_to_generate_response(response):
            failure_reason = response.debug_info["factual_consistency"]
            response = await retry_search(
                query_refined=user_query_refined_template, response=response
            )
            response.debug_info["past_failure"] = failure_reason

    await save_query_response_to_db(user_query_db, response, asession)
    await increment_query_count(
@@ -228,7 +236,6 @@ async def voice_search(
        asession=asession,
        exclude_archived=True,
    )

    if user_query.generate_llm_response:
        response = await get_generation_response(
            query_refined=user_query_refined_template,
@@ -322,6 +329,36 @@ async def get_search_response(
    return response


def is_unable_to_generate_response(response: QueryResponse) -> bool:
    """
    Check if the response is of type QueryResponseError and caused
    by low alignment score.
    """
    return (
        isinstance(response, QueryResponseError)
        and response.error_type == ErrorType.ALIGNMENT_TOO_LOW
    )

Added this function so that we retry only if that condition is met.

@backoff.on_predicate(
    backoff.expo,
    max_tries=int(ALIGN_SCORE_N_RETRIES),
    predicate=is_unable_to_generate_response,
)
async def retry_search(
    query_refined: QueryRefined,
    response: QueryResponse | QueryResponseError,
) -> QueryResponse | QueryResponseError:
    """
    Retry wrapper for get_generation_response.
    """
    metadata = query_refined.query_metadata
    metadata["failure_reason"] = response.debug_info["factual_consistency"]["reason"]
    query_refined.query_metadata = metadata
    return await get_generation_response(query_refined, response)

Comment thread on the backoff configuration:

What …

Should we just have logic that retries once, instead of adding a config (num retries) we don't know if we'll use 🤔?

I guess it depends on how useful the approach is, since we haven't done any analysis to see how well it works.
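For reference, a small self-contained sketch of how backoff.on_predicate behaves here (the coroutine and its return value are hypothetical): the decorated function is re-invoked with exponential backoff for as long as the predicate returns True for its result, up to max_tries calls in total, after which the last result is returned.

    import asyncio

    import backoff


    def looks_misaligned(result: str) -> bool:
        # Stand-in predicate, analogous to is_unable_to_generate_response.
        return result == "misaligned"


    @backoff.on_predicate(backoff.expo, predicate=looks_misaligned, max_tries=3)
    async def generate() -> str:
        # Stand-in for get_generation_response; always returns a "bad" result,
        # so the decorator calls it three times and then returns the last result.
        return "misaligned"


    print(asyncio.run(generate()))  # prints "misaligned" after three attempts

With max_tries=int(ALIGN_SCORE_N_RETRIES), the same pattern caps how many times retry_search re-runs generation while the alignment check keeps failing.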

@generate_tts__after
@check_align_score__after
async def get_generation_response(

@@ -341,10 +378,13 @@ async def get_generation_response(
        query_id=response.query_id, user_id=query_refined.user_id
    )

    metadata["failure_reason"] = query_refined.query_metadata.get(
        "failure_reason", None
    )

    response = await generate_llm_query_response(
        query_refined=query_refined, response=response, metadata=metadata
    )

    return response
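Putting the pieces of this diff together: when search() sees an ALIGNMENT_TOO_LOW error, retry_search copies the factual-consistency reason into query_refined.query_metadata["failure_reason"]; get_generation_response forwards it in metadata to generate_llm_query_response; and get_llm_rag_answer then formats RAG.retry_prompt instead of RAG.prompt, with backoff repeating the whole generation up to ALIGN_SCORE_N_RETRIES times while the predicate still reports a failure.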
Added a suffix to the prompt to incorporate failure reason