Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sqla_factory): added __set_association_proxy__ attribute #624

Closed
wants to merge 10 commits into from
Closed
37 changes: 33 additions & 4 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Callable,
ClassVar,
Collection,
Coroutine,
Generic,
Hashable,
Iterable,
Expand Down Expand Up @@ -1068,9 +1069,25 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
resolved[field_name] = post_generator.to_value(field_name, resolved)
yield resolved

@classmethod
async def build_async(cls, **kwargs: Any) -> T:
"""Asynchronously build an instance of the factory's __model__

:param kwargs: Any kwargs. If field names are set in kwargs, their values will be used.

:returns: An instance of type T.

"""
data: dict[str, Any] = cls.process_kwargs(**kwargs)
for k, v in data.items():
if isinstance(v, Coroutine):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this behaviour generally used for other async methods or more specific to SQLA relationships? If specific to SQLA then should live within that factory

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should check a bit, but could suppose this is more specific for SQLAlchemy proxy. If this behaviour is only for the Alchemy, I'll try to think something and leave it within SQLAlchemy factory.

data[k] = await v

return cast("T", cls.__model__(**data))

@classmethod
def build(cls, **kwargs: Any) -> T:
"""Build an instance of the factory's __model__
"""Synchronously build an instance of the factory's __model__

:param kwargs: Any kwargs. If field names are set in kwargs, their values will be used.

Expand All @@ -1081,7 +1098,7 @@ def build(cls, **kwargs: Any) -> T:

@classmethod
def batch(cls, size: int, **kwargs: Any) -> list[T]:
"""Build a batch of size n of the factory's Meta.model.
"""Synchronously build a batch of size n of the factory's Meta.model.

:param size: Size of the batch.
:param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.
Expand All @@ -1091,6 +1108,18 @@ def batch(cls, size: int, **kwargs: Any) -> list[T]:
"""
return [cls.build(**kwargs) for _ in range(size)]

@classmethod
async def batch_async(cls, size: int, **kwargs: Any) -> list[T]:
"""Asynchronously build a batch of size n of the factory's Meta.model.

:param size: Size of the batch.
:param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.

:returns: A list of instances of type T.

"""
return [await cls.build_async(**kwargs) for _ in range(size)]

@classmethod
def coverage(cls, **kwargs: Any) -> abc.Iterator[T]:
"""Build a batch of the factory's Meta.model will full coverage of the sub-types of the model.
Expand Down Expand Up @@ -1135,7 +1164,7 @@ async def create_async(cls, **kwargs: Any) -> T:

:returns: An instance of type T.
"""
return await cls._get_async_persistence().save(data=cls.build(**kwargs))
return await cls._get_async_persistence().save(data=await cls.build_async(**kwargs))

@classmethod
async def create_batch_async(cls, size: int, **kwargs: Any) -> list[T]:
Expand All @@ -1147,7 +1176,7 @@ async def create_batch_async(cls, size: int, **kwargs: Any) -> list[T]:

:returns: A list of instances of type T.
"""
return await cls._get_async_persistence().save_many(data=cls.batch(size, **kwargs))
return await cls._get_async_persistence().save_many(data=await cls.batch_async(size, **kwargs))


def _register_builtin_factories() -> None:
Expand Down
36 changes: 29 additions & 7 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from sqlalchemy import ARRAY, Column, Numeric, String, inspect, types
from sqlalchemy.dialects import mssql, mysql, postgresql, sqlite
from sqlalchemy.exc import NoInspectionAvailable
from sqlalchemy.ext.associationproxy import AssociationProxy
from sqlalchemy.orm import InstanceState, Mapper
except ImportError as e:
msg = "sqlalchemy is not installed"
Expand Down Expand Up @@ -52,16 +53,18 @@ def __init__(self, session: AsyncSession) -> None:
self.session = session

async def save(self, data: T) -> T:
self.session.add(data)
await self.session.commit()
await self.session.refresh(data)
async with self.session as session:
session.add(data)
await session.commit()
await session.refresh(data)
return data

async def save_many(self, data: list[T]) -> list[T]:
self.session.add_all(data)
await self.session.commit()
for batch_item in data:
await self.session.refresh(batch_item)
async with self.session as session:
session.add_all(data)
await session.commit()
for batch_item in data:
await session.refresh(batch_item)
return data


Expand All @@ -76,6 +79,8 @@ class SQLAlchemyFactory(Generic[T], BaseFactory[T]):
"""Configuration to consider columns with foreign keys as a field or not."""
__set_relationships__: ClassVar[bool] = False
"""Configuration to consider relationships property as a model field or not."""
__set_association_proxy__: ClassVar[bool] = False
"""Configuration to consider AssociationProxy property as a model field or not."""

__session__: ClassVar[Session | Callable[[], Session] | None] = None
__async_session__: ClassVar[AsyncSession | Callable[[], AsyncSession] | None] = None
Expand Down Expand Up @@ -213,6 +218,23 @@ def get_model_fields(cls) -> list[FieldMeta]:
random=cls.__random__,
),
)
if cls.__set_association_proxy__:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. Adding as a feature makes sense here. Can appropriate tests and documentation be added for this feature please?

Copy link
Contributor Author

@nisemenov nisemenov Dec 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Ok, I'd try to do that.

for name, attr in table.all_orm_descriptors.items():
if isinstance(attr, AssociationProxy):
target_collection = table.relationships.get(attr.target_collection)
if target_collection:
target_class = target_collection.entity.class_
target_attr = getattr(target_class, attr.value_attr)
if target_attr:
class_ = target_attr.entity.class_
annotation = class_ if not target_collection.uselist else List[class_] # type: ignore[valid-type]
fields_meta.append(
FieldMeta.from_type(
name=name,
annotation=annotation,
random=cls.__random__,
)
)

return fields_meta

Expand Down
14 changes: 9 additions & 5 deletions tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
func,
inspect,
orm,
select,
text,
types,
)
Expand Down Expand Up @@ -343,13 +344,15 @@ class Factory(SQLAlchemyFactory[AsyncModel]):
__async_session__ = session_config(session)
__model__ = AsyncModel

result = await Factory.create_async()
assert inspect(result).persistent # type: ignore[union-attr]
instance = await Factory.create_async()
result = await session.scalar(select(AsyncModel).where(AsyncModel.id == instance.id))
assert result

batch_result = await Factory.create_batch_async(size=2)
assert len(batch_result) == 2
for batch_item in batch_result:
assert inspect(batch_item).persistent # type: ignore[union-attr]
result = await session.scalar(select(AsyncModel).where(AsyncModel.id == batch_item.id))
assert result


@pytest.mark.parametrize(
Expand Down Expand Up @@ -392,8 +395,9 @@ class Factory(SQLAlchemyFactory[AsyncRefreshModel]):
test_int = Ignore()
test_bool = Ignore()

result = await Factory.create_async()
assert inspect(result).persistent # type: ignore[union-attr]
instance = await Factory.create_async()
result = await session.scalar(select(AsyncRefreshModel).where(AsyncRefreshModel.id == instance.id))
assert result
assert result.test_datetime is not None
assert isinstance(result.test_datetime, datetime)
assert result.test_str == "test_str"
Expand Down
Loading