diff --git a/src/django_pg_hll/fields.py b/src/django_pg_hll/fields.py index 4d18fea..866f932 100644 --- a/src/django_pg_hll/fields.py +++ b/src/django_pg_hll/fields.py @@ -27,10 +27,10 @@ class HllField(BinaryField): } def __init__(self, *args, **kwargs): - self._log2m = kwargs.get('log2m', self.custom_params['log2m']) - self._regwidth = kwargs.get('regwidth', self.custom_params['regwidth']) - self._expthresh = kwargs.get('expthresh', self.custom_params['expthresh']) - self._sparseon = kwargs.get('sparseon', self.custom_params['sparseon']) + self._log2m = kwargs.pop('log2m', self.custom_params['log2m']) + self._regwidth = kwargs.pop('regwidth', self.custom_params['regwidth']) + self._expthresh = kwargs.pop('expthresh', self.custom_params['expthresh']) + self._sparseon = kwargs.pop('sparseon', self.custom_params['sparseon']) super(HllField, self).__init__(*args, **kwargs) @@ -40,7 +40,7 @@ def deconstruct(self): # Only include kwarg if it's not the default for param_name, default in self.custom_params.items(): if getattr(self, '_%s' % param_name) != default: - kwargs[name] = getattr(self, '_%s' % param_name) + kwargs[param_name] = getattr(self, '_%s' % param_name) return name, path, args, kwargs diff --git a/tests/migrations/0001_initial.py b/tests/migrations/0001_initial.py index 8ee5193..393fbaa 100644 --- a/tests/migrations/0001_initial.py +++ b/tests/migrations/0001_initial.py @@ -28,5 +28,15 @@ class Migration(migrations.Migration): options={ 'abstract': False, } + ), + migrations.CreateModel( + name='TestConfiguredModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('hll_field', HllField(log2m=13, regwidth=2, expthresh=1, sparseon=0)), + ], + options={ + 'abstract': False, + } ) ] diff --git a/tests/models.py b/tests/models.py index 596ec3c..cf541db 100644 --- a/tests/models.py +++ b/tests/models.py @@ -11,3 +11,7 @@ class TestModel(models.Model): hll_field = HllField() fk = models.ForeignKey(FKModel, null=True, blank=True, on_delete=models.CASCADE) + +class TestConfiguredModel(models.Model): + hll_field = HllField(log2m=13, regwidth=2, expthresh=1, sparseon=0) + diff --git a/tests/test_hll_field.py b/tests/test_hll_field.py index 2fc7465..c614e09 100644 --- a/tests/test_hll_field.py +++ b/tests/test_hll_field.py @@ -12,7 +12,7 @@ from django_pg_hll.bulk_update import HllConcatFunction from django_pg_hll.compatibility import django_pg_bulk_update_available -from tests.models import TestModel, FKModel +from tests.models import TestConfiguredModel, TestModel, FKModel class HllFieldTest(TestCase): @@ -40,6 +40,14 @@ def test_combine_auto_parse(self): def test_create(self): TestModel.objects.create(hll_field=HllEmpty()) + def test_create_custom_params(self): + with connection.cursor() as cursor: + cursor.execute('select hll_set_defaults(13,2,1,0);') + try: + TestConfiguredModel.objects.create(hll_field=HllEmpty()) + finally: + cursor.execute('select hll_set_defaults(11,5,-1,1);') + def test_migration(self): query = "SELECT hll_cardinality(hll_field) FROM tests_testmodel;" self.cursor.execute(query)