diff --git a/qcelemental/datum.py b/qcelemental/datum.py index 02087a53..13ea1c4b 100644 --- a/qcelemental/datum.py +++ b/qcelemental/datum.py @@ -6,11 +6,58 @@ from typing import Any, Dict, Optional import numpy as np +from pydantic import ( + BaseModel, + ConfigDict, + SerializationInfo, + SerializerFunctionWrapHandler, + WrapSerializer, + field_validator, + model_serializer, +) +from typing_extensions import Annotated + + +def reduce_complex(data): + # Reduce Complex + if isinstance(data, complex): + return [data.real, data.imag] + # Fallback + return data + + +def keep_decimal_cast_ndarray_complex( + v: Any, nxt: SerializerFunctionWrapHandler, info: SerializationInfo +) -> Union[list, Decimal, float]: + """ + Ensure Decimal types are preserved on the way out + + This arose because Decimal was serialized to string and "dump" is equal to "serialize" in v2 pydantic + https://docs.pydantic.dev/latest/migration/#changes-to-json-schema-generation + + + This also checks against NumPy Arrays and complex numbers in the instance of being in JSON mode + """ + if isinstance(v, Decimal): + return v + if info.mode == "json": + if isinstance(v, complex): + return nxt(reduce_complex(v)) + if isinstance(v, np.ndarray): + # Handle NDArray and complex NDArray + flat_list = v.flatten().tolist() + reduced_list = list(map(reduce_complex, flat_list)) + return nxt(reduced_list) + try: + # Cast NumPy scalar data types to native Python data type + v = v.item() + except (AttributeError, ValueError): + pass + return nxt(v) + -try: - from pydantic.v1 import BaseModel, validator -except ImportError: # Will also trap ModuleNotFoundError - from pydantic import BaseModel, validator +# Only 1 serializer is allowed. You can't chain wrap serializers. 
+AnyArrayComplex = Annotated[Any, WrapSerializer(keep_decimal_cast_ndarray_complex)] class Datum(BaseModel): @@ -38,15 +85,15 @@ class Datum(BaseModel): numeric: bool label: str units: str - data: Any + data: AnyArrayComplex comment: str = "" doi: Optional[str] = None glossary: str = "" - class Config: - extra = "forbid" - allow_mutation = False - json_encoders = {np.ndarray: lambda v: v.flatten().tolist(), complex: lambda v: (v.real, v.imag)} + model_config = ConfigDict( + extra="forbid", + frozen=True, + ) def __init__(self, label, units, data, *, comment=None, doi=None, glossary=None, numeric=True): kwargs = {"label": label, "units": units, "data": data, "numeric": numeric} @@ -59,20 +106,21 @@ def __init__(self, label, units, data, *, comment=None, doi=None, glossary=None, super().__init__(**kwargs) - @validator("data") - def must_be_numerical(cls, v, values, **kwargs): + @field_validator("data") + @classmethod + def must_be_numerical(cls, v, info): try: 1.0 * v except TypeError: try: Decimal("1.0") * v except TypeError: - if values["numeric"]: + if info.data["numeric"]: raise ValueError(f"Datum data should be float, Decimal, or np.ndarray, not {type(v)}.") else: - values["numeric"] = True + info.data["numeric"] = True else: - values["numeric"] = True + info.data["numeric"] = True return v @@ -90,8 +138,35 @@ def __str__(self, label=""): text.append("-" * width) return "\n".join(text) + @model_serializer(mode="wrap") + def _serialize_model(self, handler) -> Dict[str, Any]: + """ + Customize the serialization output. Does duplicate with some code in model_dump, but handles the case of nested + models and any model config options. + + Encoding is handled at the `model_dump` level and not here as that should happen only after EVERYTHING has been + dumped/de-pydantic-ized. 
+ """ + + # Get the default return, let the model_dump handle kwarg + default_result = handler(self) + # Exclude unset always + output_dict = {key: value for key, value in default_result.items() if key in self.model_fields_set} + return output_dict + def dict(self, *args, **kwargs): - return super().dict(*args, **{**kwargs, **{"exclude_unset": True}}) + """ + Passthrough to model_dump without deprecation warning + exclude_unset is forced through the model_serializer + """ + return super().model_dump(*args, **kwargs) + + def json(self, *args, **kwargs): + """ + Passthrough to model_dump_json without deprecation warning + exclude_unset is forced through the model_serializer + """ + return super().model_dump_json(*args, **kwargs) def to_units(self, units=None): from .physical_constants import constants diff --git a/qcelemental/info/cpu_info.py b/qcelemental/info/cpu_info.py index 4fe35689..cb7ad7a0 100644 --- a/qcelemental/info/cpu_info.py +++ b/qcelemental/info/cpu_info.py @@ -8,7 +8,8 @@ from functools import lru_cache from typing import List, Optional -from pydantic.v1 import Field +from pydantic import BeforeValidator, Field +from typing_extensions import Annotated from ..models import ProtoModel @@ -22,6 +23,13 @@ class VendorEnum(str, Enum): arm = "arm" +def stringify(v) -> str: + return str(v) + + +Stringify = Annotated[str, BeforeValidator(stringify)] + + class InstructionSetEnum(int, Enum): """Allowed instruction sets for CPUs in an ordinal enum.""" @@ -37,13 +45,13 @@ class ProcessorInfo(ProtoModel): ncores: int = Field(..., description="The number of physical cores on the chip.") nthreads: Optional[int] = Field(..., description="The maximum number of concurrent threads.") base_clock: float = Field(..., description="The base clock frequency (GHz).") - boost_clock: Optional[float] = Field(..., description="The boost clock frequency (GHz).") - model: str = Field(..., description="The model number of the chip.") + boost_clock: Optional[float] = Field(None, 
description="The boost clock frequency (GHz).") + model: Stringify = Field(..., description="The model number of the chip.") family: str = Field(..., description="The family of the chip.") - launch_date: Optional[int] = Field(..., description="The launch year of the chip.") + launch_date: Optional[int] = Field(None, description="The launch year of the chip.") target_use: str = Field(..., description="Target use case (Desktop, Server, etc).") vendor: VendorEnum = Field(..., description="The vendor the chip is produced by.") - microarchitecture: Optional[str] = Field(..., description="The microarchitecture the chip follows.") + microarchitecture: Optional[str] = Field(None, description="The microarchitecture the chip follows.") instructions: InstructionSetEnum = Field(..., description="The maximum vectorized instruction set available.") type: str = Field(..., description="The type of chip (cpu, gpu, etc).") diff --git a/qcelemental/info/dft_info.py b/qcelemental/info/dft_info.py index 073e40d3..742c587a 100644 --- a/qcelemental/info/dft_info.py +++ b/qcelemental/info/dft_info.py @@ -4,7 +4,8 @@ from typing import Dict -from pydantic.v1 import Field +from pydantic import Field +from typing_extensions import Annotated from ..models import ProtoModel @@ -68,4 +69,4 @@ def get(name: str) -> DFTFunctionalInfo: name = name.replace(x, "") break - return dftfunctionalinfo.functionals[name].copy() + return dftfunctionalinfo.functionals[name].model_copy() diff --git a/qcelemental/tests/test_utils.py b/qcelemental/tests/test_utils.py index 43bbb285..8373d894 100644 --- a/qcelemental/tests/test_utils.py +++ b/qcelemental/tests/test_utils.py @@ -1,8 +1,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np +import pydantic import pytest -from pydantic.v1 import BaseModel, Field import qcelemental as qcel from qcelemental.testing import compare_recursive, compare_values @@ -14,22 +14,22 @@ def doc_fixture(): # associated with AutoDoc, so leaving at 
Pydantic v1 syntax - class Nest(BaseModel): + class Nest(pydantic.v1.BaseModel): """A nested model""" n: float = 56 - class X(BaseModel): + class X(pydantic.v1.BaseModel): """A Pydantic model made up of many, many different combinations of ways of mapping types in Pydantic""" x: int - y: str = Field(...) + y: str = pydantic.v1.Field(...) n: Nest - n2: Nest = Field(Nest(), description="A detailed description") + n2: Nest = pydantic.v1.Field(Nest(), description="A detailed description") z: float = 5 z2: float = None z3: Optional[float] - z4: Optional[float] = Field(5, description="Some number I just made up") + z4: Optional[float] = pydantic.v1.Field(5, description="Some number I just made up") z5: Optional[Union[float, int]] z6: Optional[List[int]] l: List[int] @@ -38,11 +38,13 @@ class X(BaseModel): t2: Tuple[List[int]] t3: Tuple[Any] d: Dict[str, Any] - dlu: Dict[Union[int, str], List[Union[int, str, float]]] = Field(..., description="this is complicated") + dlu: Dict[Union[int, str], List[Union[int, str, float]]] = pydantic.v1.Field( + ..., description="this is complicated" + ) dlu2: Dict[Any, List[Union[int, str, float]]] dlu3: Dict[str, Any] - si: int = Field(..., description="A level of constraint", gt=0) - sf: float = Field(None, description="Optional Constrained Number", le=100.3) + si: int = pydantic.v1.Field(..., description="A level of constraint", gt=0) + sf: float = pydantic.v1.Field(None, description="Optional Constrained Number", le=100.3) yield X @@ -308,3 +310,85 @@ def test_auto_gen_doc_delete(doc_fixture): def test_serialization(obj, encoding): new_obj = qcel.util.deserialize(qcel.util.serialize(obj, encoding=encoding), encoding=encoding) assert compare_recursive(obj, new_obj) + + +@pytest.fixture +def atomic_result(): + """Mock AtomicResult output which can be tested against for complex serialization methods""" + + data = { + "id": None, + "schema_name": "qcschema_output", + "schema_version": 1, + "molecule": { + "schema_name": 
"qcschema_molecule", + "schema_version": 2, + "validated": True, + "symbols": np.array(["O", "H", "H"], dtype=" Any: # First try pydantic base objects try: - return pydantic_encoder(obj) - except TypeError: + return to_jsonable_python(obj) + except ValueError: pass if isinstance(obj, np.ndarray): @@ -123,8 +125,8 @@ def msgpackext_loads(data: bytes) -> Any: class JSONExtArrayEncoder(json.JSONEncoder): def default(self, obj: Any) -> Any: try: - return pydantic_encoder(obj) - except TypeError: + return to_jsonable_python(obj) + except ValueError: pass if isinstance(obj, np.ndarray): @@ -192,11 +194,20 @@ def jsonext_loads(data: Union[str, bytes]) -> Any: class JSONArrayEncoder(json.JSONEncoder): def default(self, obj: Any) -> Any: + # See if pydantic can do this on its own. + # Note: This calls DIFFERENT logic on BaseModels than BaseModel.model_dump_json, for some reason try: - return pydantic_encoder(obj) - except TypeError: + return to_jsonable_python(obj) + except ValueError: pass + # See if pydantic model can be just serialized if the above couldn't be dumped + if isinstance(obj, pydantic.BaseModel): + try: + return obj.model_dump_json() + except PydanticSerializationError: + pass + if isinstance(obj, np.ndarray): if obj.shape: return obj.ravel().tolist() @@ -260,8 +271,8 @@ def msgpack_encode(obj: Any) -> Any: """ try: - return pydantic_encoder(obj) - except TypeError: + return to_jsonable_python(obj) + except ValueError: pass if isinstance(obj, np.ndarray):