Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alternative implementation for complex number support #5614

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aiida/orm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
'CodeEntityLoader',
'Collection',
'Comment',
'Complex',
'Computer',
'ComputerEntityLoader',
'ContainerizedCode',
Expand Down
21 changes: 17 additions & 4 deletions aiida/orm/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import copy
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Tuple, Union

from .implementation.utils import clean_value, deserialize_value

if TYPE_CHECKING:
from .groups import Group
from .nodes.node import Node
Expand Down Expand Up @@ -57,7 +59,7 @@ def all(self) -> Dict[str, Any]:
extras = self._backend_entity.extras

if self._entity.is_stored:
extras = copy.deepcopy(extras)
extras = deserialize_value(copy.deepcopy(extras))

return extras

Expand All @@ -83,7 +85,7 @@ def get(self, key: str, default: Any = _NO_DEFAULT) -> Any:
extra = default

if self._entity.is_stored:
extra = copy.deepcopy(extra)
extra = deserialize_value(copy.deepcopy(extra))

return extra

Expand All @@ -105,7 +107,7 @@ def get_many(self, keys: List[str]) -> List[Any]:
extras = self._backend_entity.get_extra_many(keys)

if self._entity.is_stored:
extras = copy.deepcopy(extras)
extras = deserialize_value(copy.deepcopy(extras))

return extras

Expand All @@ -116,6 +118,8 @@ def set(self, key: str, value: Any) -> None:
:param value: value of the extra
:raise aiida.common.ValidationError: if the key is invalid, i.e. contains periods
"""
if self._entity.is_stored:
value = clean_value(value)
self._backend_entity.set_extra(key, value)

def set_many(self, extras: Dict[str, Any]) -> None:
Expand All @@ -126,6 +130,8 @@ def set_many(self, extras: Dict[str, Any]) -> None:
:param extras: a dictionary with the extras to set
:raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods
"""
if self._entity.is_stored:
extras = clean_value(extras)
self._backend_entity.set_extra_many(extras)

def reset(self, extras: Dict[str, Any]) -> None:
Expand All @@ -136,6 +142,8 @@ def reset(self, extras: Dict[str, Any]) -> None:
:param extras: a dictionary with the extras to set
:raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods
"""
if self._entity.is_stored:
extras = clean_value(extras)
self._backend_entity.reset_extras(extras)

def delete(self, key: str) -> None:
Expand Down Expand Up @@ -163,7 +171,12 @@ def items(self) -> Iterator[Tuple[str, Any]]:

:return: an iterator with extra key value pairs
"""
return self._backend_entity.extras_items()

def deserialize_values(items):
for key, value in items:
yield key, deserialize_value(value)

return deserialize_values(self._backend_entity.extras_items())

def keys(self) -> Iterable[str]:
"""Return an iterator over the extra keys.
Expand Down
45 changes: 45 additions & 0 deletions aiida/orm/implementation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Utility methods for backend non-specific implementations."""
import cmath
from collections.abc import Iterable, Mapping
from decimal import Decimal
import math
Expand Down Expand Up @@ -93,6 +94,15 @@ def clean_builtin(val):
return int(new_val)
return new_val

if isinstance(val, numbers.Complex) and (cmath.isnan(val) or cmath.isinf(val)):
# see https://www.postgresql.org/docs/current/static/datatype-json.html#JSON-TYPE-MAPPING-TABLE
raise exceptions.ValidationError('nan and inf/-inf can not be serialized to the database')

if isinstance(val, numbers.Complex):
string_representation = f'{{:.{AIIDA_FLOAT_PRECISION}g}}'.format(val)
new_val = complex(string_representation)
return {'__complex__': True, 'real': new_val.real, 'imag': new_val.imag}

# Anything else we do not understand and we refuse
raise exceptions.ValidationError(f'type `{type(val)}` is not supported as it is not json-serializable')

Expand All @@ -101,6 +111,11 @@ def clean_builtin(val):

if isinstance(value, Mapping):
# Check dictionary before iterables
if '__complex__' in value and (\
not isinstance(value.get('real'),(numbers.Integral, numbers.Real, Decimal)) or \
not isinstance(value.get('real'),(numbers.Integral, numbers.Real, Decimal))):
#A dict with a __complex__ key will be deserialized as a complex number
raise exceptions.ValidationError('The key __complex__ is reserved for internal use')
return {k: clean_value(v) for k, v in value.items()}

if (isinstance(value, Iterable) and not isinstance(value, str)):
Expand All @@ -115,3 +130,33 @@ def clean_builtin(val):
# but is not an integer, I still accept it)

return clean_builtin(value)


