Skip to content

Commit

Permalink
Merge pull request #348 from loriab/csse_pyd2_510_pt2_more
Browse files Browse the repository at this point in the history
Csse pyd2 510 Part 2
  • Loading branch information
loriab authored Sep 11, 2024
2 parents c219aa2 + a1f7dcc commit d54ebf2
Show file tree
Hide file tree
Showing 9 changed files with 270 additions and 72 deletions.
1 change: 1 addition & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Enhancements
* The ``models.v2`` have had their `schema_version` bumped for ``BasisSet``, ``AtomicInput``, ``OptimizationInput`` (implicit for ``AtomicResult`` and ``OptimizationResult``), ``TorsionDriveInput`` , and ``TorsionDriveResult``.
* The ``models.v2`` ``AtomicResultProperties`` has been given a ``schema_name`` and ``schema_version`` (2) for the first time.
* Note that ``models.v2`` ``QCInputSpecification`` and ``OptimizationSpecification`` have *not* had schema_version bumped.
* All of ``Datum``, ``DFTFunctional``, and ``CPUInfo`` models, none of which are mixed with QCSchema models, are translated to Pydantic v2 API syntax.

Bug Fixes
+++++++++
Expand Down
107 changes: 91 additions & 16 deletions qcelemental/datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,61 @@
"""

from decimal import Decimal
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import numpy as np
from pydantic import (
BaseModel,
ConfigDict,
SerializationInfo,
SerializerFunctionWrapHandler,
WrapSerializer,
field_validator,
model_serializer,
)
from typing_extensions import Annotated


def reduce_complex(data):
    """Split a complex number into a ``[real, imag]`` list.

    Any value that is not an instance of ``complex`` is handed back unchanged,
    so the function can be mapped over a list of mixed entries.
    """
    return [data.real, data.imag] if isinstance(data, complex) else data


def keep_decimal_cast_ndarray_complex(
    v: Any, nxt: SerializerFunctionWrapHandler, info: SerializationInfo
) -> Union[list, Decimal, float]:
    """Wrap-serializer that preserves ``Decimal`` and makes NumPy/complex data dumpable.

    ``Decimal`` values are returned as-is instead of being handed to the default
    handler. This arose because Decimal was serialized to string and "dump" is
    equal to "serialize" in v2 pydantic.
    https://docs.pydantic.dev/latest/migration/#changes-to-json-schema-generation

    Parameters
    ----------
    v
        The field value being serialized.
    nxt
        The default serialization handler; called on the (possibly converted) value.
    info
        Serialization context. Only ``info.mode`` is consulted.

    Returns
    -------
    list, Decimal, or float
        * ``Decimal`` input: the same object, untouched.
        * JSON mode, ``complex`` input: ``nxt([real, imag])``.
        * JSON mode, ``np.ndarray`` input: ``nxt`` of the flattened list, with each
          complex entry expanded to ``[real, imag]``.
        * Otherwise: ``nxt(v)``, after casting a NumPy scalar (or 0-d array) to the
          corresponding native Python value.
    """
    if isinstance(v, Decimal):
        return v

    if info.mode == "json":
        if isinstance(v, complex):
            return nxt(reduce_complex(v))
        if isinstance(v, np.ndarray):
            # Handle NDArray and complex NDArray: flatten, then split complex entries.
            return nxt([reduce_complex(entry) for entry in v.flatten().tolist()])

    # Cast NumPy scalar data types to native Python data types. Arrays with at
    # least one dimension are deliberately skipped: ``ndarray.item()`` succeeds on
    # a one-element array and would silently turn e.g. ``np.array([1.5])`` into a
    # bare float, while larger arrays pass through unchanged.
    if not (isinstance(v, np.ndarray) and v.ndim > 0):
        try:
            v = v.item()
        except (AttributeError, ValueError):
            # Not a NumPy scalar (no ``.item``) or not convertible: leave as-is.
            pass
    return nxt(v)


try:
from pydantic.v1 import BaseModel, validator
except ImportError: # Will also trap ModuleNotFoundError
from pydantic import BaseModel, validator
# Only 1 serializer is allowed. You can't chain wrap serializers.
AnyArrayComplex = Annotated[Any, WrapSerializer(keep_decimal_cast_ndarray_complex)]


class Datum(BaseModel):
Expand Down Expand Up @@ -38,15 +85,15 @@ class Datum(BaseModel):
numeric: bool
label: str
units: str
data: Any
data: AnyArrayComplex
comment: str = ""
doi: Optional[str] = None
glossary: str = ""

class Config:
extra = "forbid"
allow_mutation = False
json_encoders = {np.ndarray: lambda v: v.flatten().tolist(), complex: lambda v: (v.real, v.imag)}
model_config = ConfigDict(
extra="forbid",
frozen=True,
)

def __init__(self, label, units, data, *, comment=None, doi=None, glossary=None, numeric=True):
kwargs = {"label": label, "units": units, "data": data, "numeric": numeric}
Expand All @@ -59,20 +106,21 @@ def __init__(self, label, units, data, *, comment=None, doi=None, glossary=None,

super().__init__(**kwargs)

@validator("data")
def must_be_numerical(cls, v, values, **kwargs):
@field_validator("data")
@classmethod
def must_be_numerical(cls, v, info):
try:
1.0 * v
except TypeError:
try:
Decimal("1.0") * v
except TypeError:
if values["numeric"]:
if info.data["numeric"]:
raise ValueError(f"Datum data should be float, Decimal, or np.ndarray, not {type(v)}.")
else:
values["numeric"] = True
info.data["numeric"] = True
else:
values["numeric"] = True
info.data["numeric"] = True

return v

Expand All @@ -90,8 +138,35 @@ def __str__(self, label=""):
text.append("-" * width)
return "\n".join(text)

@model_serializer(mode="wrap")
def _serialize_model(self, handler) -> Dict[str, Any]:
    """Serialize the model, always dropping fields that were never explicitly set.

    The default serialization is produced by ``handler`` (so any kwargs given to
    the dump call are honored there), and the result is then filtered down to
    the keys present in ``self.model_fields_set``. This forces exclude-unset
    behavior no matter what the caller requested.
    """
    # Let the default machinery build the full dump first.
    dumped = handler(self)
    # Keep only explicitly-set fields, preserving the dump's key order.
    explicitly_set = self.model_fields_set
    return {name: dumped[name] for name in dumped if name in explicitly_set}

def dict(self, *args, **kwargs):
return super().dict(*args, **{**kwargs, **{"exclude_unset": True}})
"""
Passthrough to model_dump without deprecation warning
exclude_unset is forced through the model_serializer
"""
return super().model_dump(*args, **kwargs)

def json(self, *args, **kwargs):
"""
Passthrough to model_dump_json without deprecation warning
exclude_unset is forced through the model_serializer
"""
return super().model_dump_json(*args, **kwargs)

def to_units(self, units=None):
from .physical_constants import constants
Expand Down
22 changes: 16 additions & 6 deletions qcelemental/info/cpu_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from functools import lru_cache
from typing import List, Optional

from pydantic.v1 import Field
from pydantic import BeforeValidator, Field
from typing_extensions import Annotated

from ..models import ProtoModel
from ..models.v2 import ProtoModel

# ProcessorInfo models don't become parts of QCSchema models afaik, so pure pydantic v2 API


class VendorEnum(str, Enum):
Expand All @@ -22,6 +25,13 @@ class VendorEnum(str, Enum):
arm = "arm"


def stringify(v) -> str:
    """Coerce *v* to ``str``.

    Wrapped in a ``BeforeValidator`` to build the ``Stringify`` annotated type, so
    a field declared with it accepts non-string input by converting it first.
    """
    text = str(v)
    return text


Stringify = Annotated[str, BeforeValidator(stringify)]


class InstructionSetEnum(int, Enum):
"""Allowed instruction sets for CPUs in an ordinal enum."""

Expand All @@ -37,13 +47,13 @@ class ProcessorInfo(ProtoModel):
ncores: int = Field(..., description="The number of physical cores on the chip.")
nthreads: Optional[int] = Field(..., description="The maximum number of concurrent threads.")
base_clock: float = Field(..., description="The base clock frequency (GHz).")
boost_clock: Optional[float] = Field(..., description="The boost clock frequency (GHz).")
model: str = Field(..., description="The model number of the chip.")
boost_clock: Optional[float] = Field(None, description="The boost clock frequency (GHz).")
model: Stringify = Field(..., description="The model number of the chip.")
family: str = Field(..., description="The family of the chip.")
launch_date: Optional[int] = Field(..., description="The launch year of the chip.")
launch_date: Optional[int] = Field(None, description="The launch year of the chip.")
target_use: str = Field(..., description="Target use case (Desktop, Server, etc).")
vendor: VendorEnum = Field(..., description="The vendor the chip is produced by.")
microarchitecture: Optional[str] = Field(..., description="The microarchitecture the chip follows.")
microarchitecture: Optional[str] = Field(None, description="The microarchitecture the chip follows.")
instructions: InstructionSetEnum = Field(..., description="The maximum vectorized instruction set available.")
type: str = Field(..., description="The type of chip (cpu, gpu, etc).")

Expand Down
9 changes: 6 additions & 3 deletions qcelemental/info/dft_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@

from typing import Dict

from pydantic.v1 import Field
from pydantic import Field
from typing_extensions import Annotated

from ..models import ProtoModel
from ..models.v2 import ProtoModel

# DFTFunctional models don't become parts of QCSchema models afaik, so pure pydantic v2 API


class DFTFunctionalInfo(ProtoModel):
Expand Down Expand Up @@ -68,4 +71,4 @@ def get(name: str) -> DFTFunctionalInfo:
name = name.replace(x, "")
break

return dftfunctionalinfo.functionals[name].copy()
return dftfunctionalinfo.functionals[name].model_copy()
20 changes: 11 additions & 9 deletions qcelemental/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@
from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union

import numpy as np

try:
from pydantic.v1 import BaseModel
except ImportError: # Will also trap ModuleNotFoundError
from pydantic import BaseModel
import pydantic

if TYPE_CHECKING:
from qcelemental.models import ProtoModel # TODO: recheck if .v1 needed
Expand Down Expand Up @@ -313,10 +309,16 @@ def _compare_recursive(expected, computed, atol, rtol, _prefix=False, equal_phas
prefix = name + "."

# Initial conversions if required
if isinstance(expected, BaseModel):
if isinstance(expected, pydantic.BaseModel):
expected = expected.model_dump()

if isinstance(computed, pydantic.BaseModel):
computed = computed.model_dump()

if isinstance(expected, pydantic.v1.BaseModel):
expected = expected.dict()

if isinstance(computed, BaseModel):
if isinstance(computed, pydantic.v1.BaseModel):
computed = computed.dict()

if isinstance(expected, (str, int, bool, complex)):
Expand Down Expand Up @@ -381,8 +383,8 @@ def _compare_recursive(expected, computed, atol, rtol, _prefix=False, equal_phas


def compare_recursive(
expected: Union[Dict, BaseModel, "ProtoModel"], # type: ignore
computed: Union[Dict, BaseModel, "ProtoModel"], # type: ignore
expected: Union[Dict, pydantic.BaseModel, pydantic.v1.BaseModel, "ProtoModel"], # type: ignore
computed: Union[Dict, pydantic.BaseModel, pydantic.v1.BaseModel, "ProtoModel"], # type: ignore
label: str = None,
*,
atol: float = 1.0e-6,
Expand Down
10 changes: 3 additions & 7 deletions qcelemental/tests/test_datum.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from decimal import Decimal

import numpy as np

try:
import pydantic.v1 as pydantic
except ImportError: # Will also trap ModuleNotFoundError
import pydantic
import pydantic
import pytest

import qcelemental as qcel
Expand Down Expand Up @@ -46,10 +42,10 @@ def test_creation_nonnum(dataset):


def test_creation_error():
with pytest.raises(pydantic.ValidationError):
with pytest.raises(pydantic.ValidationError) as e:
qcel.Datum("ze lbl", "ze unit", "ze data")

# assert 'Datum data should be float' in str(e)
assert "Datum data should be float" in str(e.value)


@pytest.mark.parametrize(
Expand Down
Loading

0 comments on commit d54ebf2

Please sign in to comment.