Skip to content

Commit

Permalink
feat: support timedelta serialization to double seconds
Browse files Browse the repository at this point in the history
Serialize and deserialize Python `datetime.timedelta` fields to/from an
Avro `double` number of seconds. This uses `fastavro`'s [Custom Logical
Types](https://fastavro.readthedocs.io/en/latest/logical_types.html#custom-logical-types)
functionality.
  • Loading branch information
fajpunk committed Nov 4, 2024
1 parent 67586d5 commit 5eedd8a
Show file tree
Hide file tree
Showing 26 changed files with 176 additions and 6 deletions.
3 changes: 3 additions & 0 deletions dataclasses_avroschema/fields/field_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"LOGICAL_TIME_MICROS",
"LOGICAL_DATETIME_MILIS",
"LOGICAL_DATETIME_MICROS",
"LOGICAL_TIMEDELTA",
"LOGICAL_UUID",
"PYTHON_TYPE_TO_AVRO",
]
Expand All @@ -38,6 +39,7 @@
TIME_MICROS = "time-micros"
TIMESTAMP_MILLIS = "timestamp-millis"
TIMESTAMP_MICROS = "timestamp-micros"
TIMEDELTA = "dataclasses-avroschema-timedelta"

BOOLEAN = "boolean"
NULL = "null"
Expand All @@ -60,6 +62,7 @@
LOGICAL_TIME_MICROS = {"type": LONG, "logicalType": TIME_MICROS}
LOGICAL_DATETIME_MILIS = {"type": LONG, "logicalType": TIMESTAMP_MILLIS}
LOGICAL_DATETIME_MICROS = {"type": LONG, "logicalType": TIMESTAMP_MICROS}
LOGICAL_TIMEDELTA = {"type": DOUBLE, "logicalType": TIMEDELTA}
LOGICAL_UUID = {"type": STRING, "logicalType": UUID}

