diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 28f5cf50..24bc6c62 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -10,6 +10,7 @@ Changelog We believe that this should not cause a noticable performance change, and the number of queries involved should not change. * Add Django 5.0 support (no code changes were needed, but now we test this release). * Add Python 3.12 support +* Add support for dumpdata/loaddata using natural keys 5.0.1 (2023-10-26) ~~~~~~~~~~~~~~~~~~ diff --git a/docs/index.rst b/docs/index.rst index 611c28d3..db978980 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -13,6 +13,7 @@ tagging to your project easy and fun. forms admin serializers + testing api faq custom_tagging diff --git a/docs/testing.rst b/docs/testing.rst new file mode 100644 index 00000000..a18763e7 --- /dev/null +++ b/docs/testing.rst @@ -0,0 +1,14 @@ +Testing +======= + +Natural Key Support +------------------- +We have added `natural key support `_ to the Tag model in the Django taggit library. This allows you to identify objects by human-readable identifiers rather than by their database ID:: + + python manage.py dumpdata taggit.Tag --natural-foreign --natural-primary > tags.json + + python manage.py loaddata tags.json + +By default tags use the name field as the natural key. + +You can customize this in your own custom tag model by setting the ``natural_key_fields`` property on your model the required fields. diff --git a/taggit/models.py b/taggit/models.py index 8d7f60bd..091d733b 100644 --- a/taggit/models.py +++ b/taggit/models.py @@ -15,7 +15,25 @@ def unidecode(tag): return tag -class TagBase(models.Model): +class NaturalKeyManager(models.Manager): + def get_by_natural_key(self, *args): + if len(args) != len(self.model.natural_key_fields): + raise ValueError( + "Number of arguments does not match number of natural key fields." + ) + lookup_kwargs = dict(zip(self.model.natural_key_fields, args)) + return self.get(**lookup_kwargs) + + +class NaturalKeyModel(models.Model): + def natural_key(self): + return [getattr(self, field) for field in self.natural_key_fields] + + class Meta: + abstract = True + + +class TagBase(NaturalKeyModel): name = models.CharField( verbose_name=pgettext_lazy("A tag name", "name"), unique=True, max_length=100 ) @@ -26,6 +44,9 @@ class TagBase(models.Model): allow_unicode=True, ) + natural_key_fields = ["name"] + objects = NaturalKeyManager() + def __str__(self): return self.name @@ -91,13 +112,15 @@ class Meta: app_label = "taggit" -class ItemBase(models.Model): +class ItemBase(NaturalKeyModel): def __str__(self): return gettext("%(object)s tagged with %(tag)s") % { "object": self.content_object, "tag": self.tag, } + objects = NaturalKeyManager() + class Meta: abstract = True @@ -170,6 +193,7 @@ def tags_for(cls, model, instance=None, **extra_filters): class GenericTaggedItemBase(CommonGenericTaggedItemBase): object_id = models.IntegerField(verbose_name=_("object ID"), db_index=True) + natural_key_fields = ["object_id"] class Meta: abstract = True @@ -177,6 +201,7 @@ class Meta: class GenericUUIDTaggedItemBase(CommonGenericTaggedItemBase): object_id = models.UUIDField(verbose_name=_("object ID"), db_index=True) + natural_key_fields = ["object_id"] class Meta: abstract = True diff --git a/tests/tests.py b/tests/tests.py index 79489598..38a14e40 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -1,3 +1,4 @@ +import os from io import StringIO from unittest import mock @@ -1398,3 +1399,111 @@ def test_tests_have_no_pending_migrations(self): out = StringIO() call_command("makemigrations", "tests", dry_run=True, stdout=out) self.assertEqual(out.getvalue().strip(), "No changes detected in app 'tests'") + + +class NaturalKeyTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.tag_names = ["circle", "square", "triangle", "rectangle", "pentagon"] + cls.filename = "test_data_dump.json" + cls.tag_count = len(cls.tag_names) + + def setUp(self): + self.tags = self._create_tags() + + def tearDown(self): + self._clear_existing_tags() + try: + os.remove(self.filename) + except FileNotFoundError: + pass + + @property + def _queryset(self): + return Tag.objects.filter(name__in=self.tag_names) + + def _create_tags(self): + return Tag.objects.bulk_create( + [Tag(name=shape, slug=shape) for shape in self.tag_names], + ignore_conflicts=True, + ) + + def _clear_existing_tags(self): + self._queryset.delete() + + def _dump_model(self, model): + model_label = model._meta.label + with open(self.filename, "w") as f: + call_command( + "dumpdata", + model_label, + natural_primary=True, + use_natural_foreign_keys=True, + stdout=f, + ) + + def _load_model(self): + call_command("loaddata", self.filename) + + def test_tag_natural_key(self): + """ + Test that tags can be dumped and loaded using natural keys. + """ + + # confirm count in the DB + self.assertEqual(self._queryset.count(), self.tag_count) + + # dump all tags to a file + self._dump_model(Tag) + + # Delete all tags + self._clear_existing_tags() + + # confirm all tags clear + self.assertEqual(self._queryset.count(), 0) + + # load the tags from the file + self._load_model() + + # confirm count in the DB + self.assertEqual(self._queryset.count(), self.tag_count) + + def test_tag_reloading_with_changed_pk(self): + """Test that tags are not reliant on the primary key of the tag model. + + Test that data is correctly loaded after database state has changed. + + """ + original_shape = self._queryset.first() + original_pk = original_shape.pk + original_shape_name = original_shape.name + new_shape_name = "hexagon" + + # dump all tags to a file + self._dump_model(Tag) + + # Delete the tag + self._clear_existing_tags() + + # create new tag with the same PK + Tag.objects.create(name=new_shape_name, slug=new_shape_name, pk=original_pk) + + # Load the tags from the file + self._load_model() + + # confirm that load did not overwrite the new_shape + self.assertEqual(Tag.objects.get(pk=original_pk).name, new_shape_name) + + # confirm that the original shape was reloaded with a different PK + self.assertNotEqual(Tag.objects.get(name=original_shape_name).pk, original_pk) + + def test_get_by_natural_key(self): + # Test retrieval of tags by their natural key + for name in self.tag_names: + tag = Tag.objects.get_by_natural_key(name) + self.assertEqual(tag.name, name) + + def test_wrong_number_of_args(self): + # Test that get_by_natural_key raises an error when the wrong number of args is passed + with self.assertRaises(ValueError): + Tag.objects.get_by_natural_key()