Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Csse pyd2 510 Part 2 #348

Merged
merged 2 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Enhancements
* The ``models.v2`` have had their `schema_version` bumped for ``BasisSet``, ``AtomicInput``, ``OptimizationInput`` (implicit for ``AtomicResult`` and ``OptimizationResult``), ``TorsionDriveInput`` , and ``TorsionDriveResult``.
* The ``models.v2`` ``AtomicResultProperties`` has been given a ``schema_name`` and ``schema_version`` (2) for the first time.
* Note that ``models.v2`` ``QCInputSpecification`` and ``OptimizationSpecification`` have *not* had schema_version bumped.
* All of ``Datum``, ``DFTFunctional``, and ``CPUInfo`` models, none of which are mixed with QCSchema models, are translated to Pydantic v2 API syntax.

Bug Fixes
+++++++++
Expand Down
107 changes: 91 additions & 16 deletions qcelemental/datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,61 @@
"""

from decimal import Decimal
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import numpy as np
from pydantic import (
BaseModel,
ConfigDict,
SerializationInfo,
SerializerFunctionWrapHandler,
WrapSerializer,
field_validator,
model_serializer,
)
from typing_extensions import Annotated


def reduce_complex(data):
# Reduce Complex
if isinstance(data, complex):
return [data.real, data.imag]
# Fallback
return data


def keep_decimal_cast_ndarray_complex(
v: Any, nxt: SerializerFunctionWrapHandler, info: SerializationInfo
) -> Union[list, Decimal, float]:
"""
Ensure Decimal types are preserved on the way out

This arose because Decimal was serialized to string and "dump" is equal to "serialize" in v2 pydantic
https://docs.pydantic.dev/latest/migration/#changes-to-json-schema-generation


This also checks against NumPy Arrays and complex numbers in the instance of being in JSON mode
"""
if isinstance(v, Decimal):
return v
if info.mode == "json":
if isinstance(v, complex):
return nxt(reduce_complex(v))
if isinstance(v, np.ndarray):
# Handle NDArray and complex NDArray
flat_list = v.flatten().tolist()
reduced_list = list(map(reduce_complex, flat_list))
return nxt(reduced_list)
try:
# Cast NumPy scalar data types to native Python data type
v = v.item()
except (AttributeError, ValueError):
pass
return nxt(v)


try:
from pydantic.v1 import BaseModel, validator
except ImportError: # Will also trap ModuleNotFoundError
from pydantic import BaseModel, validator
# Only 1 serializer is allowed. You can't chain wrap serializers.
AnyArrayComplex = Annotated[Any, WrapSerializer(keep_decimal_cast_ndarray_complex)]


class Datum(BaseModel):
Expand Down Expand Up @@ -38,15 +85,15 @@ class Datum(BaseModel):
numeric: bool
label: str
units: str
data: Any
data: AnyArrayComplex
comment: str = ""
doi: Optional[str] = None
glossary: str = ""

class Config:
extra = "forbid"
allow_mutation = False
json_encoders = {np.ndarray: lambda v: v.flatten().tolist(), complex: lambda v: (v.real, v.imag)}
model_config = ConfigDict(
extra="forbid",
frozen=True,
)

def __init__(self, label, units, data, *, comment=None, doi=None, glossary=None, numeric=True):
kwargs = {"label": label, "units": units, "data": data, "numeric": numeric}
Expand All @@ -59,20 +106,21 @@ def __init__(self, label, units, data, *, comment=None, doi=None, glossary=None,

super().__init__(**kwargs)

@validator("data")
def must_be_numerical(cls, v, values, **kwargs):
@field_validator("data")
@classmethod
def must_be_numerical(cls, v, info):
try:
1.0 * v
except TypeError:
try:
Decimal("1.0") * v
except TypeError:
if values["numeric"]:
if info.data["numeric"]:
raise ValueError(f"Datum data should be float, Decimal, or np.ndarray, not {type(v)}.")
else:
values["numeric"] = True
info.data["numeric"] = True
else:
values["numeric"] = True
info.data["numeric"] = True

return v

Expand All @@ -90,8 +138,35 @@ def __str__(self, label=""):
text.append("-" * width)
return "\n".join(text)

@model_serializer(mode="wrap")
def _serialize_model(self, handler) -> Dict[str, Any]:
"""
Customize the serialization output. Does duplicate with some code in model_dump, but handles the case of nested
models and any model config options.

