diff --git a/factory/alchemy.py b/factory/alchemy.py index f934ce5d..42bd2bbf 100644 --- a/factory/alchemy.py +++ b/factory/alchemy.py @@ -1,10 +1,14 @@ # Copyright: See the LICENSE file. +from typing import TypeVar + from sqlalchemy.exc import IntegrityError from sqlalchemy.orm.exc import NoResultFound from . import base, errors +T = TypeVar("T") + SESSION_PERSISTENCE_COMMIT = 'commit' SESSION_PERSISTENCE_FLUSH = 'flush' VALID_SESSION_PERSISTENCE_TYPES = [ @@ -46,7 +50,7 @@ def _build_default_options(self): ] -class SQLAlchemyModelFactory(base.Factory): +class SQLAlchemyModelFactory(base.Factory[T]): """Factory for SQLAlchemy models. """ _options_class = SQLAlchemyOptions diff --git a/factory/base.py b/factory/base.py index 8d499501..46a43b9e 100644 --- a/factory/base.py +++ b/factory/base.py @@ -656,7 +656,7 @@ def __init__(self, **kwargs): setattr(self, field, value) -class StubFactory(Factory): +class StubFactory(Factory[StubObject]): class Meta: strategy = enums.STUB_STRATEGY @@ -671,7 +671,7 @@ def create(cls, **kwargs): raise errors.UnsupportedStrategy() -class BaseDictFactory(Factory): +class BaseDictFactory(Factory[T]): """Factory for dictionary-like classes.""" class Meta: abstract = True @@ -688,12 +688,12 @@ def _create(cls, model_class, *args, **kwargs): return cls._build(model_class, *args, **kwargs) -class DictFactory(BaseDictFactory): +class DictFactory(BaseDictFactory[dict]): class Meta: model = dict -class BaseListFactory(Factory): +class BaseListFactory(Factory[T]): """Factory for list-like classes.""" class Meta: abstract = True @@ -714,7 +714,7 @@ def _create(cls, model_class, *args, **kwargs): return cls._build(model_class, *args, **kwargs) -class ListFactory(BaseListFactory): +class ListFactory(BaseListFactory[list]): class Meta: model = list diff --git a/factory/mogo.py b/factory/mogo.py index f886ae14..35089953 100644 --- a/factory/mogo.py +++ b/factory/mogo.py @@ -2,12 +2,14 @@ """factory_boy extensions for use with the mogo library (pymongo wrapper).""" - +from typing import TypeVar from . import base +T = TypeVar("T") + -class MogoFactory(base.Factory): +class MogoFactory(base.Factory[T]): """Factory for mogo objects.""" class Meta: abstract = True diff --git a/factory/mongoengine.py b/factory/mongoengine.py index eb4a8dc5..9767d14b 100644 --- a/factory/mongoengine.py +++ b/factory/mongoengine.py @@ -2,12 +2,14 @@ """factory_boy extensions for use with the mongoengine library (pymongo wrapper).""" - +from typing import TypeVar from . import base +T = TypeVar("T") + -class MongoEngineFactory(base.Factory): +class MongoEngineFactory(base.Factory[T]): """Factory for mongoengine objects.""" class Meta: diff --git a/tests/test_typing.py b/tests/test_typing.py index c2f8b564..103a598d 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -2,6 +2,9 @@ import dataclasses import unittest +from typing import List + +from typing_extensions import assert_type import factory @@ -25,7 +28,21 @@ class UserFactory(factory.Factory[User]): class Meta: model = User - result: User - result = UserFactory.build() - result = UserFactory.create() - self.assertEqual(result.name, "John Doe") + assert_type(UserFactory.build(), User) + assert_type(UserFactory.create(), User) + assert_type(UserFactory.build_batch(2), List[User]) + assert_type(UserFactory.create_batch(2), List[User]) + self.assertEqual(UserFactory.create().name, "John Doe") + + def test_dict_factory(self) -> None: + + class Pet(factory.DictFactory): + species = "dog" + name = "rover" + + assert_type(Pet.build(), dict) + assert_type(Pet.create(), dict) + + def test_list_factory(self) -> None: + assert_type(factory.ListFactory.build(), list) + assert_type(factory.ListFactory.create(), list) diff --git a/tox.ini b/tox.ini index d842c759..79e76f26 100644 --- a/tox.ini +++ b/tox.ini @@ -36,6 +36,7 @@ passenv = POSTGRES_DATABASE deps = mypy + typing_extensions alchemy: SQLAlchemy alchemy: sqlalchemy_utils mongo: mongoengine