def deserialize_value(value):
"""
Get value from input and (recursively) deserialize all values
that were serialized to another type (e.g. complex -> dict) by clean_value

- Mappings containg the key `__complex__` are converted into a complex number

Note however that there is no logic to avoid infinite loops when the
user passes some perverse recursive dictionary or list. However,
these should already fail when calling clean_value

:param value: A value to be set as an attribute or an extra
:return: a "deserialized" value, potentially identical to value, but with
values replaced where needed.
"""
if isinstance(value, Mapping):
if '__complex__' in value:
return complex(value['real'], value['imag'])
# Check dictionary before iterables
return {k: deserialize_value(v) for k, v in value.items()}

if (isinstance(value, Iterable) and not isinstance(value, str)):
# list, tuple, ... but not a string
# This should also properly take care of dealing with the
# basedatatypes.List object
return [deserialize_value(v) for v in value]

return value
1 change: 1 addition & 0 deletions aiida/orm/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
'CalculationNode',
'CifData',
'Code',
'Complex',
'ContainerizedCode',
'Data',
'Dict',
Expand Down
21 changes: 17 additions & 4 deletions aiida/orm/nodes/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import copy
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Tuple

from ..implementation.utils import clean_value, deserialize_value

if TYPE_CHECKING:
from .node import Node

Expand Down Expand Up @@ -55,7 +57,7 @@ def all(self) -> Dict[str, Any]:
attributes = self._backend_node.attributes

if self._node.is_stored:
attributes = copy.deepcopy(attributes)
attributes = deserialize_value(copy.deepcopy(attributes))

return attributes

Expand All @@ -81,7 +83,7 @@ def get(self, key: str, default=_NO_DEFAULT) -> Any:
attribute = default

if self._node.is_stored:
attribute = copy.deepcopy(attribute)
attribute = deserialize_value(copy.deepcopy(attribute))

return attribute

Expand All @@ -103,7 +105,7 @@ def get_many(self, keys: List[str]) -> List[Any]:
attributes = self._backend_node.get_attribute_many(keys)

if self._node.is_stored:
attributes = copy.deepcopy(attributes)
attributes = deserialize_value(copy.deepcopy(attributes))

return attributes

Expand All @@ -116,6 +118,8 @@ def set(self, key: str, value: Any) -> None:
:raise aiida.common.ModificationNotAllowed: if the entity is stored
"""
self._node._check_mutability_attributes([key]) # pylint: disable=protected-access
if self._node.is_stored:
value = clean_value(value)
self._backend_node.set_attribute(key, value)

def set_many(self, attributes: Dict[str, Any]) -> None:
Expand All @@ -128,6 +132,8 @@ def set_many(self, attributes: Dict[str, Any]) -> None:
:raise aiida.common.ModificationNotAllowed: if the entity is stored
"""
self._node._check_mutability_attributes(list(attributes)) # pylint: disable=protected-access
if self._node.is_stored:
attributes = clean_value(attributes)
self._backend_node.set_attribute_many(attributes)

def reset(self, attributes: Dict[str, Any]) -> None:
Expand All @@ -140,6 +146,8 @@ def reset(self, attributes: Dict[str, Any]) -> None:
:raise aiida.common.ModificationNotAllowed: if the entity is stored
"""
self._node._check_mutability_attributes() # pylint: disable=protected-access
if self._node.is_stored:
attributes = clean_value(attributes)
self._backend_node.reset_attributes(attributes)

def delete(self, key: str) -> None:
Expand Down Expand Up @@ -172,7 +180,12 @@ def items(self) -> Iterator[Tuple[str, Any]]:

:return: an iterator with attribute key value pairs
"""
return self._backend_node.attributes_items()

def deserialize_values(items):
for key, value in items:
yield key, deserialize_value(value)

return deserialize_values(self._backend_node.attributes_items())

def keys(self) -> Iterable[str]:
"""Return an iterator over the attribute keys.
Expand Down
2 changes: 2 additions & 0 deletions aiida/orm/nodes/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .bool import *
from .cif import *
from .code import *
from .complex import *
from .data import *
from .dict import *
from .enum import *
Expand All @@ -43,6 +44,7 @@
'Bool',
'CifData',
'Code',
'Complex',
'ContainerizedCode',
'Data',
'Dict',
Expand Down
28 changes: 28 additions & 0 deletions aiida/orm/nodes/data/complex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""`Data` sub class to represent a complex float value."""

import numbers

from .base import to_aiida_type
from .numeric import NumericType

__all__ = ('Complex',)


class Complex(NumericType):
"""`Data` sub class to represent a float value."""

_type = complex


@to_aiida_type.register(numbers.Complex)
def _(value):
return Complex(value)
34 changes: 19 additions & 15 deletions aiida/storage/psql_dos/orm/querybuilder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from aiida.common.exceptions import NotExistent
from aiida.orm.entities import EntityTypes
from aiida.orm.implementation.querybuilder import QUERYBUILD_LOGGER, BackendQueryBuilder, QueryDictType
from aiida.orm.implementation.utils import clean_value

