feat: add strict mode option
lsorber committed Dec 5, 2024
1 parent 31ce3fe commit 1fcf177
Showing 3 changed files with 33 additions and 32 deletions.
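
In short, this commit lets callers opt in to OpenAI's strict structured outputs when extracting typed data with an LLM. A minimal usage sketch follows; the model id and prompt are assumptions for illustration, not part of the commit:

from typing import ClassVar

from pydantic import BaseModel, ConfigDict, Field

from raglite import RAGLiteConfig
from raglite._extract import extract_with_llm


class CityResponse(BaseModel):
    """A city mentioned in the input."""

    model_config = ConfigDict(extra="forbid")  # Required by OpenAI's strict mode.
    city: str = Field(..., description="The name of the city.")
    system_prompt: ClassVar[str] = "Extract the city from the input."


config = RAGLiteConfig(llm="gpt-4o-mini")  # Assumed model id; any LiteLLM-supported LLM works.
response = extract_with_llm(CityResponse, "I flew to Paris last week.", strict=True, config=config)
print(response.city)  # Expected: "Paris"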
27 changes: 13 additions & 14 deletions src/raglite/_eval.py
@@ -5,7 +5,7 @@
 
 import numpy as np
 import pandas as pd
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 from sqlmodel import Session, func, select
 from tqdm.auto import tqdm, trange
 
@@ -25,6 +25,9 @@ def insert_evals(  # noqa: C901
     class QuestionResponse(BaseModel):
         """A specific question about the content of a set of document contexts."""
 
+        model_config = ConfigDict(
+            extra="forbid"  # Forbid extra attributes as required by OpenAI's strict mode.
+        )
         question: str = Field(
             ..., description="A specific question about the content of a set of document contexts."
         )
@@ -41,11 +44,6 @@ class QuestionResponse(BaseModel):
         - The question MUST treat the context as if its contents are entirely part of your working memory.
         """.strip()
 
-        class Config:
-            extra = (
-                "forbid"  # Ensure no extra fields are allowed as required by OpenAI's strict mode.
-            )
-
         @field_validator("question")
         @classmethod
         def validate_question(cls, value: str) -> str:
@@ -88,7 +86,7 @@ def validate_question(cls, value: str) -> str:
         # Extract a question from the seed chunk's related chunks.
         try:
             question_response = extract_with_llm(
-                QuestionResponse, related_chunks, config=config
+                QuestionResponse, related_chunks, strict=True, config=config
             )
         except ValueError:
             continue
@@ -104,6 +102,9 @@ def validate_question(cls, value: str) -> str:
         class ContextEvalResponse(BaseModel):
             """Indicate whether the provided context can be used to answer a given question."""
 
+            model_config = ConfigDict(
+                extra="forbid"  # Forbid extra attributes as required by OpenAI's strict mode.
+            )
             hit: bool = Field(
                 ...,
                 description="True if the provided context contains (a part of) the answer to the given question, false otherwise.",
@@ -115,16 +116,13 @@ class ContextEvalResponse(BaseModel):
             An example of a context that does NOT contain (a part of) the answer is a table of contents.
             """.strip()
 
-            class Config:
-                extra = "forbid"  # Ensure no extra fields are allowed as required by OpenAI API's strict mode.
-
         relevant_chunks = []
         for candidate_chunk in tqdm(
             candidate_chunks, desc="Evaluating chunks", unit="chunk", dynamic_ncols=True
         ):
             try:
                 context_eval_response = extract_with_llm(
-                    ContextEvalResponse, str(candidate_chunk), config=config
+                    ContextEvalResponse, str(candidate_chunk), strict=True, config=config
                 )
             except ValueError:  # noqa: PERF203
                 pass
@@ -138,6 +136,9 @@ class Config:
     class AnswerResponse(BaseModel):
         """Answer a question using the provided context."""
 
+        model_config = ConfigDict(
+            extra="forbid"  # Forbid extra attributes as required by OpenAI's strict mode.
+        )
         answer: str = Field(
             ...,
             description="A complete answer to the given question using the provided context.",
@@ -153,13 +154,11 @@ class AnswerResponse(BaseModel):
         - The answer MUST treat the context as if its contents are entirely part of your working memory.
         """.strip()
 
-        class Config:
-            extra = "forbid"  # Ensure no extra fields are allowed as required by OpenAI API's strict mode.
-
     try:
         answer_response = extract_with_llm(
             AnswerResponse,
             [str(relevant_chunk) for relevant_chunk in relevant_chunks],
+            strict=True,
             config=config,
         )
     except ValueError:
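Aside: the switch from the deprecated nested `class Config` to Pydantic v2's `model_config = ConfigDict(...)` is what makes these models strict-mode compatible, because `extra="forbid"` adds `additionalProperties: false` to the generated JSON schema. A minimal sketch (not part of the commit) of that effect:

from pydantic import BaseModel, ConfigDict, Field


class QuestionResponse(BaseModel):
    """A specific question about the content of a set of document contexts."""

    model_config = ConfigDict(extra="forbid")
    question: str = Field(..., description="A specific question about the document contexts.")


schema = QuestionResponse.model_json_schema()
assert schema["additionalProperties"] is False  # The key OpenAI's strict mode requires.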
12 changes: 5 additions & 7 deletions src/raglite/_extract.py
@@ -2,7 +2,6 @@
 
 from typing import Any, TypeVar
 
-import litellm
 from litellm import completion, get_supported_openai_params  # type: ignore[attr-defined]
 from pydantic import BaseModel, ValidationError
 
@@ -14,6 +13,7 @@
 def extract_with_llm(
     return_type: type[T],
     user_prompt: str | list[str],
+    strict: bool = False,  # noqa: FBT001,FBT002
     config: RAGLiteConfig | None = None,
     **kwargs: Any,
 ) -> T:
@@ -41,8 +41,10 @@ class MyNameResponse(BaseModel):
             str(return_type.model_json_schema()),
         )
     )
-    # Constrain the response format to the JSON schema if it's supported by the LLM [1].
+    # Constrain the response format to the JSON schema if it's supported by the LLM [1]. Strict mode
+    # is disabled by default because it only supports a subset of JSON schema features [2].
     # [1] https://docs.litellm.ai/docs/completion/json_mode
+    # [2] https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported
     # TODO: Fall back to {"type": "json_object"} if JSON schema is not supported by the LLM.
     llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
     response_format: dict[str, Any] | None = (
@@ -52,6 +54,7 @@ class MyNameResponse(BaseModel):
                 "name": return_type.__name__,
                 "description": return_type.__doc__ or "",
                 "schema": return_type.model_json_schema(),
+                "strict": strict,
             },
         }
         if "response_format"
@@ -64,9 +67,6 @@ class MyNameResponse(BaseModel):
             f'<context index="{i + 1}">\n{chunk.strip()}\n</context>'
             for i, chunk in enumerate(user_prompt)
         )
-    # Enable JSON schema validation.
-    enable_json_schema_validation = litellm.enable_json_schema_validation
-    litellm.enable_json_schema_validation = True
     # Extract structured data from the unstructured input.
     for _ in range(config.llm_max_tries):
         response = completion(
@@ -89,6 +89,4 @@ class MyNameResponse(BaseModel):
     else:
         error_message = f"Failed to extract {return_type} from input {user_prompt}."
         raise ValueError(error_message) from last_exception
-    # Restore the previous JSON schema validation setting.
-    litellm.enable_json_schema_validation = enable_json_schema_validation
     return instance
26 changes: 15 additions & 11 deletions tests/test_extract.py
@@ -3,7 +3,7 @@
 from typing import ClassVar
 
 import pytest
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 from raglite import RAGLiteConfig
 from raglite._extract import extract_with_llm
@@ -26,18 +26,22 @@ def test_extract(llm: str) -> None:
     # Set the LLM.
     config = RAGLiteConfig(llm=llm)
 
-    # Extract structured data.
+    # Define the JSON schema of the response.
     class LoginResponse(BaseModel):
+        model_config = ConfigDict(
+            extra="forbid"  # Forbid extra attributes as required by OpenAI's strict mode.
+        )
         username: str = Field(..., description="The username.")
         password: str = Field(..., description="The password.")
         system_prompt: ClassVar[str] = "Extract the username and password from the input."
 
-        class Config:
-            extra = "forbid"
-
-    username, password = "cypher", "steak"
-    login_response = extract_with_llm(LoginResponse, f"{username} // {password}", config=config)
-    # Validate the response.
-    assert isinstance(login_response, LoginResponse)
-    assert login_response.username == username
-    assert login_response.password == password
+    for strict in (False, True):
+        # Extract structured data.
+        username, password = "cypher", "steak"
+        login_response = extract_with_llm(
+            LoginResponse, f"{username} // {password}", strict=strict, config=config
+        )
+        # Validate the response.
+        assert isinstance(login_response, LoginResponse)
+        assert login_response.username == username
+        assert login_response.password == password
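
One caveat worth noting when enabling strict mode, per the link added in the diff comment above: OpenAI's strict mode supports only a subset of JSON Schema, so type-specific keywords such as `minLength` may be rejected. A sketch (illustrative model, not from the repository) of how such a keyword ends up in a Pydantic schema:

from pydantic import BaseModel, ConfigDict, Field


class UsernameResponse(BaseModel):
    """A username extracted from the input."""

    model_config = ConfigDict(extra="forbid")
    username: str = Field(..., min_length=3, description="The username.")


# Pydantic maps `min_length` to the JSON Schema keyword `minLength`, which OpenAI's
# strict mode does not (yet) support, so strict=True may fail for this model.
print(UsernameResponse.model_json_schema()["properties"]["username"])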
