Skip to content

Commit

Permalink
Fix/360 (#361)
Browse files Browse the repository at this point in the history
* Unwrap forward reference to preserve sqlalchemt 2.0.32 behavior

* Update tests for additional type information
  • Loading branch information
PGijsbers authored Oct 16, 2024
1 parent 69c1033 commit 5b40190
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/database/model/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
import inspect
from collections import ChainMap
from typing import Type
from typing import Type, ForwardRef

import typing_inspect
from sqlalchemy.orm.util import _is_mapped_annotation, _extract_mapped_subtype
Expand All @@ -21,15 +21,16 @@ def all_annotations(cls) -> ChainMap:
return ChainMap(*(inspect.get_annotations(c) for c in cls.mro()))


def datatype_of_field(clazz: Type[SQLModel], field_name: str) -> Type:
def datatype_of_field(clazz: Type[SQLModel], field_name: str) -> Type | str:
"""
Returns the datatype of a field, based on the annotations. It returns the inner type in case
of a list, or an optional.
of a list, or an optional. Returns a str in case a forward reference was used.
Examples:
- name: str returns str
- issn: str | None returns str
- funder: list["AgentTable"] returns AgentTable
- funder: list[AgentTable] returns AgentTable
- funder: list["AgentTable"] returns "AgentTable"
"""
annotation = inspect.get_annotations(clazz)[field_name]

Expand All @@ -56,4 +57,6 @@ def datatype_of_field(clazz: Type[SQLModel], field_name: str) -> Type:
]
if typing_inspect.is_generic_type(annotation): # e.g. List[Dataset]
(annotation,) = typing_inspect.get_args(annotation)
if isinstance(annotation, ForwardRef):
annotation = annotation.__forward_arg__
return annotation
10 changes: 10 additions & 0 deletions src/database/model/relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ class OneToOne(_ResourceRelationshipSingle):
def create_triggers(self, parent_class: Type[SQLModel], field_name: str):
if self.on_delete_trigger_deletion_by is not None:
to_delete = datatype_of_field(parent_class, field_name)
if isinstance(to_delete, str):
raise ValueError(
"Deletion trigger is configured wrongly: field cannot use a forward reference "
f"`{parent_class}.{field_name}`"
)
if not issubclass(to_delete, SQLModel):
raise ValueError(
"The deletion trigger is configured wrongly: the field doesn't "
Expand Down Expand Up @@ -199,6 +204,11 @@ def create_triggers(self, parent_class: Type[SQLModel], field_name: str):
if self.on_delete_trigger_orphan_deletion is not None:
link = parent_class.__sqlmodel_relationships__[field_name].link_model
to_delete = datatype_of_field(parent_class, field_name)
if isinstance(to_delete, str):
raise ValueError(
"Deletion trigger is configured wrongly: field cannot use a forward reference "
f"`{parent_class}.{field_name}`"
)
if not issubclass(to_delete, SQLModel):
raise ValueError(
"The deletion trigger is configured wrongly: the field doesn't "
Expand Down
3 changes: 3 additions & 0 deletions src/tests/database/model/ai_asset/test_ai_asset_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
def test_happy_path(client: TestClient):
dataset_distribution = datatype_of_field(Dataset, "distribution")
publication_distribution = datatype_of_field(Publication, "distribution")
assert not isinstance(dataset_distribution, str)
assert not isinstance(publication_distribution, str)

dataset_1 = Dataset(
name="dataset 1",
distribution=[
Expand Down
2 changes: 2 additions & 0 deletions src/tests/database/model/resource/test_resource_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
def test_happy_path(client: TestClient):
dataset_media = datatype_of_field(Dataset, "media")
dataset_note = datatype_of_field(Dataset, "note")
assert not isinstance(dataset_media, str)
assert not isinstance(dataset_note, str)

alternate_name_a = AlternateName(name="a")
alternate_name_b = AlternateName(name="b")
Expand Down

0 comments on commit 5b40190

Please sign in to comment.