Skip to content

Commit

Permalink
Fix serialization, make custom serializers in pydantic work with csp.…
Browse files Browse the repository at this point in the history
…Structs

Signed-off-by: Nijat Khanbabayev <[email protected]>
  • Loading branch information
NeejWeej committed Jan 20, 2025
1 parent c252851 commit 71844f5
Show file tree
Hide file tree
Showing 4 changed files with 263 additions and 7 deletions.
2 changes: 1 addition & 1 deletion csp/impl/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _validate(cls, v) -> "Enum":
raise ValueError(f"Cannot convert value to enum: {v}")

@staticmethod
def _serialize(value: "Enum") -> str:
def _serialize(value: typing.Union[str, "Enum"]) -> str:
    """Return the string form used when serializing an enum value.

    Args:
        value: An enum member (any object exposing ``name``) or a string that
            is already in serialized form.

    Returns:
        ``value.name`` when the attribute exists; otherwise ``value`` itself
        when it is a plain ``str``.

    Raises:
        AttributeError: If ``value`` has no ``name`` attribute and is not a
            ``str``.
    """
    try:
        return value.name
    except AttributeError:
        # The signature admits plain strings, which have no ``name``; pass
        # them through unchanged instead of failing.
        if isinstance(value, str):
            return value
        raise

@classmethod
Expand Down
20 changes: 16 additions & 4 deletions csp/impl/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def layout(self, num_cols=8):

@staticmethod
def _get_pydantic_core_schema(cls, _source_type, handler):
"""Tell Pydantic how to validate this Struct class."""
"""Tell Pydantic how to validate and serialize this Struct class."""
from pydantic import PydanticSchemaGenerationError
from pydantic_core import core_schema

Expand Down Expand Up @@ -131,12 +131,24 @@ def create_instance(validated_data):
data_dict = validated_data[0] if isinstance(validated_data, tuple) else validated_data
return cls(**data_dict)

def serializer(val, handler):
    """Serialize ``val`` by passing a one-level dict of its set fields to ``handler``.

    Args:
        val: Object whose ``__full_metadata_typed__`` iterates over field names.
        handler: Callable that receives the collected ``{name: value}`` dict.

    Returns:
        Whatever ``handler`` returns for the collected dict.
    """
    # 'to_dict' is deliberately not used because it works recursively; here the
    # field values are handed over untouched.
    unset = object()
    present = {}
    for field_name in val.__full_metadata_typed__:
        # A single getattr with a sentinel default skips fields whose lookup
        # raises AttributeError, without a separate hasattr call.
        field_value = getattr(val, field_name, unset)
        if field_value is not unset:
            present[field_name] = field_value
    return handler(present)

