From c5fb6e979c7ce9de5ef80bb9f424a4d872c79d81 Mon Sep 17 00:00:00 2001 From: Davide Date: Tue, 12 Dec 2023 14:56:39 +0100 Subject: [PATCH] Add get_placeholder to handle automatic check constraint evaluation --- netfields/fields.py | 3 +++ test/models.py | 18 ++++++++++++++++++ test/tests/test_sql_fields.py | 13 +++++++++++++ 3 files changed, 34 insertions(+) diff --git a/netfields/fields.py b/netfields/fields.py index 42800eb..6a48d3f 100644 --- a/netfields/fields.py +++ b/netfields/fields.py @@ -120,6 +120,9 @@ def get_db_prep_lookup(self, lookup_type, value, connection, return super(_NetAddressField, self).get_db_prep_lookup( lookup_type, value, connection=connection, prepared=prepared) + def get_placeholder(self, value, compiler, connection): + return "%s::{}".format(self.db_type(connection)) + def formfield(self, **kwargs): defaults = {'form_class': self.form_class()} defaults.update(kwargs) diff --git a/test/models.py b/test/models.py index 4829a11..f78e95c 100644 --- a/test/models.py +++ b/test/models.py @@ -1,3 +1,4 @@ +from django import VERSION from django.contrib.postgres.fields import ArrayField from django.db.models import CASCADE, ForeignKey, Model @@ -117,3 +118,20 @@ class AggregateTestChildModel(Model): ) network = CidrAddressField() inet = InetAddressField() + + +if VERSION >= (4, 1): + from django.db.models import F, Q, CheckConstraint + + + class ConstraintModel(Model): + network = CidrAddressField() + inet = InetAddressField() + + class Meta: + constraints = ( + CheckConstraint( + check=Q(network__net_contains=F('inet')), + name='inet_contained', + ), + ) diff --git a/test/tests/test_sql_fields.py b/test/tests/test_sql_fields.py index 72d0740..e7fc0ec 100644 --- a/test/tests/test_sql_fields.py +++ b/test/tests/test_sql_fields.py @@ -772,3 +772,16 @@ def test_aggregate_network(self): self.assertEqual(network_qs[0].agg_network, [None]) AggregateTestChildModel.objects.create(parent=parent, network=network, inet=inet) self.assertEqual(network_qs[0].agg_network, [network]) + + +class TestConstraints(TestCase): + + @skipIf(VERSION < (4, 1), 'Check constraint validation is supported from django 4.1 onwards') + def test_check_constraint(self): + from test.models import ConstraintModel + + inet = IPv4Interface('10.10.10.20/32') + network = IPv4Network('10.10.10.0/24') + model = ConstraintModel(inet=inet, network=network) + model.full_clean() + model.save()