Skip to content

Commit

Permalink
fix(TimedeltaField): fix timedelta fastavro serialization so it doesn…
Browse files Browse the repository at this point in the history
…'t break in union types (#797)

Fastavro [custom logical type encoders](https://fastavro.readthedocs.io/en/latest/logical_types.html#custom-logical-types) should explicitly check that they are operating on the expected python type before encoding it, and return the unchanged data otherwise. If a custom logical type is used as a member of a union type, then this encoder function will be called on any value that is provided for a field of that type, even if the value's type is a different member of the union type.
  • Loading branch information
fajpunk authored Nov 18, 2024
1 parent 5e65d6c commit 18326ac
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
23 changes: 19 additions & 4 deletions dataclasses_avroschema/fields/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,21 +679,36 @@ def default_to_avro(self, value: datetime.timedelta) -> float:
return self.to_avro(value)

@classmethod
def from_avro(cls, value: float, *_) -> datetime.timedelta:
def from_avro(cls, value: float) -> datetime.timedelta:
"""Convert from a fastavro-supported type to a timedelta."""
return datetime.timedelta(seconds=value)

@classmethod
def to_avro(cls, value: datetime.timedelta, *_) -> float:
def to_avro(cls, value: datetime.timedelta) -> float:
"""Convert from a timedelta to a fastavro-supported type."""
return value.total_seconds()

def fake(self) -> datetime.timedelta:
return fake.time_delta(end_datetime=fake.date_time())


fastavro.write.LOGICAL_WRITERS["double-dataclasses-avroschema-timedelta"] = TimedeltaField.to_avro
fastavro.read.LOGICAL_READERS["double-dataclasses-avroschema-timedelta"] = TimedeltaField.from_avro
def _fastavro_serialize_timedelta(data: typing.Any, *_) -> typing.Any:
"""Serialize a timedelta for fastavro."""
# Fastavro will call this function on any value that is provided for a
# union type where one of the union members is a timedelta, so we have to
# explicitly check if the value is a timedelta before trying to convert it.
if isinstance(data, datetime.timedelta):
return TimedeltaField.to_avro(data)
return data


def _fastavro_deserialize_timedelta(data: float, *_) -> datetime.timedelta:
"""Deserialize a timedelta for fastavro."""
return TimedeltaField.from_avro(data)


fastavro.write.LOGICAL_WRITERS["double-dataclasses-avroschema-timedelta"] = _fastavro_serialize_timedelta
fastavro.read.LOGICAL_READERS["double-dataclasses-avroschema-timedelta"] = _fastavro_deserialize_timedelta


@dataclasses.dataclass
Expand Down
3 changes: 3 additions & 0 deletions tests/serialization/test_logical_types_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,16 @@ class UnionSchema(model_class):
"Some Unions"

logical_union: typing.Union[datetime.datetime, datetime.date, uuid.UUID]
logical_union_timedelta: typing.Union[datetime.timedelta, None]

data = {
"logical_union": a_datetime.date(),
"logical_union_timedelta": None,
}

data_json = {
"logical_union": serialization.date_to_str(a_datetime.date()),
"logical_union_timedelta": None,
}

logical_types = UnionSchema(**data)
Expand Down

0 comments on commit 18326ac

Please sign in to comment.