Skip to content

Commit

Permalink
948 Adding a self referencing foreign key to an existing table which …
Browse files Browse the repository at this point in the history
…has a custom primary key (#949)

* fix migration

* catch exception

* add test functions

* reformat with black

* fix typo in comment
  • Loading branch information
dantownsend authored Mar 12, 2024
1 parent bd5ef61 commit 39959b4
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 30 deletions.
59 changes: 54 additions & 5 deletions piccolo/apps/migrations/auto/migration_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import inspect
import logging
import typing as t
from dataclasses import dataclass, field

Expand All @@ -14,14 +15,16 @@
)
from piccolo.apps.migrations.auto.serialisation import deserialise_params
from piccolo.columns import Column, column_types
from piccolo.columns.column_types import Serial
from piccolo.columns.column_types import ForeignKey, Serial
from piccolo.engine import engine_finder
from piccolo.query import Query
from piccolo.query.base import DDL
from piccolo.schema import SchemaDDLBase
from piccolo.table import Table, create_table_class, sort_table_classes
from piccolo.utils.warnings import colored_warning

logger = logging.getLogger(__name__)


@dataclass
class AddColumnClass:
Expand Down Expand Up @@ -793,26 +796,72 @@ async def _run_add_columns(self, backwards: bool = False):
AddColumnClass
] = self.add_columns.for_table_class_name(table_class_name)

###############################################################
# Define the table, with the columns, so the metaclass
# sets up the columns correctly.

table_class_members = {
add_column.column._meta.name: add_column.column
for add_column in add_columns
}

# There's an extreme edge case, when we're adding a foreign
# key which references its own table, for example:
#
# fk = ForeignKey('self')
#
# And that table has a custom primary key, for example:
#
# id = UUID(primary_key=True)
#
# In this situation, we need to know the primary key of the
# table in order to correctly add this new foreign key.
for add_column in add_columns:
if (
isinstance(add_column.column, ForeignKey)
and add_column.column._meta.params.get("references")
== "self"
):
try:
existing_table = (
await self.get_table_from_snapshot(
table_class_name=table_class_name,
app_name=self.app_name,
offset=-1,
)
)
except ValueError:
logger.error(
"Unable to find primary key for the table - "
"assuming Serial."
)
else:
primary_key = existing_table._meta.primary_key

table_class_members[
primary_key._meta.name
] = primary_key

break

_Table = create_table_class(
class_name=add_columns[0].table_class_name,
class_kwargs={
"tablename": add_columns[0].tablename,
"schema": add_columns[0].schema,
},
class_members={
add_column.column._meta.name: add_column.column
for add_column in add_columns
},
class_members=table_class_members,
)

###############################################################

for add_column in add_columns:
# We fetch the column from the Table, as the metaclass
# copies and sets it up properly.
column = _Table._meta.get_column_by_name(
add_column.column._meta.name
)

await self._run_query(
_Table.alter().add_column(
name=column._meta.name, column=column
Expand Down
77 changes: 52 additions & 25 deletions tests/apps/migrations/auto/integration/test_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def _get_app_config(self) -> AppConfig:
def _test_migrations(
self,
table_snapshots: t.List[t.List[t.Type[Table]]],
test_function: t.Optional[t.Callable[[RowMeta], None]] = None,
test_function: t.Optional[t.Callable[[RowMeta], bool]] = None,
):
"""
Writes a migration file to disk and runs it.
Expand Down Expand Up @@ -1040,47 +1040,74 @@ def test_target_column(self):


@engines_only("postgres", "cockroach")
class TestTargetColumnString(MigrationTestCase):
class TestForeignKeySelf(MigrationTestCase):
def setUp(self):
class TableA(Table):
name = Varchar(unique=True)

class TableB(Table):
table_a = ForeignKey(TableA, target_column="name")
id = UUID(primary_key=True)
table_a = ForeignKey("self")

self.table_classes = [TableA, TableB]
self.table_classes: t.List[t.Type[Table]] = [TableA]

def tearDown(self):
drop_db_tables_sync(Migration, *self.table_classes)

def test_target_column(self):
def test_create_table(self):
"""
Make sure migrations still work when a foreign key references a column
other than the primary key.
Make sure migrations still work when:
* Creating a new table with a foreign key which references itself.
* The table has a custom primary key type (e.g. UUID).
"""
self._test_migrations(
table_snapshots=[self.table_classes],
test_function=lambda x: x.data_type == "uuid",
)

for table_class in self.table_classes:
self.assertTrue(table_class.table_exists().run_sync())

# Make sure the constraint was created correctly.
response = self.run_sync(
"""
SELECT EXISTS(
SELECT 1
FROM INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE CCU
JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS TC ON
CCU.CONSTRAINT_NAME = TC.CONSTRAINT_NAME
WHERE CONSTRAINT_TYPE = 'FOREIGN KEY'
AND TC.TABLE_NAME = 'table_b'
AND CCU.TABLE_NAME = 'table_a'
AND CCU.COLUMN_NAME = 'name'
)
"""

@engines_only("postgres", "cockroach")
class TestAddForeignKeySelf(MigrationTestCase):
def setUp(self):
pass

def tearDown(self):
drop_db_tables_sync(create_table_class("MyTable"), Migration)

@patch("piccolo.conf.apps.Finder.get_app_config")
def test_add_column(self, get_app_config):
"""
Make sure migrations still work when:
* A foreign key is added to an existing table.
* The foreign key references its own table.
* The table has a custom primary key (e.g. UUID).
"""
get_app_config.return_value = self._get_app_config()

self._test_migrations(
table_snapshots=[
[
create_table_class(
class_name="MyTable",
class_members={"id": UUID(primary_key=True)},
)
],
[
create_table_class(
class_name="MyTable",
class_members={
"id": UUID(primary_key=True),
"fk": ForeignKey("self"),
},
)
],
],
test_function=lambda x: x.data_type == "uuid",
)
self.assertTrue(response[0]["exists"])


###############################################################################
Expand Down

0 comments on commit 39959b4

Please sign in to comment.