From 1fcf1773fd1a46f19d1bb24230661501ea74d981 Mon Sep 17 00:00:00 2001
From: Laurent Sorber
Date: Thu, 5 Dec 2024 16:06:58 +0100
Subject: [PATCH] feat: add strict mode option

---
 src/raglite/_eval.py    | 27 +++++++++++++--------------
 src/raglite/_extract.py | 12 +++++-------
 tests/test_extract.py   | 26 +++++++++++++++-----------
 3 files changed, 33 insertions(+), 32 deletions(-)

diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py
index a163ba5..b5dd058 100644
--- a/src/raglite/_eval.py
+++ b/src/raglite/_eval.py
@@ -5,7 +5,7 @@
 
 import numpy as np
 import pandas as pd
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 from sqlmodel import Session, func, select
 from tqdm.auto import tqdm, trange
 
@@ -25,6 +25,9 @@ def insert_evals(  # noqa: C901
     class QuestionResponse(BaseModel):
         """A specific question about the content of a set of document contexts."""
 
+        model_config = ConfigDict(
+            extra="forbid"  # Forbid extra attributes as required by OpenAI's strict mode.
+        )
         question: str = Field(
             ..., description="A specific question about the content of a set of document contexts."
         )
@@ -41,11 +44,6 @@
         - The question MUST treat the context as if its contents are entirely part of your working memory.
         """.strip()
 
-        class Config:
-            extra = (
-                "forbid"  # Ensure no extra fields are allowed as required by OpenAI's strict mode.
-            )
-
         @field_validator("question")
         @classmethod
         def validate_question(cls, value: str) -> str:
@@ -88,7 +86,7 @@
         # Extract a question from the seed chunk's related chunks.
         try:
             question_response = extract_with_llm(
-                QuestionResponse, related_chunks, config=config
+                QuestionResponse, related_chunks, strict=True, config=config
             )
         except ValueError:
             continue
@@ -104,6 +102,9 @@
     class ContextEvalResponse(BaseModel):
         """Indicate whether the provided context can be used to answer a given question."""
 
+        model_config = ConfigDict(
+            extra="forbid"  # Forbid extra attributes as required by OpenAI's strict mode.
+        )
         hit: bool = Field(
             ...,
             description="True if the provided context contains (a part of) the answer to the given question, false otherwise.",
@@ -115,16 +116,13 @@
             An example of a context that does NOT contain (a part of) the answer is a table of contents.
             """.strip()
 
-        class Config:
-            extra = "forbid"  # Ensure no extra fields are allowed as required by OpenAI API's strict mode.
-
     relevant_chunks = []
     for candidate_chunk in tqdm(
         candidate_chunks, desc="Evaluating chunks", unit="chunk", dynamic_ncols=True
     ):
         try:
             context_eval_response = extract_with_llm(
-                ContextEvalResponse, str(candidate_chunk), config=config
+                ContextEvalResponse, str(candidate_chunk), strict=True, config=config
             )
         except ValueError:  # noqa: PERF203
             pass
@@ -138,6 +136,9 @@
     class AnswerResponse(BaseModel):
         """Answer a question using the provided context."""
 
+        model_config = ConfigDict(
+            extra="forbid"  # Forbid extra attributes as required by OpenAI's strict mode.
+        )
         answer: str = Field(
             ...,
             description="A complete answer to the given question using the provided context.",
@@ -153,13 +154,11 @@
         - The answer MUST treat the context as if its contents are entirely part of your working memory.
         """.strip()
 
-        class Config:
-            extra = "forbid"  # Ensure no extra fields are allowed as required by OpenAI API's strict mode.
-
     try:
         answer_response = extract_with_llm(
             AnswerResponse,
             [str(relevant_chunk) for relevant_chunk in relevant_chunks],
+            strict=True,
             config=config,
         )
     except ValueError:

diff --git a/src/raglite/_extract.py b/src/raglite/_extract.py
index f3d73ff..634f6dc 100644
--- a/src/raglite/_extract.py
+++ b/src/raglite/_extract.py
@@ -2,7 +2,6 @@
 
 from typing import Any, TypeVar
 
-import litellm
 from litellm import completion, get_supported_openai_params  # type: ignore[attr-defined]
 from pydantic import BaseModel, ValidationError
 
@@ -14,6 +13,7 @@
 def extract_with_llm(
     return_type: type[T],
     user_prompt: str | list[str],
+    strict: bool = False,  # noqa: FBT001,FBT002
     config: RAGLiteConfig | None = None,
     **kwargs: Any,
 ) -> T:
@@ -41,8 +41,10 @@ class MyNameResponse(BaseModel):
             str(return_type.model_json_schema()),
         )
     )
-    # Constrain the reponse format to the JSON schema if it's supported by the LLM [1].
+    # Constrain the response format to the JSON schema if it's supported by the LLM [1]. Strict mode
+    # is disabled by default because it only supports a subset of JSON schema features [2].
     # [1] https://docs.litellm.ai/docs/completion/json_mode
+    # [2] https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported
    # TODO: Fall back to {"type": "json_object"} if JSON schema is not supported by the LLM.
     llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
     response_format: dict[str, Any] | None = (
@@ -52,6 +54,7 @@ class MyNameResponse(BaseModel):
                 "name": return_type.__name__,
                 "description": return_type.__doc__ or "",
                 "schema": return_type.model_json_schema(),
+                "strict": strict,
             },
         }
         if "response_format"
@@ -64,9 +67,6 @@ class MyNameResponse(BaseModel):
             f'<context index="{i}">\n{chunk.strip()}\n</context>'
             for i, chunk in enumerate(user_prompt)
         )
-    # Enable JSON schema validation.
-    enable_json_schema_validation = litellm.enable_json_schema_validation
-    litellm.enable_json_schema_validation = True
     # Extract structured data from the unstructured input.
     for _ in range(config.llm_max_tries):
         response = completion(
@@ -89,6 +89,4 @@ class MyNameResponse(BaseModel):
     else:
         error_message = f"Failed to extract {return_type} from input {user_prompt}."
         raise ValueError(error_message) from last_exception
-    # Restore the previous JSON schema validation setting.
-    litellm.enable_json_schema_validation = enable_json_schema_validation
     return instance

diff --git a/tests/test_extract.py b/tests/test_extract.py
index 23eaf2b..ab9d881 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -3,7 +3,7 @@
 from typing import ClassVar
 
 import pytest
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 from raglite import RAGLiteConfig
 from raglite._extract import extract_with_llm
@@ -26,18 +26,22 @@ def test_extract(llm: str) -> None:
     # Set the LLM.
     config = RAGLiteConfig(llm=llm)
 
-    # Extract structured data.
+    # Define the JSON schema of the response.
     class LoginResponse(BaseModel):
+        model_config = ConfigDict(
+            extra="forbid"  # Forbid extra attributes as required by OpenAI's strict mode.
+        )
         username: str = Field(..., description="The username.")
         password: str = Field(..., description="The password.")
         system_prompt: ClassVar[str] = "Extract the username and password from the input."
 
-        class Config:
-            extra = "forbid"
-
-    username, password = "cypher", "steak"
-    login_response = extract_with_llm(LoginResponse, f"{username} // {password}", config=config)
-    # Validate the response.
-    assert isinstance(login_response, LoginResponse)
-    assert login_response.username == username
-    assert login_response.password == password
+    for strict in (False, True):
+        # Extract structured data.
+        username, password = "cypher", "steak"
+        login_response = extract_with_llm(
+            LoginResponse, f"{username} // {password}", strict=strict, config=config
+        )
+        # Validate the response.
+        assert isinstance(login_response, LoginResponse)
+        assert login_response.username == username
+        assert login_response.password == password
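
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new `strict` option is meant to be called. The `CityResponse` model and the `llm` value are assumptions made up for this example; `extract_with_llm`, `RAGLiteConfig`, and the `strict` keyword come from the diff above.

    from pydantic import BaseModel, ConfigDict, Field

    from raglite import RAGLiteConfig
    from raglite._extract import extract_with_llm

    class CityResponse(BaseModel):
        """The city mentioned in the input."""

        # OpenAI's strict mode requires `additionalProperties: false` in the JSON
        # schema, which pydantic emits when extra attributes are forbidden.
        model_config = ConfigDict(extra="forbid")
        city: str = Field(..., description="The name of the city.")

    config = RAGLiteConfig(llm="gpt-4o-mini")  # Hypothetical model choice.
    response = extract_with_llm(
        CityResponse, "We visited Paris in June.", strict=True, config=config
    )
    assert isinstance(response, CityResponse)

Strict mode remains off by default because, per link [2] in the diff, OpenAI's strict mode supports only a subset of JSON schema features.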