Encoding is handled at the `model_dump` level and not here as that should happen only after EVERYTHING has been
dumped/de-pydantic-ized.
"""

# Get the default return, let the model_dump handle kwarg
default_result = handler(self)
# Exclude unset always
output_dict = {key: value for key, value in default_result.items() if key in self.model_fields_set}
return output_dict

def dict(self, *args, **kwargs):
return super().dict(*args, **{**kwargs, **{"exclude_unset": True}})
"""
Passthrough to model_dump without deprecation warning
exclude_unset is forced through the model_serializer
"""
return super().model_dump(*args, **kwargs)

def json(self, *args, **kwargs):
"""
Passthrough to model_dump_sjon without deprecation warning
exclude_unset is forced through the model_serializer
"""
return super().model_dump_json(*args, **kwargs)

def to_units(self, units=None):
from .physical_constants import constants
Expand Down
22 changes: 16 additions & 6 deletions qcelemental/info/cpu_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from functools import lru_cache
from typing import List, Optional

from pydantic.v1 import Field
from pydantic import BeforeValidator, Field
from typing_extensions import Annotated

from ..models import ProtoModel
from ..models.v2 import ProtoModel

# ProcessorInfo models don't become parts of QCSchema models afaik, so pure pydantic v2 API


class VendorEnum(str, Enum):
Expand All @@ -22,6 +25,13 @@ class VendorEnum(str, Enum):
arm = "arm"


def stringify(v) -> str:
return str(v)


Stringify = Annotated[str, BeforeValidator(stringify)]


class InstructionSetEnum(int, Enum):
"""Allowed instruction sets for CPUs in an ordinal enum."""

Expand All @@ -37,13 +47,13 @@ class ProcessorInfo(ProtoModel):
ncores: int = Field(..., description="The number of physical cores on the chip.")
nthreads: Optional[int] = Field(..., description="The maximum number of concurrent threads.")
base_clock: float = Field(..., description="The base clock frequency (GHz).")
boost_clock: Optional[float] = Field(..., description="The boost clock frequency (GHz).")
model: str = Field(..., description="The model number of the chip.")
boost_clock: Optional[float] = Field(None, description="The boost clock frequency (GHz).")
model: Stringify = Field(..., description="The model number of the chip.")
family: str = Field(..., description="The family of the chip.")
launch_date: Optional[int] = Field(..., description="The launch year of the chip.")
launch_date: Optional[int] = Field(None, description="The launch year of the chip.")
target_use: str = Field(..., description="Target use case (Desktop, Server, etc).")
vendor: VendorEnum = Field(..., description="The vendor the chip is produced by.")
microarchitecture: Optional[str] = Field(..., description="The microarchitecture the chip follows.")
microarchitecture: Optional[str] = Field(None, description="The microarchitecture the chip follows.")
instructions: InstructionSetEnum = Field(..., description="The maximum vectorized instruction set available.")
type: str = Field(..., description="The type of chip (cpu, gpu, etc).")

Expand Down
9 changes: 6 additions & 3 deletions qcelemental/info/dft_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@

from typing import Dict

from pydantic.v1 import Field
from pydantic import Field
from typing_extensions import Annotated

from ..models import ProtoModel
from ..models.v2 import ProtoModel

# DFTFunctional models don't become parts of QCSchema models afaik, so pure pydantic v2 API


class DFTFunctionalInfo(ProtoModel):
Expand Down Expand Up @@ -68,4 +71,4 @@ def get(name: str) -> DFTFunctionalInfo:
name = name.replace(x, "")
break

return dftfunctionalinfo.functionals[name].copy()
return dftfunctionalinfo.functionals[name].model_copy()
20 changes: 11 additions & 9 deletions qcelemental/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@
from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union

import numpy as np

try:
from pydantic.v1 import BaseModel
except ImportError: # Will also trap ModuleNotFoundError
from pydantic import BaseModel
import pydantic

if TYPE_CHECKING:
from qcelemental.models import ProtoModel # TODO: recheck if .v1 needed
Expand Down Expand Up @@ -313,10 +309,16 @@ def _compare_recursive(expected, computed, atol, rtol, _prefix=False, equal_phas
prefix = name + "."

# Initial conversions if required
if isinstance(expected, BaseModel):
if isinstance(expected, pydantic.BaseModel):
expected = expected.model_dump()

if isinstance(computed, pydantic.BaseModel):
computed = computed.model_dump()

if isinstance(expected, pydantic.v1.BaseModel):
expected = expected.dict()

if isinstance(computed, BaseModel):
if isinstance(computed, pydantic.v1.BaseModel):
computed = computed.dict()

if isinstance(expected, (str, int, bool, complex)):
Expand Down Expand Up @@ -381,8 +383,8 @@ def _compare_recursive(expected, computed, atol, rtol, _prefix=False, equal_phas


def compare_recursive(
expected: Union[Dict, BaseModel, "ProtoModel"], # type: ignore
computed: Union[Dict, BaseModel, "ProtoModel"], # type: ignore
expected: Union[Dict, pydantic.BaseModel, pydantic.v1.BaseModel, "ProtoModel"], # type: ignore
computed: Union[Dict, pydantic.BaseModel, pydantic.v1.BaseModel, "ProtoModel"], # type: ignore
label: str = None,
*,
atol: float = 1.0e-6,
Expand Down
10 changes: 3 additions & 7 deletions qcelemental/tests/test_datum.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from decimal import Decimal

import numpy as np

try:
import pydantic.v1 as pydantic
except ImportError: # Will also trap ModuleNotFoundError
import pydantic
import pydantic
import pytest

import qcelemental as qcel
Expand Down Expand Up @@ -46,10 +42,10 @@ def test_creation_nonnum(dataset):


def test_creation_error():
with pytest.raises(pydantic.ValidationError):
with pytest.raises(pydantic.ValidationError) as e:
qcel.Datum("ze lbl", "ze unit", "ze data")

# assert 'Datum data should be float' in str(e)
assert "Datum data should be float" in str(e.value)


@pytest.mark.parametrize(
Expand Down
Loading
Loading