From 5b401908820b1610d275c814120011b52397545b Mon Sep 17 00:00:00 2001 From: Pieter Gijsbers Date: Wed, 16 Oct 2024 17:30:14 +0200 Subject: [PATCH] Fix/360 (#361) * Unwrap forward reference to preserve sqlalchemt 2.0.32 behavior * Update tests for additional type information --- src/database/model/annotations.py | 11 +++++++---- src/database/model/relationships.py | 10 ++++++++++ .../database/model/ai_asset/test_ai_asset_delete.py | 3 +++ .../database/model/resource/test_resource_delete.py | 2 ++ 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/database/model/annotations.py b/src/database/model/annotations.py index 9ec59f3f..a3841220 100644 --- a/src/database/model/annotations.py +++ b/src/database/model/annotations.py @@ -4,7 +4,7 @@ """ import inspect from collections import ChainMap -from typing import Type +from typing import Type, ForwardRef import typing_inspect from sqlalchemy.orm.util import _is_mapped_annotation, _extract_mapped_subtype @@ -21,15 +21,16 @@ def all_annotations(cls) -> ChainMap: return ChainMap(*(inspect.get_annotations(c) for c in cls.mro())) -def datatype_of_field(clazz: Type[SQLModel], field_name: str) -> Type: +def datatype_of_field(clazz: Type[SQLModel], field_name: str) -> Type | str: """ Returns the datatype of a field, based on the annotations. It returns the inner type in case - of a list, or an optional. + of a list, or an optional. Returns a str in case a forward reference was used. Examples: - name: str returns str - issn: str | None returns str - - funder: list["AgentTable"] returns AgentTable + - funder: list[AgentTable] returns AgentTable + - funder: list["AgentTable"] returns "AgentTable" """ annotation = inspect.get_annotations(clazz)[field_name] @@ -56,4 +57,6 @@ def datatype_of_field(clazz: Type[SQLModel], field_name: str) -> Type: ] if typing_inspect.is_generic_type(annotation): # e.g. List[Dataset] (annotation,) = typing_inspect.get_args(annotation) + if isinstance(annotation, ForwardRef): + annotation = annotation.__forward_arg__ return annotation diff --git a/src/database/model/relationships.py b/src/database/model/relationships.py index 927bea86..80097bca 100644 --- a/src/database/model/relationships.py +++ b/src/database/model/relationships.py @@ -136,6 +136,11 @@ class OneToOne(_ResourceRelationshipSingle): def create_triggers(self, parent_class: Type[SQLModel], field_name: str): if self.on_delete_trigger_deletion_by is not None: to_delete = datatype_of_field(parent_class, field_name) + if isinstance(to_delete, str): + raise ValueError( + "Deletion trigger is configured wrongly: field cannot use a forward reference " + f"`{parent_class}.{field_name}`" + ) if not issubclass(to_delete, SQLModel): raise ValueError( "The deletion trigger is configured wrongly: the field doesn't " @@ -199,6 +204,11 @@ def create_triggers(self, parent_class: Type[SQLModel], field_name: str): if self.on_delete_trigger_orphan_deletion is not None: link = parent_class.__sqlmodel_relationships__[field_name].link_model to_delete = datatype_of_field(parent_class, field_name) + if isinstance(to_delete, str): + raise ValueError( + "Deletion trigger is configured wrongly: field cannot use a forward reference " + f"`{parent_class}.{field_name}`" + ) if not issubclass(to_delete, SQLModel): raise ValueError( "The deletion trigger is configured wrongly: the field doesn't " diff --git a/src/tests/database/model/ai_asset/test_ai_asset_delete.py b/src/tests/database/model/ai_asset/test_ai_asset_delete.py index 474bc37f..d8a390d6 100644 --- a/src/tests/database/model/ai_asset/test_ai_asset_delete.py +++ b/src/tests/database/model/ai_asset/test_ai_asset_delete.py @@ -11,6 +11,9 @@ def test_happy_path(client: TestClient): dataset_distribution = datatype_of_field(Dataset, "distribution") publication_distribution = datatype_of_field(Publication, "distribution") + assert not isinstance(dataset_distribution, str) + assert not isinstance(publication_distribution, str) + dataset_1 = Dataset( name="dataset 1", distribution=[ diff --git a/src/tests/database/model/resource/test_resource_delete.py b/src/tests/database/model/resource/test_resource_delete.py index dc49ea4c..2428966d 100644 --- a/src/tests/database/model/resource/test_resource_delete.py +++ b/src/tests/database/model/resource/test_resource_delete.py @@ -14,6 +14,8 @@ def test_happy_path(client: TestClient): dataset_media = datatype_of_field(Dataset, "media") dataset_note = datatype_of_field(Dataset, "note") + assert not isinstance(dataset_media, str) + assert not isinstance(dataset_note, str) alternate_name_a = AlternateName(name="a") alternate_name_b = AlternateName(name="b")