From f0de61ea194830cf70528b1775e1889a9404a33e Mon Sep 17 00:00:00 2001 From: Bastien Gerard Date: Wed, 2 Oct 2024 22:19:07 +0200 Subject: [PATCH] refactor _document_registry + log a warning when user register multiple Document classes with the same name (only flagging when this happens in different module) --- docs/changelog.rst | 3 ++ mongoengine/base/__init__.py | 3 +- mongoengine/base/common.py | 77 +++++++++++++++++++++++---------- mongoengine/base/document.py | 6 +-- mongoengine/base/metaclasses.py | 4 +- mongoengine/dereference.py | 20 ++++----- mongoengine/document.py | 6 +-- mongoengine/fields.py | 20 ++++----- mongoengine/queryset/base.py | 6 ++- tests/document/test_instance.py | 8 ++-- tests/fields/test_fields.py | 4 +- 11 files changed, 96 insertions(+), 61 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 01f6c3236..887a67773 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -13,6 +13,8 @@ Development - make sure to read https://www.mongodb.com/docs/manual/core/transactions-in-applications/#callback-api-vs-core-api - run_in_transaction context manager relies on Pymongo coreAPI, it will retry automatically in case of `UnknownTransactionCommitResult` but not `TransientTransactionError` exceptions - Using .count() in a transaction will always use Collection.count_document (as estimated_document_count is not supported in transactions) +- BREAKING CHANGE: wrap _document_registry (normally not used by end users) with _DocumentRegistry which acts as a singleton to access the registry +- Log a warning in case users creates multiple Document classes with the same name as it can lead to unexpected behavior #1778 - Fix use of $geoNear or $collStats in aggregate #2493 - BREAKING CHANGE: Further to the deprecation warning, remove ability to use an unpacked list to `Queryset.aggregate(*pipeline)`, a plain list must be provided instead `Queryset.aggregate(pipeline)`, as it's closer to pymongo interface - BREAKING CHANGE: Further to the deprecation warning, remove `full_response` from `QuerySet.modify` as it wasn't supported with Pymongo 3+ @@ -21,6 +23,7 @@ Development - BREAKING CHANGE: Remove LongField as it's equivalent to IntField since we drop support to Python2 long time ago (User should simply switch to IntField) #2309 - BugFix - Calling .clear on a ListField wasn't being marked as changed (and flushed to db upon .save()) #2858 + Changes in 0.29.0 ================= - Fix weakref in EmbeddedDocumentListField (causing brief mem leak in certain circumstances) #2827 diff --git a/mongoengine/base/__init__.py b/mongoengine/base/__init__.py index dca0c4bb7..a2c88aae6 100644 --- a/mongoengine/base/__init__.py +++ b/mongoengine/base/__init__.py @@ -13,8 +13,7 @@ __all__ = ( # common "UPDATE_OPERATORS", - "_document_registry", - "get_document", + "_DocumentRegistry", # datastructures "BaseDict", "BaseList", diff --git a/mongoengine/base/common.py b/mongoengine/base/common.py index 85897324f..fe631a40e 100644 --- a/mongoengine/base/common.py +++ b/mongoengine/base/common.py @@ -1,6 +1,8 @@ +import warnings + from mongoengine.errors import NotRegistered -__all__ = ("UPDATE_OPERATORS", "get_document", "_document_registry") +__all__ = ("UPDATE_OPERATORS", "_DocumentRegistry") UPDATE_OPERATORS = { @@ -25,28 +27,57 @@ _document_registry = {} -def get_document(name): - """Get a registered Document class by name.""" - doc = _document_registry.get(name, None) - if not doc: - # Possible old style name - single_end = name.split(".")[-1] - compound_end = ".%s" % single_end - possible_match = [ - k for k in _document_registry if k.endswith(compound_end) or k == single_end - ] - if len(possible_match) == 1: - doc = _document_registry.get(possible_match.pop(), None) - if not doc: - raise NotRegistered( - """ - `%s` has not been registered in the document registry. - Importing the document class automatically registers it, has it - been imported? - """.strip() - % name - ) - return doc +class _DocumentRegistry: + """Wrapper for the document registry (providing a singleton pattern). + This is part of MongoEngine's internals, not meant to be used directly by end-users + """ + + @staticmethod + def get(name): + doc = _document_registry.get(name, None) + if not doc: + # Possible old style name + single_end = name.split(".")[-1] + compound_end = ".%s" % single_end + possible_match = [ + k + for k in _document_registry + if k.endswith(compound_end) or k == single_end + ] + if len(possible_match) == 1: + doc = _document_registry.get(possible_match.pop(), None) + if not doc: + raise NotRegistered( + """ + `%s` has not been registered in the document registry. + Importing the document class automatically registers it, has it + been imported? + """.strip() + % name + ) + return doc + + @staticmethod + def register(DocCls): + ExistingDocCls = _document_registry.get(DocCls._class_name) + if ( + ExistingDocCls is not None + and ExistingDocCls.__module__ != DocCls.__module__ + ): + # A sign that a codebase may have named two different classes with the same name accidentally, + # this could cause issues with dereferencing because MongoEngine makes the assumption that a Document + # class name is unique. + warnings.warn( + f"Multiple Document classes named `{DocCls._class_name}` were registered, " + f"first from: `{ExistingDocCls.__module__}`, then from: `{DocCls.__module__}`. " + "this may lead to unexpected behavior during dereferencing.", + stacklevel=4, + ) + _document_registry[DocCls._class_name] = DocCls + + @staticmethod + def unregister(doc_cls_name): + _document_registry.pop(doc_cls_name) def _get_documents_by_db(connection_alias, default_connection_alias): diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 00db2c218..ea3962ad7 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -7,7 +7,7 @@ from bson import SON, DBRef, ObjectId, json_util from mongoengine import signals -from mongoengine.base.common import get_document +from mongoengine.base.common import _DocumentRegistry from mongoengine.base.datastructures import ( BaseDict, BaseList, @@ -500,7 +500,7 @@ def __expand_dynamic_values(self, name, value): # If the value is a dict with '_cls' in it, turn it into a document is_dict = isinstance(value, dict) if is_dict and "_cls" in value: - cls = get_document(value["_cls"]) + cls = _DocumentRegistry.get(value["_cls"]) return cls(**value) if is_dict: @@ -802,7 +802,7 @@ def _from_son(cls, son, _auto_dereference=True, created=False): # Return correct subclass for document type if class_name != cls._class_name: - cls = get_document(class_name) + cls = _DocumentRegistry.get(class_name) errors_dict = {} diff --git a/mongoengine/base/metaclasses.py b/mongoengine/base/metaclasses.py index a48bde959..a311aa167 100644 --- a/mongoengine/base/metaclasses.py +++ b/mongoengine/base/metaclasses.py @@ -1,7 +1,7 @@ import itertools import warnings -from mongoengine.base.common import _document_registry +from mongoengine.base.common import _DocumentRegistry from mongoengine.base.fields import ( BaseField, ComplexBaseField, @@ -169,7 +169,7 @@ def __new__(mcs, name, bases, attrs): new_class._collection = None # Add class to the _document_registry - _document_registry[new_class._class_name] = new_class + _DocumentRegistry.register(new_class) # Handle delete rules for field in new_class._fields.values(): diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 1f4f7594c..38da2e873 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -5,7 +5,7 @@ BaseList, EmbeddedDocumentList, TopLevelDocumentMetaclass, - get_document, + _DocumentRegistry, ) from mongoengine.base.datastructures import LazyReference from mongoengine.connection import _get_session, get_db @@ -131,9 +131,9 @@ def _find_references(self, items, depth=0): elif isinstance(v, DBRef): reference_map.setdefault(field.document_type, set()).add(v.id) elif isinstance(v, (dict, SON)) and "_ref" in v: - reference_map.setdefault(get_document(v["_cls"]), set()).add( - v["_ref"].id - ) + reference_map.setdefault( + _DocumentRegistry.get(v["_cls"]), set() + ).add(v["_ref"].id) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: field_cls = getattr( getattr(field, "field", None), "document_type", None @@ -151,9 +151,9 @@ def _find_references(self, items, depth=0): elif isinstance(item, DBRef): reference_map.setdefault(item.collection, set()).add(item.id) elif isinstance(item, (dict, SON)) and "_ref" in item: - reference_map.setdefault(get_document(item["_cls"]), set()).add( - item["_ref"].id - ) + reference_map.setdefault( + _DocumentRegistry.get(item["_cls"]), set() + ).add(item["_ref"].id) elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth: references = self._find_references(item, depth - 1) for key, refs in references.items(): @@ -198,9 +198,9 @@ def _fetch_objects(self, doc_type=None): ) for ref in references: if "_cls" in ref: - doc = get_document(ref["_cls"])._from_son(ref) + doc = _DocumentRegistry.get(ref["_cls"])._from_son(ref) elif doc_type is None: - doc = get_document( + doc = _DocumentRegistry.get( "".join(x.capitalize() for x in collection.split("_")) )._from_son(ref) else: @@ -235,7 +235,7 @@ def _attach_objects(self, items, depth=0, instance=None, name=None): (items["_ref"].collection, items["_ref"].id), items ) elif "_cls" in items: - doc = get_document(items["_cls"])._from_son(items) + doc = _DocumentRegistry.get(items["_cls"])._from_son(items) _cls = doc._data.pop("_cls", None) del items["_cls"] doc._data = self._attach_objects(doc._data, depth, doc, None) diff --git a/mongoengine/document.py b/mongoengine/document.py index 4907589e9..829c07135 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -12,7 +12,7 @@ DocumentMetaclass, EmbeddedDocumentList, TopLevelDocumentMetaclass, - get_document, + _DocumentRegistry, ) from mongoengine.base.utils import NonOrderedList from mongoengine.common import _import_class @@ -851,12 +851,12 @@ def register_delete_rule(cls, document_cls, field_name, rule): object. """ classes = [ - get_document(class_name) + _DocumentRegistry.get(class_name) for class_name in cls._subclasses if class_name != cls.__name__ ] + [cls] documents = [ - get_document(class_name) + _DocumentRegistry.get(class_name) for class_name in document_cls._subclasses if class_name != document_cls.__name__ ] + [document_cls] diff --git a/mongoengine/fields.py b/mongoengine/fields.py index c74539691..e9cf5b817 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -30,7 +30,7 @@ GeoJsonBaseField, LazyReference, ObjectIdField, - get_document, + _DocumentRegistry, ) from mongoengine.base.utils import LazyRegexCompiler from mongoengine.common import _import_class @@ -725,7 +725,7 @@ def document_type(self): if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: resolved_document_type = self.owner_document else: - resolved_document_type = get_document(self.document_type_obj) + resolved_document_type = _DocumentRegistry.get(self.document_type_obj) if not issubclass(resolved_document_type, EmbeddedDocument): # Due to the late resolution of the document_type @@ -801,7 +801,7 @@ def prepare_query_value(self, op, value): def to_python(self, value): if isinstance(value, dict): - doc_cls = get_document(value["_cls"]) + doc_cls = _DocumentRegistry.get(value["_cls"]) value = doc_cls._from_son(value) return value @@ -879,7 +879,7 @@ def to_mongo(self, value, use_db_field=True, fields=None): def to_python(self, value): if isinstance(value, dict) and "_cls" in value: - doc_cls = get_document(value["_cls"]) + doc_cls = _DocumentRegistry.get(value["_cls"]) if "_ref" in value: value = doc_cls._get_db().dereference( value["_ref"], session=_get_session() @@ -1171,7 +1171,7 @@ def document_type(self): if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: self.document_type_obj = self.owner_document else: - self.document_type_obj = get_document(self.document_type_obj) + self.document_type_obj = _DocumentRegistry.get(self.document_type_obj) return self.document_type_obj @staticmethod @@ -1195,7 +1195,7 @@ def __get__(self, instance, owner): if auto_dereference and isinstance(ref_value, DBRef): if hasattr(ref_value, "cls"): # Dereference using the class type specified in the reference - cls = get_document(ref_value.cls) + cls = _DocumentRegistry.get(ref_value.cls) else: cls = self.document_type @@ -1335,7 +1335,7 @@ def document_type(self): if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: self.document_type_obj = self.owner_document else: - self.document_type_obj = get_document(self.document_type_obj) + self.document_type_obj = _DocumentRegistry.get(self.document_type_obj) return self.document_type_obj @staticmethod @@ -1498,7 +1498,7 @@ def __get__(self, instance, owner): auto_dereference = instance._fields[self.name]._auto_dereference if auto_dereference and isinstance(value, dict): - doc_cls = get_document(value["_cls"]) + doc_cls = _DocumentRegistry.get(value["_cls"]) instance._data[self.name] = self._lazy_load_ref(doc_cls, value["_ref"]) return super().__get__(instance, owner) @@ -2443,7 +2443,7 @@ def document_type(self): if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT: self.document_type_obj = self.owner_document else: - self.document_type_obj = get_document(self.document_type_obj) + self.document_type_obj = _DocumentRegistry.get(self.document_type_obj) return self.document_type_obj def build_lazyref(self, value): @@ -2584,7 +2584,7 @@ def build_lazyref(self, value): elif value is not None: if isinstance(value, (dict, SON)): value = LazyReference( - get_document(value["_cls"]), + _DocumentRegistry.get(value["_cls"]), value["_ref"].id, passthrough=self.passthrough, ) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index aef996996..f04ef06c5 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -13,7 +13,7 @@ from pymongo.read_concern import ReadConcern from mongoengine import signals -from mongoengine.base import get_document +from mongoengine.base import _DocumentRegistry from mongoengine.common import _import_class from mongoengine.connection import _get_session, get_db from mongoengine.context_managers import ( @@ -1956,7 +1956,9 @@ def _fields_to_dbfields(self, fields): """Translate fields' paths to their db equivalents.""" subclasses = [] if self._document._meta["allow_inheritance"]: - subclasses = [get_document(x) for x in self._document._subclasses][1:] + subclasses = [_DocumentRegistry.get(x) for x in self._document._subclasses][ + 1: + ] db_field_paths = [] for field in fields: diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 89645e3ef..c5970420f 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -13,7 +13,7 @@ from mongoengine import * from mongoengine import signals -from mongoengine.base import _document_registry, get_document +from mongoengine.base import _DocumentRegistry from mongoengine.connection import get_db from mongoengine.context_managers import query_counter, switch_db from mongoengine.errors import ( @@ -392,7 +392,7 @@ class NicePlace(Place): # Mimic Place and NicePlace definitions being in a different file # and the NicePlace model not being imported in at query time. - del _document_registry["Place.NicePlace"] + _DocumentRegistry.unregister("Place.NicePlace") with pytest.raises(NotRegistered): list(Place.objects.all()) @@ -407,8 +407,8 @@ class Area(Location): Location.drop_collection() - assert Area == get_document("Area") - assert Area == get_document("Location.Area") + assert Area == _DocumentRegistry.get("Area") + assert Area == _DocumentRegistry.get("Location.Area") def test_creation(self): """Ensure that document may be created using keyword arguments.""" diff --git a/tests/fields/test_fields.py b/tests/fields/test_fields.py index 5655f12e2..8980b6037 100644 --- a/tests/fields/test_fields.py +++ b/tests/fields/test_fields.py @@ -37,7 +37,7 @@ from mongoengine.base import ( BaseField, EmbeddedDocumentList, - _document_registry, + _DocumentRegistry, ) from mongoengine.base.fields import _no_dereference_for_fields from mongoengine.errors import DeprecatedError @@ -1678,7 +1678,7 @@ class User(Document): # Mimic User and Link definitions being in a different file # and the Link model not being imported in the User file. - del _document_registry["Link"] + _DocumentRegistry.unregister("Link") user = User.objects.first() try: