From 88f05cdca1dc66060525c95fc9af2f10e8773cd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=ADn=20Gait=C3=A1n?= Date: Wed, 22 Jun 2022 08:52:59 -0300 Subject: [PATCH] support fully qualify path string as Meta.model --- Makefile | 1 - docs/reference.rst | 4 ++++ factory/base.py | 2 +- factory/declarations.py | 30 ++++++++++-------------------- factory/utils.py | 9 +++++++++ setup.cfg | 2 +- tests/test_base.py | 11 +++++++++++ tests/test_declarations.py | 4 ++-- tests/test_utils.py | 20 ++++++++++++++++++++ 9 files changed, 58 insertions(+), 25 deletions(-) diff --git a/Makefile b/Makefile index 93e36997..475d8deb 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,6 @@ testall: test: python \ -b \ - -X dev \ -Werror \ -Wdefault:"the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses":DeprecationWarning:distutils: \ -Wdefault:"Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working":DeprecationWarning:: \ diff --git a/docs/reference.rst b/docs/reference.rst index 76f913fd..869e65d0 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -32,11 +32,15 @@ Meta options .. attribute:: model This optional attribute describes the class of objects to generate. + It could be a class or the fully qualified import path to it. If unset, it will be inherited from parent :class:`Factory` subclasses. .. versionadded:: 2.4.0 + .. versionadded:: 3.3 + Support to be a a fully qualified import path to the class + .. method:: get_model_class() Returns the actual model class (:attr:`FactoryOptions.model` might be the diff --git a/factory/base.py b/factory/base.py index 36b2359a..cef66804 100644 --- a/factory/base.py +++ b/factory/base.py @@ -373,7 +373,7 @@ def get_model_class(self): This can be overridden in framework-specific subclasses to hook into existing model repositories, for instance. """ - return self.model + return utils.resolve_type(self.model) if isinstance(self.model, str) else self.model def __str__(self): return "<%s for %s>" % (self.__class__.__name__, self.factory.__name__) diff --git a/factory/declarations.py b/factory/declarations.py index fe2e34d9..52678286 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -343,31 +343,21 @@ class _FactoryWrapper: path for that subclass (e.g 'myapp.factories.MyFactory'). """ def __init__(self, factory_or_path): - self.factory = None - self.module = self.name = '' - if isinstance(factory_or_path, type): - self.factory = factory_or_path - else: - if not (isinstance(factory_or_path, str) and '.' in factory_or_path): - raise ValueError( - "A factory= argument must receive either a class " - "or the fully qualified path to a Factory subclass; got " - "%r instead." % factory_or_path) - self.module, self.name = factory_or_path.rsplit('.', 1) + + if not (isinstance(factory_or_path, type) or (isinstance(factory_or_path, str) and '.' in factory_or_path)): + raise ValueError( + "A factory= argument must receive either a class " + "or the fully qualified path to a Factory subclass; got " + "%r instead." % factory_or_path) + self.factory = factory_or_path def get(self): - if self.factory is None: - self.factory = utils.import_object( - self.module, - self.name, - ) + if isinstance(self.factory, str): + self.factory = utils.resolve_type(self.factory) return self.factory def __repr__(self): - if self.factory is None: - return f'<_FactoryImport: {self.module}.{self.name}>' - else: - return f'<_FactoryImport: {self.factory.__class__}>' + return f'<_FactoryImport: {self.factory}>' class SubFactory(BaseDeclaration): diff --git a/factory/utils.py b/factory/utils.py index a74e0b35..1f76dff9 100644 --- a/factory/utils.py +++ b/factory/utils.py @@ -16,6 +16,15 @@ def import_object(module_name, attribute_name): return getattr(module, attribute_name) +def resolve_type(type_or_path): + if isinstance(type_or_path, type): + return type_or_path + + if not (isinstance(type_or_path, str) and '.' in type_or_path): + raise ValueError("Must receive either an object or the fully qualified path") + return import_object(*type_or_path.rsplit('.', 1)) + + class log_pprint: """Helper for properly printing args / kwargs passed to an object. diff --git a/setup.cfg b/setup.cfg index 3ae6d65f..df64343f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,7 +36,7 @@ classifiers = [options] zip_safe = false packages = factory -python_requires = >=3.7 +python_requires = >=3.6 install_requires = Faker>=0.7.0 [options.extras_require] diff --git a/tests/test_base.py b/tests/test_base.py index 0b9ffa15..61a12de5 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -215,6 +215,17 @@ class Meta: with self.assertRaises(TypeError): type("SecondFactory", (base.Factory,), {"Meta": Meta}) + def test_meta_model_as_path(self): + class MailboxFactory(base.Factory): + + class Meta: + model = "mailbox.Mailbox" + path = "/tmp/mail" + + import mailbox + box = MailboxFactory() + assert isinstance(box, mailbox.Mailbox) + class DeclarationParsingTests(unittest.TestCase): def test_classmethod(self): diff --git a/tests/test_declarations.py b/tests/test_declarations.py index c9458ffe..b379bab0 100644 --- a/tests/test_declarations.py +++ b/tests/test_declarations.py @@ -196,7 +196,7 @@ def test_path(self): def test_lazyness(self): f = declarations._FactoryWrapper('factory.declarations.Sequence') - self.assertEqual(None, f.factory) + self.assertEqual('factory.declarations.Sequence', f.factory) factory_class = f.get() self.assertEqual(declarations.Sequence, factory_class) @@ -205,7 +205,7 @@ def test_cache(self): """Ensure that _FactoryWrapper tries to import only once.""" orig_date = datetime.date w = declarations._FactoryWrapper('datetime.date') - self.assertEqual(None, w.factory) + self.assertEqual('datetime.date', w.factory) factory_class = w.get() self.assertEqual(orig_date, factory_class) diff --git a/tests/test_utils.py b/tests/test_utils.py index 1d54eefe..6442fa85 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -23,6 +23,26 @@ def test_invalid_module(self): utils.import_object('this-is-an-invalid-module', '__name__') +class ResolveTypeTestCase(unittest.TestCase): + def test_datetime(self): + imported = utils.resolve_type('datetime.date') + import datetime + d = datetime.date + self.assertEqual(d, imported) + + def test_unknown_attribute(self): + with self.assertRaises(AttributeError): + utils.resolve_type('datetime.foo') + + def test_invalid_module(self): + with self.assertRaises(ImportError): + utils.resolve_type('this-is-an-invalid-module.__name__') + + def test_is_a_class(self): + import datetime + return utils.resolve_type(datetime.date) is datetime.date + + class LogPPrintTestCase(unittest.TestCase): def test_nothing(self): txt = str(utils.log_pprint())