Skip to content

Commit

Permalink
fix: gracefully catch error if streamed json does not meet schema val…
Browse files Browse the repository at this point in the history
…idation (fixes #14)

PiperOrigin-RevId: 711449980
  • Loading branch information
sasha-gitg authored and copybara-github committed Jan 2, 2025
1 parent 4340939 commit f494432
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 3 deletions.
54 changes: 54 additions & 0 deletions google/genai/tests/models/test_generate_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,60 @@ def test_json_schema(client):
assert isinstance(response.parsed, dict)


def test_json_schema_with_streaming(client):

response = client.models.generate_content_stream(
model='gemini-2.0-flash-exp',
contents='Give me information of the United States.',
config={
'response_mime_type': 'application/json',
'response_schema': {
'properties': {
'name': {'type': 'STRING'},
'population': {'type': 'INTEGER'},
'capital': {'type': 'STRING'},
'continent': {'type': 'STRING'},
'gdp': {'type': 'INTEGER'},
'official_language': {'type': 'STRING'},
'total_area_sq_mi': {'type': 'INTEGER'},
},
'type': 'OBJECT',
},
},
)

for r in response:
parts = r.candidates[0].content.parts
for p in parts:
print(p.text)


def test_pydantic_schema_with_streaming(client):

class CountryInfo(BaseModel):
name: str
population: int
capital: str
continent: str
gdp: int
official_language: str
total_area_sq_mi: int

response = client.models.generate_content_stream(
model='gemini-2.0-flash-exp',
contents='Give me information of the United States.',
config={
'response_mime_type': 'application/json',
'response_schema': CountryInfo
},
)

for r in response:
parts = r.candidates[0].content.parts
for p in parts:
print(p.text)


def test_function(client):
def get_weather(city: str) -> str:
return f'The weather in {city} is sunny and 100 degrees.'
Expand Down
18 changes: 15 additions & 3 deletions google/genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2453,7 +2453,10 @@ class GenerateContentResponse(_common.BaseModel):
default=None, description="""Usage metadata about the response(s)."""
)
automatic_function_calling_history: Optional[list[Content]] = None
parsed: Union[pydantic.BaseModel, dict] = None
parsed: Union[pydantic.BaseModel, dict] = Field(
default=None,
description="""Parsed response if response_schema is provided. Not available for streaming.""",
)

@property
def text(self) -> Optional[str]:
Expand Down Expand Up @@ -2503,12 +2506,21 @@ def _from_response(
response_schema, pydantic.BaseModel
):
# Pydantic schema.
result.parsed = response_schema.model_validate_json(result.text)
try:
result.parsed = response_schema.model_validate_json(result.text)
# may not be a valid json per stream response
except pydantic.ValidationError:
pass

elif isinstance(response_schema, dict) or isinstance(
response_schema, pydantic.BaseModel
):
# JSON schema.
result.parsed = json.loads(result.text)
try:
result.parsed = json.loads(result.text)
# may not be a valid json per stream response
except json.decoder.JSONDecodeError:
pass

return result

Expand Down

0 comments on commit f494432

Please sign in to comment.