AVRO_TYPES = (
Expand Down
39 changes: 39 additions & 0 deletions dataclasses_avroschema/fields/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import typing
import uuid

import fastavro
from typing_extensions import get_args, get_origin

from dataclasses_avroschema import (
Expand Down Expand Up @@ -60,6 +61,7 @@
"DateField",
"DatetimeField",
"DatetimeMicroField",
"TimedeltaField",
"TimeMilliField",
"TimeMicroField",
"UUIDField",
Expand Down Expand Up @@ -657,6 +659,43 @@ def fake(self) -> datetime.time:
return datetime_object.time()


@dataclasses.dataclass
class TimedeltaField(ImmutableField):
"""
The timedelta logical represents a absolute length of time.
It annotates an Avro `double`, which stores a number of seconds to microsecond precision.
Note that is different than an Avro `duration`, which could represent different lengths of time depending on
when it is measured from.
This is not an official Avro logical type, so consumers will need to know how to handle it.
"""

@property
def avro_type(self) -> typing.Dict:
return field_utils.LOGICAL_TIMEDELTA

def default_to_avro(self, value: datetime.timedelta) -> float:
return self.to_avro(value)

@classmethod
def from_avro(cls, value: float, *_) -> datetime.timedelta:
"""Convert from a fastavro-supported type to a timedelta."""
return datetime.timedelta(seconds=value)

@classmethod
def to_avro(cls, value: datetime.timedelta, *_) -> float:
"""Convert from a timedelta to a fastavro-supported type."""
return value.total_seconds()

def fake(self) -> datetime.timedelta:
return fake.time_delta(end_datetime=fake.date_time())


fastavro.write.LOGICAL_WRITERS["double-dataclasses-avroschema-timedelta"] = TimedeltaField.to_avro
fastavro.read.LOGICAL_READERS["double-dataclasses-avroschema-timedelta"] = TimedeltaField.from_avro


@dataclasses.dataclass
class DatetimeField(ImmutableField):
"""
Expand Down
1 change: 1 addition & 0 deletions dataclasses_avroschema/fields/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
datetime.time: fields.TimeMilliField,
types.TimeMicro: fields.TimeMicroField,
datetime.datetime: fields.DatetimeField,
datetime.timedelta: fields.TimedeltaField,
types.DateTimeMicro: fields.DatetimeMicroField,
uuid.uuid4: fields.UUIDField,
uuid.UUID: fields.UUIDField,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
field_utils.TIME_MICROS: "types.TimeMicro",
field_utils.TIMESTAMP_MILLIS: "datetime.datetime",
field_utils.TIMESTAMP_MICROS: "types.DateTimeMicro",
field_utils.TIMEDELTA: "datetime.timedelta",
field_utils.UUID: "uuid.UUID",
}

LOGICAL_TYPES_IMPORTS: typing.Dict[str, str] = {
field_utils.DECIMAL: "import decimal",
field_utils.DATE: "import datetime",
field_utils.TIMEDELTA: "import datetime",
field_utils.TIME_MILLIS: "import datetime",
field_utils.TIME_MICROS: "from dataclasses_avroschema import types",
field_utils.TIMESTAMP_MILLIS: "import datetime",
Expand All @@ -41,6 +43,7 @@
field_utils.TIMESTAMP_MICROS: lambda value: datetime.datetime.fromtimestamp(
value / 1000000, tz=datetime.timezone.utc
),
field_utils.TIMEDELTA: lambda value: datetime.timedelta(seconds=value),
}

# Logical types objects to template
Expand Down Expand Up @@ -74,6 +77,9 @@
second=datetime_obj.second,
microsecond=datetime_obj.microsecond,
),
field_utils.TIMEDELTA: lambda timedelta_obj: templates.timedelta_template.safe_substitute(
seconds=timedelta_obj.total_seconds(),
),
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
DATETIME_MICROS_TEMPLATE = (
"datetime.datetime($year, $month, $day, $hour, $minute, $second, $microsecond, tzinfo=datetime.timezone.utc)"
)
TIMEDELTA_TEMPLATE = "datetime.timedelta(seconds=$seconds)"
DECIMAL_TEMPLATE = "decimal.Decimal('$value')"
DECIMAL_TYPE_TEMPLATE = "types.condecimal(max_digits=$precision, decimal_places=$scale)"

Expand Down Expand Up @@ -85,5 +86,6 @@ class Meta:
time_micros_template = Template(TIME_MICROS_TEMPLATE)
datetime_template = Template(DATETIME_TEMPLATE)
datetime_micros_template = Template(DATETIME_MICROS_TEMPLATE)
timedelta_template = Template(TIMEDELTA_TEMPLATE)
imports_template = Template(IMPORTS_TEMPLATE.strip())
module_template = Template(MODULE_TEMPLATE.strip())
2 changes: 2 additions & 0 deletions dataclasses_avroschema/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ def serialize_value(*, value: typing.Any) -> typing.Any:
value = date_to_str(value)
elif isinstance(value, datetime.time):
value = time_to_str(value)
elif isinstance(value, datetime.timedelta):
value = value.total_seconds()
elif isinstance(value, (uuid.UUID, decimal.Decimal)):
value = str(value)
elif isinstance(value, dict):
Expand Down
40 changes: 40 additions & 0 deletions docs/logical_types.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The following list represent the avro logical types mapped to python types:
| long | time-micros | types.TimeMicro |
| long | timestamp-millis | datetime.datetime |
| long | timestamp-micros | types.DateTimeMicro |
| double | timedelta | datetime.timedelta |
| string | uuid | uuid.uuid4 |
| string | uuid | uuid.UUID |
| bytes | decimal | types.condecimal |
Expand Down Expand Up @@ -171,6 +172,45 @@ DatetimeLogicalType.avro_schema()
!!! note
To use `timestamp-micros` in avro schemas you need to use `types.DateTimeMicro`

## Timedelta

`timedelta` fields are serialized to a `double` number of seconds.

```python title="Timedelta example"
import datetime
import dataclasses
import typing

from dataclasses_avroschema import AvroModel

delta = datetime.timedelta(weeks=1, days=2, hours=3, minutes=4, seconds=5, milliseconds=6, microseconds=7)

@dataclasses.dataclass
class TimedeltaLogicalType(AvroModel):
"Timedelta logical type"
time_elapsed: datetime.timedelta = delta

DatetimeLogicalType.avro_schema()

'{
"type": "record",
"name": "DatetimeLogicalType",
"fields": [
{
"name": "time_elapsed",
"type": {
"type": "double",
"logicalType": "dataclasses-avroschema-timedelta"
},
"default": 788645.006007
}
],
"doc": "Timedelta logical type"
}'
```

*(This script is complete, it should run "as is")*

## UUID

```python title="UUID example"
Expand Down
15 changes: 15 additions & 0 deletions models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from dataclasses_avroschema.pydantic import AvroBaseModel
import pydantic
import typing



class Infrastructure(AvroBaseModel):
email: pydantic.EmailStr
kafka_url: pydantic.KafkaDsn
total_nodes: pydantic.PositiveInt
event_id: pydantic.UUID1
landing_zone_nodes: typing.List[pydantic.PositiveInt]
total_nodes_in_aws: pydantic.PositiveInt = 10
optional_kafka_url: typing.Optional[pydantic.KafkaDsn] = None

1 change: 1 addition & 0 deletions tests/fake/test_fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class LogicalTypes(AvroModel):
meeting_time_micro: types.TimeMicro
release_datetime: datetime.datetime
release_datetime_micro: types.DateTimeMicro
time_elapsed: datetime.timedelta
event_uuid: uuid.UUID

assert isinstance(LogicalTypes.fake(), LogicalTypes)
Expand Down
1 change: 1 addition & 0 deletions tests/fake/test_fake_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class LogicalTypes(AvroBaseModel):
meeting_time_micro: types.TimeMicro
release_datetime: datetime.datetime
release_datetime_micro: types.DateTimeMicro
time_elapsed: datetime.timedelta
event_uuid: uuid.UUID

assert isinstance(LogicalTypes.fake(), LogicalTypes)
Expand Down
1 change: 1 addition & 0 deletions tests/fake/test_fake_pydantic_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class LogicalTypes(AvroBaseModel):
meeting_time_micro: types.TimeMicro
release_datetime: datetime.datetime
release_datetime_micro: types.DateTimeMicro
time_elapsed: datetime.timedelta
event_uuid: uuid.UUID

assert isinstance(LogicalTypes.fake(), LogicalTypes)
Expand Down
19 changes: 19 additions & 0 deletions tests/fields/test_logical_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,25 @@ def test_logical_type_datetime_with_default() -> None:
assert expected == field_with_default_factory.to_dict()


def test_logical_type_timedelta_with_default() -> None:
name = "a timedelta"
python_type = datetime.timedelta
delta = datetime.timedelta(seconds=1.234567)
seconds = delta.total_seconds()

field = AvroField(name, python_type, default=delta)
field_with_default_factory = AvroField(name, python_type, default_factory=lambda: delta)

expected = {
"name": name,
"type": {"type": field_utils.DOUBLE, "logicalType": field_utils.TIMEDELTA},
"default": seconds,
}

assert expected == field.to_dict()
assert expected == field_with_default_factory.to_dict()


@pytest.mark.parametrize(
"python_type,avro_type",
(
Expand Down
5 changes: 5 additions & 0 deletions tests/model_generator/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,11 @@ def schema_with_logical_types() -> JsonDict:
"type": {"type": "long", "logicalType": "timestamp-micros"},
"default": 1570903062000000,
},
{
"name": "time_elapsed",
"type": {"type": "double", "logicalType": "dataclasses-avroschema-timedelta"},
"default": 788645.006007,
},
{
"name": "uuid_2",
"type": ["null", {"type": "string", "logicalType": "uuid"}],
Expand Down
1 change: 1 addition & 0 deletions tests/model_generator/test_model_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,7 @@ class LogicalTypes(AvroModel):
meeting_datetime: typing.Optional[datetime.datetime] = None
release_datetime: datetime.datetime = {release_datetime}
release_datetime_micro: types.DateTimeMicro = {release_datetime_micro}
time_elapsed: datetime.timedelta = datetime.timedelta(seconds=788645.006007)
uuid_2: typing.Optional[uuid.UUID] = None
event_uuid: uuid.UUID = "ad0677ab-bd1c-4383-9d45-e46c56bcc5c9"
explicit_with_default: types.condecimal(max_digits=3, decimal_places=2) = decimal.Decimal('3.14')
Expand Down
1 change: 1 addition & 0 deletions tests/model_generator/test_model_pydantic_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ class LogicalTypes(AvroBaseModel):
meeting_datetime: typing.Optional[datetime.datetime] = None
release_datetime: datetime.datetime = {release_datetime}
release_datetime_micro: types.DateTimeMicro = {release_datetime_micro}
time_elapsed: datetime.timedelta = datetime.timedelta(seconds=788645.006007)
uuid_2: typing.Optional[uuid.UUID] = None
event_uuid: uuid.UUID = "ad0677ab-bd1c-4383-9d45-e46c56bcc5c9"
explicit_with_default: types.condecimal(max_digits=3, decimal_places=2) = decimal.Decimal('3.14')
Expand Down
10 changes: 9 additions & 1 deletion tests/schemas/avro/logical_types.avsc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@
},
"default": 1570903062000
},
{
"name": "time_elapsed",
"type": {
"type": "double",
"logicalType": "dataclasses-avroschema-timedelta"
},
"default": 788645.006007
},
{
"name": "event_uuid",
"type": {
Expand All @@ -37,4 +45,4 @@
],
"doc": "Some logical types"
}

10 changes: 9 additions & 1 deletion tests/schemas/avro/logical_types_pydantic.avsc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@
},
"default": 1570903062000
},
{
"name": "time_elapsed",
"type": {
"type": "double",
"logicalType": "dataclasses-avroschema-timedelta"
},
"default": 788645.006007
},
{
"name": "event_uuid",
"type": {
Expand All @@ -90,4 +98,4 @@
}
],
"doc": "Some logical types"
}
}
1 change: 1 addition & 0 deletions tests/schemas/avro/union_type.avsc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"type": [
{"type": "long", "logicalType": "timestamp-millis"},
{"type": "int", "logicalType": "date"},
{"type": "double", "logicalType": "dataclasses-avroschema-timedelta"},
{"type": "string", "logicalType": "uuid"}
]
},
Expand Down
4 changes: 3 additions & 1 deletion tests/schemas/pydantic/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def test_pydantic_record_schema_logical_types(logical_types_pydantic_schema):
a_past_datetime = datetime.datetime(2019, 10, 12, 17, 57, 42, tzinfo=datetime.timezone.utc)
a_future_datetime = datetime.datetime(9999, 12, 31, 23, 59, 59)
a_naive_datetime = datetime.datetime(2019, 10, 12, 17, 57, 42)
delta = datetime.timedelta(weeks=1, days=2, hours=3, minutes=4, seconds=5, milliseconds=6, microseconds=7)

class LogicalTypesPydantic(AvroBaseModel):
"Some logical types"
Expand All @@ -163,6 +164,7 @@ class LogicalTypesPydantic(AvroBaseModel):
future_datetime: FutureDatetime = a_future_datetime
aware_datetime: AwareDatetime = a_datetime
naive_datetime: NaiveDatetime = a_naive_datetime
time_elapsed: datetime.timedelta = delta
event_uuid: uuid.UUID = "09f00184-7721-4266-a955-21048a5cc235"

assert LogicalTypesPydantic.avro_schema() == json.dumps(logical_types_pydantic_schema)
Expand Down Expand Up @@ -337,7 +339,7 @@ class UnionSchema(AvroBaseModel):
"Some Unions"

first_union: typing.Union[str, int]
logical_union: typing.Union[datetime.datetime, datetime.date, uuid.UUID]
logical_union: typing.Union[datetime.datetime, datetime.date, datetime.timedelta, uuid.UUID]
lake_trip: typing.Union[Bus, Car]
river_trip: typing.Union[Bus, Car] = None
mountain_trip: typing.Union[Bus, Car] = Field(default_factory=lambda: Bus(engine_name="honda"))
Expand Down
4 changes: 3 additions & 1 deletion tests/schemas/pydantic/test_pydantic_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,15 @@ class Meta:

def test_pydantic_record_schema_logical_types(logical_types_schema):
a_datetime = datetime.datetime(2019, 10, 12, 17, 57, 42, tzinfo=datetime.timezone.utc)
delta = datetime.timedelta(weeks=1, days=2, hours=3, minutes=4, seconds=5, milliseconds=6, microseconds=7)

class LogicalTypes(AvroBaseModel):
"Some logical types"

birthday: datetime.date = a_datetime.date()
meeting_time: datetime.time = a_datetime.time()
release_datetime: datetime.datetime = a_datetime
time_elapsed: datetime.timedelta = delta
event_uuid: uuid.UUID = "09f00184-7721-4266-a955-21048a5cc235"

assert LogicalTypes.avro_schema() == json.dumps(logical_types_schema)
Expand Down Expand Up @@ -290,7 +292,7 @@ class UnionSchema(AvroBaseModel):
"Some Unions"

first_union: typing.Union[str, int]
logical_union: typing.Union[datetime.datetime, datetime.date, uuid.UUID]
logical_union: typing.Union[datetime.datetime, datetime.date, datetime.timedelta, uuid.UUID]
lake_trip: typing.Union[Bus, Car]
river_trip: typing.Union[Bus, Car] = None
mountain_trip: typing.Union[Bus, Car] = Field(default_factory=lambda: Bus(engine_name="honda"))
Expand Down
Loading

0 comments on commit 5eedd8a

Please sign in to comment.