Skip to content

Commit

Permalink
Levi's pyd v2 changes to Datum, serialization, dft_info, cpu_info
Browse files Browse the repository at this point in the history
  • Loading branch information
Lnaden authored and loriab committed Sep 10, 2024
1 parent c219aa2 commit fe963b9
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 44 deletions.
105 changes: 90 additions & 15 deletions qcelemental/datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,58 @@
from typing import Any, Dict, Optional

import numpy as np
from pydantic import (
BaseModel,
ConfigDict,
SerializationInfo,
SerializerFunctionWrapHandler,
WrapSerializer,
field_validator,
model_serializer,
)
from typing_extensions import Annotated


def reduce_complex(data):
# Reduce Complex
if isinstance(data, complex):
return [data.real, data.imag]
# Fallback
return data


def keep_decimal_cast_ndarray_complex(
v: Any, nxt: SerializerFunctionWrapHandler, info: SerializationInfo
) -> Union[list, Decimal, float]:
"""
Ensure Decimal types are preserved on the way out
This arose because Decimal was serialized to string and "dump" is equal to "serialize" in v2 pydantic
https://docs.pydantic.dev/latest/migration/#changes-to-json-schema-generation
This also checks against NumPy Arrays and complex numbers in the instance of being in JSON mode
"""
if isinstance(v, Decimal):
return v
if info.mode == "json":
if isinstance(v, complex):
return nxt(reduce_complex(v))
if isinstance(v, np.ndarray):
# Handle NDArray and complex NDArray
flat_list = v.flatten().tolist()
reduced_list = list(map(reduce_complex, flat_list))
return nxt(reduced_list)
try:
# Cast NumPy scalar data types to native Python data type
v = v.item()
except (AttributeError, ValueError):
pass
return nxt(v)


try:
from pydantic.v1 import BaseModel, validator
except ImportError: # Will also trap ModuleNotFoundError
from pydantic import BaseModel, validator
# Only 1 serializer is allowed. You can't chain wrap serializers.
AnyArrayComplex = Annotated[Any, WrapSerializer(keep_decimal_cast_ndarray_complex)]


class Datum(BaseModel):
Expand Down Expand Up @@ -38,15 +85,15 @@ class Datum(BaseModel):
numeric: bool
label: str
units: str
data: Any
data: AnyArrayComplex
comment: str = ""
doi: Optional[str] = None
glossary: str = ""

class Config:
extra = "forbid"
allow_mutation = False
json_encoders = {np.ndarray: lambda v: v.flatten().tolist(), complex: lambda v: (v.real, v.imag)}
model_config = ConfigDict(
extra="forbid",
frozen=True,
)

