Skip to content

Commit

Permalink
[ENG-1067] Improve ergonomics of accessing response values in Python (#…
Browse files Browse the repository at this point in the history
…328)

* fix extra space

* write test

* save

* works with union

* refactor construction of pydantic model

* update snapshot
  • Loading branch information
dphuang2 authored Nov 3, 2023
1 parent 7e179ba commit 46da09e
Show file tree
Hide file tree
Showing 73 changed files with 9,541 additions and 135 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import typing_extensions
import aiohttp
import urllib3
{{#if prstv2}}
from pydantic import BaseModel
from pydantic import BaseModel, RootModel, ValidationError
{{/if}}
from urllib3._collections import HTTPHeaderDict
from urllib.parse import urlparse, quote
Expand Down Expand Up @@ -66,32 +66,47 @@ class RequestField(RequestFieldBase):


{{#if prstv2}}
T = typing.TypeVar('T', bound=BaseModel)
T = typing.TypeVar('T')


def construct_model_instance(model: typing.Type[T], data: dict) -> T:
def construct_model_instance(model: typing.Type[T], data: typing.Any) -> T:
"""
Recursively construct an instance of a Pydantic model along with its nested models.
"""
for field_name, field_type in model.__annotations__.items():
if field_name in data:
if typing_extensions.get_origin(field_type) is list:
list_item_type = typing_extensions.get_args(field_type)[0]
if issubclass(list_item_type, BaseModel):
data[field_name] = [construct_model_instance(list_item_type, item) for item in data[field_name]]
elif issubclass(field_type, BaseModel):
data[field_name] = construct_model_instance(field_type, data[field_name])

return model.model_construct(**data)


def construct_model_list(model: typing.Type[typing.List[T]], data_list: typing.List[dict]) -> typing.List[T]:
"""
Construct a list of Pydantic model instances from a list of dictionaries.
"""
# Extract the inner model type from Type[List[T]]
inner_model = typing_extensions.get_args(model)[0]
return [construct_model_instance(inner_model, data) for data in data_list]
# if model is Union,
if typing_extensions.get_origin(model) is typing.Union:
closest = []
# iterate over all union types and determine which one is closest to the data
for union_type in model.__args__:
matches = isinstance(data, union_type)
if matches:
# found match, return using construct_model_instance
return construct_model_instance(union_type, data)
# if no match, just use the first union_type
return construct_model_instance(model.__args__[0], data)
# if model is scalar value like str, number, etc., use RootModel to construct
elif isinstance(model, type):
model = RootModel[model]
# try to coerce value to model type
try:
return model(data).root
except ValidationError as e:
pass
# if not possible, give up
return model.model_construct(data).root
# if model is list, iterate over list and recursively call
elif typing_extensions.get_origin(model) is list:
item_model = typing_extensions.get_args(model)[0]
return [construct_model_instance(item_model, item) for item in data]
# if model is BaseModel, iterate over fields and recursively call
elif issubclass(model, BaseModel):
new_data = {}
for field_name, field_type in model.__annotations__.items():
if field_name in data:
new_data[field_name] = construct_model_instance(field_type, data[field_name])
return model.model_construct(**data)
raise ApiTypeError(f"Unable to construct model instance of type {model}")


class Dictionary(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{{#if returnModel.isArray}}
if validate:
return RootModel[{{> endpoint_pydantic_impl_return_model}}](raw_response.body).root
return api_client.construct_model_list({{> endpoint_pydantic_impl_return_model}}, raw_response.body)
return api_client.construct_model_instance({{> endpoint_pydantic_impl_return_model}}, raw_response.body)
{{else}}
if validate:
return {{> endpoint_pydantic_impl_return_model}}(**raw_response.body)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import typing_extensions
import aiohttp
import urllib3
from pydantic import BaseModel
from pydantic import BaseModel, RootModel, ValidationError
from urllib3._collections import HTTPHeaderDict
from urllib.parse import urlparse, quote
from urllib3.fields import RequestField as RequestFieldBase
Expand Down Expand Up @@ -68,32 +68,47 @@ def __eq__(self, other):
return self.__dict__ == other.__dict__


T = typing.TypeVar('T', bound=BaseModel)
T = typing.TypeVar('T')


def construct_model_instance(model: typing.Type[T], data: dict) -> T:
def construct_model_instance(model: typing.Type[T], data: typing.Any) -> T:
"""
Recursively construct an instance of a Pydantic model along with its nested models.
"""
for field_name, field_type in model.__annotations__.items():
if field_name in data:
if typing_extensions.get_origin(field_type) is list:
list_item_type = typing_extensions.get_args(field_type)[0]
if issubclass(list_item_type, BaseModel):
data[field_name] = [construct_model_instance(list_item_type, item) for item in data[field_name]]
elif issubclass(field_type, BaseModel):
data[field_name] = construct_model_instance(field_type, data[field_name])

return model.model_construct(**data)


def construct_model_list(model: typing.Type[typing.List[T]], data_list: typing.List[dict]) -> typing.List[T]:
"""
Construct a list of Pydantic model instances from a list of dictionaries.
"""
# Extract the inner model type from Type[List[T]]
inner_model = typing_extensions.get_args(model)[0]
return [construct_model_instance(inner_model, data) for data in data_list]
# if model is Union,
if typing_extensions.get_origin(model) is typing.Union:
closest = []
# iterate over all union types and determine which one is closest to the data
for union_type in model.__args__:
matches = isinstance(data, union_type)
if matches:
# found match, return using construct_model_instance
return construct_model_instance(union_type, data)
# if no match, just use the first union_type
return construct_model_instance(model.__args__[0], data)
# if model is scalar value like str, number, etc., use RootModel to construct
elif isinstance(model, type):
model = RootModel[model]
# try to coerce value to model type
try:
return model(data).root
except ValidationError as e:
pass
# if not possible, give up
return model.model_construct(data).root
# if model is list, iterate over list and recursively call
elif typing_extensions.get_origin(model) is list:
item_model = typing_extensions.get_args(model)[0]
return [construct_model_instance(item_model, item) for item in data]
# if model is BaseModel, iterate over fields and recursively call
elif issubclass(model, BaseModel):
new_data = {}
for field_name, field_type in model.__annotations__.items():
if field_name in data:
new_data[field_name] = construct_model_instance(field_type, data[field_name])
return model.model_construct(**data)
raise ApiTypeError(f"Unable to construct model instance of type {model}")


class Dictionary(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import typing_extensions
import aiohttp
import urllib3
from pydantic import BaseModel
from pydantic import BaseModel, RootModel, ValidationError
from urllib3._collections import HTTPHeaderDict
from urllib.parse import urlparse, quote
from urllib3.fields import RequestField as RequestFieldBase
Expand Down Expand Up @@ -68,32 +68,47 @@ def __eq__(self, other):
return self.__dict__ == other.__dict__


T = typing.TypeVar('T', bound=BaseModel)
T = typing.TypeVar('T')


def construct_model_instance(model: typing.Type[T], data: dict) -> T:
def construct_model_instance(model: typing.Type[T], data: typing.Any) -> T:
"""
Recursively construct an instance of a Pydantic model along with its nested models.
"""
for field_name, field_type in model.__annotations__.items():
if field_name in data:
if typing_extensions.get_origin(field_type) is list:
list_item_type = typing_extensions.get_args(field_type)[0]
if issubclass(list_item_type, BaseModel):
data[field_name] = [construct_model_instance(list_item_type, item) for item in data[field_name]]
elif issubclass(field_type, BaseModel):
data[field_name] = construct_model_instance(field_type, data[field_name])

return model.model_construct(**data)


def construct_model_list(model: typing.Type[typing.List[T]], data_list: typing.List[dict]) -> typing.List[T]:
"""
Construct a list of Pydantic model instances from a list of dictionaries.
"""
# Extract the inner model type from Type[List[T]]
inner_model = typing_extensions.get_args(model)[0]
return [construct_model_instance(inner_model, data) for data in data_list]
# if model is Union,
if typing_extensions.get_origin(model) is typing.Union:
closest = []
# iterate over all union types and determine which one is closest to the data
for union_type in model.__args__:
matches = isinstance(data, union_type)
if matches:
# found match, return using construct_model_instance
return construct_model_instance(union_type, data)
# if no match, just use the first union_type
return construct_model_instance(model.__args__[0], data)
# if model is scalar value like str, number, etc., use RootModel to construct
elif isinstance(model, type):
model = RootModel[model]
# try to coerce value to model type
try:
return model(data).root
except ValidationError as e:
pass
# if not possible, give up
return model.model_construct(data).root
# if model is list, iterate over list and recursively call
elif typing_extensions.get_origin(model) is list:
item_model = typing_extensions.get_args(model)[0]
return [construct_model_instance(item_model, item) for item in data]
# if model is BaseModel, iterate over fields and recursively call
elif issubclass(model, BaseModel):
new_data = {}
for field_name, field_type in model.__annotations__.items():
if field_name in data:
new_data[field_name] = construct_model_instance(field_type, data[field_name])
return model.model_construct(**data)
raise ApiTypeError(f"Unable to construct model instance of type {model}")


class Dictionary(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ async def afetch(
)
if validate:
return RootModel[TestFetchResponsePydantic](raw_response.body).root
return api_client.construct_model_list(TestFetchResponsePydantic, raw_response.body)
return api_client.construct_model_instance(TestFetchResponsePydantic, raw_response.body)


def fetch(
Expand All @@ -288,7 +288,7 @@ def fetch(
)
if validate:
return RootModel[TestFetchResponsePydantic](raw_response.body).root
return api_client.construct_model_list(TestFetchResponsePydantic, raw_response.body)
return api_client.construct_model_instance(TestFetchResponsePydantic, raw_response.body)


class ApiForget(BaseApi):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ class Fetch(BaseApi):
)
if validate:
return RootModel[TestFetchResponsePydantic](raw_response.body).root
return api_client.construct_model_list(TestFetchResponsePydantic, raw_response.body)
return api_client.construct_model_instance(TestFetchResponsePydantic, raw_response.body)


def fetch(
Expand All @@ -280,7 +280,7 @@ class Fetch(BaseApi):
)
if validate:
return RootModel[TestFetchResponsePydantic](raw_response.body).root
return api_client.construct_model_list(TestFetchResponsePydantic, raw_response.body)
return api_client.construct_model_instance(TestFetchResponsePydantic, raw_response.body)


class ApiForget(BaseApi):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import typing_extensions
import aiohttp
import urllib3
from pydantic import BaseModel
from pydantic import BaseModel, RootModel, ValidationError
from urllib3._collections import HTTPHeaderDict
from urllib.parse import urlparse, quote
from urllib3.fields import RequestField as RequestFieldBase
Expand Down Expand Up @@ -68,32 +68,47 @@ def __eq__(self, other):
return self.__dict__ == other.__dict__


T = typing.TypeVar('T', bound=BaseModel)
T = typing.TypeVar('T')


def construct_model_instance(model: typing.Type[T], data: dict) -> T:
def construct_model_instance(model: typing.Type[T], data: typing.Any) -> T:
"""
Recursively construct an instance of a Pydantic model along with its nested models.
"""
for field_name, field_type in model.__annotations__.items():
if field_name in data:
if typing_extensions.get_origin(field_type) is list:
list_item_type = typing_extensions.get_args(field_type)[0]
if issubclass(list_item_type, BaseModel):
data[field_name] = [construct_model_instance(list_item_type, item) for item in data[field_name]]
elif issubclass(field_type, BaseModel):
data[field_name] = construct_model_instance(field_type, data[field_name])

return model.model_construct(**data)


def construct_model_list(model: typing.Type[typing.List[T]], data_list: typing.List[dict]) -> typing.List[T]:
"""
Construct a list of Pydantic model instances from a list of dictionaries.
"""
# Extract the inner model type from Type[List[T]]
inner_model = typing_extensions.get_args(model)[0]
return [construct_model_instance(inner_model, data) for data in data_list]
# if model is Union,
if typing_extensions.get_origin(model) is typing.Union:
closest = []
# iterate over all union types and determine which one is closest to the data
for union_type in model.__args__:
matches = isinstance(data, union_type)
if matches:
# found match, return using construct_model_instance
return construct_model_instance(union_type, data)
# if no match, just use the first union_type
return construct_model_instance(model.__args__[0], data)
# if model is scalar value like str, number, etc., use RootModel to construct
elif isinstance(model, type):
model = RootModel[model]
# try to coerce value to model type
try:
return model(data).root
except ValidationError as e:
pass
# if not possible, give up
return model.model_construct(data).root
# if model is list, iterate over list and recursively call
elif typing_extensions.get_origin(model) is list:
item_model = typing_extensions.get_args(model)[0]
return [construct_model_instance(item_model, item) for item in data]
# if model is BaseModel, iterate over fields and recursively call
elif issubclass(model, BaseModel):
new_data = {}
for field_name, field_type in model.__annotations__.items():
if field_name in data:
new_data[field_name] = construct_model_instance(field_type, data[field_name])
return model.model_construct(**data)
raise ApiTypeError(f"Unable to construct model instance of type {model}")


class Dictionary(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ async def alist(
)
if validate:
return RootModel[ListInnerPydantic](raw_response.body).root
return api_client.construct_model_list(ListInnerPydantic, raw_response.body)
return api_client.construct_model_instance(ListInnerPydantic, raw_response.body)


def list(
Expand All @@ -288,7 +288,7 @@ def list(
)
if validate:
return RootModel[ListInnerPydantic](raw_response.body).root
return api_client.construct_model_list(ListInnerPydantic, raw_response.body)
return api_client.construct_model_instance(ListInnerPydantic, raw_response.body)


class ApiForget(BaseApi):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ class List(BaseApi):
)
if validate:
return RootModel[ListInnerPydantic](raw_response.body).root
return api_client.construct_model_list(ListInnerPydantic, raw_response.body)
return api_client.construct_model_instance(ListInnerPydantic, raw_response.body)


def list(
Expand All @@ -280,7 +280,7 @@ class List(BaseApi):
)
if validate:
return RootModel[ListInnerPydantic](raw_response.body).root
return api_client.construct_model_list(ListInnerPydantic, raw_response.body)
return api_client.construct_model_instance(ListInnerPydantic, raw_response.body)


class ApiForget(BaseApi):
Expand Down
Loading

0 comments on commit 46da09e

Please sign in to comment.