From 5b401908820b1610d275c814120011b52397545b Mon Sep 17 00:00:00 2001
From: Pieter Gijsbers
Date: Wed, 16 Oct 2024 17:30:14 +0200
Subject: [PATCH] Fix/360 (#361)
* Unwrap forward reference to preserve sqlalchemt 2.0.32 behavior
* Update tests for additional type information
---
src/database/model/annotations.py | 11 +++++++----
src/database/model/relationships.py | 10 ++++++++++
.../database/model/ai_asset/test_ai_asset_delete.py | 3 +++
.../database/model/resource/test_resource_delete.py | 2 ++
4 files changed, 22 insertions(+), 4 deletions(-)
diff --git a/src/database/model/annotations.py b/src/database/model/annotations.py
index 9ec59f3f..a3841220 100644
--- a/src/database/model/annotations.py
+++ b/src/database/model/annotations.py
@@ -4,7 +4,7 @@
"""
import inspect
from collections import ChainMap
-from typing import Type
+from typing import Type, ForwardRef
import typing_inspect
from sqlalchemy.orm.util import _is_mapped_annotation, _extract_mapped_subtype
@@ -21,15 +21,16 @@ def all_annotations(cls) -> ChainMap:
return ChainMap(*(inspect.get_annotations(c) for c in cls.mro()))
-def datatype_of_field(clazz: Type[SQLModel], field_name: str) -> Type:
+def datatype_of_field(clazz: Type[SQLModel], field_name: str) -> Type | str:
"""
Returns the datatype of a field, based on the annotations. It returns the inner type in case
- of a list, or an optional.
+ of a list, or an optional. Returns a str in case a forward reference was used.
Examples:
- name: str returns str
- issn: str | None returns str
- - funder: list["AgentTable"] returns AgentTable
+ - funder: list[AgentTable] returns AgentTable
+ - funder: list["AgentTable"] returns "AgentTable"
"""
annotation = inspect.get_annotations(clazz)[field_name]
@@ -56,4 +57,6 @@ def datatype_of_field(clazz: Type[SQLModel], field_name: str) -> Type:
]
if typing_inspect.is_generic_type(annotation): # e.g. List[Dataset]
(annotation,) = typing_inspect.get_args(annotation)
+ if isinstance(annotation, ForwardRef):
+ annotation = annotation.__forward_arg__
return annotation
diff --git a/src/database/model/relationships.py b/src/database/model/relationships.py
index 927bea86..80097bca 100644
--- a/src/database/model/relationships.py
+++ b/src/database/model/relationships.py
@@ -136,6 +136,11 @@ class OneToOne(_ResourceRelationshipSingle):
def create_triggers(self, parent_class: Type[SQLModel], field_name: str):
if self.on_delete_trigger_deletion_by is not None:
to_delete = datatype_of_field(parent_class, field_name)
+ if isinstance(to_delete, str):
+ raise ValueError(
+ "Deletion trigger is configured wrongly: field cannot use a forward reference "
+ f"`{parent_class}.{field_name}`"
+ )
if not issubclass(to_delete, SQLModel):
raise ValueError(
"The deletion trigger is configured wrongly: the field doesn't "
@@ -199,6 +204,11 @@ def create_triggers(self, parent_class: Type[SQLModel], field_name: str):
if self.on_delete_trigger_orphan_deletion is not None:
link = parent_class.__sqlmodel_relationships__[field_name].link_model
to_delete = datatype_of_field(parent_class, field_name)
+ if isinstance(to_delete, str):
+ raise ValueError(
+ "Deletion trigger is configured wrongly: field cannot use a forward reference "
+ f"`{parent_class}.{field_name}`"
+ )
if not issubclass(to_delete, SQLModel):
raise ValueError(
"The deletion trigger is configured wrongly: the field doesn't "
diff --git a/src/tests/database/model/ai_asset/test_ai_asset_delete.py b/src/tests/database/model/ai_asset/test_ai_asset_delete.py
index 474bc37f..d8a390d6 100644
--- a/src/tests/database/model/ai_asset/test_ai_asset_delete.py
+++ b/src/tests/database/model/ai_asset/test_ai_asset_delete.py
@@ -11,6 +11,9 @@
def test_happy_path(client: TestClient):
dataset_distribution = datatype_of_field(Dataset, "distribution")
publication_distribution = datatype_of_field(Publication, "distribution")
+ assert not isinstance(dataset_distribution, str)
+ assert not isinstance(publication_distribution, str)
+
dataset_1 = Dataset(
name="dataset 1",
distribution=[
diff --git a/src/tests/database/model/resource/test_resource_delete.py b/src/tests/database/model/resource/test_resource_delete.py
index dc49ea4c..2428966d 100644
--- a/src/tests/database/model/resource/test_resource_delete.py
+++ b/src/tests/database/model/resource/test_resource_delete.py
@@ -14,6 +14,8 @@
def test_happy_path(client: TestClient):
dataset_media = datatype_of_field(Dataset, "media")
dataset_note = datatype_of_field(Dataset, "note")
+ assert not isinstance(dataset_media, str)
+ assert not isinstance(dataset_note, str)
alternate_name_a = AlternateName(name="a")
alternate_name_b = AlternateName(name="b")