diff --git a/dataclasses_avroschema/fields/fields.py b/dataclasses_avroschema/fields/fields.py index 223332d4..a0b0c89b 100644 --- a/dataclasses_avroschema/fields/fields.py +++ b/dataclasses_avroschema/fields/fields.py @@ -679,12 +679,12 @@ def default_to_avro(self, value: datetime.timedelta) -> float: return self.to_avro(value) @classmethod - def from_avro(cls, value: float, *_) -> datetime.timedelta: + def from_avro(cls, value: float) -> datetime.timedelta: """Convert from a fastavro-supported type to a timedelta.""" return datetime.timedelta(seconds=value) @classmethod - def to_avro(cls, value: datetime.timedelta, *_) -> float: + def to_avro(cls, value: datetime.timedelta) -> float: """Convert from a timedelta to a fastavro-supported type.""" return value.total_seconds() @@ -692,8 +692,23 @@ def fake(self) -> datetime.timedelta: return fake.time_delta(end_datetime=fake.date_time()) -fastavro.write.LOGICAL_WRITERS["double-dataclasses-avroschema-timedelta"] = TimedeltaField.to_avro -fastavro.read.LOGICAL_READERS["double-dataclasses-avroschema-timedelta"] = TimedeltaField.from_avro +def _fastavro_serialize_timedelta(data: typing.Any, *_) -> typing.Any: + """Serialize a timedelta for fastavro.""" + # Fastavro will call this function on any value that is provided for a + # union type where one of the union members is a timedelta, so we have to + # explicitly check if the value is a timedelta before trying to convert it. + if isinstance(data, datetime.timedelta): + return TimedeltaField.to_avro(data) + return data + + +def _fastavro_deserialize_timedelta(data: float, *_) -> datetime.timedelta: + """Deserialize a timedelta for fastavro.""" + return TimedeltaField.from_avro(data) + + +fastavro.write.LOGICAL_WRITERS["double-dataclasses-avroschema-timedelta"] = _fastavro_serialize_timedelta +fastavro.read.LOGICAL_READERS["double-dataclasses-avroschema-timedelta"] = _fastavro_deserialize_timedelta @dataclasses.dataclass diff --git a/tests/serialization/test_logical_types_serialization.py b/tests/serialization/test_logical_types_serialization.py index 6894c31a..d0e29cb5 100644 --- a/tests/serialization/test_logical_types_serialization.py +++ b/tests/serialization/test_logical_types_serialization.py @@ -82,13 +82,16 @@ class UnionSchema(model_class): "Some Unions" logical_union: typing.Union[datetime.datetime, datetime.date, uuid.UUID] + logical_union_timedelta: typing.Union[datetime.timedelta, None] data = { "logical_union": a_datetime.date(), + "logical_union_timedelta": None, } data_json = { "logical_union": serialization.date_to_str(a_datetime.date()), + "logical_union_timedelta": None, } logical_types = UnionSchema(**data)