Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing generic on Factory subclasses #1060

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion factory/alchemy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# Copyright: See the LICENSE file.

from typing import TypeVar

from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import NoResultFound

from . import base, errors

T = TypeVar("T")
Copy link
Contributor

@foarsitter foarsitter Sep 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it required to create a new TypeVar here? Can't we import it from base.py or user base.T? The same applies todjango.py and mogo.py.


SESSION_PERSISTENCE_COMMIT = 'commit'
SESSION_PERSISTENCE_FLUSH = 'flush'
VALID_SESSION_PERSISTENCE_TYPES = [
Expand Down Expand Up @@ -46,7 +50,7 @@ def _build_default_options(self):
]


class SQLAlchemyModelFactory(base.Factory):
class SQLAlchemyModelFactory(base.Factory[T]):
"""Factory for SQLAlchemy models. """

_options_class = SQLAlchemyOptions
Expand Down
10 changes: 5 additions & 5 deletions factory/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ def __init__(self, **kwargs):
setattr(self, field, value)


class StubFactory(Factory):
class StubFactory(Factory[StubObject]):

class Meta:
strategy = enums.STUB_STRATEGY
Expand All @@ -671,7 +671,7 @@ def create(cls, **kwargs):
raise errors.UnsupportedStrategy()


class BaseDictFactory(Factory):
class BaseDictFactory(Factory[T]):
"""Factory for dictionary-like classes."""
class Meta:
abstract = True
Expand All @@ -688,12 +688,12 @@ def _create(cls, model_class, *args, **kwargs):
return cls._build(model_class, *args, **kwargs)


class DictFactory(BaseDictFactory):
class DictFactory(BaseDictFactory[dict]):
class Meta:
model = dict


class BaseListFactory(Factory):
class BaseListFactory(Factory[T]):
"""Factory for list-like classes."""
class Meta:
abstract = True
Expand All @@ -714,7 +714,7 @@ def _create(cls, model_class, *args, **kwargs):
return cls._build(model_class, *args, **kwargs)


class ListFactory(BaseListFactory):
class ListFactory(BaseListFactory[list]):
class Meta:
model = list

Expand Down
6 changes: 4 additions & 2 deletions factory/mogo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@


"""factory_boy extensions for use with the mogo library (pymongo wrapper)."""

from typing import TypeVar

from . import base

T = TypeVar("T")


class MogoFactory(base.Factory):
class MogoFactory(base.Factory[T]):
"""Factory for mogo objects."""
class Meta:
abstract = True
Expand Down
6 changes: 4 additions & 2 deletions factory/mongoengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@


"""factory_boy extensions for use with the mongoengine library (pymongo wrapper)."""

from typing import TypeVar

from . import base

T = TypeVar("T")


class MongoEngineFactory(base.Factory):
class MongoEngineFactory(base.Factory[T]):
"""Factory for mongoengine objects."""

class Meta:
Expand Down
25 changes: 21 additions & 4 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import dataclasses
import unittest
from typing import List

from typing_extensions import assert_type

import factory

Expand All @@ -25,7 +28,21 @@ class UserFactory(factory.Factory[User]):
class Meta:
model = User

result: User
result = UserFactory.build()
result = UserFactory.create()
self.assertEqual(result.name, "John Doe")
assert_type(UserFactory.build(), User)
assert_type(UserFactory.create(), User)
assert_type(UserFactory.build_batch(2), List[User])
assert_type(UserFactory.create_batch(2), List[User])
self.assertEqual(UserFactory.create().name, "John Doe")

def test_dict_factory(self) -> None:

class Pet(factory.DictFactory):
species = "dog"
name = "rover"

assert_type(Pet.build(), dict)
assert_type(Pet.create(), dict)

def test_list_factory(self) -> None:
assert_type(factory.ListFactory.build(), list)
assert_type(factory.ListFactory.create(), list)
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ passenv =
POSTGRES_DATABASE
deps =
mypy
typing_extensions
alchemy: SQLAlchemy
alchemy: sqlalchemy_utils
mongo: mongoengine
Expand Down
Loading