From fe963b9be339e464cc3035a867fe5535d2762fce Mon Sep 17 00:00:00 2001
From: Levi Naden
Date: Tue, 10 Sep 2024 12:13:23 -0400
Subject: [PATCH 1/2] Levi's pyd v2 changes to Datum, serialization, dft_info, cpu_info

---
 qcelemental/datum.py              | 105 +++++++++++++++++++++++++-----
 qcelemental/info/cpu_info.py      |  18 +++--
 qcelemental/info/dft_info.py      |   5 +-
 qcelemental/tests/test_utils.py   | 102 ++++++++++++++++++++++++++---
 qcelemental/util/serialization.py |  37 +++++++----
 5 files changed, 223 insertions(+), 44 deletions(-)

diff --git a/qcelemental/datum.py b/qcelemental/datum.py
index 02087a53..13ea1c4b 100644
--- a/qcelemental/datum.py
+++ b/qcelemental/datum.py
@@ -6,11 +6,58 @@
 from typing import Any, Dict, Optional
 
 import numpy as np
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    SerializationInfo,
+    SerializerFunctionWrapHandler,
+    WrapSerializer,
+    field_validator,
+    model_serializer,
+)
+from typing_extensions import Annotated
+
+
+def reduce_complex(data):
+    # Reduce Complex
+    if isinstance(data, complex):
+        return [data.real, data.imag]
+    # Fallback
+    return data
+
+
+def keep_decimal_cast_ndarray_complex(
+    v: Any, nxt: SerializerFunctionWrapHandler, info: SerializationInfo
+) -> Union[list, Decimal, float]:
+    """
+    Ensure Decimal types are preserved on the way out
+
+    This arose because Decimal was serialized to string and "dump" is equal to "serialize" in v2 pydantic
+    https://docs.pydantic.dev/latest/migration/#changes-to-json-schema-generation
+
+
+    This also checks against NumPy Arrays and complex numbers in the instance of being in JSON mode
+    """
+    if isinstance(v, Decimal):
+        return v
+    if info.mode == "json":
+        if isinstance(v, complex):
+            return nxt(reduce_complex(v))
+        if isinstance(v, np.ndarray):
+            # Handle NDArray and complex NDArray
+            flat_list = v.flatten().tolist()
+            reduced_list = list(map(reduce_complex, flat_list))
+            return nxt(reduced_list)
+    try:
+        # Cast NumPy scalar data types to native Python data type
+        v = v.item()
+    except (AttributeError, ValueError):
+        pass
+    return nxt(v)
+
-try:
-    from pydantic.v1 import BaseModel, validator
-except ImportError:  # Will also trap ModuleNotFoundError
-    from pydantic import BaseModel, validator
+# Only 1 serializer is allowed. You can't chain wrap serializers.
+AnyArrayComplex = Annotated[Any, WrapSerializer(keep_decimal_cast_ndarray_complex)]
 
 
 class Datum(BaseModel):
@@ -38,15 +85,15 @@ class Datum(BaseModel):
     numeric: bool
     label: str
     units: str
-    data: Any
+    data: AnyArrayComplex
     comment: str = ""
     doi: Optional[str] = None
     glossary: str = ""
 
-    class Config:
-        extra = "forbid"
-        allow_mutation = False
-        json_encoders = {np.ndarray: lambda v: v.flatten().tolist(), complex: lambda v: (v.real, v.imag)}
+    model_config = ConfigDict(
+        extra="forbid",
+        frozen=True,
+    )
 
     def __init__(self, label, units, data, *, comment=None, doi=None, glossary=None, numeric=True):
         kwargs = {"label": label, "units": units, "data": data, "numeric": numeric}
@@ -59,20 +106,21 @@ def __init__(self, label, units, data, *, comment=None, doi=None, glossary=None
 
         super().__init__(**kwargs)
 
-    @validator("data")
-    def must_be_numerical(cls, v, values, **kwargs):
+    @field_validator("data")
+    @classmethod
+    def must_be_numerical(cls, v, info):
         try:
             1.0 * v
         except TypeError:
             try:
                 Decimal("1.0") * v
             except TypeError:
-                if values["numeric"]:
+                if info.data["numeric"]:
                     raise ValueError(f"Datum data should be float, Decimal, or np.ndarray, not {type(v)}.")
                 else:
-                    values["numeric"] = True
+                    info.data["numeric"] = True
             else:
-                values["numeric"] = True
+                info.data["numeric"] = True
 
         return v
 
@@ -90,8 +138,35 @@ def __str__(self, label=""):
         text.append("-" * width)
         return "\n".join(text)
 
+    @model_serializer(mode="wrap")
+    def _serialize_model(self, handler) -> Dict[str, Any]:
+        """
+        Customize the serialization output. Duplicates some code in model_dump, but handles the case of nested
+        models and any model config options.
+
+        Encoding is handled at the `model_dump` level and not here as that should happen only after EVERYTHING has been
+        dumped/de-pydantic-ized.
+        """
+
+        # Get the default return, let the model_dump handle kwarg
+        default_result = handler(self)
+        # Exclude unset always
+        output_dict = {key: value for key, value in default_result.items() if key in self.model_fields_set}
+        return output_dict
+
     def dict(self, *args, **kwargs):
-        return super().dict(*args, **{**kwargs, **{"exclude_unset": True}})
+        """
+        Passthrough to model_dump without deprecation warning
+        exclude_unset is forced through the model_serializer
+        """
+        return super().model_dump(*args, **kwargs)
+
+    def json(self, *args, **kwargs):
+        """
+        Passthrough to model_dump_json without deprecation warning
+        exclude_unset is forced through the model_serializer
+        """
+        return super().model_dump_json(*args, **kwargs)
 
     def to_units(self, units=None):
         from .physical_constants import constants
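A rough sketch of the Datum behavior the datum.py hunks above aim for (illustrative only, not part of the patch; it assumes qcelemental is installed with this change, and the labels and values are made up):

    import numpy as np
    import qcelemental as qcel

    d = qcel.Datum("an array", "bohr", np.arange(4))
    d.dict()   # python-mode dump; the wrap serializer passes the ndarray through untouched
    d.json()   # JSON-mode dump; keep_decimal_cast_ndarray_complex flattens the array to a plain list

    c = qcel.Datum("a complex value", "au", 3 + 4j)
    c.json()   # complex data is expected to come out as [real, imag]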
diff --git a/qcelemental/info/cpu_info.py b/qcelemental/info/cpu_info.py
index 4fe35689..cb7ad7a0 100644
--- a/qcelemental/info/cpu_info.py
+++ b/qcelemental/info/cpu_info.py
@@ -8,7 +8,8 @@
 from functools import lru_cache
 from typing import List, Optional
 
-from pydantic.v1 import Field
+from pydantic import BeforeValidator, Field
+from typing_extensions import Annotated
 
 from ..models import ProtoModel
 
@@ -22,6 +23,13 @@ class VendorEnum(str, Enum):
     arm = "arm"
 
 
+def stringify(v) -> str:
+    return str(v)
+
+
+Stringify = Annotated[str, BeforeValidator(stringify)]
+
+
 class InstructionSetEnum(int, Enum):
     """Allowed instruction sets for CPUs in an ordinal enum."""
 
@@ -37,13 +45,13 @@ class ProcessorInfo(ProtoModel):
     ncores: int = Field(..., description="The number of physical cores on the chip.")
     nthreads: Optional[int] = Field(..., description="The maximum number of concurrent threads.")
     base_clock: float = Field(..., description="The base clock frequency (GHz).")
-    boost_clock: Optional[float] = Field(..., description="The boost clock frequency (GHz).")
-    model: str = Field(..., description="The model number of the chip.")
+    boost_clock: Optional[float] = Field(None, description="The boost clock frequency (GHz).")
+    model: Stringify = Field(..., description="The model number of the chip.")
     family: str = Field(..., description="The family of the chip.")
-    launch_date: Optional[int] = Field(..., description="The launch year of the chip.")
+    launch_date: Optional[int] = Field(None, description="The launch year of the chip.")
     target_use: str = Field(..., description="Target use case (Desktop, Server, etc).")
     vendor: VendorEnum = Field(..., description="The vendor the chip is produced by.")
-    microarchitecture: Optional[str] = Field(..., description="The microarchitecture the chip follows.")
+    microarchitecture: Optional[str] = Field(None, description="The microarchitecture the chip follows.")
     instructions: InstructionSetEnum = Field(..., description="The maximum vectorized instruction set available.")
     type: str = Field(..., description="The type of chip (cpu, gpu, etc).")
 
diff --git a/qcelemental/info/dft_info.py b/qcelemental/info/dft_info.py
index 073e40d3..742c587a 100644
--- a/qcelemental/info/dft_info.py
+++ b/qcelemental/info/dft_info.py
@@ -4,7 +4,8 @@
 
 from typing import Dict
 
-from pydantic.v1 import Field
+from pydantic import Field
+from typing_extensions import Annotated
 
 from ..models import ProtoModel
 
@@ -68,4 +69,4 @@ def get(name: str) -> DFTFunctionalInfo:
             name = name.replace(x, "")
             break
 
-    return dftfunctionalinfo.functionals[name].copy()
+    return dftfunctionalinfo.functionals[name].model_copy()
diff --git a/qcelemental/tests/test_utils.py b/qcelemental/tests/test_utils.py
index 43bbb285..8373d894 100644
--- a/qcelemental/tests/test_utils.py
+++ b/qcelemental/tests/test_utils.py
@@ -1,8 +1,8 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
+import pydantic
 import pytest
-from pydantic.v1 import BaseModel, Field
 
 import qcelemental as qcel
 from qcelemental.testing import compare_recursive, compare_values
@@ -14,22 +14,22 @@ def doc_fixture():
     # associated with AutoDoc, so leaving at Pydantic v1 syntax
-    class Nest(BaseModel):
+    class Nest(pydantic.v1.BaseModel):
         """A nested model"""
 
         n: float = 56
 
-    class X(BaseModel):
+    class X(pydantic.v1.BaseModel):
         """A Pydantic model made up of many, many different combinations of ways of mapping types in Pydantic"""
 
         x: int
-        y: str = Field(...)
+        y: str = pydantic.v1.Field(...)
         n: Nest
-        n2: Nest = Field(Nest(), description="A detailed description")
+        n2: Nest = pydantic.v1.Field(Nest(), description="A detailed description")
        z: float = 5
         z2: float = None
         z3: Optional[float]
-        z4: Optional[float] = Field(5, description="Some number I just made up")
+        z4: Optional[float] = pydantic.v1.Field(5, description="Some number I just made up")
         z5: Optional[Union[float, int]]
         z6: Optional[List[int]]
         l: List[int]
@@ -38,11 +38,13 @@ class X(BaseModel):
         t2: Tuple[List[int]]
         t3: Tuple[Any]
         d: Dict[str, Any]
-        dlu: Dict[Union[int, str], List[Union[int, str, float]]] = Field(..., description="this is complicated")
+        dlu: Dict[Union[int, str], List[Union[int, str, float]]] = pydantic.v1.Field(
+            ..., description="this is complicated"
+        )
         dlu2: Dict[Any, List[Union[int, str, float]]]
         dlu3: Dict[str, Any]
-        si: int = Field(..., description="A level of constraint", gt=0)
-        sf: float = Field(None, description="Optional Constrained Number", le=100.3)
+        si: int = pydantic.v1.Field(..., description="A level of constraint", gt=0)
+        sf: float = pydantic.v1.Field(None, description="Optional Constrained Number", le=100.3)
 
     yield X
 
@@ -308,3 +310,85 @@ def test_auto_gen_doc_delete(doc_fixture):
 def test_serialization(obj, encoding):
     new_obj = qcel.util.deserialize(qcel.util.serialize(obj, encoding=encoding), encoding=encoding)
     assert compare_recursive(obj, new_obj)
+
+
+@pytest.fixture
+def atomic_result():
+    """Mock AtomicResult output which can be tested against for complex serialization methods"""
+
+    data = {
+        "id": None,
+        "schema_name": "qcschema_output",
+        "schema_version": 1,
+        "molecule": {
+            "schema_name": "qcschema_molecule",
+            "schema_version": 2,
+            "validated": True,
+            "symbols": np.array(["O", "H", "H"], dtype="
 Any:
 
     # First try pydantic base objects
     try:
-        return pydantic_encoder(obj)
-    except TypeError:
+        return to_jsonable_python(obj)
+    except ValueError:
         pass
 
     if isinstance(obj, np.ndarray):
@@ -123,8 +125,8 @@ def msgpackext_loads(data: bytes) -> Any:
 class JSONExtArrayEncoder(json.JSONEncoder):
     def default(self, obj: Any) -> Any:
         try:
-            return pydantic_encoder(obj)
-        except TypeError:
+            return to_jsonable_python(obj)
+        except ValueError:
             pass
 
         if isinstance(obj, np.ndarray):
@@ -192,11 +194,20 @@ def jsonext_loads(data: Union[str, bytes]) -> Any:
 
 class JSONArrayEncoder(json.JSONEncoder):
     def default(self, obj: Any) -> Any:
+        # See if pydantic can do this on its own.
+        # Note: This calls DIFFERENT logic on BaseModels than BaseModel.model_dump_json, for some reason
         try:
-            return pydantic_encoder(obj)
-        except TypeError:
+            return to_jsonable_python(obj)
+        except ValueError:
             pass
 
+        # See if pydantic model can be just serialized if the above couldn't be dumped
+        if isinstance(obj, pydantic.BaseModel):
+            try:
+                return obj.model_dump_json()
+            except PydanticSerializationError:
+                pass
+
         if isinstance(obj, np.ndarray):
             if obj.shape:
                 return obj.ravel().tolist()
@@ -260,8 +271,8 @@ def msgpack_encode(obj: Any) -> Any:
     """
 
     try:
-        return pydantic_encoder(obj)
-    except TypeError:
+        return to_jsonable_python(obj)
+    except ValueError:
         pass
 
     if isinstance(obj, np.ndarray):
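Patch 1 swaps the old pydantic_encoder calls for pydantic_core.to_jsonable_python in every encoder, and the except clauses widen from TypeError to ValueError because to_jsonable_python signals unknown types with PydanticSerializationError, a ValueError subclass. A small round-trip sketch of the intended end-user behavior (illustrative only, not part of the patch; assumes patch 1 is applied):

    import numpy as np
    import qcelemental as qcel

    datum = qcel.Datum("an array", "bohr", np.arange(3))
    blob = qcel.util.serialize(datum, encoding="json")    # JSONArrayEncoder flattens the ndarray
    again = qcel.util.deserialize(blob, encoding="json")
    print(again["data"])                                  # expected: [0, 1, 2]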

From a1f7dccaa16ecb4d4efe4a34bd8006ba0af2af3f Mon Sep 17 00:00:00 2001
From: "Lori A. Burns"
Date: Tue, 10 Sep 2024 14:25:10 -0400
Subject: [PATCH 2/2] fix up Datum, DFTFunctional, CPUInfo, and serialization

---
 docs/changelog.rst                |  1 +
 qcelemental/datum.py              |  2 +-
 qcelemental/info/cpu_info.py      |  4 +++-
 qcelemental/info/dft_info.py      |  4 +++-
 qcelemental/testing.py            | 20 +++++++++++---------
 qcelemental/tests/test_datum.py   | 10 +++-------
 qcelemental/tests/test_utils.py   | 10 ++++++----
 qcelemental/util/autodocs.py      |  5 +----
 qcelemental/util/serialization.py | 25 +++++++++++++++++++++----
 9 files changed, 50 insertions(+), 31 deletions(-)

diff --git a/docs/changelog.rst b/docs/changelog.rst
index 1148f544..c7ca857c 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -38,6 +38,7 @@ Enhancements
 * The ``models.v2`` have had their `schema_version` bumped for ``BasisSet``, ``AtomicInput``, ``OptimizationInput`` (implicit for ``AtomicResult`` and ``OptimizationResult``), ``TorsionDriveInput`` , and ``TorsionDriveResult``.
 * The ``models.v2`` ``AtomicResultProperties`` has been given a ``schema_name`` and ``schema_version`` (2) for the first time.
 * Note that ``models.v2`` ``QCInputSpecification`` and ``OptimizationSpecification`` have *not* had schema_version bumped.
+* All of ``Datum``, ``DFTFunctional``, and ``CPUInfo`` models, none of which are mixed with QCSchema models, are translated to Pydantic v2 API syntax.
 
 Bug Fixes
 +++++++++
diff --git a/qcelemental/datum.py b/qcelemental/datum.py
index 13ea1c4b..5ca348e3 100644
--- a/qcelemental/datum.py
+++ b/qcelemental/datum.py
@@ -3,7 +3,7 @@
 """
 
 from decimal import Decimal
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 import numpy as np
 from pydantic import (
diff --git a/qcelemental/info/cpu_info.py b/qcelemental/info/cpu_info.py
index cb7ad7a0..d65a3b57 100644
--- a/qcelemental/info/cpu_info.py
+++ b/qcelemental/info/cpu_info.py
@@ -11,7 +11,9 @@
 from pydantic import BeforeValidator, Field
 from typing_extensions import Annotated
 
-from ..models import ProtoModel
+from ..models.v2 import ProtoModel
+
+# ProcessorInfo models don't become parts of QCSchema models afaik, so pure pydantic v2 API
 
 
 class VendorEnum(str, Enum):
diff --git a/qcelemental/info/dft_info.py b/qcelemental/info/dft_info.py
index 742c587a..1c76eb00 100644
--- a/qcelemental/info/dft_info.py
+++ b/qcelemental/info/dft_info.py
@@ -7,7 +7,9 @@
 from pydantic import Field
 from typing_extensions import Annotated
 
-from ..models import ProtoModel
+from ..models.v2 import ProtoModel
+
+# DFTFunctional models don't become parts of QCSchema models afaik, so pure pydantic v2 API
 
 
 class DFTFunctionalInfo(ProtoModel):
diff --git a/qcelemental/testing.py b/qcelemental/testing.py
index 6db0ada6..d911a78d 100644
--- a/qcelemental/testing.py
+++ b/qcelemental/testing.py
@@ -5,11 +5,7 @@
 from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union
 
 import numpy as np
-
-try:
-    from pydantic.v1 import BaseModel
-except ImportError:  # Will also trap ModuleNotFoundError
-    from pydantic import BaseModel
+import pydantic
 
 if TYPE_CHECKING:
     from qcelemental.models import ProtoModel  # TODO: recheck if .v1 needed
@@ -313,10 +309,16 @@ def _compare_recursive(expected, computed, atol, rtol, _prefix=False, equal_phas
         prefix = name + "."

     # Initial conversions if required
-    if isinstance(expected, BaseModel):
+    if isinstance(expected, pydantic.BaseModel):
+        expected = expected.model_dump()
+
+    if isinstance(computed, pydantic.BaseModel):
+        computed = computed.model_dump()
+
+    if isinstance(expected, pydantic.v1.BaseModel):
         expected = expected.dict()
 
-    if isinstance(computed, BaseModel):
+    if isinstance(computed, pydantic.v1.BaseModel):
         computed = computed.dict()
 
     if isinstance(expected, (str, int, bool, complex)):
@@ -381,8 +383,8 @@ def _compare_recursive(expected, computed, atol, rtol, _prefix=False, equal_phas
 
 
 def compare_recursive(
-    expected: Union[Dict, BaseModel, "ProtoModel"],  # type: ignore
-    computed: Union[Dict, BaseModel, "ProtoModel"],  # type: ignore
+    expected: Union[Dict, pydantic.BaseModel, pydantic.v1.BaseModel, "ProtoModel"],  # type: ignore
+    computed: Union[Dict, pydantic.BaseModel, pydantic.v1.BaseModel, "ProtoModel"],  # type: ignore
     label: str = None,
     *,
     atol: float = 1.0e-6,
diff --git a/qcelemental/tests/test_datum.py b/qcelemental/tests/test_datum.py
index 018040e4..bda69f6c 100644
--- a/qcelemental/tests/test_datum.py
+++ b/qcelemental/tests/test_datum.py
@@ -1,11 +1,7 @@
 from decimal import Decimal
 
 import numpy as np
-
-try:
-    import pydantic.v1 as pydantic
-except ImportError:  # Will also trap ModuleNotFoundError
-    import pydantic
+import pydantic
 import pytest
 
 import qcelemental as qcel
@@ -46,10 +42,10 @@ def test_creation_nonnum(dataset):
 
 
 def test_creation_error():
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pydantic.ValidationError) as e:
         qcel.Datum("ze lbl", "ze unit", "ze data")
 
-    # assert 'Datum data should be float' in str(e)
+    assert "Datum data should be float" in str(e.value)
 
 
 @pytest.mark.parametrize(
diff --git a/qcelemental/tests/test_utils.py b/qcelemental/tests/test_utils.py
index 8373d894..65613981 100644
--- a/qcelemental/tests/test_utils.py
+++ b/qcelemental/tests/test_utils.py
@@ -7,7 +7,7 @@
 import qcelemental as qcel
 from qcelemental.testing import compare_recursive, compare_values
 
-from .addons import serialize_extensions
+from .addons import schema_versions, serialize_extensions
 
 
 @pytest.fixture(scope="function")
@@ -313,7 +313,7 @@ def test_serialization(obj, encoding):
 
 
 @pytest.fixture
-def atomic_result():
+def atomic_result_data():
     """Mock AtomicResult output which can be tested against for complex serialization methods"""
 
     data = {
@@ -385,10 +385,12 @@ def atomic_result():
         "success": True,
         "error": None,
     }
+    return data
 
-    yield qcel.models.results.AtomicResult(**data)
 
+def test_json_dumps(atomic_result_data, schema_versions):
+    AtomicResult = schema_versions.AtomicResult
 
-def test_json_dumps(atomic_result):
+    atomic_result = AtomicResult(**atomic_result_data)
     ret = qcel.util.json_dumps(atomic_result)
     assert isinstance(ret, str)
diff --git a/qcelemental/util/autodocs.py b/qcelemental/util/autodocs.py
index ac57b50d..e0bd964a 100644
--- a/qcelemental/util/autodocs.py
+++ b/qcelemental/util/autodocs.py
@@ -41,10 +41,7 @@ def is_pydantic(test_object):
 
 def parse_type_str(prop) -> str:
     # Import here to minimize issues
-    try:
-        from pydantic.v1 import fields
-    except ImportError:  # Will also trap ModuleNotFoundError
-        from pydantic import fields
+    from pydantic.v1 import fields
 
     typing_map = {
         fields.SHAPE_TUPLE: "Tuple",
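The qcelemental/testing.py hunk earlier in this patch teaches compare_recursive to accept models from both pydantic API generations; the conversion step amounts to the following dispatch (illustrative restatement of the diff, helper name is made up):

    import pydantic

    def _as_dict(model):
        if isinstance(model, pydantic.BaseModel):     # pydantic v2 API models
            return model.model_dump()
        if isinstance(model, pydantic.v1.BaseModel):  # pydantic v1 API models (e.g., current QCSchema classes)
            return model.dict()
        return model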
diff --git a/qcelemental/util/serialization.py b/qcelemental/util/serialization.py
index 6a6c6625..171e21b7 100644
--- a/qcelemental/util/serialization.py
+++ b/qcelemental/util/serialization.py
@@ -3,6 +3,7 @@
 
 import numpy as np
 import pydantic
+from pydantic.v1.json import pydantic_encoder
 from pydantic_core import PydanticSerializationError, to_jsonable_python
 
 from .importing import which_import
@@ -41,7 +42,14 @@ def msgpackext_encode(obj: Any) -> Any:
     try:
         return to_jsonable_python(obj)
     except ValueError:
-        pass
+        # above to_jsonable_python is for Pydantic v2 API models
+        # below pydantic_encoder is for Pydantic v1 API models
+        # tentative whether handling both together will work beyond tests
+        # or if separate files called by models.v1 and .v2 will be req'd
+        try:
+            return pydantic_encoder(obj)
+        except TypeError:
+            pass
 
     if isinstance(obj, np.ndarray):
         if obj.shape:
@@ -127,7 +135,10 @@ def default(self, obj: Any) -> Any:
         try:
             return to_jsonable_python(obj)
         except ValueError:
-            pass
+            try:
+                return pydantic_encoder(obj)
+            except TypeError:
+                pass
 
         if isinstance(obj, np.ndarray):
             if obj.shape:
@@ -199,7 +210,10 @@ def default(self, obj: Any) -> Any:
         try:
             return to_jsonable_python(obj)
         except ValueError:
-            pass
+            try:
+                return pydantic_encoder(obj)
+            except TypeError:
+                pass
 
         # See if pydantic model can be just serialized if the above couldn't be dumped
         if isinstance(obj, pydantic.BaseModel):
@@ -273,7 +287,10 @@ def msgpack_encode(obj: Any) -> Any:
     try:
         return to_jsonable_python(obj)
     except ValueError:
-        pass
+        try:
+            return pydantic_encoder(obj)
+        except TypeError:
+            pass
 
     if isinstance(obj, np.ndarray):
         if obj.shape:
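Stripped of the numpy and bytes handling, the fallback that patch 2 threads through all four encoders is (simplified restatement, not part of the patch; encode_any is a made-up name):

    from pydantic.v1.json import pydantic_encoder
    from pydantic_core import to_jsonable_python

    def encode_any(obj):
        try:
            return to_jsonable_python(obj)    # pydantic v2 API objects
        except ValueError:                    # PydanticSerializationError et al.
            try:
                return pydantic_encoder(obj)  # pydantic v1 API objects
            except TypeError:
                return obj                    # the real encoders fall through to numpy/bytes handling here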