From 18326acd9a274068ae197b4edc4f2060943a14a0 Mon Sep 17 00:00:00 2001 From: Dan Fuchs <330402+fajpunk@users.noreply.github.com> Date: Mon, 18 Nov 2024 14:43:06 -0600 Subject: [PATCH] fix(TimedeltaField): fix timedelta fastavro serialization so it doesn't break in union types (#797) Fastavro [custom logical type encoders](https://fastavro.readthedocs.io/en/latest/logical_types.html#custom-logical-types) should explicitly check that they are operating on the expected Python type before encoding the value, and return the data unchanged otherwise. If a custom logical type is used as a member of a union type, then this encoder function will be called on any value that is provided for a field of that type, even if the value's type is a different member of the union type. --- dataclasses_avroschema/fields/fields.py | 23 +++++++++++++++---- .../test_logical_types_serialization.py | 3 +++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/dataclasses_avroschema/fields/fields.py b/dataclasses_avroschema/fields/fields.py index 223332d4..a0b0c89b 100644 --- a/dataclasses_avroschema/fields/fields.py +++ b/dataclasses_avroschema/fields/fields.py @@ -679,12 +679,12 @@ def default_to_avro(self, value: datetime.timedelta) -> float: return self.to_avro(value) @classmethod - def from_avro(cls, value: float, *_) -> datetime.timedelta: + def from_avro(cls, value: float) -> datetime.timedelta: """Convert from a fastavro-supported type to a timedelta.""" return datetime.timedelta(seconds=value) @classmethod - def to_avro(cls, value: datetime.timedelta, *_) -> float: + def to_avro(cls, value: datetime.timedelta) -> float: """Convert from a timedelta to a fastavro-supported type.""" return value.total_seconds() @@ -692,8 +692,23 @@ def fake(self) -> datetime.timedelta: return fake.time_delta(end_datetime=fake.date_time()) -fastavro.write.LOGICAL_WRITERS["double-dataclasses-avroschema-timedelta"] = TimedeltaField.to_avro 
-fastavro.read.LOGICAL_READERS["double-dataclasses-avroschema-timedelta"] = TimedeltaField.from_avro +def _fastavro_serialize_timedelta(data: typing.Any, *_) -> typing.Any: + """Serialize a timedelta for fastavro.""" + # Fastavro will call this function on any value that is provided for a + # union type where one of the union members is a timedelta, so we have to + # explicitly check if the value is a timedelta before trying to convert it. + if isinstance(data, datetime.timedelta): + return TimedeltaField.to_avro(data) + return data + + +def _fastavro_deserialize_timedelta(data: float, *_) -> datetime.timedelta: + """Deserialize a timedelta for fastavro.""" + return TimedeltaField.from_avro(data) + + +fastavro.write.LOGICAL_WRITERS["double-dataclasses-avroschema-timedelta"] = _fastavro_serialize_timedelta +fastavro.read.LOGICAL_READERS["double-dataclasses-avroschema-timedelta"] = _fastavro_deserialize_timedelta @dataclasses.dataclass diff --git a/tests/serialization/test_logical_types_serialization.py b/tests/serialization/test_logical_types_serialization.py index 6894c31a..d0e29cb5 100644 --- a/tests/serialization/test_logical_types_serialization.py +++ b/tests/serialization/test_logical_types_serialization.py @@ -82,13 +82,16 @@ class UnionSchema(model_class): "Some Unions" logical_union: typing.Union[datetime.datetime, datetime.date, uuid.UUID] + logical_union_timedelta: typing.Union[datetime.timedelta, None] data = { "logical_union": a_datetime.date(), + "logical_union_timedelta": None, } data_json = { "logical_union": serialization.date_to_str(a_datetime.date()), + "logical_union_timedelta": None, } logical_types = UnionSchema(**data)