def __init__(self, label, units, data, *, comment=None, doi=None, glossary=None, numeric=True):
kwargs = {"label": label, "units": units, "data": data, "numeric": numeric}
Expand All @@ -59,20 +106,21 @@ def __init__(self, label, units, data, *, comment=None, doi=None, glossary=None,

super().__init__(**kwargs)

@validator("data")
def must_be_numerical(cls, v, values, **kwargs):
@field_validator("data")
@classmethod
def must_be_numerical(cls, v, info):
try:
1.0 * v
except TypeError:
try:
Decimal("1.0") * v
except TypeError:
if values["numeric"]:
if info.data["numeric"]:
raise ValueError(f"Datum data should be float, Decimal, or np.ndarray, not {type(v)}.")
else:
values["numeric"] = True
info.data["numeric"] = True
else:
values["numeric"] = True
info.data["numeric"] = True

return v

Expand All @@ -90,8 +138,35 @@ def __str__(self, label=""):
text.append("-" * width)
return "\n".join(text)

@model_serializer(mode="wrap")
def _serialize_model(self, handler) -> Dict[str, Any]:
"""
Customize the serialization output. Does duplicate with some code in model_dump, but handles the case of nested
models and any model config options.
Encoding is handled at the `model_dump` level and not here as that should happen only after EVERYTHING has been
dumped/de-pydantic-ized.
"""

# Get the default return, let the model_dump handle kwarg
default_result = handler(self)
# Exclude unset always
output_dict = {key: value for key, value in default_result.items() if key in self.model_fields_set}
return output_dict

def dict(self, *args, **kwargs):
return super().dict(*args, **{**kwargs, **{"exclude_unset": True}})
"""
Passthrough to model_dump without deprecation warning
exclude_unset is forced through the model_serializer
"""
return super().model_dump(*args, **kwargs)

def json(self, *args, **kwargs):
"""
Passthrough to model_dump_sjon without deprecation warning
exclude_unset is forced through the model_serializer
"""
return super().model_dump_json(*args, **kwargs)

def to_units(self, units=None):
from .physical_constants import constants
Expand Down
18 changes: 13 additions & 5 deletions qcelemental/info/cpu_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from functools import lru_cache
from typing import List, Optional

from pydantic.v1 import Field
from pydantic import BeforeValidator, Field
from typing_extensions import Annotated

from ..models import ProtoModel

Expand All @@ -22,6 +23,13 @@ class VendorEnum(str, Enum):
arm = "arm"


def stringify(v) -> str:
return str(v)


Stringify = Annotated[str, BeforeValidator(stringify)]


class InstructionSetEnum(int, Enum):
"""Allowed instruction sets for CPUs in an ordinal enum."""

Expand All @@ -37,13 +45,13 @@ class ProcessorInfo(ProtoModel):
ncores: int = Field(..., description="The number of physical cores on the chip.")
nthreads: Optional[int] = Field(..., description="The maximum number of concurrent threads.")
base_clock: float = Field(..., description="The base clock frequency (GHz).")
boost_clock: Optional[float] = Field(..., description="The boost clock frequency (GHz).")
model: str = Field(..., description="The model number of the chip.")
boost_clock: Optional[float] = Field(None, description="The boost clock frequency (GHz).")
model: Stringify = Field(..., description="The model number of the chip.")
family: str = Field(..., description="The family of the chip.")
launch_date: Optional[int] = Field(..., description="The launch year of the chip.")
launch_date: Optional[int] = Field(None, description="The launch year of the chip.")
target_use: str = Field(..., description="Target use case (Desktop, Server, etc).")
vendor: VendorEnum = Field(..., description="The vendor the chip is produced by.")
microarchitecture: Optional[str] = Field(..., description="The microarchitecture the chip follows.")
microarchitecture: Optional[str] = Field(None, description="The microarchitecture the chip follows.")
instructions: InstructionSetEnum = Field(..., description="The maximum vectorized instruction set available.")
type: str = Field(..., description="The type of chip (cpu, gpu, etc).")

Expand Down
5 changes: 3 additions & 2 deletions qcelemental/info/dft_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from typing import Dict

from pydantic.v1 import Field
from pydantic import Field
from typing_extensions import Annotated

from ..models import ProtoModel

Expand Down Expand Up @@ -68,4 +69,4 @@ def get(name: str) -> DFTFunctionalInfo:
name = name.replace(x, "")
break

return dftfunctionalinfo.functionals[name].copy()
return dftfunctionalinfo.functionals[name].model_copy()
102 changes: 93 additions & 9 deletions qcelemental/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pydantic
import pytest
from pydantic.v1 import BaseModel, Field

import qcelemental as qcel
from qcelemental.testing import compare_recursive, compare_values
Expand All @@ -14,22 +14,22 @@
def doc_fixture():
# associated with AutoDoc, so leaving at Pydantic v1 syntax

class Nest(BaseModel):
class Nest(pydantic.v1.BaseModel):
"""A nested model"""

n: float = 56

class X(BaseModel):
class X(pydantic.v1.BaseModel):
"""A Pydantic model made up of many, many different combinations of ways of mapping types in Pydantic"""

x: int
y: str = Field(...)
y: str = pydantic.v1.Field(...)
n: Nest
n2: Nest = Field(Nest(), description="A detailed description")
n2: Nest = pydantic.v1.Field(Nest(), description="A detailed description")
z: float = 5
z2: float = None
z3: Optional[float]
z4: Optional[float] = Field(5, description="Some number I just made up")
z4: Optional[float] = pydantic.v1.Field(5, description="Some number I just made up")
z5: Optional[Union[float, int]]
z6: Optional[List[int]]
l: List[int]
Expand All @@ -38,11 +38,13 @@ class X(BaseModel):
t2: Tuple[List[int]]
t3: Tuple[Any]
d: Dict[str, Any]
dlu: Dict[Union[int, str], List[Union[int, str, float]]] = Field(..., description="this is complicated")
dlu: Dict[Union[int, str], List[Union[int, str, float]]] = pydantic.v1.Field(
..., description="this is complicated"
)
dlu2: Dict[Any, List[Union[int, str, float]]]
dlu3: Dict[str, Any]
si: int = Field(..., description="A level of constraint", gt=0)
sf: float = Field(None, description="Optional Constrained Number", le=100.3)
si: int = pydantic.v1.Field(..., description="A level of constraint", gt=0)
sf: float = pydantic.v1.Field(None, description="Optional Constrained Number", le=100.3)

yield X

Expand Down Expand Up @@ -308,3 +310,85 @@ def test_auto_gen_doc_delete(doc_fixture):
def test_serialization(obj, encoding):
new_obj = qcel.util.deserialize(qcel.util.serialize(obj, encoding=encoding), encoding=encoding)
assert compare_recursive(obj, new_obj)


@pytest.fixture
def atomic_result():
"""Mock AtomicResult output which can be tested against for complex serialization methods"""

data = {
"id": None,
"schema_name": "qcschema_output",
"schema_version": 1,
"molecule": {
"schema_name": "qcschema_molecule",
"schema_version": 2,
"validated": True,
"symbols": np.array(["O", "H", "H"], dtype="<U1"),
"geometry": np.array(
[
[0.0000000000000000, 0.0000000000000000, -0.1242978140796278],
[0.0000000000000000, -1.4344192748456206, 0.9863482549166890],
[0.0000000000000000, 1.4344192748456206, 0.9863482549166890],
]
),
"name": "h2o",
"molecular_charge": 0.0,
"molecular_multiplicity": 1,
"masses": np.array([15.9949146195700003, 1.0078250322300000, 1.0078250322300000]),
"real": np.array([True, True, True]),
"atom_labels": np.array(["", "", ""], dtype="<U1"),
"atomic_numbers": np.array([8, 1, 1], dtype=np.int16),
"mass_numbers": np.array([16, 1, 1], dtype=np.int16),
"fragments": [np.array([0, 1, 2], dtype=np.int32)],
"fragment_charges": [0.0],
"fragment_multiplicities": [1],
"fix_com": True,
"fix_orientation": True,
"provenance": {
"creator": "QCElemental",
"version": "0.29.0.dev1",
"routine": "qcelemental.molparse.from_string",
},
"extras": {},
},
"driver": "gradient",
"model": {"method": "unknown", "basis": "unknown"},
"keywords": {},
"protocols": {},
"extras": {
"qcvars": {
"NUCLEAR REPULSION ENERGY": 9.168193296424349,
"CURRENT ENERGY": -76.02663273512756,
"CURRENT GRADIENT": np.array(
[
[-0.0000000000000000, 0.0000000000000000, -0.0176416299024253],
[0.0000000000000000, -0.0124384148528182, 0.0088208149511995],
[-0.0000000000000000, 0.0124384148528182, 0.0088208149511995],
]
),
}
},
"provenance": {"creator": "User", "version": "0.1", "routine": ""},
"properties": {"nuclear_repulsion_energy": 9.168193296424349, "return_energy": -76.02663273512756},
"wavefunction": None,
"return_result": np.array(
[
[-0.0000000000000000, 0.0000000000000000, -0.0176416299024253],
[0.0000000000000000, -0.0124384148528182, 0.0088208149511995],
[-0.0000000000000000, 0.0124384148528182, 0.0088208149511995],
]
),
"stdout": "User provided energy, gradient, or hessian is returned",
"stderr": None,
"native_files": {},
"success": True,
"error": None,
}

yield qcel.models.results.AtomicResult(**data)


def test_json_dumps(atomic_result):
ret = qcel.util.json_dumps(atomic_result)
assert isinstance(ret, str)
Loading

0 comments on commit fe963b9

Please sign in to comment.