Skip to content

Commit

Permalink
1091 Make get_related work multiple levels deep (#1092)
Browse files Browse the repository at this point in the history
* make `get_related` work multiple levels deep

* fix linter errors
  • Loading branch information
dantownsend authored Oct 4, 2024
1 parent a8d401e commit 2cae26b
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 45 deletions.
7 changes: 7 additions & 0 deletions docs/src/piccolo/query_types/objects.rst
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,13 @@ using ``get_related``.
>>> manager.name
'Guido'
It works multiple levels deep - for example:

.. code-block:: python
concert = await Concert.objects().first()
manager = await concert.get_related(Concert.band_1.manager)
Prefetching related objects
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
50 changes: 49 additions & 1 deletion piccolo/query/methods/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import typing as t

from piccolo.columns.column_types import ForeignKey
from piccolo.columns.column_types import ForeignKey, ReferencedTable
from piccolo.columns.combination import And, Where
from piccolo.custom_types import Combinable, TableInstance
from piccolo.engine.base import BaseBatch
Expand Down Expand Up @@ -231,6 +231,54 @@ def run_sync(self, *args, **kwargs) -> None:
return run_sync(self.run(*args, **kwargs))


class GetRelated(t.Generic[ReferencedTable]):

def __init__(self, row: Table, foreign_key: ForeignKey[ReferencedTable]):
self.row = row
self.foreign_key = foreign_key

async def run(
self,
node: t.Optional[str] = None,
in_pool: bool = True,
) -> t.Optional[ReferencedTable]:
references = t.cast(
t.Type[ReferencedTable],
self.foreign_key._foreign_key_meta.resolved_references,
)

data = (
await self.row.__class__.select(
*[
i.as_alias(i._meta.name)
for i in self.foreign_key.all_columns()
]
)
.first()
.run(node=node, in_pool=in_pool)
)

# Make sure that some values were returned:
if data is None or not any(data.values()):
return None

referenced_object = references(**data)
referenced_object._exists_in_db = True
return referenced_object

def __await__(
self,
) -> t.Generator[None, None, t.Optional[ReferencedTable]]:
"""
If the user doesn't explicity call .run(), proxy to it as a
convenience.
"""
return self.run().__await__()

def run_sync(self, *args, **kwargs) -> t.Optional[ReferencedTable]:
return run_sync(self.run(*args, **kwargs))


###############################################################################


Expand Down
25 changes: 7 additions & 18 deletions piccolo/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
)
from piccolo.query.methods.create_index import CreateIndex
from piccolo.query.methods.indexes import Indexes
from piccolo.query.methods.objects import First, UpdateSelf
from piccolo.query.methods.objects import GetRelated, UpdateSelf
from piccolo.query.methods.refresh import Refresh
from piccolo.querystring import QueryString
from piccolo.utils import _camel_to_snake
Expand Down Expand Up @@ -612,14 +612,14 @@ def refresh(
@t.overload
def get_related(
self, foreign_key: ForeignKey[ReferencedTable]
) -> First[ReferencedTable]: ...
) -> GetRelated[ReferencedTable]: ...

@t.overload
def get_related(self, foreign_key: str) -> First[Table]: ...
def get_related(self, foreign_key: str) -> GetRelated[Table]: ...

def get_related(
self, foreign_key: t.Union[str, ForeignKey[ReferencedTable]]
) -> t.Union[First[Table], First[ReferencedTable]]:
) -> GetRelated[ReferencedTable]:
"""
Used to fetch a ``Table`` instance, for the target of a foreign key.
Expand All @@ -630,8 +630,8 @@ def get_related(
>>> print(manager.name)
'Guido'
It can only follow foreign keys one level currently.
i.e. ``Band.manager``, but not ``Band.manager.x.y.z``.
It can only follow foreign keys multiple levels deep. For example,
``Concert.band_1.manager``.
"""
if isinstance(foreign_key, str):
Expand All @@ -645,18 +645,7 @@ def get_related(
"ForeignKey column."
)

column_name = foreign_key._meta.name

references = foreign_key._foreign_key_meta.resolved_references

return (
references.objects()
.where(
foreign_key._foreign_key_meta.resolved_target_column
== getattr(self, column_name)
)
.first()
)
return GetRelated(foreign_key=foreign_key, row=self)

def get_m2m(self, m2m: M2M) -> M2MGetRelated:
"""
Expand Down
72 changes: 46 additions & 26 deletions tests/table/instance/test_get_related.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,62 @@
import typing as t
from unittest import TestCase

from tests.example_apps.music.tables import Band, Manager
from piccolo.testing.test_case import AsyncTableTest
from tests.example_apps.music.tables import Band, Concert, Manager, Venue

TABLES = [Manager, Band]

class TestGetRelated(AsyncTableTest):
tables = [Manager, Band, Concert, Venue]

class TestGetRelated(TestCase):
def setUp(self):
for table in TABLES:
table.create_table().run_sync()
async def asyncSetUp(self):
await super().asyncSetUp()

def tearDown(self):
for table in reversed(TABLES):
table.alter().drop_table().run_sync()
self.manager = Manager(name="Guido")
await self.manager.save()

def test_get_related(self) -> None:
self.band = Band(
name="Pythonistas", manager=self.manager.id, popularity=100
)
await self.band.save()

async def test_foreign_key(self) -> None:
"""
Make sure you can get a related object from another object instance.
"""
manager = Manager(name="Guido")
manager.save().run_sync()
manager = await self.band.get_related(Band.manager)
assert manager is not None
self.assertTrue(manager.name == "Guido")

band = Band(name="Pythonistas", manager=manager.id, popularity=100)
band.save().run_sync()
async def test_non_foreign_key(self):
"""
Make sure that non-ForeignKey raise an exception.
"""
with self.assertRaises(ValueError):
self.band.get_related(Band.name) # type: ignore

_manager = band.get_related(Band.manager).run_sync()
assert _manager is not None
self.assertTrue(_manager.name == "Guido")
async def test_string(self):
"""
Make sure it also works using a string representation of a foreign key.
"""
manager = t.cast(Manager, await self.band.get_related("manager"))
self.assertTrue(manager.name == "Guido")

# Test non-ForeignKey
async def test_invalid_string(self):
"""
Make sure an exception is raised if the foreign key string is invalid.
"""
with self.assertRaises(ValueError):
band.get_related(Band.name) # type: ignore
self.band.get_related("abc123")

async def test_multiple_levels(self):
"""
Make sure ``get_related`` works multiple levels deep.
"""
concert = Concert(band_1=self.band)
await concert.save()

# Make sure it also works using a string
_manager_2 = t.cast(Manager, band.get_related("manager").run_sync())
self.assertTrue(_manager_2.name == "Guido")
manager = await concert.get_related(Concert.band_1._.manager)
assert manager is not None
self.assertTrue(manager.name == "Guido")

# Test an invalid string
with self.assertRaises(ValueError):
band.get_related("abc123")
band_2_manager = await concert.get_related(Concert.band_2._.manager)
assert band_2_manager is None
6 changes: 6 additions & 0 deletions tests/type_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ async def get_related() -> None:
manager = await band.get_related(Band.manager)
assert_type(manager, t.Optional[Manager])

async def get_related_multiple_levels() -> None:
concert = await Concert.objects().first()
assert concert is not None
manager = await concert.get_related(Concert.band_1._.manager)
assert_type(manager, t.Optional[Manager])

async def get_or_create() -> None:
query = Band.objects().get_or_create(Band.name == "Pythonistas")
assert_type(await query, Band)
Expand Down

0 comments on commit 2cae26b

Please sign in to comment.