return core_schema.no_info_after_validator_function(
function=create_instance,
schema=schema,
serialization=core_schema.plain_serializer_function_ser_schema(
function=lambda x: x.to_dict(), # Use the built-in to_dict method
return_schema=core_schema.dict_schema(),
serialization=core_schema.wrap_serializer_function_ser_schema(
function=serializer, schema=fields_schema, when_used="always"
),
)

Expand Down
87 changes: 87 additions & 0 deletions csp/tests/impl/test_enum.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import _pickle
import json
import pytest
import unittest
from datetime import datetime, timedelta
from pydantic import BaseModel, ConfigDict, RootModel
from typing import Dict, List

import csp
from csp import ts
Expand Down Expand Up @@ -36,6 +40,22 @@ def s1():
MyDEnum = csp.DynamicEnum("MyDEnum", ["A", "B", "C"])


# csp.Enum with two members whose values are assigned by csp.Enum.auto().
# It is the field type of the pydantic models MyModel and MyDictModel in this module.
class MyEnum3(csp.Enum):
FIELD1 = csp.Enum.auto()
FIELD2 = csp.Enum.auto()


# Pydantic model with one required MyEnum3 field ("enum") and one that
# defaults to MyEnum3.FIELD1 ("enum_default").
class MyModel(BaseModel):
enum: MyEnum3
enum_default: MyEnum3 = MyEnum3.FIELD1


# Pydantic model holding a dict keyed by MyEnum3, configured with use_enum_values=True.
class MyDictModel(BaseModel):
model_config = ConfigDict(use_enum_values=True)

# NOTE(review): the default None does not match the annotated Dict type;
# presumably this relies on pydantic not validating defaults -- confirm.
enum_dict: Dict[MyEnum3, int] = None


class TestCspEnum(unittest.TestCase):
def test_basic(self):
self.assertEqual(MyEnum("A"), MyEnum.A)
Expand Down Expand Up @@ -152,6 +172,73 @@ class B(A):

self.assertEqual("Cannot extend csp.Enum 'A': inheriting from an Enum is prohibited", str(cm.exception))

def test_pydantic_validation(self):
    """MyModel accepts a member name, an int and a member for ``enum``; a float is rejected."""
    accepted = (
        ("FIELD2", MyEnum3.FIELD2),
        (0, MyEnum3.FIELD1),
        (MyEnum3.FIELD1, MyEnum3.FIELD1),
    )
    for raw, member in accepted:
        assert MyModel(enum=raw).enum == member

    with pytest.raises(ValueError):
        MyModel(enum=3.14)

def test_pydantic_dict(self):
    """``dict()`` and python-mode dumps keep enum members; a json-mode dump yields member names."""
    as_members = {"enum": MyEnum3.FIELD2, "enum_default": MyEnum3.FIELD1}
    as_names = {"enum": "FIELD2", "enum_default": "FIELD1"}

    # A fresh model is built for every check, one per conversion path.
    assert dict(MyModel(enum=MyEnum3.FIELD2)) == as_members
    assert MyModel(enum=MyEnum3.FIELD2).model_dump(mode="python") == as_members
    assert MyModel(enum=MyEnum3.FIELD2).model_dump(mode="json") == as_names

def test_pydantic_serialization(self):
    """Both fields are registered on MyModel and the JSON dump writes enum members by name."""
    for field_name in ("enum", "enum_default"):
        assert field_name in MyModel.model_fields

    model = MyModel(enum=MyEnum3.FIELD2)
    dumped = json.loads(model.model_dump_json())
    expected = json.loads('{"enum": "FIELD2", "enum_default": "FIELD1"}')
    assert dumped == expected

def test_enum_as_dict_key_json_serialization(self):
    """Dicts keyed by MyEnum3 dump to JSON with member names as keys, directly or via a RootModel."""

    class DictWrapper(RootModel[Dict[MyEnum3, int]]):
        model_config = ConfigDict(use_enum_values=True)

        def __getitem__(self, item):
            return self.root[item]

    class MyDictWrapperModel(BaseModel):
        model_config = ConfigDict(use_enum_values=True)

        enum_dict: DictWrapper

    def fresh_payload():
        # Each model gets its own input dict.
        return {MyEnum3.FIELD1: 8, MyEnum3.FIELD2: 19}

    lookups = ((MyEnum3.FIELD1, 8), (MyEnum3.FIELD2, 19))

    # Plain Dict[MyEnum3, int] field.
    plain = MyDictModel(enum_dict=fresh_payload())
    for member, count in lookups:
        assert plain.enum_dict[member] == count
    assert json.loads(plain.model_dump_json()) == json.loads('{"enum_dict":{"FIELD1":8,"FIELD2":19}}')

    # Same mapping wrapped in a RootModel subclass.
    wrapped = MyDictWrapperModel(enum_dict=DictWrapper(fresh_payload()))
    for member, count in lookups:
        assert wrapped.enum_dict[member] == count
    assert json.loads(wrapped.model_dump_json()) == json.loads('{"enum_dict":{"FIELD1":8,"FIELD2":19}}')

def test_json_schema_csp(self):
    """MyModel's JSON schema describes both enum fields as string enums of the member names."""
    # Schema fragment shared by both fields; only "enum_default" adds a default.
    enum_property = {
        "description": "An enumeration of MyEnum3",
        "enum": ["FIELD1", "FIELD2"],
        "title": "MyEnum3",
        "type": "string",
    }
    expected_schema = {
        "properties": {
            "enum": dict(enum_property),
            "enum_default": dict(enum_property, default="FIELD1"),
        },
        "required": ["enum"],
        "title": "MyModel",
        "type": "object",
    }
    assert MyModel.model_json_schema() == expected_schema


if __name__ == "__main__":
unittest.main()
161 changes: 159 additions & 2 deletions csp/tests/impl/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -3567,10 +3567,167 @@ class DataPoint(csp.Struct):
self.assertIsInstance(result.history[2], BaseMetric) # Should be base

# Test serialization and deserialization preserves specific types
json_data = result.to_json()
restored = TypeAdapter(DataPoint).validate_json(json_data)
json_data_csp = result.to_json()
json_data_pydantic = TypeAdapter(DataPoint).dump_json(result).decode()
self.assertEqual(json.loads(json_data_csp), json.loads(json_data_pydantic))
restored = TypeAdapter(DataPoint).validate_json(json_data_csp)
self.assertEqual(restored, result)

def test_pydantic_custom_serialization(self):
    """Test that CustomStruct correctly serializes integers with comma formatting.

    A field-level ``PlainSerializer`` attached through ``Annotated`` must be
    applied when the struct is dumped to JSON through a pydantic ``TypeAdapter``.
    """
    from pydantic.functional_serializers import PlainSerializer

    # Integer type that always serializes as a comma-grouped string.
    FancyInt = Annotated[int, PlainSerializer(lambda x: f"{x:,}", return_type=str, when_used="always")]

    # Simple struct with just the FancyInt
    class CustomStruct(csp.Struct):
        value: FancyInt

    # The adapter does not depend on the test case, so build it once.
    adapter = TypeAdapter(CustomStruct)

    test_cases = [
        (1234, "1,234"),
        (1000000, "1,000,000"),
        (42, "42"),
    ]

    for input_value, expected_output in test_cases:
        # subTest makes a failure report which input was being checked.
        with self.subTest(input_value=input_value):
            s = CustomStruct(value=input_value)
            serialized = json.loads(adapter.dump_json(s))
            self.assertEqual(serialized["value"], expected_output)

def test_pydantic_serialization_with_enums(self):
    """Enum-bearing structs serialize identically via csp's ``to_json`` and pydantic, and round-trip."""

    class Color(csp.Enum):
        RED = 1
        GREEN = 2
        BLUE = 3

    class Shape(csp.Enum):
        CIRCLE = 1
        SQUARE = 2
        TRIANGLE = 3

    class DrawingStruct(csp.Struct):
        color: Color
        shape: Shape
        colors: List[Color]
        shapes: Dict[str, Shape]

    drawing = DrawingStruct(
        color=Color.RED,
        shape=Shape.CIRCLE,
        colors=[Color.RED, Color.GREEN, Color.BLUE],
        shapes={"a": Shape.SQUARE, "b": Shape.TRIANGLE},
    )

    # Native serialization: every enum is written by name.
    native_json = json.loads(drawing.to_json())
    expected_native = {
        "color": "RED",
        "shape": "CIRCLE",
        "colors": ["RED", "GREEN", "BLUE"],
        "shapes": {"a": "SQUARE", "b": "TRIANGLE"},
    }
    for key, expected in expected_native.items():
        self.assertEqual(native_json[key], expected)

    # Pydantic serialization should be identical for enums.
    pydantic_bytes = TypeAdapter(DrawingStruct).dump_json(drawing)
    self.assertEqual(json.loads(pydantic_bytes), native_json)

    # Round-trip through both methods.
    native_restored = DrawingStruct.from_dict(json.loads(drawing.to_json()))
    round_trip_bytes = TypeAdapter(DrawingStruct).dump_json(drawing)
    pydantic_restored = TypeAdapter(DrawingStruct).validate_json(round_trip_bytes)

    self.assertEqual(native_restored, drawing)
    self.assertEqual(pydantic_restored, drawing)

def test_pydantic_serialization_vs_native(self):
    """Test that Pydantic serialization matches CSP native serialization for basic types.

    Compares ``to_json``/``to_dict`` with ``TypeAdapter.dump_json``/``dump_python``
    for a struct left at its defaults, a fully populated struct, and a struct
    nesting both.
    """
    from pydantic.functional_serializers import PlainSerializer

    class MyEnum(csp.Enum):
        OPTION1 = csp.Enum.auto()
        OPTION2 = csp.Enum.auto()

    # Define custom datetime serialization
    # This is so that pydantic serializes datetime with the same precision as csp natively does
    SimpleDatetime = Annotated[
        datetime,
        PlainSerializer(lambda dt: dt.strftime("%Y-%m-%dT%H:%M:%S.%f+00:00"), return_type=str, when_used="json"),
    ]

    class SimpleStruct(csp.Struct):
        i: int = 123
        f: float = 3.14
        s: str = "test"
        b: bool = True
        dt: SimpleDatetime = datetime(2023, 1, 1)
        l: List[int] = [1, 2, 3]
        d: Dict[str, float] = {"a": 1.1, "b": 2.2}
        e: MyEnum

    # Test with default values
    s1 = SimpleStruct()
    json_native = s1.to_json()
    json_pydantic = TypeAdapter(SimpleStruct).dump_json(s1).decode()
    self.assertEqual(json.loads(json_native), json.loads(json_pydantic))
    python_native = s1.to_dict()
    python_pydantic = TypeAdapter(SimpleStruct).dump_python(s1)
    self.assertEqual(python_native, python_pydantic)
    # unset variables with no default do not get encoded
    self.assertNotIn("e", python_native)
    self.assertNotIn("e", python_pydantic)

    # Test with custom values
    s2 = SimpleStruct(
        i=456,
        f=2.718,
        s="custom",
        b=False,
        dt=datetime(2024, 1, 1, tzinfo=pytz.UTC),
        l=[4, 5, 6],
        d={"x": 9.9, "y": 8.8},
        e=MyEnum.OPTION2,
    )
    python_native = s2.to_dict()
    python_pydantic = TypeAdapter(SimpleStruct).dump_python(s2)
    # NOTE: csp, when running 'to_dict'
    # converts csp Enums to str
    # The pydantic version maintains them as csp Enums, which is arguably more correct
    enum_as_str = python_native.pop("e")
    enum_as_enum = python_pydantic.pop("e")
    self.assertEqual(python_native, python_pydantic)
    self.assertEqual(enum_as_enum.name, enum_as_str)

    json_native = s2.to_json()
    json_pydantic = TypeAdapter(SimpleStruct).dump_json(s2).decode()
    self.assertEqual(json.loads(json_native), json.loads(json_pydantic))

    # Test with nested structs
    class NestedStruct(csp.Struct):
        name: str
        simple: SimpleStruct
        simples: List[SimpleStruct]

    nested = NestedStruct(name="test", simple=s1, simples=[s1, s2])

    python_native = nested.to_dict()
    python_pydantic = TypeAdapter(NestedStruct).dump_python(nested)
    # Only s2 (index 1) has the enum set; compare it separately for the same
    # str-vs-Enum reason as above.
    enum_as_str = python_native["simples"][1].pop("e")
    enum_as_enum = python_pydantic["simples"][1].pop("e")
    self.assertEqual(python_native, python_pydantic)
    self.assertEqual(enum_as_enum.name, enum_as_str)

    json_native = nested.to_json()
    json_pydantic = TypeAdapter(NestedStruct).dump_json(nested).decode()
    self.assertEqual(json.loads(json_native), json.loads(json_pydantic))


if __name__ == "__main__":
unittest.main()

0 comments on commit 71844f5

Please sign in to comment.