from .joiner import JoinReturn, SqlaJoiner

Expand Down Expand Up @@ -587,7 +588,6 @@ def get_filter_expr_from_jsonb(
"""Return a filter expression"""

# pylint: disable=too-many-branches, too-many-arguments, too-many-statements

def cast_according_to_type(path_in_json, value):
"""Cast the value according to the type"""
if isinstance(value, bool):
Expand All @@ -602,6 +602,9 @@ def cast_according_to_type(path_in_json, value):
elif isinstance(value, dict):
type_filter = jsonb_typeof(path_in_json) == 'array'
casted_entity = path_in_json.astext.cast(JSONB) # BOOLEANS?
elif isinstance(value, complex):
type_filter = jsonb_typeof(path_in_json) == 'object'
casted_entity = path_in_json.astext.cast(JSONB) # BOOLEANS?
elif isinstance(value, str):
type_filter = jsonb_typeof(path_in_json) == 'string'
casted_entity = path_in_json.astext
Expand All @@ -615,48 +618,49 @@ def cast_according_to_type(path_in_json, value):
if column is None:
column = get_column(column_name, alias)

cleaned_value = clean_value(value)
database_entity = column[tuple(attr_key)]
expr: Any
if operator == '==':
type_filter, casted_entity = cast_according_to_type(database_entity, value)
expr = case((type_filter, casted_entity == value), else_=False)
expr = case((type_filter, casted_entity == cleaned_value), else_=False)
elif operator == '>':
type_filter, casted_entity = cast_according_to_type(database_entity, value)
expr = case((type_filter, casted_entity > value), else_=False)
expr = case((type_filter, casted_entity > cleaned_value), else_=False)
elif operator == '<':
type_filter, casted_entity = cast_according_to_type(database_entity, value)
expr = case((type_filter, casted_entity < value), else_=False)
expr = case((type_filter, casted_entity < cleaned_value), else_=False)
elif operator in ('>=', '=>'):
type_filter, casted_entity = cast_according_to_type(database_entity, value)
expr = case((type_filter, casted_entity >= value), else_=False)
expr = case((type_filter, casted_entity >= cleaned_value), else_=False)
elif operator in ('<=', '=<'):
type_filter, casted_entity = cast_according_to_type(database_entity, value)
expr = case((type_filter, casted_entity <= value), else_=False)
expr = case((type_filter, casted_entity <= cleaned_value), else_=False)
elif operator == 'of_type':
# http://www.postgresql.org/docs/9.5/static/functions-json.html
# Possible types are object, array, string, number, boolean, and null.
valid_types = ('object', 'array', 'string', 'number', 'boolean', 'null')
if value not in valid_types:
raise ValueError(f'value {value} for of_type is not among valid types\n{valid_types}')
expr = jsonb_typeof(database_entity) == value
expr = jsonb_typeof(database_entity) == cleaned_value
elif operator == 'like':
type_filter, casted_entity = cast_according_to_type(database_entity, value)
expr = case((type_filter, casted_entity.like(value)), else_=False)
expr = case((type_filter, casted_entity.like(cleaned_value)), else_=False)
elif operator == 'ilike':
type_filter, casted_entity = cast_according_to_type(database_entity, value)
expr = case((type_filter, casted_entity.ilike(value)), else_=False)
expr = case((type_filter, casted_entity.ilike(cleaned_value)), else_=False)
elif operator == 'in':
type_filter, casted_entity = cast_according_to_type(database_entity, value[0])
expr = case((type_filter, casted_entity.in_(value)), else_=False)
expr = case((type_filter, casted_entity.in_(cleaned_value)), else_=False)
elif operator == 'contains':
expr = database_entity.cast(JSONB).contains(value)
expr = database_entity.cast(JSONB).contains(cleaned_value)
elif operator == 'has_key':
expr = database_entity.cast(JSONB).has_key(value) # noqa
expr = database_entity.cast(JSONB).has_key(cleaned_value) # noqa
elif operator == 'of_length':
expr = case(
(
jsonb_typeof(database_entity) == 'array',
jsonb_array_length(database_entity.cast(JSONB)) == value,
jsonb_array_length(database_entity.cast(JSONB)) == cleaned_value,
),
else_=False,
)
Expand All @@ -665,15 +669,15 @@ def cast_according_to_type(path_in_json, value):
expr = case(
(
jsonb_typeof(database_entity) == 'array',
jsonb_array_length(database_entity.cast(JSONB)) > value,
jsonb_array_length(database_entity.cast(JSONB)) > cleaned_value,
),
else_=False,
)
elif operator == 'shorter':
expr = case(
(
jsonb_typeof(database_entity) == 'array',
jsonb_array_length(database_entity.cast(JSONB)) < value,
jsonb_array_length(database_entity.cast(JSONB)) < cleaned_value,
),
else_=False,
)
Expand Down
Loading