From 01a864cf9b87fb294fbed29ae0912ca5c9b85811 Mon Sep 17 00:00:00 2001 From: instanceofmel Date: Tue, 9 Apr 2024 11:36:29 +0200 Subject: [PATCH] Allow functions as FactoryOptions.model Facilitates the use of a factory function, that may create several related objects at once. --- docs/changelog.rst | 1 + factory/base.py | 2 ++ tests/test_base.py | 16 ++++++++++++++++ 3 files changed, 19 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index c5fc0767..81f7da27 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -16,6 +16,7 @@ ChangeLog - :issue:`1031`: Do not require :attr:`~factory.alchemy.SQLAlchemyOptions.sqlalchemy_session` when :attr:`~factory.alchemy.SQLAlchemyOptions.sqlalchemy_session_factory` is provided. +- :issue:`1072`: Allow using functions as models for a :attr:`~factory.FactoryOptions.model`. *Removed:* diff --git a/factory/base.py b/factory/base.py index 454513be..26e4199a 100644 --- a/factory/base.py +++ b/factory/base.py @@ -2,6 +2,7 @@ import collections +import inspect import logging import warnings from typing import Generic, List, Type, TypeVar @@ -243,6 +244,7 @@ def _get_counter_reference(self): if (self.model is not None and self.base_factory is not None and self.base_factory._meta.model is not None + and inspect.isclass(self.model) and issubclass(self.model, self.base_factory._meta.model)): return self.base_factory._meta.counter_reference else: diff --git a/tests/test_base.py b/tests/test_base.py index d3b32570..b19e9eac 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -281,6 +281,22 @@ class Meta: ones = {x.one for x in (parent, alt_parent, sub, alt_sub)} self.assertEqual(4, len(ones)) + def test_inheritance_with_function_as_meta_model(self): + def make_test_object(**kwargs): + return TestObject(**kwargs) + + class TestObjectFactory(base.Factory): + class Meta: + model = make_test_object + + one = "foo" + + class TestSubFactory(TestObjectFactory): + one = "bar" + + sub = TestSubFactory.build() + self.assertEqual(sub.one, "bar") + class FactorySequenceTestCase(unittest.TestCase): def setUp(self):