diff --git a/pynamodb_attributes/__init__.py b/pynamodb_attributes/__init__.py index 9f9008c..e4edb15 100644 --- a/pynamodb_attributes/__init__.py +++ b/pynamodb_attributes/__init__.py @@ -1,3 +1,4 @@ +from .float import FloatAttribute from .integer import IntegerAttribute from .integer_date import IntegerDateAttribute from .integer_enum import IntegerEnumAttribute @@ -8,6 +9,7 @@ from .unicode_enum import UnicodeEnumAttribute __all__ = [ + 'FloatAttribute', 'IntegerAttribute', 'IntegerDateAttribute', 'IntegerEnumAttribute', diff --git a/pynamodb_attributes/float.py b/pynamodb_attributes/float.py new file mode 100644 index 0000000..a3b5bd1 --- /dev/null +++ b/pynamodb_attributes/float.py @@ -0,0 +1,11 @@ +from pynamodb.attributes import Attribute +from pynamodb.attributes import NumberAttribute + + +class FloatAttribute(Attribute): + """ + Unlike NumberAttribute, this attribute has its type hinted as 'float'. + """ + attr_type = NumberAttribute.attr_type + serialize = NumberAttribute.serialize + deserialize = NumberAttribute.deserialize diff --git a/pynamodb_attributes/float.pyi b/pynamodb_attributes/float.pyi new file mode 100644 index 0000000..e41e106 --- /dev/null +++ b/pynamodb_attributes/float.pyi @@ -0,0 +1,5 @@ +from ._typing import Attribute + + +class FloatAttribute(Attribute[float]): + ... diff --git a/tests/float_attribute_test.py b/tests/float_attribute_test.py new file mode 100644 index 0000000..1cc80db --- /dev/null +++ b/tests/float_attribute_test.py @@ -0,0 +1,48 @@ +import pytest +from pynamodb.attributes import UnicodeAttribute +from pynamodb.models import Model + +from pynamodb_attributes import FloatAttribute +from tests.meta import dynamodb_table_meta + + +class MyModel(Model): + Meta = dynamodb_table_meta(__name__) + + key = UnicodeAttribute(hash_key=True) + value = FloatAttribute(null=True) + + +@pytest.fixture(scope='module', autouse=True) +def create_table(): + MyModel.create_table() + + +def test_serialization_non_null(uuid_key): + model = MyModel() + model.key = uuid_key + model.value = 45.6 + model.save() + + # verify underlying storage + item = MyModel._get_connection().get_item(uuid_key) + assert item['Item']['value'] == {'N': '45.6'} + + # verify deserialization + model = MyModel.get(uuid_key) + assert model.value == 45.6 + + +def test_serialization_null(uuid_key): + model = MyModel() + model.key = uuid_key + model.value = None + model.save() + + # verify underlying storage + item = MyModel._get_connection().get_item(uuid_key) + assert 'value' not in item['Item'] + + # verify deserialization + model = MyModel.get(uuid_key) + assert model.value is None