From 9e673752ef36d0e2f2246372b48b006490430794 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Mon, 14 Mar 2022 13:12:02 +0100 Subject: [PATCH 1/7] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20REFACTOR:=20`EntityAtt?= =?UTF-8?q?ributesMixin`=20->=20`NodeAttributes`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .molecule/default/files/polish/cli.py | 2 +- .pre-commit-config.yaml | 1 + aiida/cmdline/commands/cmd_node.py | 2 +- aiida/cmdline/utils/common.py | 5 +- aiida/engine/processes/calcjobs/calcjob.py | 2 +- aiida/engine/processes/process.py | 2 +- aiida/orm/__init__.py | 2 +- aiida/orm/entities.py | 162 +----------- aiida/orm/nodes/__init__.py | 2 + aiida/orm/nodes/attributes.py | 183 +++++++++++++ aiida/orm/nodes/data/array/array.py | 8 +- aiida/orm/nodes/data/array/bands.py | 14 +- aiida/orm/nodes/data/array/kpoints.py | 30 +-- aiida/orm/nodes/data/array/projection.py | 8 +- aiida/orm/nodes/data/array/trajectory.py | 10 +- aiida/orm/nodes/data/array/xy.py | 16 +- aiida/orm/nodes/data/base.py | 4 +- aiida/orm/nodes/data/cif.py | 22 +- aiida/orm/nodes/data/code.py | 32 +-- aiida/orm/nodes/data/data.py | 6 +- aiida/orm/nodes/data/dict.py | 18 +- aiida/orm/nodes/data/enum.py | 12 +- aiida/orm/nodes/data/jsonable.py | 4 +- aiida/orm/nodes/data/list.py | 6 +- aiida/orm/nodes/data/orbital.py | 6 +- aiida/orm/nodes/data/remote/base.py | 4 +- aiida/orm/nodes/data/remote/stash/base.py | 4 +- aiida/orm/nodes/data/remote/stash/folder.py | 8 +- aiida/orm/nodes/data/singlefile.py | 4 +- aiida/orm/nodes/data/structure.py | 34 +-- aiida/orm/nodes/data/upf.py | 16 +- aiida/orm/nodes/node.py | 53 +++- .../orm/nodes/process/calculation/calcjob.py | 46 ++-- aiida/orm/nodes/process/process.py | 38 +-- aiida/orm/nodes/process/workflow/workchain.py | 4 +- aiida/orm/utils/managers.py | 14 +- aiida/orm/utils/mixins.py | 87 +++---- aiida/restapi/translator/nodes/node.py | 6 +- aiida/tools/visualization/graph.py | 10 +- .../source/developer_guide/core/internals.rst | 15 +- docs/source/topics/data_types.rst | 16 +- tests/benchmark/test_nodes.py | 6 +- tests/cmdline/commands/test_node.py | 2 +- tests/cmdline/commands/test_process.py | 8 +- .../processes/calcjobs/test_calc_job.py | 20 +- .../engine/processes/workchains/test_utils.py | 18 +- tests/engine/test_process_function.py | 2 +- tests/orm/data/test_enum.py | 6 +- tests/orm/nodes/data/test_jsonable.py | 4 +- tests/orm/nodes/data/test_orbital.py | 2 +- tests/orm/nodes/data/test_remote.py | 2 +- tests/orm/nodes/data/test_trajectory.py | 20 +- tests/orm/nodes/test_calcjob.py | 2 +- tests/orm/nodes/test_node.py | 90 +++---- tests/orm/test_mixins.py | 2 +- tests/orm/test_querybuilder.py | 68 ++--- tests/restapi/test_routes.py | 4 +- tests/storage/psql_dos/test_query.py | 2 +- tests/test_calculation_node.py | 22 +- tests/test_dataclasses.py | 6 +- tests/test_dbimporters.py | 2 +- tests/test_nodes.py | 246 +++++++++--------- tests/tools/archive/orm/test_attributes.py | 6 +- tests/tools/archive/test_complex.py | 2 +- tests/tools/archive/test_simple.py | 4 +- tests/tools/archive/test_specific_import.py | 4 +- tests/tools/groups/test_paths.py | 2 +- 67 files changed, 757 insertions(+), 713 deletions(-) create mode 100644 aiida/orm/nodes/attributes.py diff --git a/.molecule/default/files/polish/cli.py b/.molecule/default/files/polish/cli.py index 9cb6bcf519..32fb454320 100755 --- a/.molecule/default/files/polish/cli.py +++ b/.molecule/default/files/polish/cli.py @@ -187,7 +187,7 @@ def run_via_daemon(workchains, inputs, sleep, timeout): except AttributeError: click.secho('Failed: ', fg='red', bold=True, nl=False) click.secho(f'the workchain<{workchain.pk}> did not return a result output node', bold=True) - click.echo(str(workchain.attributes)) + click.echo(str(workchain.base.attributes.all)) return None return result, workchain, total_time diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 24d9688979..9676733d24 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -95,6 +95,7 @@ repos: aiida/orm/users.py| aiida/orm/nodes/data/enum.py| aiida/orm/nodes/data/jsonable.py| + aiida/orm/nodes/attributes.py| aiida/orm/nodes/node.py| aiida/orm/nodes/process/.*py| aiida/orm/nodes/repository.py| diff --git a/aiida/cmdline/commands/cmd_node.py b/aiida/cmdline/commands/cmd_node.py index 28a049a5ac..f0c5b9c790 100644 --- a/aiida/cmdline/commands/cmd_node.py +++ b/aiida/cmdline/commands/cmd_node.py @@ -234,7 +234,7 @@ def echo_node_dict(nodes, keys, fmt, identifier, raw, use_attrs=True): id_value = node.uuid if use_attrs: - node_dict = node.attributes + node_dict = node.base.attributes.all dict_name = 'attributes' else: node_dict = node.extras diff --git a/aiida/cmdline/utils/common.py b/aiida/cmdline/utils/common.py index 1d8f7152c8..54ca85307b 100644 --- a/aiida/cmdline/utils/common.py +++ b/aiida/cmdline/utils/common.py @@ -195,7 +195,10 @@ def format_flat_links(links, headers): table = [] for link_triple in links: - table.append([link_triple.link_label, link_triple.node.pk, link_triple.node.get_attribute('process_label', '')]) + table.append([ + link_triple.link_label, link_triple.node.pk, + link_triple.node.base.attributes.get('process_label', '') + ]) result = f'\n{tabulate(table, headers=headers)}' diff --git a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py index c6fa45cfb0..08e77ad3cb 100644 --- a/aiida/engine/processes/calcjobs/calcjob.py +++ b/aiida/engine/processes/calcjobs/calcjob.py @@ -440,7 +440,7 @@ def _perform_import(self): ) retrieve_calculation(self.node, transport, retrieved_temporary_folder.abspath) self.node.set_state(CalcJobState.PARSING) - self.node.set_attribute(orm.CalcJobNode.IMMIGRATED_KEY, True) + self.node.base.attributes.set(orm.CalcJobNode.IMMIGRATED_KEY, True) return self.parse(retrieved_temporary_folder.abspath) def parse(self, retrieved_temporary_folder: Optional[str] = None) -> ExitCode: diff --git a/aiida/engine/processes/process.py b/aiida/engine/processes/process.py index f9dbed327a..2324b54790 100644 --- a/aiida/engine/processes/process.py +++ b/aiida/engine/processes/process.py @@ -685,7 +685,7 @@ def _setup_db_record(self) -> None: def _setup_metadata(self) -> None: """Store the metadata on the ProcessNode.""" version_info = self.runner.plugin_version_provider.get_version_info(self.__class__) - self.node.set_attribute_many(version_info) + self.node.base.attributes.set_many(version_info) for name, metadata in self.metadata.items(): if name in ['store_provenance', 'dry_run', 'call_link_label']: diff --git a/aiida/orm/__init__.py b/aiida/orm/__init__.py index a20396ea7f..1358c5c9f4 100644 --- a/aiida/orm/__init__.py +++ b/aiida/orm/__init__.py @@ -51,7 +51,6 @@ 'Data', 'Dict', 'Entity', - 'EntityAttributesMixin', 'EntityExtrasMixin', 'EntityTypes', 'EnumData', @@ -70,6 +69,7 @@ 'List', 'Log', 'Node', + 'NodeAttributes', 'NodeEntityLoader', 'NodeLinksManager', 'NodeRepository', diff --git a/aiida/orm/entities.py b/aiida/orm/entities.py index ba56e5bbd0..1bb67e846c 100644 --- a/aiida/orm/entities.py +++ b/aiida/orm/entities.py @@ -16,7 +16,6 @@ from plumpy.base.utils import call_with_super_check, super_check -from aiida.common import exceptions from aiida.common.lang import classproperty, type_check from aiida.manage import get_manager @@ -24,7 +23,7 @@ from aiida.orm.implementation import BackendEntity, StorageBackend from aiida.orm.querybuilder import FilterType, OrderByType, QueryBuilder -__all__ = ('Entity', 'Collection', 'EntityAttributesMixin', 'EntityExtrasMixin', 'EntityTypes') +__all__ = ('Entity', 'Collection', 'EntityExtrasMixin', 'EntityTypes') CollectionType = TypeVar('CollectionType', bound='Collection') EntityType = TypeVar('EntityType', bound='Entity') @@ -260,165 +259,6 @@ def is_stored(self) -> bool: ... -class EntityAttributesMixin: - """Mixin class that adds all methods for the attributes column to an entity.""" - - @property - def attributes(self: EntityProtocol) -> Dict[str, Any]: - """Return the complete attributes dictionary. - - .. warning:: While the entity is unstored, this will return references of the attributes on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the entity is stored, the returned - attributes will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you - only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the - getters `get_attribute` and `get_attribute_many` instead. - - :return: the attributes as a dictionary - """ - attributes = self.backend_entity.attributes - - if self.is_stored: - attributes = copy.deepcopy(attributes) - - return attributes - - def get_attribute(self: EntityProtocol, key: str, default=_NO_DEFAULT) -> Any: - """Return the value of an attribute. - - .. warning:: While the entity is unstored, this will return a reference of the attribute on the database model, - meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the entity is stored, the returned - attribute will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. - - :param key: name of the attribute - :param default: return this value instead of raising if the attribute does not exist - :return: the value of the attribute - :raises AttributeError: if the attribute does not exist and no default is specified - """ - try: - attribute = self.backend_entity.get_attribute(key) - except AttributeError: - if default is _NO_DEFAULT: - raise - attribute = default - - if self.is_stored: - attribute = copy.deepcopy(attribute) - - return attribute - - def get_attribute_many(self: EntityProtocol, keys: List[str]) -> List[Any]: - """Return the values of multiple attributes. - - .. warning:: While the entity is unstored, this will return references of the attributes on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the entity is stored, the returned - attributes will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you - only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the - getters `get_attribute` and `get_attribute_many` instead. - - :param keys: a list of attribute names - :return: a list of attribute values - :raises AttributeError: if at least one attribute does not exist - """ - attributes = self.backend_entity.get_attribute_many(keys) - - if self.is_stored: - attributes = copy.deepcopy(attributes) - - return attributes - - def set_attribute(self: EntityProtocol, key: str, value: Any) -> None: - """Set an attribute to the given value. - - :param key: name of the attribute - :param value: value of the attribute - :raise aiida.common.ValidationError: if the key is invalid, i.e. contains periods - :raise aiida.common.ModificationNotAllowed: if the entity is stored - """ - if self.is_stored: - raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') - - self.backend_entity.set_attribute(key, value) - - def set_attribute_many(self: EntityProtocol, attributes: Dict[str, Any]) -> None: - """Set multiple attributes. - - .. note:: This will override any existing attributes that are present in the new dictionary. - - :param attributes: a dictionary with the attributes to set - :raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods - :raise aiida.common.ModificationNotAllowed: if the entity is stored - """ - if self.is_stored: - raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') - - self.backend_entity.set_attribute_many(attributes) - - def reset_attributes(self: EntityProtocol, attributes: Dict[str, Any]) -> None: - """Reset the attributes. - - .. note:: This will completely clear any existing attributes and replace them with the new dictionary. - - :param attributes: a dictionary with the attributes to set - :raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods - :raise aiida.common.ModificationNotAllowed: if the entity is stored - """ - if self.is_stored: - raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') - - self.backend_entity.reset_attributes(attributes) - - def delete_attribute(self: EntityProtocol, key: str) -> None: - """Delete an attribute. - - :param key: name of the attribute - :raises AttributeError: if the attribute does not exist - :raise aiida.common.ModificationNotAllowed: if the entity is stored - """ - if self.is_stored: - raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') - - self.backend_entity.delete_attribute(key) - - def delete_attribute_many(self: EntityProtocol, keys: List[str]) -> None: - """Delete multiple attributes. - - :param keys: names of the attributes to delete - :raises AttributeError: if at least one of the attribute does not exist - :raise aiida.common.ModificationNotAllowed: if the entity is stored - """ - if self.is_stored: - raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') - - self.backend_entity.delete_attribute_many(keys) - - def clear_attributes(self: EntityProtocol) -> None: - """Delete all attributes.""" - if self.is_stored: - raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') - - self.backend_entity.clear_attributes() - - def attributes_items(self: EntityProtocol): - """Return an iterator over the attributes. - - :return: an iterator with attribute key value pairs - """ - return self.backend_entity.attributes_items() - - def attributes_keys(self: EntityProtocol): - """Return an iterator over the attribute keys. - - :return: an iterator with attribute keys - """ - return self.backend_entity.attributes_keys() - - class EntityExtrasMixin: """Mixin class that adds all methods for the extras column to an entity.""" diff --git a/aiida/orm/nodes/__init__.py b/aiida/orm/nodes/__init__.py index 9cd3c65371..f192d2fafd 100644 --- a/aiida/orm/nodes/__init__.py +++ b/aiida/orm/nodes/__init__.py @@ -14,6 +14,7 @@ # yapf: disable # pylint: disable=wildcard-import +from .attributes import * from .data import * from .node import * from .process import * @@ -40,6 +41,7 @@ 'KpointsData', 'List', 'Node', + 'NodeAttributes', 'NodeRepository', 'NumericType', 'OrbitalData', diff --git a/aiida/orm/nodes/attributes.py b/aiida/orm/nodes/attributes.py new file mode 100644 index 0000000000..780047dfe8 --- /dev/null +++ b/aiida/orm/nodes/attributes.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=too-many-lines,too-many-arguments +"""Interface to the attributes of a node instance.""" +import copy +from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Tuple + +if TYPE_CHECKING: + from .node import Node + +__all__ = ('NodeAttributes',) + +_NO_DEFAULT: Any = tuple() + + +class NodeAttributes: + """Interface to the attributes of a node instance. + + Attributes are a JSONable dictionary, stored on each node, + allowing for arbitrary data to be stored by node subclasses (and thus data plugins). + + Once the node is stored, the attributes are generally deemed immutable + (except for some updatable keys on process nodes, which can be mutated whilst the node is not "sealed"). + """ + + def __init__(self, node: 'Node') -> None: + """Initialize the interface.""" + self._entity = node + self._backend_entity = node.backend_entity + + def __contains__(self, key: str) -> bool: + """Check if the node contains an attribute with the given key.""" + return key in self._backend_entity.attributes + + @property + def all(self) -> Dict[str, Any]: + """Return the complete attributes dictionary. + + .. warning:: While the entity is unstored, this will return references of the attributes on the database model, + meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + attributes will be a deep copy and mutations of the database attributes will have to go through the + appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you + only need the keys or some values, use the iterators `keys` and `items`, or the + getters `get` and `get_many` instead. + + :return: the attributes as a dictionary + """ + attributes = self._backend_entity.attributes + + if self._entity.is_stored: + attributes = copy.deepcopy(attributes) + + return attributes + + def get(self, key: str, default=_NO_DEFAULT) -> Any: + """Return the value of an attribute. + + .. warning:: While the entity is unstored, this will return a reference of the attribute on the database model, + meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + attribute will be a deep copy and mutations of the database attributes will have to go through the + appropriate set methods. + + :param key: name of the attribute + :param default: return this value instead of raising if the attribute does not exist + :return: the value of the attribute + :raises AttributeError: if the attribute does not exist and no default is specified + """ + try: + attribute = self._backend_entity.get_attribute(key) + except AttributeError: + if default is _NO_DEFAULT: + raise + attribute = default + + if self._entity.is_stored: + attribute = copy.deepcopy(attribute) + + return attribute + + def get_many(self, keys: List[str]) -> List[Any]: + """Return the values of multiple attributes. + + .. warning:: While the entity is unstored, this will return references of the attributes on the database model, + meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will + automatically be reflected on the database model as well. As soon as the entity is stored, the returned + attributes will be a deep copy and mutations of the database attributes will have to go through the + appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you + only need the keys or some values, use the iterators `keys` and `items`, or the + getters `get` and `get_many` instead. + + :param keys: a list of attribute names + :return: a list of attribute values + :raises AttributeError: if at least one attribute does not exist + """ + attributes = self._backend_entity.get_attribute_many(keys) + + if self._entity.is_stored: + attributes = copy.deepcopy(attributes) + + return attributes + + def set(self, key: str, value: Any) -> None: + """Set an attribute to the given value. + + :param key: name of the attribute + :param value: value of the attribute + :raise aiida.common.ValidationError: if the key is invalid, i.e. contains periods + :raise aiida.common.ModificationNotAllowed: if the entity is stored + """ + self._entity._check_mutability_attributes([key]) # pylint: disable=protected-access + self._backend_entity.set_attribute(key, value) + + def set_many(self, attributes: Dict[str, Any]) -> None: + """Set multiple attributes. + + .. note:: This will override any existing attributes that are present in the new dictionary. + + :param attributes: a dictionary with the attributes to set + :raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods + :raise aiida.common.ModificationNotAllowed: if the entity is stored + """ + self._entity._check_mutability_attributes(list(attributes)) # pylint: disable=protected-access + self._backend_entity.set_attribute_many(attributes) + + def reset(self, attributes: Dict[str, Any]) -> None: + """Reset the attributes. + + .. note:: This will completely clear any existing attributes and replace them with the new dictionary. + + :param attributes: a dictionary with the attributes to set + :raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods + :raise aiida.common.ModificationNotAllowed: if the entity is stored + """ + self._entity._check_mutability_attributes() # pylint: disable=protected-access + self._backend_entity.reset_attributes(attributes) + + def delete(self, key: str) -> None: + """Delete an attribute. + + :param key: name of the attribute + :raises AttributeError: if the attribute does not exist + :raise aiida.common.ModificationNotAllowed: if the entity is stored + """ + self._entity._check_mutability_attributes([key]) # pylint: disable=protected-access + self._backend_entity.delete_attribute(key) + + def delete_many(self, keys: List[str]) -> None: + """Delete multiple attributes. + + :param keys: names of the attributes to delete + :raises AttributeError: if at least one of the attribute does not exist + :raise aiida.common.ModificationNotAllowed: if the entity is stored + """ + self._entity._check_mutability_attributes(keys) # pylint: disable=protected-access + self._backend_entity.delete_attribute_many(keys) + + def clear(self) -> None: + """Delete all attributes.""" + self._entity._check_mutability_attributes() # pylint: disable=protected-access + self._backend_entity.clear_attributes() + + def items(self) -> Iterator[Tuple[str, Any]]: + """Return an iterator over the attributes. + + :return: an iterator with attribute key value pairs + """ + return self._backend_entity.attributes_items() + + def keys(self) -> Iterable[str]: + """Return an iterator over the attribute keys. + + :return: an iterator with attribute keys + """ + return self._backend_entity.attributes_keys() diff --git a/aiida/orm/nodes/data/array/array.py b/aiida/orm/nodes/data/array/array.py index 43088209a4..166eca8e53 100644 --- a/aiida/orm/nodes/data/array/array.py +++ b/aiida/orm/nodes/data/array/array.py @@ -51,7 +51,7 @@ def delete_array(self, name): # remove both file and attribute self.base.repository.delete_object(fname) try: - self.delete_attribute(f'{self.array_prefix}{name}') + self.base.attributes.delete(f'{self.array_prefix}{name}') except (KeyError, AttributeError): # Should not happen, but do not crash if for some reason the property was not set. pass @@ -78,7 +78,7 @@ def _arraynames_from_properties(self): Return a list of all arrays stored in the node, listing the attributes starting with the correct prefix. """ - return [i[len(self.array_prefix):] for i in self.attributes.keys() if i.startswith(self.array_prefix)] + return [i[len(self.array_prefix):] for i in self.base.attributes.keys() if i.startswith(self.array_prefix)] def get_shape(self, name): """ @@ -87,7 +87,7 @@ def get_shape(self, name): :param name: The name of the array. """ - return tuple(self.get_attribute(f'{self.array_prefix}{name}')) + return tuple(self.base.attributes.get(f'{self.array_prefix}{name}')) def get_iterarrays(self): """ @@ -174,7 +174,7 @@ def set_array(self, name, array): self.base.repository.put_object_from_filelike(handle, f'{name}.npy') # Store the array name and shape for querying purposes - self.set_attribute(f'{self.array_prefix}{name}', list(array.shape)) + self.base.attributes.set(f'{self.array_prefix}{name}', list(array.shape)) def _validate(self): """ diff --git a/aiida/orm/nodes/data/array/bands.py b/aiida/orm/nodes/data/array/bands.py index c47f96afb8..40e451a28f 100644 --- a/aiida/orm/nodes/data/array/bands.py +++ b/aiida/orm/nodes/data/array/bands.py @@ -341,7 +341,7 @@ def set_bands(self, bands, units=None, occupations=None, labels=None): self.units = units if the_labels is not None: - self.set_attribute('array_labels', the_labels) + self.base.attributes.set('array_labels', the_labels) if the_occupations is not None: # set occupations @@ -352,7 +352,7 @@ def array_labels(self): """ Get the labels associated with the band arrays """ - return self.get_attribute('array_labels', None) + return self.base.attributes.get('array_labels', None) @property def units(self): @@ -360,7 +360,7 @@ def units(self): Units in which the data in bands were stored. A string """ # return copy.deepcopy(self._pbc) - return self.get_attribute('units') + return self.base.attributes.get('units') @units.setter def units(self, value): @@ -369,7 +369,7 @@ def units(self, value): cell is periodic in the 1,2,3 crystal direction """ the_str = str(value) - self.set_attribute('units', the_str) + self.base.attributes.set('units', the_str) def _set_pbc(self, value): """ @@ -381,9 +381,9 @@ def _set_pbc(self, value): if self.is_stored: raise ModificationNotAllowed('The KpointsData object cannot be modified, it has already been stored') the_pbc = get_valid_pbc(value) - self.set_attribute('pbc1', the_pbc[0]) - self.set_attribute('pbc2', the_pbc[1]) - self.set_attribute('pbc3', the_pbc[2]) + self.base.attributes.set('pbc1', the_pbc[0]) + self.base.attributes.set('pbc2', the_pbc[1]) + self.base.attributes.set('pbc3', the_pbc[2]) def get_bands(self, also_occupations=False, also_labels=False): """ diff --git a/aiida/orm/nodes/data/array/kpoints.py b/aiida/orm/nodes/data/array/kpoints.py index 71ecc8cdd7..1f7432d8f2 100644 --- a/aiida/orm/nodes/data/array/kpoints.py +++ b/aiida/orm/nodes/data/array/kpoints.py @@ -59,7 +59,7 @@ def cell(self): The crystal unit cell. Rows are the crystal vectors in Angstroms. :return: a 3x3 numpy.array """ - return numpy.array(self.get_attribute('cell')) + return numpy.array(self.base.attributes.get('cell')) @cell.setter def cell(self, value): @@ -82,7 +82,7 @@ def _set_cell(self, value): the_cell = _get_valid_cell(value) - self.set_attribute('cell', the_cell) + self.base.attributes.set('cell', the_cell) @property def pbc(self): @@ -93,7 +93,7 @@ def pbc(self): boundary conditions for the i-th real-space direction (i=1,2,3) """ # return copy.deepcopy(self._pbc) - return (self.get_attribute('pbc1'), self.get_attribute('pbc2'), self.get_attribute('pbc3')) + return (self.base.attributes.get('pbc1'), self.base.attributes.get('pbc2'), self.base.attributes.get('pbc3')) @pbc.setter def pbc(self, value): @@ -113,9 +113,9 @@ def _set_pbc(self, value): if self.is_stored: raise ModificationNotAllowed('The KpointsData object cannot be modified, it has already been stored') the_pbc = get_valid_pbc(value) - self.set_attribute('pbc1', the_pbc[0]) - self.set_attribute('pbc2', the_pbc[1]) - self.set_attribute('pbc3', the_pbc[2]) + self.base.attributes.set('pbc1', the_pbc[0]) + self.base.attributes.set('pbc2', the_pbc[1]) + self.base.attributes.set('pbc3', the_pbc[2]) @property def labels(self): @@ -123,8 +123,8 @@ def labels(self): Labels associated with the list of kpoints. List of tuples with kpoint index and kpoint name: ``[(0,'G'),(13,'M'),...]`` """ - label_numbers = self.get_attribute('label_numbers', None) - labels = self.get_attribute('labels', None) + label_numbers = self.base.attributes.get('label_numbers', None) + labels = self.base.attributes.get('labels', None) if labels is None or label_numbers is None: return None return list(zip(label_numbers, labels)) @@ -155,8 +155,8 @@ def _set_labels(self, value): if any(i > len(self.get_kpoints()) - 1 for i in label_numbers): raise ValueError('Index of label exceeding the list of kpoints') - self.set_attribute('label_numbers', label_numbers) - self.set_attribute('labels', labels) + self.base.attributes.set('label_numbers', label_numbers) + self.base.attributes.set('labels', labels) def _change_reference(self, kpoints, to_cartesian=True): """ @@ -267,8 +267,8 @@ def set_kpoints_mesh(self, mesh, offset=None): pass # store - self.set_attribute('mesh', the_mesh) - self.set_attribute('offset', the_offset) + self.base.attributes.set('mesh', the_mesh) + self.base.attributes.set('offset', the_offset) def get_kpoints_mesh(self, print_list=False): """ @@ -282,8 +282,8 @@ def get_kpoints_mesh(self, print_list=False): :return kpoints: (if print_list = True) an explicit list of kpoints coordinates, similar to what returned by get_kpoints() """ - mesh = self.get_attribute('mesh') - offset = self.get_attribute('offset') + mesh = self.base.attributes.get('mesh') + offset = self.base.attributes.get('offset') if not print_list: return mesh, offset @@ -467,7 +467,7 @@ def set_kpoints(self, kpoints, cartesian=False, labels=None, weights=None, fill_ the_kpoints = self._change_reference(the_kpoints, to_cartesian=False) # check that we did not saved a mesh already - if self.get_attribute('mesh', None) is not None: + if self.base.attributes.get('mesh', None) is not None: raise ModificationNotAllowed('KpointsData has already a mesh stored') # store diff --git a/aiida/orm/nodes/data/array/projection.py b/aiida/orm/nodes/data/array/projection.py index 86e7aa96ad..d2db2a0dda 100644 --- a/aiida/orm/nodes/data/array/projection.py +++ b/aiida/orm/nodes/data/array/projection.py @@ -79,7 +79,7 @@ def set_reference_bandsdata(self, value): 'The value passed to set_reference_bandsdata was not associated to any bandsdata' ) - self.set_attribute('reference_bandsdata_uuid', uuid) + self.base.attributes.set('reference_bandsdata_uuid', uuid) def get_reference_bandsdata(self): """ @@ -92,7 +92,7 @@ def get_reference_bandsdata(self): """ from aiida.orm import load_node try: - uuid = self.get_attribute('reference_bandsdata_uuid') + uuid = self.base.attributes.get('reference_bandsdata_uuid') except AttributeError: raise AttributeError('BandsData has not been set for this instance') try: @@ -248,7 +248,7 @@ def array_list_checker(array_list, array_name, orb_length): cls = OrbitalFactory(orbital_type) test_orbital = cls(**orbital_dict) list_of_orbital_dicts.append(test_orbital.get_orbital_dict()) - self.set_attribute('orbital_dicts', list_of_orbital_dicts) + self.base.attributes.set('orbital_dicts', list_of_orbital_dicts) # verifies and sets the projections if list_of_projections: @@ -286,7 +286,7 @@ def array_list_checker(array_list, array_name, orb_length): if not all(isinstance(_, str) for _ in tags): raise exceptions.ValidationError('Tags must set a list of strings') - self.set_attribute('tags', tags) + self.base.attributes.set('tags', tags) def set_orbitals(self, **kwargs): # pylint: disable=arguments-differ """ diff --git a/aiida/orm/nodes/data/array/trajectory.py b/aiida/orm/nodes/data/array/trajectory.py index 8193ab6529..70f29c252f 100644 --- a/aiida/orm/nodes/data/array/trajectory.py +++ b/aiida/orm/nodes/data/array/trajectory.py @@ -138,7 +138,7 @@ def set_trajectory(self, symbols, positions, stepids=None, cells=None, times=Non self._internal_validate(stepids, cells, symbols, positions, times, velocities) # set symbols as attribute for easier querying - self.set_attribute('symbols', list(symbols)) + self.base.attributes.set('symbols', list(symbols)) self.set_array('positions', positions) if stepids is not None: # use input stepids self.set_array('steps', stepids) @@ -271,7 +271,7 @@ def symbols(self): :raises KeyError: if the trajectory has not been set yet. """ - return self.get_attribute('symbols') + return self.base.attributes.get('symbols') def get_positions(self): """ @@ -610,11 +610,11 @@ def show_mpl_pos(self, **kwargs): # pylint: disable=too-many-locals # Try to get the units. try: - positions_unit = self.get_attribute('units|positions') + positions_unit = self.base.attributes.get('units|positions') except AttributeError: positions_unit = 'A' try: - times_unit = self.get_attribute('units|times') + times_unit = self.base.attributes.get('units|times') except AttributeError: times_unit = 'ps' @@ -731,7 +731,7 @@ def collapse_into_unit_cell(point, cell): positions = self.get_positions()[minindex:maxindex:stepsize] try: - if self.get_attribute('units|positions') in ('bohr', 'atomic'): + if self.base.attributes.get('units|positions') in ('bohr', 'atomic'): bohr_to_ang = 0.52917720859 positions *= bohr_to_ang except AttributeError: diff --git a/aiida/orm/nodes/data/array/xy.py b/aiida/orm/nodes/data/array/xy.py index c643fc4f36..7827fc050f 100644 --- a/aiida/orm/nodes/data/array/xy.py +++ b/aiida/orm/nodes/data/array/xy.py @@ -70,8 +70,8 @@ def set_x(self, x_array, x_name, x_units): :param x_units: the units of x """ self._arrayandname_validator(x_array, x_name, x_units) - self.set_attribute('x_name', x_name) - self.set_attribute('x_units', x_units) + self.base.attributes.set('x_name', x_name) + self.base.attributes.set('x_units', x_units) self.set_array('x_array', x_array) def set_y(self, y_arrays, y_names, y_units): @@ -107,8 +107,8 @@ def set_y(self, y_arrays, y_names, y_units): self.set_array(f'y_array_{num}', y_array) # if the y_arrays pass the initial validation, sets each - self.set_attribute('y_names', y_names) - self.set_attribute('y_units', y_units) + self.base.attributes.set('y_names', y_names) + self.base.attributes.set('y_units', y_units) def get_x(self): """ @@ -119,9 +119,9 @@ def get_x(self): :return x_units: the x units set earlier """ try: - x_name = self.get_attribute('x_name') + x_name = self.base.attributes.get('x_name') x_array = self.get_array('x_array') - x_units = self.get_attribute('x_units') + x_units = self.base.attributes.get('x_units') except (KeyError, AttributeError): raise NotExistent('No x array has been set yet!') return x_name, x_array, x_units @@ -136,11 +136,11 @@ def get_y(self): :return y_units: list of strings giving the units for the y_arrays """ try: - y_names = self.get_attribute('y_names') + y_names = self.base.attributes.get('y_names') except (KeyError, AttributeError): raise NotExistent('No y names has been set yet!') try: - y_units = self.get_attribute('y_units') + y_units = self.base.attributes.get('y_units') except (KeyError, AttributeError): raise NotExistent('No y units has been set yet!') y_arrays = [] diff --git a/aiida/orm/nodes/data/base.py b/aiida/orm/nodes/data/base.py index 176b1445d0..f95cacaa2e 100644 --- a/aiida/orm/nodes/data/base.py +++ b/aiida/orm/nodes/data/base.py @@ -36,11 +36,11 @@ def __init__(self, value=None, **kwargs): @property def value(self): - return self.get_attribute('value', None) + return self.base.attributes.get('value', None) @value.setter def value(self, value): - self.set_attribute('value', self._type(value)) # pylint: disable=no-member + self.base.attributes.set('value', self._type(value)) # pylint: disable=no-member def __str__(self): return f'{super().__str__()} value: {self.value}' diff --git a/aiida/orm/nodes/data/cif.py b/aiida/orm/nodes/data/cif.py index 35cf6db881..638de9a3a8 100644 --- a/aiida/orm/nodes/data/cif.py +++ b/aiida/orm/nodes/data/cif.py @@ -293,7 +293,7 @@ def __init__(self, ase=None, file=None, filename=None, values=None, scan_type=No if values is not None: self.set_values(values) - if not self.is_stored and file is not None and self.get_attribute('parse_policy') == 'eager': + if not self.is_stored and file is not None and self.base.attributes.get('parse_policy') == 'eager': self.parse() @staticmethod @@ -437,7 +437,7 @@ def values(self): from CifFile import CifBlock # pylint: disable=no-name-in-module with self.open() as handle: - c = CifFile.ReadCif(handle, scantype=self.get_attribute('scan_type', CifData._SCAN_TYPE_DEFAULT)) # pylint: disable=no-member + c = CifFile.ReadCif(handle, scantype=self.base.attributes.get('scan_type', CifData._SCAN_TYPE_DEFAULT)) # pylint: disable=no-member for k, v in c.items(): c.dictionary[k] = CifBlock(v) self._values = c @@ -477,15 +477,15 @@ def parse(self, scan_type=None): self.set_scan_type(scan_type) # Note: this causes parsing, if not already parsed - self.set_attribute('formulae', self.get_formulae()) - self.set_attribute('spacegroup_numbers', self.get_spacegroup_numbers()) + self.base.attributes.set('formulae', self.get_formulae()) + self.base.attributes.set('spacegroup_numbers', self.get_spacegroup_numbers()) def store(self, *args, **kwargs): # pylint: disable=signature-differs """ Store the node. """ if not self.is_stored: - self.set_attribute('md5', self.generate_md5()) + self.base.attributes.set('md5', self.generate_md5()) return super().store(*args, **kwargs) @@ -507,12 +507,12 @@ def set_file(self, file, filename=None): self.source.get('source_md5', None) is not None and \ self.source['source_md5'] != md5sum: self.source = {} - self.set_attribute('md5', md5sum) + self.base.attributes.set('md5', md5sum) self._values = None self._ase = None - self.set_attribute('formulae', None) - self.set_attribute('spacegroup_numbers', None) + self.base.attributes.set('formulae', None) + self.base.attributes.set('spacegroup_numbers', None) def set_scan_type(self, scan_type): """ @@ -525,7 +525,7 @@ def set_scan_type(self, scan_type): :param scan_type: Either 'standard' or 'flex' (see _scan_types) """ if scan_type in CifData._SCAN_TYPES: - self.set_attribute('scan_type', scan_type) + self.base.attributes.set('scan_type', scan_type) else: raise ValueError(f'Got unknown scan_type {scan_type}') @@ -537,7 +537,7 @@ def set_parse_policy(self, parse_policy): or 'lazy' (defer parsing until needed) """ if parse_policy in CifData._PARSE_POLICIES: - self.set_attribute('parse_policy', parse_policy) + self.base.attributes.set('parse_policy', parse_policy) else: raise ValueError(f'Got unknown parse_policy {parse_policy}') @@ -799,7 +799,7 @@ def _validate(self): super()._validate() try: - attr_md5 = self.get_attribute('md5') + attr_md5 = self.base.attributes.get('md5') except AttributeError: raise ValidationError("attribute 'md5' not set.") md5 = self.generate_md5() diff --git a/aiida/orm/nodes/data/code.py b/aiida/orm/nodes/data/code.py index 081d80445b..c3cf8a701b 100644 --- a/aiida/orm/nodes/data/code.py +++ b/aiida/orm/nodes/data/code.py @@ -325,14 +325,14 @@ def set_prepend_text(self, code): Pass a string of code that will be put in the scheduler script before the execution of the code. """ - self.set_attribute('prepend_text', str(code)) + self.base.attributes.set('prepend_text', str(code)) def get_prepend_text(self): """ Return the code that will be put in the scheduler script before the execution, or an empty string if no pre-exec code was defined. """ - return self.get_attribute('prepend_text', '') + return self.base.attributes.get('prepend_text', '') def set_input_plugin_name(self, input_plugin): """ @@ -340,29 +340,29 @@ def set_input_plugin_name(self, input_plugin): generation of a new calculation. """ if input_plugin is None: - self.set_attribute('input_plugin', None) + self.base.attributes.set('input_plugin', None) else: - self.set_attribute('input_plugin', str(input_plugin)) + self.base.attributes.set('input_plugin', str(input_plugin)) def get_input_plugin_name(self): """ Return the name of the default input plugin (or None if no input plugin was set. """ - return self.get_attribute('input_plugin', None) + return self.base.attributes.get('input_plugin', None) def set_append_text(self, code): """ Pass a string of code that will be put in the scheduler script after the execution of the code. """ - self.set_attribute('append_text', str(code)) + self.base.attributes.set('append_text', str(code)) def get_append_text(self): """ Return the postexec_code, or an empty string if no post-exec code was defined. """ - return self.get_attribute('append_text', '') + return self.base.attributes.get('append_text', '') def set_local_executable(self, exec_name): """ @@ -370,10 +370,10 @@ def set_local_executable(self, exec_name): Implicitly set the code as local. """ self._set_local() - self.set_attribute('local_executable', exec_name) + self.base.attributes.set('local_executable', exec_name) def get_local_executable(self): - return self.get_attribute('local_executable', '') + return self.base.attributes.get('local_executable', '') def set_remote_computer_exec(self, remote_computer_exec): """ @@ -400,12 +400,12 @@ def set_remote_computer_exec(self, remote_computer_exec): self._set_remote() self.computer = computer - self.set_attribute('remote_exec_path', remote_exec_path) + self.base.attributes.set('remote_exec_path', remote_exec_path) def get_remote_exec_path(self): if self.is_local(): raise ValueError('The code is local') - return self.get_attribute('remote_exec_path', '') + return self.base.attributes.get('remote_exec_path', '') def get_remote_computer(self): if self.is_local(): @@ -421,10 +421,10 @@ def _set_local(self): It also deletes the flags related to the local case (if any) """ - self.set_attribute('is_local', True) + self.base.attributes.set('is_local', True) self.computer = None try: - self.delete_attribute('remote_exec_path') + self.base.attributes.delete('remote_exec_path') except AttributeError: pass @@ -436,9 +436,9 @@ def _set_remote(self): It also deletes the flags related to the local case (if any) """ - self.set_attribute('is_local', False) + self.base.attributes.set('is_local', False) try: - self.delete_attribute('local_executable') + self.base.attributes.delete('local_executable') except AttributeError: pass @@ -447,7 +447,7 @@ def is_local(self): Return True if the code is 'local', False if it is 'remote' (see also documentation of the set_local and set_remote functions). """ - return self.get_attribute('is_local', None) + return self.base.attributes.get('is_local', None) def can_run_on(self, computer): """ diff --git a/aiida/orm/nodes/data/data.py b/aiida/orm/nodes/data/data.py index f204fc2587..d2b92dfe45 100644 --- a/aiida/orm/nodes/data/data.py +++ b/aiida/orm/nodes/data/data.py @@ -69,7 +69,7 @@ def clone(self): backend_clone = self.backend_entity.clone() clone = self.__class__.from_backend_entity(backend_clone) - clone.reset_attributes(copy.deepcopy(self.attributes)) + clone.base.attributes.reset(copy.deepcopy(self.base.attributes.all)) clone.base.repository._clone(self.base.repository) # pylint: disable=protected-access return clone @@ -93,7 +93,7 @@ def source(self): :return: dictionary describing the source of Data object. """ - return self.get_attribute('source', None) + return self.base.attributes.get('source', None) @source.setter def source(self, source): @@ -109,7 +109,7 @@ def source(self, source): if unknown_attrs: raise KeyError(f"Unknown source parameters: {', '.join(unknown_attrs)}") - self.set_attribute('source', source) + self.base.attributes.set('source', source) def set_source(self, source): """ diff --git a/aiida/orm/nodes/data/dict.py b/aiida/orm/nodes/data/dict.py index 03d3eb9d10..e0a975f916 100644 --- a/aiida/orm/nodes/data/dict.py +++ b/aiida/orm/nodes/data/dict.py @@ -63,12 +63,12 @@ def __init__(self, value=None, **kwargs): def __getitem__(self, key): try: - return self.get_attribute(key) + return self.base.attributes.get(key) except AttributeError as exc: raise KeyError from exc def __setitem__(self, key, value): - self.set_attribute(key, value) + self.base.attributes.set(key, value) def __eq__(self, other): if isinstance(other, Dict): @@ -77,7 +77,7 @@ def __eq__(self, other): def __contains__(self, key: str) -> bool: """Return whether the node contains a key.""" - return key in self.attributes + return key in self.base.attributes def set_dict(self, dictionary): """Replace the current dictionary with another one. @@ -88,14 +88,14 @@ def set_dict(self, dictionary): try: # Clear existing attributes and set the new dictionary - self.clear_attributes() + self.base.attributes.clear() self.update_dict(dictionary) except exceptions.ModificationNotAllowed: # pylint: disable=try-except-raise # I reraise here to avoid to go in the generic 'except' below that would raise the same exception again raise except Exception: # Try to restore the old data - self.clear_attributes() + self.base.attributes.clear() self.update_dict(dictionary_backup) raise @@ -107,26 +107,26 @@ def update_dict(self, dictionary): :param dictionary: a dictionary with the keys to substitute """ for key, value in dictionary.items(): - self.set_attribute(key, value) + self.base.attributes.set(key, value) def get_dict(self): """Return a dictionary with the parameters currently set. :return: dictionary """ - return dict(self.attributes) + return dict(self.base.attributes.all) def keys(self): """Iterator of valid keys stored in the Dict object. :return: iterator over the keys of the current dictionary """ - for key in self.attributes.keys(): + for key in self.base.attributes.keys(): yield key def items(self): """Iterator of all items stored in the Dict node.""" - for key, value in self.attributes_items(): + for key, value in self.base.attributes.items(): yield key, value @property diff --git a/aiida/orm/nodes/data/enum.py b/aiida/orm/nodes/data/enum.py index cc5fe3b71e..4b9d4b50e9 100644 --- a/aiida/orm/nodes/data/enum.py +++ b/aiida/orm/nodes/data/enum.py @@ -60,24 +60,24 @@ def __init__(self, member: Enum, *args, **kwargs): self.KEY_IDENTIFIER: get_object_loader().identify_object(member.__class__) } - self.set_attribute_many(data) + self.base.attributes.set_many(data) @property def name(self) -> str: """Return the name of the enum member.""" - return self.get_attribute(self.KEY_NAME) + return self.base.attributes.get(self.KEY_NAME) @property def value(self) -> t.Any: """Return the value of the enum member.""" - return self.get_attribute(self.KEY_VALUE) + return self.base.attributes.get(self.KEY_VALUE) def get_enum(self) -> t.Type[EnumType]: """Return the enum class reconstructed from the serialized identifier stored in the database. :raises `ImportError`: if the enum class represented by the stored identifier cannot be imported. """ - identifier = self.get_attribute(self.KEY_IDENTIFIER) + identifier = self.base.attributes.get(self.KEY_IDENTIFIER) try: return get_object_loader().load_object(identifier) except ValueError as exc: @@ -93,7 +93,7 @@ def get_member(self) -> EnumType: :raises `ImportError`: if the enum class represented by the stored identifier cannot be imported. :raises `ValueError`: if the stored enum member value is no longer valid for the imported enum class. """ - value = self.get_attribute(self.KEY_VALUE) + value = self.base.attributes.get(self.KEY_VALUE) enum: t.Type[EnumType] = self.get_enum() try: @@ -112,6 +112,6 @@ def __eq__(self, other: t.Any) -> bool: except (ImportError, ValueError): return False elif isinstance(other, EnumData): - return self.attributes == other.attributes + return self.base.attributes.all == other.base.attributes.all return False diff --git a/aiida/orm/nodes/data/jsonable.py b/aiida/orm/nodes/data/jsonable.py index f351b95d26..d16670a4e3 100644 --- a/aiida/orm/nodes/data/jsonable.py +++ b/aiida/orm/nodes/data/jsonable.py @@ -81,7 +81,7 @@ def __init__(self, obj: JsonSerializableProtocol, *args, **kwargs): except TypeError as exc: raise TypeError(f'the object `{obj}` is not JSON-serializable and therefore cannot be stored.') from exc - self.set_attribute_many(serialized) + self.base.attributes.set_many(serialized) @classmethod def _deserialize_float_constants(cls, data: typing.Any): @@ -114,7 +114,7 @@ def _get_object(self) -> JsonSerializableProtocol: try: return self._obj except AttributeError: - attributes = self.attributes + attributes = self.base.attributes.all class_name = attributes.pop('@class') module_name = attributes.pop('@module') diff --git a/aiida/orm/nodes/data/list.py b/aiida/orm/nodes/data/list.py index 36bb57ae39..1b5c8dc998 100644 --- a/aiida/orm/nodes/data/list.py +++ b/aiida/orm/nodes/data/list.py @@ -115,10 +115,10 @@ def get_list(self): :return: a list """ try: - return self.get_attribute(self._LIST_KEY) + return self.base.attributes.get(self._LIST_KEY) except AttributeError: self.set_list([]) - return self.get_attribute(self._LIST_KEY) + return self.base.attributes.get(self._LIST_KEY) def set_list(self, data): """Set the contents of this node. @@ -127,7 +127,7 @@ def set_list(self, data): """ if not isinstance(data, list): raise TypeError('Must supply list type') - self.set_attribute(self._LIST_KEY, data.copy()) + self.base.attributes.set(self._LIST_KEY, data.copy()) def _using_list_reference(self): """ diff --git a/aiida/orm/nodes/data/orbital.py b/aiida/orm/nodes/data/orbital.py index e3f4ce3c34..32f1640cac 100644 --- a/aiida/orm/nodes/data/orbital.py +++ b/aiida/orm/nodes/data/orbital.py @@ -29,7 +29,7 @@ def clear_orbitals(self): Remove all orbitals that were added to the class Cannot work if OrbitalData has been already stored """ - self.set_attribute('orbital_dicts', []) + self.base.attributes.set('orbital_dicts', []) def get_orbitals(self, **kwargs): """ @@ -43,7 +43,7 @@ def get_orbitals(self, **kwargs): :return list_of_outputs: a list of orbitals """ - orbital_dicts = copy.deepcopy(self.get_attribute('orbital_dicts', None)) + orbital_dicts = copy.deepcopy(self.base.attributes.get('orbital_dicts', None)) if orbital_dicts is None: raise AttributeError('Orbitals must be set before being retrieved') @@ -83,7 +83,7 @@ def set_orbitals(self, orbitals): except KeyError: raise ValueError(f'No _orbital_type found in: {orbital_dict}') orbital_dicts.append(orbital_dict) - self.set_attribute('orbital_dicts', orbital_dicts) + self.base.attributes.set('orbital_dicts', orbital_dicts) ########################################################################## diff --git a/aiida/orm/nodes/data/remote/base.py b/aiida/orm/nodes/data/remote/base.py index 4b2ac74268..81d30d952d 100644 --- a/aiida/orm/nodes/data/remote/base.py +++ b/aiida/orm/nodes/data/remote/base.py @@ -32,10 +32,10 @@ def __init__(self, remote_path=None, **kwargs): self.set_remote_path(remote_path) def get_remote_path(self): - return self.get_attribute('remote_path') + return self.base.attributes.get('remote_path') def set_remote_path(self, val): - self.set_attribute('remote_path', val) + self.base.attributes.set('remote_path', val) @property def is_empty(self): diff --git a/aiida/orm/nodes/data/remote/stash/base.py b/aiida/orm/nodes/data/remote/stash/base.py index 1fe4e315c3..c768505249 100644 --- a/aiida/orm/nodes/data/remote/stash/base.py +++ b/aiida/orm/nodes/data/remote/stash/base.py @@ -41,7 +41,7 @@ def stash_mode(self) -> StashMode: :return: the stash mode. """ - return StashMode(self.get_attribute('stash_mode')) + return StashMode(self.base.attributes.get('stash_mode')) @stash_mode.setter def stash_mode(self, value: StashMode): @@ -50,4 +50,4 @@ def stash_mode(self, value: StashMode): :param value: the stash mode. """ type_check(value, StashMode) - self.set_attribute('stash_mode', value.value) + self.base.attributes.set('stash_mode', value.value) diff --git a/aiida/orm/nodes/data/remote/stash/folder.py b/aiida/orm/nodes/data/remote/stash/folder.py index ebe097fd1f..bf182b7a5c 100644 --- a/aiida/orm/nodes/data/remote/stash/folder.py +++ b/aiida/orm/nodes/data/remote/stash/folder.py @@ -38,7 +38,7 @@ def target_basepath(self) -> str: :return: the target basepath. """ - return self.get_attribute('target_basepath') + return self.base.attributes.get('target_basepath') @target_basepath.setter def target_basepath(self, value: str): @@ -47,7 +47,7 @@ def target_basepath(self, value: str): :param value: the target basepath. """ type_check(value, str) - self.set_attribute('target_basepath', value) + self.base.attributes.set('target_basepath', value) @property def source_list(self) -> typing.Union[typing.List, typing.Tuple]: @@ -55,7 +55,7 @@ def source_list(self) -> typing.Union[typing.List, typing.Tuple]: :return: the list of source files. """ - return self.get_attribute('source_list') + return self.base.attributes.get('source_list') @source_list.setter def source_list(self, value: typing.Union[typing.List, typing.Tuple]): @@ -64,4 +64,4 @@ def source_list(self, value: typing.Union[typing.List, typing.Tuple]): :param value: the list of source files. """ type_check(value, (list, tuple)) - self.set_attribute('source_list', value) + self.base.attributes.set('source_list', value) diff --git a/aiida/orm/nodes/data/singlefile.py b/aiida/orm/nodes/data/singlefile.py index 32e6ae63ca..988c06a6f6 100644 --- a/aiida/orm/nodes/data/singlefile.py +++ b/aiida/orm/nodes/data/singlefile.py @@ -43,7 +43,7 @@ def filename(self): :return: the filename under which the file is stored in the repository """ - return self.get_attribute('filename') + return self.base.attributes.get('filename') @contextlib.contextmanager def open(self, path=None, mode='r'): @@ -111,7 +111,7 @@ def set_file(self, file, filename=None): for existing_key in existing_object_names: self.base.repository.delete_object(existing_key) - self.set_attribute('filename', key) + self.base.attributes.set('filename', key) def _validate(self): """Ensure that there is one object stored in the repository, whose key matches value set for `filename` attr.""" diff --git a/aiida/orm/nodes/data/structure.py b/aiida/orm/nodes/data/structure.py index 700cb5f4f6..bbe0a77b74 100644 --- a/aiida/orm/nodes/data/structure.py +++ b/aiida/orm/nodes/data/structure.py @@ -1020,8 +1020,8 @@ def _prepare_chemdoodle(self, main_file_name=''): # pylint: disable=unused-argu supercell_factors = [1, 1, 1] # Get cell vectors and atomic position - lattice_vectors = np.array(self.get_attribute('cell')) - base_sites = self.get_attribute('sites') + lattice_vectors = np.array(self.base.attributes.get('cell')) + base_sites = self.base.attributes.get('sites') start1 = -int(supercell_factors[0] / 2) start2 = -int(supercell_factors[1] / 2) @@ -1149,7 +1149,7 @@ def get_extremas_from_positions(positions): # Translate the structure to the origin, such that the minimal values in each dimension # amount to (0,0,0) positions -= position_min - for index, site in enumerate(self.get_attribute('sites')): + for index, site in enumerate(self.base.attributes.get('sites')): site['position'] = list(positions[index]) # The orthorhombic cell that (just) accomodates the whole structure is now given by the @@ -1349,12 +1349,12 @@ def append_kind(self, kind): raise ValueError(f'A kind with the same name ({kind.name}) already exists.') # If here, no exceptions have been raised, so I add the site. - self.attributes.setdefault('kinds', []).append(new_kind.get_raw()) + self.base.attributes.all.setdefault('kinds', []).append(new_kind.get_raw()) # Note, this is a dict (with integer keys) so it allows for empty spots! if self._internal_kind_tags is None: self._internal_kind_tags = {} - self._internal_kind_tags[len(self.get_attribute('kinds')) - 1] = kind._internal_tag # pylint: disable=protected-access + self._internal_kind_tags[len(self.base.attributes.get('kinds')) - 1] = kind._internal_tag # pylint: disable=protected-access def append_site(self, site): """ @@ -1377,7 +1377,7 @@ def append_site(self, site): ) # If here, no exceptions have been raised, so I add the site. - self.attributes.setdefault('sites', []).append(new_site.get_raw()) + self.base.attributes.all.setdefault('sites', []).append(new_site.get_raw()) def append_atom(self, **kwargs): """ @@ -1495,7 +1495,7 @@ def clear_kinds(self): if self.is_stored: raise ModificationNotAllowed('The StructureData object cannot be modified, it has already been stored') - self.set_attribute('kinds', []) + self.base.attributes.set('kinds', []) self._internal_kind_tags = {} self.clear_sites() @@ -1508,7 +1508,7 @@ def clear_sites(self): if self.is_stored: raise ModificationNotAllowed('The StructureData object cannot be modified, it has already been stored') - self.set_attribute('sites', []) + self.base.attributes.set('sites', []) @property def sites(self): @@ -1516,7 +1516,7 @@ def sites(self): Returns a list of sites. """ try: - raw_sites = self.get_attribute('sites') + raw_sites = self.base.attributes.get('sites') except AttributeError: raw_sites = [] return [Site(raw=i) for i in raw_sites] @@ -1527,7 +1527,7 @@ def kinds(self): Returns a list of kinds. """ try: - raw_kinds = self.get_attribute('kinds') + raw_kinds = self.base.attributes.get('kinds') except AttributeError: raw_kinds = [] return [Kind(raw=i) for i in raw_kinds] @@ -1578,7 +1578,7 @@ def cell(self): :return: a 3x3 list of lists. """ - return copy.deepcopy(self.get_attribute('cell')) + return copy.deepcopy(self.base.attributes.get('cell')) @cell.setter def cell(self, value): @@ -1593,7 +1593,7 @@ def set_cell(self, value): raise ModificationNotAllowed('The StructureData object cannot be modified, it has already been stored') the_cell = _get_valid_cell(value) - self.set_attribute('cell', the_cell) + self.base.attributes.set('cell', the_cell) def reset_cell(self, new_cell): """ @@ -1609,7 +1609,7 @@ def reset_cell(self, new_cell): if self.is_stored: raise ModificationNotAllowed() - self.set_attribute('cell', new_cell) + self.base.attributes.set('cell', new_cell) def reset_sites_positions(self, new_positions, conserve_particle=True): """ @@ -1670,7 +1670,7 @@ def pbc(self): boundary conditions for the i-th real-space direction (i=1,2,3) """ # return copy.deepcopy(self._pbc) - return (self.get_attribute('pbc1'), self.get_attribute('pbc2'), self.get_attribute('pbc3')) + return (self.base.attributes.get('pbc1'), self.base.attributes.get('pbc2'), self.base.attributes.get('pbc3')) @pbc.setter def pbc(self, value): @@ -1686,9 +1686,9 @@ def set_pbc(self, value): the_pbc = get_valid_pbc(value) # self._pbc = the_pbc - self.set_attribute('pbc1', the_pbc[0]) - self.set_attribute('pbc2', the_pbc[1]) - self.set_attribute('pbc3', the_pbc[2]) + self.base.attributes.set('pbc1', the_pbc[0]) + self.base.attributes.set('pbc2', the_pbc[1]) + self.base.attributes.set('pbc3', the_pbc[2]) @property def cell_lengths(self): diff --git a/aiida/orm/nodes/data/upf.py b/aiida/orm/nodes/data/upf.py index 6896ca5a19..d84c83267a 100644 --- a/aiida/orm/nodes/data/upf.py +++ b/aiida/orm/nodes/data/upf.py @@ -314,8 +314,8 @@ def store(self, *args, **kwargs): # pylint: disable=signature-differs except KeyError: raise ParsingError(f'Could not parse the element from the UPF file {self.filename}') - self.set_attribute('element', str(element)) - self.set_attribute('md5', md5) + self.base.attributes.set('element', str(element)) + self.base.attributes.set('md5', md5) return super().store(*args, **kwargs) @@ -358,8 +358,8 @@ def set_file(self, file, filename=None): super().set_file(file, filename=filename) - self.set_attribute('element', str(element)) - self.set_attribute('md5', md5sum) + self.base.attributes.set('element', str(element)) + self.base.attributes.set('md5', md5sum) def get_upf_family_names(self): """Get the list of all upf family names to which the pseudo belongs.""" @@ -376,7 +376,7 @@ def element(self): :return: the element """ - return self.get_attribute('element', None) + return self.base.attributes.get('element', None) @property def md5sum(self): @@ -384,7 +384,7 @@ def md5sum(self): :return: the md5 checksum """ - return self.get_attribute('md5', None) + return self.base.attributes.get('md5', None) def _validate(self): """Validate the UPF potential file stored for this node.""" @@ -411,12 +411,12 @@ def _validate(self): raise ValidationError(f"No 'element' could be parsed in the UPF {self.filename}") try: - attr_element = self.get_attribute('element') + attr_element = self.base.attributes.get('element') except AttributeError: raise ValidationError("attribute 'element' not set.") try: - attr_md5 = self.get_attribute('md5') + attr_md5 = self.base.attributes.get('md5') except AttributeError: raise ValidationError("attribute 'md5' not set.") diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index 19bbfeee13..d773a75314 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -43,9 +43,10 @@ from ..comments import Comment from ..computers import Computer from ..entities import Collection as EntityCollection -from ..entities import Entity, EntityAttributesMixin, EntityExtrasMixin +from ..entities import Entity, EntityExtrasMixin from ..querybuilder import QueryBuilder from ..users import User +from .attributes import NodeAttributes from .repository import NodeRepository if TYPE_CHECKING: @@ -114,8 +115,13 @@ def repository(self) -> 'NodeRepository': """Return the repository for this node.""" return NodeRepository(self._node) + @cached_property + def attributes(self) -> 'NodeAttributes': + """Return the attributes for this node.""" + return NodeAttributes(self._node) + -class Node(Entity['BackendNode'], EntityAttributesMixin, EntityExtrasMixin, metaclass=AbstractNodeMeta): +class Node(Entity['BackendNode'], EntityExtrasMixin, metaclass=AbstractNodeMeta): """ Base class for all nodes in AiiDA. @@ -194,6 +200,16 @@ def base(self) -> NodeBase: """Return the node base namespace.""" return NodeBase(self) + def _check_mutability_attributes(self, keys: Optional[List[str]] = None) -> None: # pylint: disable=unused-argument + """Check if the entity is mutable and raise an exception if not. + + This is called from `NodeAttributes` methods that modify the attributes. + + :param keys: the keys that will be mutated, or all if None + """ + if self.is_stored: + raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') + def __eq__(self, other: Any) -> bool: """Fallback equality comparison by uuid (can be overwritten by specific types)""" if isinstance(other, Node) and self.uuid == other.uuid: @@ -240,7 +256,7 @@ def _validate(self) -> bool: and should usually call ``super()._validate()`` first! This method is called automatically before storing the node in the DB. - Therefore, use :py:meth:`~aiida.orm.entities.EntityAttributesMixin.get_attribute()` and similar methods that + Therefore, use :py:meth:`~aiida.orm.nodes.attributes.NodeAttributes.get()` and similar methods that automatically read either from the DB or from the internal attribute cache. """ # pylint: disable=no-self-use @@ -774,9 +790,9 @@ def _store_from_cache(self, cache_node: 'Node', with_transaction: bool) -> None: # Make sure to reinitialize the repository instance of the clone to that of the source node. self.base.repository._copy(cache_node.base.repository) # pylint: disable=protected-access - for key, value in cache_node.attributes.items(): + for key, value in cache_node.base.attributes.all.items(): if key != Sealable.SEALED_KEY: - self.set_attribute(key, value) + self.base.attributes.set(key, value) self._store(with_transaction=with_transaction, clean=False) self._add_outputs_from_cache(cache_node) @@ -827,7 +843,7 @@ def _get_objects_to_hash(self) -> List[Any]: version, { key: val - for key, val in self.attributes_items() + for key, val in self.base.attributes.items() if key not in self._hash_ignored_attributes and key not in self._updatable_attributes # pylint: disable=unsupported-membership-test }, self.base.repository.hash(), @@ -951,12 +967,36 @@ def get_description(self) -> str: 'repository_metadata': 'metadata', } + _deprecated_attr_methods = { + 'attributes': 'all', + 'get_attribute': 'get', + 'get_attribute_many': 'get_many', + 'set_attribute': 'set', + 'set_attribute_many': 'set_many', + 'reset_attributes': 'reset', + 'delete_attribute': 'delete', + 'delete_attribute_many': 'delete_many', + 'clear_attributes': 'clear', + 'attributes_items': 'items', + 'attributes_keys': 'keys', + } + def __getattr__(self, name: str) -> Any: """ This method is called when an attribute is not found in the instance. It allows for the handling of deprecated mixin methods. """ + if name in self._deprecated_attr_methods: + new_name = self._deprecated_attr_methods[name] + kls = self.__class__.__name__ + warn_deprecation( + f'`{kls}.{name}` is deprecated, use `{kls}.base.attributes.{new_name}` instead.', + version=3, + stacklevel=3 + ) + return getattr(self.base.attributes, new_name) + if name in self._deprecated_repo_methods: new_name = self._deprecated_repo_methods[name] kls = self.__class__.__name__ @@ -966,4 +1006,5 @@ def __getattr__(self, name: str) -> Any: stacklevel=3 ) return getattr(self.base.repository, new_name) + raise AttributeError(name) diff --git a/aiida/orm/nodes/process/calculation/calcjob.py b/aiida/orm/nodes/process/calculation/calcjob.py index ea06c8871f..402ef3b42e 100644 --- a/aiida/orm/nodes/process/calculation/calcjob.py +++ b/aiida/orm/nodes/process/calculation/calcjob.py @@ -125,7 +125,7 @@ def _get_objects_to_hash(self) -> List[Any]: import_module(self.__module__.split('.', 1)[0]).__version__, { key: val - for key, val in self.attributes_items() + for key, val in self.base.attributes.items() if key not in self._hash_ignored_attributes and key not in self._updatable_attributes # pylint: disable=unsupported-membership-test }, self.computer.uuid if self.computer is not None else None, # pylint: disable=no-member @@ -155,7 +155,7 @@ def get_builder_restart(self) -> 'ProcessBuilder': @property def is_imported(self) -> bool: """Return whether the calculation job was imported instead of being an actual run.""" - return self.get_attribute(self.IMMIGRATED_KEY, None) is True + return self.base.attributes.get(self.IMMIGRATED_KEY, None) is True def get_option(self, name: str) -> Optional[Any]: """ @@ -165,7 +165,7 @@ def get_option(self, name: str) -> Optional[Any]: :return: the option value or None :raises: ValueError for unknown option """ - return self.get_attribute(name, None) + return self.base.attributes.get(name, None) def set_option(self, name: str, value: Any) -> None: """ @@ -176,7 +176,7 @@ def set_option(self, name: str, value: Any) -> None: :raises: ValueError for unknown option :raises: TypeError for values with invalid type """ - self.set_attribute(name, value) + self.base.attributes.set(name, value) def get_options(self) -> Dict[str, Any]: """ @@ -211,7 +211,7 @@ def get_state(self) -> Optional[CalcJobState]: :return: instance of `aiida.common.datastructures.CalcJobState` or `None` if invalid value, or not set """ - state = self.get_attribute(self.CALC_JOB_STATE_KEY, None) + state = self.base.attributes.get(self.CALC_JOB_STATE_KEY, None) try: state = CalcJobState(state) @@ -228,12 +228,12 @@ def set_state(self, state: CalcJobState) -> None: if not isinstance(state, CalcJobState): raise ValueError(f'{state} is not a valid CalcJobState') - self.set_attribute(self.CALC_JOB_STATE_KEY, state.value) + self.base.attributes.set(self.CALC_JOB_STATE_KEY, state.value) def delete_state(self) -> None: """Delete the calculation job state attribute if it exists.""" try: - self.delete_attribute(self.CALC_JOB_STATE_KEY) + self.base.attributes.delete(self.CALC_JOB_STATE_KEY) except AttributeError: pass @@ -242,14 +242,14 @@ def set_remote_workdir(self, remote_workdir: str) -> None: :param remote_workdir: absolute filepath to the remote working directory """ - self.set_attribute(self.REMOTE_WORKDIR_KEY, remote_workdir) + self.base.attributes.set(self.REMOTE_WORKDIR_KEY, remote_workdir) def get_remote_workdir(self) -> Optional[str]: """Return the path to the remote (on cluster) scratch folder of the calculation. :return: a string with the remote path """ - return self.get_attribute(self.REMOTE_WORKDIR_KEY, None) + return self.base.attributes.get(self.REMOTE_WORKDIR_KEY, None) @staticmethod def _validate_retrieval_directive(directives: Sequence[Union[str, Tuple[str, str, str]]]) -> None: @@ -289,14 +289,14 @@ def set_retrieve_list(self, retrieve_list: Sequence[Union[str, Tuple[str, str, s :param retrieve_list: list or tuple of with filepath directives """ self._validate_retrieval_directive(retrieve_list) - self.set_attribute(self.RETRIEVE_LIST_KEY, retrieve_list) + self.base.attributes.set(self.RETRIEVE_LIST_KEY, retrieve_list) def get_retrieve_list(self) -> Optional[Sequence[Union[str, Tuple[str, str, str]]]]: """Return the list of files/directories to be retrieved on the cluster after the calculation has completed. :return: a list of file directives """ - return self.get_attribute(self.RETRIEVE_LIST_KEY, None) + return self.base.attributes.get(self.RETRIEVE_LIST_KEY, None) def set_retrieve_temporary_list(self, retrieve_temporary_list: Sequence[Union[str, Tuple[str, str, str]]]) -> None: """Set the retrieve temporary list. @@ -307,14 +307,14 @@ def set_retrieve_temporary_list(self, retrieve_temporary_list: Sequence[Union[st :param retrieve_temporary_list: list or tuple of with filepath directives """ self._validate_retrieval_directive(retrieve_temporary_list) - self.set_attribute(self.RETRIEVE_TEMPORARY_LIST_KEY, retrieve_temporary_list) + self.base.attributes.set(self.RETRIEVE_TEMPORARY_LIST_KEY, retrieve_temporary_list) def get_retrieve_temporary_list(self) -> Optional[Sequence[Union[str, Tuple[str, str, str]]]]: """Return list of files to be retrieved from the cluster which will be available during parsing. :return: a list of file directives """ - return self.get_attribute(self.RETRIEVE_TEMPORARY_LIST_KEY, None) + return self.base.attributes.get(self.RETRIEVE_TEMPORARY_LIST_KEY, None) def set_job_id(self, job_id: Union[int, str]) -> None: """Set the job id that was assigned to the calculation by the scheduler. @@ -323,14 +323,14 @@ def set_job_id(self, job_id: Union[int, str]) -> None: :param job_id: the id assigned by the scheduler after submission """ - return self.set_attribute(self.SCHEDULER_JOB_ID_KEY, str(job_id)) + return self.base.attributes.set(self.SCHEDULER_JOB_ID_KEY, str(job_id)) def get_job_id(self) -> Optional[str]: """Return job id that was assigned to the calculation by the scheduler. :return: the string representation of the scheduler job id """ - return self.get_attribute(self.SCHEDULER_JOB_ID_KEY, None) + return self.base.attributes.get(self.SCHEDULER_JOB_ID_KEY, None) def set_scheduler_state(self, state: 'JobState') -> None: """Set the scheduler state. @@ -343,8 +343,8 @@ def set_scheduler_state(self, state: 'JobState') -> None: if not isinstance(state, JobState): raise ValueError(f'scheduler state should be an instance of JobState, got: {state}') - self.set_attribute(self.SCHEDULER_STATE_KEY, state.value) - self.set_attribute(self.SCHEDULER_LAST_CHECK_TIME_KEY, timezone.datetime_to_isoformat(timezone.now())) + self.base.attributes.set(self.SCHEDULER_STATE_KEY, state.value) + self.base.attributes.set(self.SCHEDULER_LAST_CHECK_TIME_KEY, timezone.datetime_to_isoformat(timezone.now())) def get_scheduler_state(self) -> Optional['JobState']: """Return the status of the calculation according to the cluster scheduler. @@ -353,7 +353,7 @@ def get_scheduler_state(self) -> Optional['JobState']: """ from aiida.schedulers.datastructures import JobState - state = self.get_attribute(self.SCHEDULER_STATE_KEY, None) + state = self.base.attributes.get(self.SCHEDULER_STATE_KEY, None) if state is None: return state @@ -366,7 +366,7 @@ def get_scheduler_lastchecktime(self) -> Optional[datetime.datetime]: :return: a datetime object or None """ from aiida.common import timezone - value = self.get_attribute(self.SCHEDULER_LAST_CHECK_TIME_KEY, None) + value = self.base.attributes.get(self.SCHEDULER_LAST_CHECK_TIME_KEY, None) if value is not None: value = timezone.isoformat_to_datetime(value) @@ -378,7 +378,7 @@ def set_detailed_job_info(self, detailed_job_info: Optional[dict]) -> None: :param detailed_job_info: a dictionary with metadata with the accounting of a completed job """ - self.set_attribute(self.SCHEDULER_DETAILED_JOB_INFO_KEY, detailed_job_info) + self.base.attributes.set(self.SCHEDULER_DETAILED_JOB_INFO_KEY, detailed_job_info) def get_detailed_job_info(self) -> Optional[dict]: """Return the detailed job info dictionary. @@ -387,14 +387,14 @@ def get_detailed_job_info(self) -> Optional[dict]: :return: the dictionary with detailed job info if defined or None """ - return self.get_attribute(self.SCHEDULER_DETAILED_JOB_INFO_KEY, None) + return self.base.attributes.get(self.SCHEDULER_DETAILED_JOB_INFO_KEY, None) def set_last_job_info(self, last_job_info: 'JobInfo') -> None: """Set the last job info. :param last_job_info: a `JobInfo` object """ - self.set_attribute(self.SCHEDULER_LAST_JOB_INFO_KEY, last_job_info.get_dict()) + self.base.attributes.set(self.SCHEDULER_LAST_JOB_INFO_KEY, last_job_info.get_dict()) def get_last_job_info(self) -> Optional['JobInfo']: """Return the last information asked to the scheduler about the status of the job. @@ -409,7 +409,7 @@ def get_last_job_info(self) -> Optional['JobInfo']: """ from aiida.schedulers.datastructures import JobInfo - last_job_info_dictserialized = self.get_attribute(self.SCHEDULER_LAST_JOB_INFO_KEY, None) + last_job_info_dictserialized = self.base.attributes.get(self.SCHEDULER_LAST_JOB_INFO_KEY, None) if last_job_info_dictserialized is not None: job_info = JobInfo.load_from_dict(last_job_info_dictserialized) diff --git a/aiida/orm/nodes/process/process.py b/aiida/orm/nodes/process/process.py index b02810a549..eb81bfd93b 100644 --- a/aiida/orm/nodes/process/process.py +++ b/aiida/orm/nodes/process/process.py @@ -146,7 +146,7 @@ def process_label(self) -> Optional[str]: :returns: the process label """ - return self.get_attribute(self.PROCESS_LABEL_KEY, None) + return self.base.attributes.get(self.PROCESS_LABEL_KEY, None) def set_process_label(self, label: str) -> None: """ @@ -154,7 +154,7 @@ def set_process_label(self, label: str) -> None: :param label: process label string """ - self.set_attribute(self.PROCESS_LABEL_KEY, label) + self.base.attributes.set(self.PROCESS_LABEL_KEY, label) @property def process_state(self) -> Optional[ProcessState]: @@ -163,7 +163,7 @@ def process_state(self) -> Optional[ProcessState]: :returns: the process state instance of ProcessState enum """ - state = self.get_attribute(self.PROCESS_STATE_KEY, None) + state = self.base.attributes.get(self.PROCESS_STATE_KEY, None) if state is None: return state @@ -178,7 +178,7 @@ def set_process_state(self, state: Union[str, ProcessState]): """ if isinstance(state, ProcessState): state = state.value - return self.set_attribute(self.PROCESS_STATE_KEY, state) + return self.base.attributes.set(self.PROCESS_STATE_KEY, state) @property def process_status(self) -> Optional[str]: @@ -189,7 +189,7 @@ def process_status(self) -> Optional[str]: :returns: the process status """ - return self.get_attribute(self.PROCESS_STATUS_KEY, None) + return self.base.attributes.get(self.PROCESS_STATUS_KEY, None) def set_process_status(self, status: Optional[str]) -> None: """ @@ -202,7 +202,7 @@ def set_process_status(self, status: Optional[str]) -> None: """ if status is None: try: - self.delete_attribute(self.PROCESS_STATUS_KEY) + self.base.attributes.delete(self.PROCESS_STATUS_KEY) except AttributeError: pass return None @@ -210,7 +210,7 @@ def set_process_status(self, status: Optional[str]) -> None: if not isinstance(status, str): raise TypeError('process status should be a string') - return self.set_attribute(self.PROCESS_STATUS_KEY, status) + return self.base.attributes.set(self.PROCESS_STATUS_KEY, status) @property def is_terminated(self) -> bool: @@ -292,7 +292,7 @@ def exit_status(self) -> Optional[int]: :returns: the exit status, an integer exit code or None """ - return self.get_attribute(self.EXIT_STATUS_KEY, None) + return self.base.attributes.get(self.EXIT_STATUS_KEY, None) def set_exit_status(self, status: Union[None, enum.Enum, int]) -> None: """ @@ -309,7 +309,7 @@ def set_exit_status(self, status: Union[None, enum.Enum, int]) -> None: if not isinstance(status, int): raise ValueError(f'exit status has to be an integer, got {status}') - return self.set_attribute(self.EXIT_STATUS_KEY, status) + return self.base.attributes.set(self.EXIT_STATUS_KEY, status) @property def exit_message(self) -> Optional[str]: @@ -318,7 +318,7 @@ def exit_message(self) -> Optional[str]: :returns: the exit message """ - return self.get_attribute(self.EXIT_MESSAGE_KEY, None) + return self.base.attributes.get(self.EXIT_MESSAGE_KEY, None) def set_exit_message(self, message: Optional[str]) -> None: """ @@ -332,7 +332,7 @@ def set_exit_message(self, message: Optional[str]) -> None: if not isinstance(message, str): raise ValueError(f'exit message has to be a string type, got {type(message)}') - return self.set_attribute(self.EXIT_MESSAGE_KEY, message) + return self.base.attributes.set(self.EXIT_MESSAGE_KEY, message) @property def exception(self) -> Optional[str]: @@ -344,7 +344,7 @@ def exception(self) -> Optional[str]: :returns: the exception message or None """ if self.is_excepted: - return self.get_attribute(self.EXCEPTION_KEY, '') + return self.base.attributes.get(self.EXCEPTION_KEY, '') return None @@ -357,7 +357,7 @@ def set_exception(self, exception: str) -> None: if not isinstance(exception, str): raise ValueError(f'exception message has to be a string type, got {type(exception)}') - return self.set_attribute(self.EXCEPTION_KEY, exception) + return self.base.attributes.set(self.EXCEPTION_KEY, exception) @property def checkpoint(self) -> Optional[Dict[str, Any]]: @@ -366,7 +366,7 @@ def checkpoint(self) -> Optional[Dict[str, Any]]: :returns: checkpoint bundle if it exists, None otherwise """ - return self.get_attribute(self.CHECKPOINT_KEY, None) + return self.base.attributes.get(self.CHECKPOINT_KEY, None) def set_checkpoint(self, checkpoint: Dict[str, Any]) -> None: """ @@ -374,14 +374,14 @@ def set_checkpoint(self, checkpoint: Dict[str, Any]) -> None: :param state: string representation of the stepper state info """ - return self.set_attribute(self.CHECKPOINT_KEY, checkpoint) + return self.base.attributes.set(self.CHECKPOINT_KEY, checkpoint) def delete_checkpoint(self) -> None: """ Delete the checkpoint bundle set for the process """ try: - self.delete_attribute(self.CHECKPOINT_KEY) + self.base.attributes.delete(self.CHECKPOINT_KEY) except AttributeError: pass @@ -392,7 +392,7 @@ def paused(self) -> bool: :returns: True if the Calculation is marked as paused, False otherwise """ - return self.get_attribute(self.PROCESS_PAUSED_KEY, False) + return self.base.attributes.get(self.PROCESS_PAUSED_KEY, False) def pause(self) -> None: """ @@ -401,7 +401,7 @@ def pause(self) -> None: This serves only to reflect that the corresponding Process is paused and so this method should not be called by anyone but the Process instance itself. """ - return self.set_attribute(self.PROCESS_PAUSED_KEY, True) + return self.base.attributes.set(self.PROCESS_PAUSED_KEY, True) def unpause(self) -> None: """ @@ -411,7 +411,7 @@ def unpause(self) -> None: by anyone but the Process instance itself. """ try: - self.delete_attribute(self.PROCESS_PAUSED_KEY) + self.base.attributes.delete(self.PROCESS_PAUSED_KEY) except AttributeError: pass diff --git a/aiida/orm/nodes/process/workflow/workchain.py b/aiida/orm/nodes/process/workflow/workchain.py index 07f0f8a0b3..ab8847d85f 100644 --- a/aiida/orm/nodes/process/workflow/workchain.py +++ b/aiida/orm/nodes/process/workflow/workchain.py @@ -34,7 +34,7 @@ def stepper_state_info(self) -> Optional[str]: :returns: string representation of the stepper state info """ - return self.get_attribute(self.STEPPER_STATE_INFO_KEY, None) + return self.base.attributes.get(self.STEPPER_STATE_INFO_KEY, None) def set_stepper_state_info(self, stepper_state_info: str) -> None: """ @@ -42,4 +42,4 @@ def set_stepper_state_info(self, stepper_state_info: str) -> None: :param state: string representation of the stepper state info """ - return self.set_attribute(self.STEPPER_STATE_INFO_KEY, stepper_state_info) + return self.base.attributes.set(self.STEPPER_STATE_INFO_KEY, stepper_state_info) diff --git a/aiida/orm/utils/managers.py b/aiida/orm/utils/managers.py index 1c56f112e4..9717784832 100644 --- a/aiida/orm/utils/managers.py +++ b/aiida/orm/utils/managers.py @@ -208,7 +208,7 @@ def __init__(self, node): """ # Possibly add checks here # We cannot set `self._node` because it would go through the __setattr__ method - # which uses said _node by calling `self._node.set_attribute(name, value)`. + # which uses said _node by calling `self._node.base.attributes.set(name, value)`. # Instead, we need to manually set it through the `self.__dict__` property. self.__dict__['_node'] = node @@ -216,20 +216,20 @@ def __dir__(self): """ Allow to list the keys of the dictionary """ - return sorted(self._node.attributes_keys()) + return sorted(self._node.base.attributes.keys()) def __iter__(self): """ Return the keys as an iterator """ - for k in self._node.attributes_keys(): + for k in self._node.base.attributes.keys(): yield k def _get_dict(self): """ Return the internal dictionary """ - return dict(self._node.attributes_items()) + return dict(self._node.base.attributes.items()) def __getattr__(self, name): """ @@ -240,10 +240,10 @@ def __getattr__(self, name): :param name: name of the key whose value is required. """ - return self._node.get_attribute(name) + return self._node.base.attributes.get(name) def __setattr__(self, name, value): - self._node.set_attribute(name, value) + self._node.base.attributes.set(name, value) def __getitem__(self, name): """ @@ -252,6 +252,6 @@ def __getitem__(self, name): :param name: name of the key whose value is required. """ try: - return self._node.get_attribute(name) + return self._node.base.attributes.get(name) except AttributeError as exception: raise KeyError(str(exception)) from exception diff --git a/aiida/orm/utils/mixins.py b/aiida/orm/utils/mixins.py index 95ddefad65..59f27e579f 100644 --- a/aiida/orm/utils/mixins.py +++ b/aiida/orm/utils/mixins.py @@ -9,6 +9,7 @@ ########################################################################### """Mixin classes for ORM classes.""" import inspect +from typing import List, Optional from aiida.common import exceptions from aiida.common.lang import classproperty, override @@ -64,14 +65,14 @@ def function_name(self): :returns: the function name or None """ - return self.get_attribute(self.FUNCTION_NAME_KEY, None) + return self.base.attributes.get(self.FUNCTION_NAME_KEY, None) def _set_function_name(self, function_name): """Set the function name of the wrapped function. :param function_name: the function name """ - self.set_attribute(self.FUNCTION_NAME_KEY, function_name) + self.base.attributes.set(self.FUNCTION_NAME_KEY, function_name) @property def function_namespace(self): @@ -79,14 +80,14 @@ def function_namespace(self): :returns: the function namespace or None """ - return self.get_attribute(self.FUNCTION_NAMESPACE_KEY, None) + return self.base.attributes.get(self.FUNCTION_NAMESPACE_KEY, None) def _set_function_namespace(self, function_namespace): """Set the function namespace of the wrapped function. :param function_namespace: the function namespace """ - self.set_attribute(self.FUNCTION_NAMESPACE_KEY, function_namespace) + self.base.attributes.set(self.FUNCTION_NAMESPACE_KEY, function_namespace) @property def function_starting_line_number(self): @@ -94,14 +95,14 @@ def function_starting_line_number(self): :returns: the starting line number or None """ - return self.get_attribute(self.FUNCTION_STARTING_LINE_KEY, None) + return self.base.attributes.get(self.FUNCTION_STARTING_LINE_KEY, None) def _set_function_starting_line_number(self, function_starting_line_number): """Set the starting line number of the wrapped function in its source file. :param function_starting_line_number: the starting line number """ - self.set_attribute(self.FUNCTION_STARTING_LINE_KEY, function_starting_line_number) + self.base.attributes.set(self.FUNCTION_STARTING_LINE_KEY, function_starting_line_number) def get_function_source_code(self): """Return the absolute path to the source file in the repository. @@ -121,6 +122,36 @@ class Sealable: def _updatable_attributes(cls): # pylint: disable=no-self-argument return (cls.SEALED_KEY,) + @property + def is_sealed(self): + """Returns whether the node is sealed, i.e. whether the sealed attribute has been set to True.""" + return self.base.attributes.get(self.SEALED_KEY, False) + + def seal(self): + """Seal the node by setting the sealed attribute to True.""" + if not self.is_sealed: + self.base.attributes.set(self.SEALED_KEY, True) + + @override + def _check_mutability_attributes(self, keys: Optional[List[str]] = None) -> None: # pylint: disable=unused-argument + """Check if the entity is mutable and raise an exception if not. + + This is called from `NodeAttributes` methods that modify the attributes. + + :param keys: the keys that will be mutated, or all if None + """ + if self.is_sealed: + raise exceptions.ModificationNotAllowed('attributes of a sealed node are immutable') + + if self.is_stored: + # here we are more lenient than the base class, since we allow the modification of some attributes + if keys is None: + raise exceptions.ModificationNotAllowed('Cannot bulk modify attributes of a stored+unsealed node') + elif any(key not in self._updatable_attributes for key in keys): + raise exceptions.ModificationNotAllowed( + f'Cannot modify non-updatable attributes of a stored+unsealed node: {keys}' + ) + def validate_incoming(self, source, link_type, link_label): """Validate adding a link of the given type from a given node to ourself. @@ -150,47 +181,3 @@ def validate_outgoing(self, target, link_type, link_label): raise exceptions.ModificationNotAllowed('Cannot add a link from a sealed node') super().validate_outgoing(target, link_type=link_type, link_label=link_label) - - @property - def is_sealed(self): - """Returns whether the node is sealed, i.e. whether the sealed attribute has been set to True.""" - return self.get_attribute(self.SEALED_KEY, False) - - def seal(self): - """Seal the node by setting the sealed attribute to True.""" - if not self.is_sealed: - self.set_attribute(self.SEALED_KEY, True) - - @override - def set_attribute(self, key, value): - """Set an attribute to the given value. - - :param key: name of the attribute - :param value: value of the attribute - :raise aiida.common.exceptions.ModificationNotAllowed: if the node is already sealed or if the node - is already stored and the attribute is not updatable. - """ - if self.is_sealed: - raise exceptions.ModificationNotAllowed('attributes of a sealed node are immutable') - - if self.is_stored and key not in self._updatable_attributes: # pylint: disable=unsupported-membership-test - raise exceptions.ModificationNotAllowed(f'`{key}` is not an updatable attribute') - - self.backend_entity.set_attribute(key, value) - - @override - def delete_attribute(self, key): - """Delete an attribute. - - :param key: name of the attribute - :raises AttributeError: if the attribute does not exist - :raise aiida.common.exceptions.ModificationNotAllowed: if the node is already sealed or if the node - is already stored and the attribute is not updatable. - """ - if self.is_sealed: - raise exceptions.ModificationNotAllowed('attributes of a sealed node are immutable') - - if self.is_stored and key not in self._updatable_attributes: # pylint: disable=unsupported-membership-test - raise exceptions.ModificationNotAllowed(f'`{key}` is not an updatable attribute') - - self.backend_entity.delete_attribute(key) diff --git a/aiida/restapi/translator/nodes/node.py b/aiida/restapi/translator/nodes/node.py index b550f10e0e..5576747208 100644 --- a/aiida/restapi/translator/nodes/node.py +++ b/aiida/restapi/translator/nodes/node.py @@ -260,13 +260,13 @@ def _get_content(self): if self._content_type == 'attributes': # Get all attrs if attributes_filter is None if self._attributes_filter is None: - data = {self._content_type: node.attributes} + data = {self._content_type: node.base.attributes.all} # Get all attrs contained in attributes_filter else: attrs = {} - for key in node.attributes.keys(): + for key in node.base.attributes.keys(): if key in self._attributes_filter: - attrs[key] = node.get_attribute(key) + attrs[key] = node.base.attributes.get(key) data = {self._content_type: attrs} # content/extras diff --git a/aiida/tools/visualization/graph.py b/aiida/tools/visualization/graph.py index 37b4fd0b04..de01e94df5 100644 --- a/aiida/tools/visualization/graph.py +++ b/aiida/tools/visualization/graph.py @@ -217,13 +217,13 @@ def default_node_sublabels(node): class_node_type = node.class_node_type if class_node_type == 'data.core.int.Int.': - sublabel = f"value: {node.get_attribute('value', '')}" + sublabel = f"value: {node.base.attributes.get('value', '')}" elif class_node_type == 'data.core.float.Float.': - sublabel = f"value: {node.get_attribute('value', '')}" + sublabel = f"value: {node.base.attributes.get('value', '')}" elif class_node_type == 'data.core.str.Str.': - sublabel = f"{node.get_attribute('value', '')}" + sublabel = f"{node.base.attributes.get('value', '')}" elif class_node_type == 'data.core.bool.Bool.': - sublabel = f"{node.get_attribute('value', '')}" + sublabel = f"{node.base.attributes.get('value', '')}" elif class_node_type == 'data.core.code.Code.': sublabel = f'{os.path.basename(node.get_execname())}@{node.computer.label}' elif class_node_type == 'data.core.singlefile.SinglefileData.': @@ -242,7 +242,7 @@ def default_node_sublabels(node): sublabel_lines.append(', '.join(sg_numbers)) sublabel = '; '.join(sublabel_lines) elif class_node_type == 'data.core.upf.UpfData.': - sublabel = f"{node.get_attribute('element', '')}" + sublabel = f"{node.base.attributes.get('element', '')}" elif isinstance(node, orm.ProcessNode): sublabel = [] if node.process_state is not None: diff --git a/docs/source/developer_guide/core/internals.rst b/docs/source/developer_guide/core/internals.rst index 80b6e79a48..bb13565f09 100644 --- a/docs/source/developer_guide/core/internals.rst +++ b/docs/source/developer_guide/core/internals.rst @@ -158,20 +158,7 @@ Attributes related methods ========================== Each :py:meth:`~aiida.orm.nodes.node.Node` object can have attributes which are properties that characterize the node. Such properties can be the energy, the atom symbols or the lattice vectors. -The following methods can be used for the management of the attributes. - -- :py:meth:`~aiida.orm.nodes.node.Node.set_attribute` and :py:meth:`~aiida.orm.nodes.node.Node.set_attribute_many` adds one or many new attributes to the node. - The key of the attribute is the property name (e.g. ``energy``, ``lattice_vectors`` etc) and the value of the attribute is the value of that property. - -- :py:meth:`~aiida.orm.nodes.node.Node.reset_attributes` will replace all existing attributes with a new set of attributes. - -- :py:meth:`~aiida.orm.nodes.node.Node.attributes` is a property that returns all attributes. - -- :py:meth:`~aiida.orm.nodes.node.Node.get_attribute` and :py:meth:`~aiida.orm.nodes.node.Node.get_attribute_many` can be used to return a single or many specific attributes. - -- :py:meth:`~aiida.orm.nodes.node.Node.delete_attribute` & :py:meth:`~aiida.orm.nodes.node.Node.delete_attribute_many` delete one or multiple specific attributes. - -- :py:meth:`~aiida.orm.nodes.node.Node.clear_attributes` will delete all existing attributes. +The methods for the management of the attributes are defined on the :py:class:`~aiida.orm.nodes.attributes.NodeAttributes` class. Extras related methods diff --git a/docs/source/topics/data_types.rst b/docs/source/topics/data_types.rst index 3ff39c4639..2beb07147d 100644 --- a/docs/source/topics/data_types.rst +++ b/docs/source/topics/data_types.rst @@ -916,7 +916,7 @@ Therefore, we have to override the constructor :meth:`~aiida.orm.nodes.node.Node def __init__(self, **kwargs): value = kwargs.pop('value') super().__init__(**kwargs) - self.set_attribute('value', value) + self.base.attributes.set('value', value) .. warning:: @@ -937,10 +937,10 @@ By adding the value to the node's attributes, they will be queryable in the data .. code-block:: python node = NewData(value=5) # Creating new node instance in memory - node.set_attribute('value', 6) # While in memory, node attributes can be changed + node.base.attributes.set('value', 6) # While in memory, node attributes can be changed node.store() # Storing node instance in the database -After storing the node instance in the database, its attributes are frozen, and ``node.set_attribute('value', 7)`` will fail. +After storing the node instance in the database, its attributes are frozen, and ``node.base.attributes.set('value', 7)`` will fail. By storing the ``value`` in the attributes of the node instance, we ensure that that ``value`` can be retrieved even when the node is reloaded at a later point in time. Besides making sure that the content of a data node is stored in the database or file repository, the data type class can also provide useful methods for users to retrieve that data. @@ -949,7 +949,7 @@ For example, with the current state of the ``NewData`` class, in order to retrie .. code-block:: python node = load_node() - node.get_attribute('value') + node.base.attributes.get('value') In other words, the user of the ``NewData`` class needs to know that the ``value`` is stored as an attribute with the name 'value'. This is not easy to remember and therefore not very user-friendly. @@ -968,7 +968,7 @@ Let's introduce one that will return the value that was stored for it: @property def value(self): """Return the value stored for this instance.""" - return self.get_attribute('value') + return self.base.attributes.get('value') The addition of the instance property ``value`` makes retrieving the value of a ``NewData`` node a lot easier: @@ -998,14 +998,14 @@ Here is an example for a custom data type that needs to wrap a single text file: filename = os.path.basename(filepath) # Get the filename from the absolute path self.put_object_from_file(filepath, filename) # Store the file in the repository under the given filename - self.set_attribute('filename', filename) # Store in the attributes what the filename is + self.base.attributes.set('filename', filename) # Store in the attributes what the filename is def get_content(self): """Return the content of the single file stored for this data node. :return: the content of the file as a string """ - filename = self.get_attribute('filename') + filename = self.base.attributes.get('filename') return self.get_object_content(filename) To create a new instance of this data type and get its content: @@ -1033,4 +1033,4 @@ However, storing large amounts of data within the database comes at the cost of Therefore, big data (think large files), whose content does not necessarily need to be queried for, is better stored in the file repository. A data type may safely use both the database and file repository in parallel for individual properties. Properties stored in the database are stored as *attributes* of the node. -The node class has various methods to set these attributes, such as :py:meth:`~aiida.orm.entities.EntityAttributesMixin.set_attribute` and :py:meth:`~aiida.orm.entities.EntityAttributesMixin.set_attribute_many`. +The node class has various methods to set these attributes, such as :py:meth:`~aiida.orm.nodes.attributes.NodeAttributes.set` and :py:meth:`~aiida.orm.nodes.attributes.NodeAttributes.set_many`. diff --git a/tests/benchmark/test_nodes.py b/tests/benchmark/test_nodes.py index cc45fcda52..2eed4963bf 100644 --- a/tests/benchmark/test_nodes.py +++ b/tests/benchmark/test_nodes.py @@ -26,7 +26,7 @@ def get_data_node(store=True): """A function to create a simple data node.""" data = Data() - data.set_attribute_many({str(i): i for i in range(10)}) + data.base.attributes.set_many({str(i): i for i in range(10)}) if store: data.store() return (), {'node': data} @@ -35,7 +35,7 @@ def get_data_node(store=True): def get_data_node_and_object(store=True): """A function to create a simple data node, with an object.""" data = Data() - data.set_attribute_many({str(i): i for i in range(10)}) + data.base.attributes.set_many({str(i): i for i in range(10)}) data.put_object_from_filelike(StringIO('a' * 10000), 'key') if store: data.store() @@ -51,7 +51,7 @@ def test_store_backend(benchmark): def _run(): data = Data() - data.set_attribute_many({str(i): i for i in range(10)}) + data.base.attributes.set_many({str(i): i for i in range(10)}) data._backend_entity.store(clean=False) return data diff --git a/tests/cmdline/commands/test_node.py b/tests/cmdline/commands/test_node.py index 57888a97c7..05967a9141 100644 --- a/tests/cmdline/commands/test_node.py +++ b/tests/cmdline/commands/test_node.py @@ -39,7 +39,7 @@ def init_profile(self, aiida_profile_clean, run_cli_command): # pylint: disable self.ATTR_KEY_TWO = 'b' self.ATTR_VAL_TWO = 'test' - node.set_attribute_many({self.ATTR_KEY_ONE: self.ATTR_VAL_ONE, self.ATTR_KEY_TWO: self.ATTR_VAL_TWO}) + node.base.attributes.set_many({self.ATTR_KEY_ONE: self.ATTR_VAL_ONE, self.ATTR_KEY_TWO: self.ATTR_VAL_TWO}) self.EXTRA_KEY_ONE = 'x' self.EXTRA_VAL_ONE = '2' diff --git a/tests/cmdline/commands/test_process.py b/tests/cmdline/commands/test_process.py index 868cc02fe6..119cada641 100644 --- a/tests/cmdline/commands/test_process.py +++ b/tests/cmdline/commands/test_process.py @@ -54,7 +54,7 @@ def init_profile(self, aiida_profile_clean, run_cli_command): # pylint: disable calc.set_exit_status(0) # Give a `process_label` to the `WorkFunctionNodes` so the `--process-label` option can be tested - calc.set_attribute('process_label', self.process_label) + calc.base.attributes.set('process_label', self.process_label) calc.store() self.calcs.append(calc) @@ -184,7 +184,7 @@ def test_process_show(self): workchain_two = WorkChainNode() workchains = [workchain_one, workchain_two] - workchain_two.set_attribute('process_label', 'workchain_one_caller') + workchain_two.base.attributes.set('process_label', 'workchain_one_caller') workchain_two.store() workchain_one.add_incoming(workchain_two, link_type=LinkType.CALL_WORK, link_label='called') workchain_one.store() @@ -192,8 +192,8 @@ def test_process_show(self): calcjob_one = CalcJobNode() calcjob_two = CalcJobNode() - calcjob_one.set_attribute('process_label', 'process_label_one') - calcjob_two.set_attribute('process_label', 'process_label_two') + calcjob_one.base.attributes.set('process_label', 'process_label_one') + calcjob_two.base.attributes.set('process_label', 'process_label_two') calcjob_one.add_incoming(workchain_one, link_type=LinkType.CALL_CALC, link_label='one') calcjob_two.add_incoming(workchain_one, link_type=LinkType.CALL_CALC, link_label='two') diff --git a/tests/engine/processes/calcjobs/test_calc_job.py b/tests/engine/processes/calcjobs/test_calc_job.py index 707ada511b..2d308a6709 100644 --- a/tests/engine/processes/calcjobs/test_calc_job.py +++ b/tests/engine/processes/calcjobs/test_calc_job.py @@ -71,7 +71,7 @@ def prepare_for_submission(self, folder): calcinfo = CalcInfo() calcinfo.codes_info = [codeinfo] - calcinfo.provenance_exclude_list = self.inputs.settings.get_attribute('provenance_exclude_list') + calcinfo.provenance_exclude_list = self.inputs.settings.base.attributes.get('provenance_exclude_list') return calcinfo @@ -551,7 +551,7 @@ def test_parse_non_zero_retval(generate_process): retrieved = orm.FolderData().store() retrieved.add_incoming(process.node, link_label='retrieved', link_type=LinkType.CREATE) - process.node.set_attribute('detailed_job_info', {'retval': 1, 'stderr': 'accounting disabled', 'stdout': ''}) + process.node.base.attributes.set('detailed_job_info', {'retval': 1, 'stderr': 'accounting disabled', 'stdout': ''}) process.parse() logs = [log.message for log in orm.Log.objects.get_logs_for(process.node)] @@ -570,7 +570,7 @@ def test_parse_not_implemented(generate_process): retrieved = orm.FolderData() retrieved.add_incoming(process.node, link_label='retrieved', link_type=LinkType.CREATE) - process.node.set_attribute('detailed_job_info', {}) + process.node.base.attributes.set('detailed_job_info', {}) filename_stderr = process.node.get_option('scheduler_stderr') filename_stdout = process.node.get_option('scheduler_stdout') @@ -604,7 +604,7 @@ def test_parse_scheduler_excepted(generate_process, monkeypatch): retrieved = orm.FolderData() retrieved.add_incoming(process.node, link_label='retrieved', link_type=LinkType.CREATE) - process.node.set_attribute('detailed_job_info', {}) + process.node.base.attributes.set('detailed_job_info', {}) filename_stderr = process.node.get_option('scheduler_stderr') filename_stdout = process.node.get_option('scheduler_stdout') @@ -702,33 +702,33 @@ def test_additional_retrieve_list(generate_process, fixture_sandbox): """Test the ``additional_retrieve_list`` option.""" process = generate_process() process.presubmit(fixture_sandbox) - retrieve_list = process.node.get_attribute('retrieve_list') + retrieve_list = process.node.base.attributes.get('retrieve_list') # Keep reference of the base contents of the retrieve list. base_retrieve_list = retrieve_list # Test that the code works if no explicit additional retrieve list is specified assert len(retrieve_list) != 0 - assert isinstance(process.node.get_attribute('retrieve_list'), list) + assert isinstance(process.node.base.attributes.get('retrieve_list'), list) # Defining explicit additional retrieve list that is disjoint with the base retrieve list additional_retrieve_list = ['file.txt', 'folder/file.txt'] process = generate_process({'metadata': {'options': {'additional_retrieve_list': additional_retrieve_list}}}) process.presubmit(fixture_sandbox) - retrieve_list = process.node.get_attribute('retrieve_list') + retrieve_list = process.node.base.attributes.get('retrieve_list') # Check that the `retrieve_list` is a list and contains the union of the base and additional retrieve list - assert isinstance(process.node.get_attribute('retrieve_list'), list) + assert isinstance(process.node.base.attributes.get('retrieve_list'), list) assert set(retrieve_list) == set(base_retrieve_list).union(set(additional_retrieve_list)) # Defining explicit additional retrieve list with elements that overlap with `base_retrieve_list additional_retrieve_list = ['file.txt', 'folder/file.txt'] + base_retrieve_list process = generate_process({'metadata': {'options': {'additional_retrieve_list': additional_retrieve_list}}}) process.presubmit(fixture_sandbox) - retrieve_list = process.node.get_attribute('retrieve_list') + retrieve_list = process.node.base.attributes.get('retrieve_list') # Check that the `retrieve_list` is a list and contains the union of the base and additional retrieve list - assert isinstance(process.node.get_attribute('retrieve_list'), list) + assert isinstance(process.node.base.attributes.get('retrieve_list'), list) assert set(retrieve_list) == set(base_retrieve_list).union(set(additional_retrieve_list)) # Test the validator diff --git a/tests/engine/processes/workchains/test_utils.py b/tests/engine/processes/workchains/test_utils.py index 63c2d12e7a..4f71940940 100644 --- a/tests/engine/processes/workchains/test_utils.py +++ b/tests/engine/processes/workchains/test_utils.py @@ -66,33 +66,33 @@ class ArithmeticAddBaseWorkChain(BaseRestartWorkChain): @process_handler(priority=100) def handler_01(self, node): """Example handler returing ExitCode 100.""" - handlers_called = node.get_attribute(attribute_key, default=[]) + handlers_called = node.base.attributes.get(attribute_key, default=[]) handlers_called.append('handler_01') - node.set_attribute(attribute_key, handlers_called) + node.base.attributes.set(attribute_key, handlers_called) return ProcessHandlerReport(False, ExitCode(100)) @process_handler(priority=300) def handler_03(self, node): """Example handler returing ExitCode 300.""" - handlers_called = node.get_attribute(attribute_key, default=[]) + handlers_called = node.base.attributes.get(attribute_key, default=[]) handlers_called.append('handler_03') - node.set_attribute(attribute_key, handlers_called) + node.base.attributes.set(attribute_key, handlers_called) return ProcessHandlerReport(False, ExitCode(300)) @process_handler(priority=200) def handler_02(self, node): """Example handler returing ExitCode 200.""" - handlers_called = node.get_attribute(attribute_key, default=[]) + handlers_called = node.base.attributes.get(attribute_key, default=[]) handlers_called.append('handler_02') - node.set_attribute(attribute_key, handlers_called) + node.base.attributes.set(attribute_key, handlers_called) return ProcessHandlerReport(False, ExitCode(200)) @process_handler(priority=400) def handler_04(self, node): """Example handler returing ExitCode 400.""" - handlers_called = node.get_attribute(attribute_key, default=[]) + handlers_called = node.base.attributes.get(attribute_key, default=[]) handlers_called.append('handler_04') - node.set_attribute(attribute_key, handlers_called) + node.base.attributes.set(attribute_key, handlers_called) return ProcessHandlerReport(False, ExitCode(400)) child = ProcessNode() @@ -105,7 +105,7 @@ def handler_04(self, node): # Last called handler should be `handler_01` which returned `ExitCode(100)` assert process.inspect_process() == ExitCode(100) - assert child.get_attribute(attribute_key, []) == ['handler_04', 'handler_03', 'handler_02', 'handler_01'] + assert child.base.attributes.get(attribute_key, []) == ['handler_04', 'handler_03', 'handler_02', 'handler_01'] def test_exit_codes_keyword_only(self): """The `exit_codes` should be keyword only.""" diff --git a/tests/engine/test_process_function.py b/tests/engine/test_process_function.py index ca25bc1216..21c5605bce 100644 --- a/tests/engine/test_process_function.py +++ b/tests/engine/test_process_function.py @@ -137,7 +137,7 @@ def test_plugin_version(self): # Since the "plugin" i.e. the process function is defined in `aiida-core` the `version.plugin` is the same as # the version of `aiida-core` itself - version_info = node.get_attribute('version') + version_info = node.base.attributes.get('version') assert version_info['core'] == version_core assert version_info['plugin'] == version_core diff --git a/tests/orm/data/test_enum.py b/tests/orm/data/test_enum.py index 08de37a8c5..6dcd47314c 100644 --- a/tests/orm/data/test_enum.py +++ b/tests/orm/data/test_enum.py @@ -99,7 +99,7 @@ def test_get_member_module_not_importable(): """Test the ``get_member`` property when the enum cannot be imported from the identifier.""" member = DummyEnum.OPTION_A node = EnumData(member) - node.set_attribute(EnumData.KEY_IDENTIFIER, 'aiida.common.links:NonExistingEnum') + node.base.attributes.set(EnumData.KEY_IDENTIFIER, 'aiida.common.links:NonExistingEnum') node.store() loaded = load_node(node.pk) @@ -139,9 +139,9 @@ def test_eq(): assert node_a != DummyEnum.OPTION_A.value # If the identifier cannot be resolved, the equality should not raise but simply return ``False``. - node_a.set_attribute(EnumData.KEY_IDENTIFIER, 'aiida.common.links:NonExistingEnum') + node_a.base.attributes.set(EnumData.KEY_IDENTIFIER, 'aiida.common.links:NonExistingEnum') assert node_a != DummyEnum.OPTION_A # If the value is incorrect for the resolved identifier, the equality should not raise but simply return ``False``. - node_b.set_attribute(EnumData.KEY_VALUE, 'c') + node_b.base.attributes.set(EnumData.KEY_VALUE, 'c') assert node_b != DummyEnum.OPTION_B diff --git a/tests/orm/nodes/data/test_jsonable.py b/tests/orm/nodes/data/test_jsonable.py index f4cc9bdb1c..da3b48c2fc 100644 --- a/tests/orm/nodes/data/test_jsonable.py +++ b/tests/orm/nodes/data/test_jsonable.py @@ -124,7 +124,7 @@ def test_unimportable_module(): node = JsonableData(obj) # Artificially change the ``@module`` in the attributes so it becomes unloadable - node.set_attribute('@module', 'not.existing') + node.base.attributes.set('@module', 'not.existing') node.store() loaded = load_node(node.pk) @@ -140,7 +140,7 @@ def test_unimportable_class(): node = JsonableData(obj) # Artificially change the ``@class`` in the attributes so it becomes unloadable - node.set_attribute('@class', 'NonExistingClass') + node.base.attributes.set('@class', 'NonExistingClass') node.store() loaded = load_node(node.pk) diff --git a/tests/orm/nodes/data/test_orbital.py b/tests/orm/nodes/data/test_orbital.py index 8239b5f206..ebe14ee21d 100644 --- a/tests/orm/nodes/data/test_orbital.py +++ b/tests/orm/nodes/data/test_orbital.py @@ -60,7 +60,7 @@ def test_real_hydrogen(self): #Check that a corrupted OribtalData fails on get_orbitals corrupted_orbitaldata = copy.deepcopy(orbitaldata) - del corrupted_orbitaldata.get_attribute('orbital_dicts')[0]['_orbital_type'] + del corrupted_orbitaldata.base.attributes.get('orbital_dicts')[0]['_orbital_type'] with pytest.raises(ValidationError): corrupted_orbitaldata.get_orbitals() diff --git a/tests/orm/nodes/data/test_remote.py b/tests/orm/nodes/data/test_remote.py index 8746deb1ae..1d3cec71ca 100644 --- a/tests/orm/nodes/data/test_remote.py +++ b/tests/orm/nodes/data/test_remote.py @@ -31,4 +31,4 @@ def test_clean(remote_data): remote_data._clean() # pylint: disable=protected-access assert remote_data.is_empty - assert remote_data.get_attribute(RemoteData.KEY_EXTRA_CLEANED, True) + assert remote_data.base.attributes.get(RemoteData.KEY_EXTRA_CLEANED, True) diff --git a/tests/orm/nodes/data/test_trajectory.py b/tests/orm/nodes/data/test_trajectory.py index d73835126e..7b55e93878 100644 --- a/tests/orm/nodes/data/test_trajectory.py +++ b/tests/orm/nodes/data/test_trajectory.py @@ -26,13 +26,13 @@ def test_get_attribute_tryexcept_default(self): Test whether the try_except statement on the get_attribute calls for units in the `show_mpl_*` functions except the correct exception type (for setting defaults). - Added for PR #5015 (behavior of BackendEntityAttributesMixin.get_attribute changed + Added for PR #5015 (behavior of BackendEntityAttributes.get changed to raise AttributeError instead of KeyError) """ tjd = TrajectoryData() try: - positions_unit = tjd.get_attribute('units|positions') + positions_unit = tjd.base.attributes.get('units|positions') except AttributeError: positions_unit = 'A' except KeyError: @@ -40,7 +40,7 @@ def test_get_attribute_tryexcept_default(self): assert positions_unit == 'A' try: - times_unit = tjd.get_attribute('units|times') + times_unit = tjd.base.attributes.get('units|times') except AttributeError: times_unit = 'ps' except KeyError: @@ -49,7 +49,7 @@ def test_get_attribute_tryexcept_default(self): positions = 1 try: - if self.get_attribute('units|positions') in ('bohr', 'atomic'): + if self.base.attributes.get('units|positions') in ('bohr', 'atomic'): bohr_to_ang = 0.52917720859 positions *= bohr_to_ang except AttributeError: @@ -62,15 +62,15 @@ def test_units(self): """Test the setting of units attributes.""" tjd = TrajectoryData() - tjd.set_attribute('units|positions', 'some_random_pos_unit') - assert tjd.get_attribute('units|positions') == 'some_random_pos_unit' + tjd.base.attributes.set('units|positions', 'some_random_pos_unit') + assert tjd.base.attributes.get('units|positions') == 'some_random_pos_unit' - tjd.set_attribute('units|times', 'some_random_time_unit') - assert tjd.get_attribute('units|times') == 'some_random_time_unit' + tjd.base.attributes.set('units|times', 'some_random_time_unit') + assert tjd.base.attributes.get('units|times') == 'some_random_time_unit' # Test after storing tjd.set_trajectory(self.symbols, self.positions) tjd.store() tjd2 = load_node(tjd.pk) - assert tjd2.get_attribute('units|positions') == 'some_random_pos_unit' - assert tjd2.get_attribute('units|times') == 'some_random_time_unit' + assert tjd2.base.attributes.get('units|positions') == 'some_random_pos_unit' + assert tjd2.base.attributes.get('units|times') == 'some_random_time_unit' diff --git a/tests/orm/nodes/test_calcjob.py b/tests/orm/nodes/test_calcjob.py index ae34376ce4..1eb7bde175 100644 --- a/tests/orm/nodes/test_calcjob.py +++ b/tests/orm/nodes/test_calcjob.py @@ -38,7 +38,7 @@ def test_get_set_state(self): assert node.get_state() == CalcJobState.UPLOADING # Setting an illegal calculation job state, the `get_state` should not fail but return `None` - node.set_attribute(node.CALC_JOB_STATE_KEY, 'INVALID') + node.base.attributes.set(node.CALC_JOB_STATE_KEY, 'INVALID') assert node.get_state() is None def test_get_scheduler_stdout(self): diff --git a/tests/orm/nodes/test_node.py b/tests/orm/nodes/test_node.py index 5b74267520..6e3e58e386 100644 --- a/tests/orm/nodes/test_node.py +++ b/tests/orm/nodes/test_node.py @@ -125,11 +125,11 @@ def setup_method(self): self.node = Data() def test_attributes(self): - """Test the `Node.attributes` property.""" + """Test the `Node.base.attributes.all` property.""" original_attribute = {'nested': {'a': 1}} - self.node.set_attribute('key', original_attribute) - node_attributes = self.node.attributes + self.node.base.attributes.set('key', original_attribute) + node_attributes = self.node.base.attributes.all assert node_attributes['key'] == original_attribute node_attributes['key']['nested']['a'] = 2 @@ -137,7 +137,7 @@ def test_attributes(self): # Now store the node and verify that `attributes` then returns a deep copy self.node.store() - node_attributes = self.node.attributes + node_attributes = self.node.base.attributes.all # We change the returned node attributes but the original attribute should remain unchanged node_attributes['key']['nested']['a'] = 3 @@ -147,37 +147,37 @@ def test_get_attribute(self): """Test the `Node.get_attribute` method.""" original_attribute = {'nested': {'a': 1}} - self.node.set_attribute('key', original_attribute) - node_attribute = self.node.get_attribute('key') + self.node.base.attributes.set('key', original_attribute) + node_attribute = self.node.base.attributes.get('key') assert node_attribute == original_attribute node_attribute['nested']['a'] = 2 assert original_attribute['nested']['a'] == 2 default = 'default' - assert self.node.get_attribute('not_existing', default=default) == default + assert self.node.base.attributes.get('not_existing', default=default) == default with pytest.raises(AttributeError): - self.node.get_attribute('not_existing') + self.node.base.attributes.get('not_existing') # Now store the node and verify that `get_attribute` then returns a deep copy self.node.store() - node_attribute = self.node.get_attribute('key') + node_attribute = self.node.base.attributes.get('key') # We change the returned node attributes but the original attribute should remain unchanged node_attribute['nested']['a'] = 3 assert original_attribute['nested']['a'] == 2 default = 'default' - assert self.node.get_attribute('not_existing', default=default) == default + assert self.node.base.attributes.get('not_existing', default=default) == default with pytest.raises(AttributeError): - self.node.get_attribute('not_existing') + self.node.base.attributes.get('not_existing') def test_get_attribute_many(self): """Test the `Node.get_attribute_many` method.""" original_attribute = {'nested': {'a': 1}} - self.node.set_attribute('key', original_attribute) - node_attribute = self.node.get_attribute_many(['key'])[0] + self.node.base.attributes.set('key', original_attribute) + node_attribute = self.node.base.attributes.get_many(['key'])[0] assert node_attribute == original_attribute node_attribute['nested']['a'] = 2 @@ -185,7 +185,7 @@ def test_get_attribute_many(self): # Now store the node and verify that `get_attribute` then returns a deep copy self.node.store() - node_attribute = self.node.get_attribute_many(['key'])[0] + node_attribute = self.node.base.attributes.get_many(['key'])[0] # We change the returned node attributes but the original attribute should remain unchanged node_attribute['nested']['a'] = 3 @@ -194,24 +194,24 @@ def test_get_attribute_many(self): def test_set_attribute(self): """Test the `Node.set_attribute` method.""" with pytest.raises(exceptions.ValidationError): - self.node.set_attribute('illegal.key', 'value') + self.node.base.attributes.set('illegal.key', 'value') - self.node.set_attribute('valid_key', 'value') + self.node.base.attributes.set('valid_key', 'value') self.node.store() with pytest.raises(exceptions.ModificationNotAllowed): - self.node.set_attribute('valid_key', 'value') + self.node.base.attributes.set('valid_key', 'value') def test_set_attribute_many(self): """Test the `Node.set_attribute` method.""" with pytest.raises(exceptions.ValidationError): - self.node.set_attribute_many({'illegal.key': 'value', 'valid_key': 'value'}) + self.node.base.attributes.set_many({'illegal.key': 'value', 'valid_key': 'value'}) - self.node.set_attribute_many({'valid_key': 'value'}) + self.node.base.attributes.set_many({'valid_key': 'value'}) self.node.store() with pytest.raises(exceptions.ModificationNotAllowed): - self.node.set_attribute_many({'valid_key': 'value'}) + self.node.base.attributes.set_many({'valid_key': 'value'}) def test_reset_attribute(self): """Test the `Node.reset_attribute` method.""" @@ -219,34 +219,34 @@ def test_reset_attribute(self): attributes_after = {'attribute_three': 'value', 'attribute_four': 'value'} attributes_illegal = {'attribute.illegal': 'value', 'attribute_four': 'value'} - self.node.set_attribute_many(attributes_before) - assert self.node.attributes == attributes_before - self.node.reset_attributes(attributes_after) - assert self.node.attributes == attributes_after + self.node.base.attributes.set_many(attributes_before) + assert self.node.base.attributes.all == attributes_before + self.node.base.attributes.reset(attributes_after) + assert self.node.base.attributes.all == attributes_after with pytest.raises(exceptions.ValidationError): - self.node.reset_attributes(attributes_illegal) + self.node.base.attributes.reset(attributes_illegal) self.node.store() with pytest.raises(exceptions.ModificationNotAllowed): - self.node.reset_attributes(attributes_after) + self.node.base.attributes.reset(attributes_after) def test_delete_attribute(self): """Test the `Node.delete_attribute` method.""" - self.node.set_attribute('valid_key', 'value') - assert self.node.get_attribute('valid_key') == 'value' - self.node.delete_attribute('valid_key') + self.node.base.attributes.set('valid_key', 'value') + assert self.node.base.attributes.get('valid_key') == 'value' + self.node.base.attributes.delete('valid_key') with pytest.raises(AttributeError): - self.node.delete_attribute('valid_key') + self.node.base.attributes.delete('valid_key') # Repeat with stored node - self.node.set_attribute('valid_key', 'value') + self.node.base.attributes.set('valid_key', 'value') self.node.store() with pytest.raises(exceptions.ModificationNotAllowed): - self.node.delete_attribute('valid_key') + self.node.base.attributes.delete('valid_key') def test_delete_attribute_many(self): """Test the `Node.delete_attribute_many` method.""" @@ -254,29 +254,29 @@ def test_delete_attribute_many(self): def test_clear_attributes(self): """Test the `Node.clear_attributes` method.""" attributes = {'attribute_one': 'value', 'attribute_two': 'value'} - self.node.set_attribute_many(attributes) - assert self.node.attributes == attributes + self.node.base.attributes.set_many(attributes) + assert self.node.base.attributes.all == attributes - self.node.clear_attributes() - assert self.node.attributes == {} + self.node.base.attributes.clear() + assert self.node.base.attributes.all == {} # Repeat for stored node self.node.store() with pytest.raises(exceptions.ModificationNotAllowed): - self.node.clear_attributes() + self.node.base.attributes.clear() def test_attributes_items(self): - """Test the `Node.attributes_items` generator.""" + """Test the `Node.base.attributes.items` generator.""" attributes = {'attribute_one': 'value', 'attribute_two': 'value'} - self.node.set_attribute_many(attributes) - assert dict(self.node.attributes_items()) == attributes + self.node.base.attributes.set_many(attributes) + assert dict(self.node.base.attributes.items()) == attributes def test_attributes_keys(self): - """Test the `Node.attributes_keys` generator.""" + """Test the `Node.base.attributes.keys` generator.""" attributes = {'attribute_one': 'value', 'attribute_two': 'value'} - self.node.set_attribute_many(attributes) - assert set(self.node.attributes_keys()) == set(attributes) + self.node.base.attributes.set_many(attributes) + assert set(self.node.base.attributes.keys()) == set(attributes) def test_extras(self): """Test the `Node.extras` property.""" @@ -449,10 +449,10 @@ def test_extras_keys(self): def test_attribute_decimal(self): """Test that the `Node.set_attribute` method supports Decimal.""" - self.node.set_attribute('a_val', Decimal('3.141')) + self.node.base.attributes.set('a_val', Decimal('3.141')) self.node.store() # ensure the returned node is a float - assert self.node.get_attribute('a_val') == 3.141 + assert self.node.base.attributes.get('a_val') == 3.141 @pytest.mark.usefixtures('aiida_profile_clean_class') diff --git a/tests/orm/test_mixins.py b/tests/orm/test_mixins.py index dd822be348..3f868880f3 100644 --- a/tests/orm/test_mixins.py +++ b/tests/orm/test_mixins.py @@ -30,7 +30,7 @@ def test_change_updatable_attrs_after_store(): for attr in CalculationNode._updatable_attributes: # pylint: disable=protected-access,not-an-iterable if attr != Sealable.SEALED_KEY: - node.set_attribute(attr, 'a') + node.base.attributes.set(attr, 'a') def test_validate_incoming_sealed(self): """Verify that trying to add a link to a sealed node will raise.""" diff --git a/tests/orm/test_querybuilder.py b/tests/orm/test_querybuilder.py index 2d950aacb6..c8dab0d11a 100644 --- a/tests/orm/test_querybuilder.py +++ b/tests/orm/test_querybuilder.py @@ -228,25 +228,25 @@ def test_simple_query_1(self): n1 = orm.Data() n1.label = 'node1' - n1.set_attribute('foo', ['hello', 'goodbye']) + n1.base.attributes.set('foo', ['hello', 'goodbye']) n1.store() n2 = orm.CalculationNode() n2.label = 'node2' - n2.set_attribute('foo', 1) + n2.base.attributes.set('foo', 1) n3 = orm.Data() n3.label = 'node3' - n3.set_attribute('foo', 1.0000) # Stored as fval + n3.base.attributes.set('foo', 1.0000) # Stored as fval n3.store() n4 = orm.CalculationNode() n4.label = 'node4' - n4.set_attribute('foo', 'bar') + n4.base.attributes.set('foo', 'bar') n5 = orm.Data() n5.label = 'node5' - n5.set_attribute('foo', None) + n5.base.attributes.set('foo', None) n5.store() n2.add_incoming(n1, link_type=LinkType.INPUT_CALC, link_label='link1') @@ -295,7 +295,7 @@ def test_simple_query_2(self): n0 = orm.Data() n0.label = 'hello' n0.description = '' - n0.set_attribute('foo', 'bar') + n0.base.attributes.set('foo', 'bar') n1 = orm.CalculationNode() n1.label = 'foo' @@ -405,14 +405,14 @@ def test_dict_multiple_projections(self): def test_operators_eq_lt_gt(self): nodes = [orm.Data() for _ in range(8)] - nodes[0].set_attribute('fa', 1) - nodes[1].set_attribute('fa', 1.0) - nodes[2].set_attribute('fa', 1.01) - nodes[3].set_attribute('fa', 1.02) - nodes[4].set_attribute('fa', 1.03) - nodes[5].set_attribute('fa', 1.04) - nodes[6].set_attribute('fa', 1.05) - nodes[7].set_attribute('fa', 1.06) + nodes[0].base.attributes.set('fa', 1) + nodes[1].base.attributes.set('fa', 1.0) + nodes[2].base.attributes.set('fa', 1.01) + nodes[3].base.attributes.set('fa', 1.02) + nodes[4].base.attributes.set('fa', 1.03) + nodes[5].base.attributes.set('fa', 1.04) + nodes[6].base.attributes.set('fa', 1.05) + nodes[7].base.attributes.set('fa', 1.06) for n in nodes: n.store() @@ -426,11 +426,11 @@ def test_operators_eq_lt_gt(self): def test_subclassing(self): s = orm.StructureData() - s.set_attribute('cat', 'miau') + s.base.attributes.set('cat', 'miau') s.store() d = orm.Data() - d.set_attribute('cat', 'miau') + d.base.attributes.set('cat', 'miau') d.store() p = orm.Dict(dict=dict(cat='miau')) @@ -794,7 +794,7 @@ def test_round_trip_append(self): g = orm.Group(label='helloworld').store() for cls in (orm.StructureData, orm.Dict, orm.Data): obj = cls() - obj.set_attribute('foo-qh2', 'bar') + obj.base.attributes.set('foo-qh2', 'bar') obj.store() g.add_nodes(obj) @@ -845,7 +845,7 @@ def test_computer_json(self): """ n1 = orm.CalculationNode() n1.label = 'node2' - n1.set_attribute('foo', 1) + n1.base.attributes.set('foo', 1) n1.store() # Checking the correct retrieval of _metadata which is @@ -872,8 +872,8 @@ def test_attribute_existence(self): val = 1. res_uuids = set() n1 = orm.Data() - n1.set_attribute('whatever', 3.) - n1.set_attribute('test_case', 'test_attribute_existence') + n1.base.attributes.set('whatever', 3.) + n1.base.attributes.set('test_case', 'test_attribute_existence') n1.store() # I want all the nodes where whatever is smaller than 1. or there is no such value: @@ -900,12 +900,12 @@ def test_attribute_existence(self): def test_attribute_type(self): key = 'value_test_attr_type' n_int, n_float, n_str, n_str2, n_bool, n_arr = [orm.Data() for _ in range(6)] - n_int.set_attribute(key, 1) - n_float.set_attribute(key, 1.0) - n_bool.set_attribute(key, True) - n_str.set_attribute(key, '1') - n_str2.set_attribute(key, 'one') - n_arr.set_attribute(key, [4, 3, 5]) + n_int.base.attributes.set(key, 1) + n_float.base.attributes.set(key, 1.0) + n_bool.base.attributes.set(key, True) + n_str.base.attributes.set(key, '1') + n_str2.base.attributes.set(key, 'one') + n_arr.base.attributes.set(key, [4, 3, 5]) for n in (n_str2, n_str, n_int, n_float, n_bool, n_arr): n.store() @@ -951,7 +951,7 @@ def test_ordering_limits_offsets_of_results_general(self): # Creating 10 nodes with an attribute that can be ordered for i in range(10): n = orm.Data() - n.set_attribute('foo', i) + n.base.attributes.set('foo', i) n.store() qb = orm.QueryBuilder().append(orm.Node, project='attributes.foo').order_by({orm.Node: 'ctime'}) @@ -1023,11 +1023,11 @@ def test_joins_node_incoming(self): good_child = orm.CalculationNode() good_child.label = 'good_child' - good_child.set_attribute('is_good', True) + good_child.base.attributes.set('is_good', True) bad_child = orm.CalculationNode() bad_child.label = 'bad_child' - bad_child.set_attribute('is_good', False) + bad_child.base.attributes.set('is_good', False) unrelated = orm.CalculationNode() unrelated.label = 'unrelated' @@ -1056,7 +1056,7 @@ def test_joins_node_incoming2(self): advisors = [orm.CalculationNode() for i in range(3)] for i, a in enumerate(advisors): a.label = f'advisor {i}' - a.set_attribute('advisor_id', i) + a.base.attributes.set('advisor_id', i) for n in advisors + students: n.store() @@ -1167,22 +1167,22 @@ def test_joins_group_node(self): # Create nodes and add them to the created group n1 = orm.Data() n1.label = 'node1' - n1.set_attribute('foo', ['hello', 'goodbye']) + n1.base.attributes.set('foo', ['hello', 'goodbye']) n1.store() n2 = orm.CalculationNode() n2.label = 'node2' - n2.set_attribute('foo', 1) + n2.base.attributes.set('foo', 1) n2.store() n3 = orm.Data() n3.label = 'node3' - n3.set_attribute('foo', 1.0000) # Stored as fval + n3.base.attributes.set('foo', 1.0000) # Stored as fval n3.store() n4 = orm.CalculationNode() n4.label = 'node4' - n4.set_attribute('foo', 'bar') + n4.base.attributes.set('foo', 'bar') n4.store() group.add_nodes([n1, n2, n3, n4]) diff --git a/tests/restapi/test_routes.py b/tests/restapi/test_routes.py index 36fb05b5f0..d86a5c5fcd 100644 --- a/tests/restapi/test_routes.py +++ b/tests/restapi/test_routes.py @@ -71,8 +71,8 @@ def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable calc = orm.CalcJobNode(computer=self.computer) calc.set_option('resources', resources) - calc.set_attribute('attr1', 'OK') - calc.set_attribute('attr2', 'OK') + calc.base.attributes.set('attr1', 'OK') + calc.base.attributes.set('attr2', 'OK') calc.set_extra('extra1', False) calc.set_extra('extra2', 'extra_info') diff --git a/tests/storage/psql_dos/test_query.py b/tests/storage/psql_dos/test_query.py index 27340d2021..dff9b78085 100644 --- a/tests/storage/psql_dos/test_query.py +++ b/tests/storage/psql_dos/test_query.py @@ -30,7 +30,7 @@ def test_qb_ordering_limits_offsets_sqla(): # Creating 10 nodes with an attribute that can be ordered for i in range(10): node = Data() - node.set_attribute('foo', i) + node.base.attributes.set('foo', i) node.store() q_b = QueryBuilder().append(Node, project='attributes.foo').order_by({Node: {'attributes.foo': {'cast': 'i'}}}) res = next(zip(*q_b.all())) diff --git a/tests/test_calculation_node.py b/tests/test_calculation_node.py index ea99b606fd..4f3cbaf20f 100644 --- a/tests/test_calculation_node.py +++ b/tests/test_calculation_node.py @@ -82,39 +82,39 @@ def test_process_node_updatable_attribute(self): } for key, value in attrs_to_set.items(): - node.set_attribute(key, value) + node.base.attributes.set(key, value) # Check before storing - node.set_attribute(CalculationNode.PROCESS_STATE_KEY, self.stateval) - assert node.get_attribute(CalculationNode.PROCESS_STATE_KEY) == self.stateval + node.base.attributes.set(CalculationNode.PROCESS_STATE_KEY, self.stateval) + assert node.base.attributes.get(CalculationNode.PROCESS_STATE_KEY) == self.stateval node.store() # Check after storing - assert node.get_attribute(CalculationNode.PROCESS_STATE_KEY) == self.stateval + assert node.base.attributes.get(CalculationNode.PROCESS_STATE_KEY) == self.stateval # I should be able to mutate the updatable attribute but not the others - node.set_attribute(CalculationNode.PROCESS_STATE_KEY, 'FINISHED') - node.delete_attribute(CalculationNode.PROCESS_STATE_KEY) + node.base.attributes.set(CalculationNode.PROCESS_STATE_KEY, 'FINISHED') + node.base.attributes.delete(CalculationNode.PROCESS_STATE_KEY) # Deleting non-existing attribute should raise attribute error with pytest.raises(AttributeError): - node.delete_attribute(CalculationNode.PROCESS_STATE_KEY) + node.base.attributes.delete(CalculationNode.PROCESS_STATE_KEY) with pytest.raises(ModificationNotAllowed): - node.set_attribute('bool', False) + node.base.attributes.set('bool', False) with pytest.raises(ModificationNotAllowed): - node.delete_attribute('bool') + node.base.attributes.delete('bool') node.seal() # After sealing, even updatable attributes should be immutable with pytest.raises(ModificationNotAllowed): - node.set_attribute(CalculationNode.PROCESS_STATE_KEY, 'FINISHED') + node.base.attributes.set(CalculationNode.PROCESS_STATE_KEY, 'FINISHED') with pytest.raises(ModificationNotAllowed): - node.delete_attribute(CalculationNode.PROCESS_STATE_KEY) + node.base.attributes.delete(CalculationNode.PROCESS_STATE_KEY) def test_get_description(self): assert self.calcjob.get_description() == '' diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index 000f7207f2..5de313eb4c 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -673,8 +673,8 @@ def test_set_file(self): # this should reset formulae and spacegroup_numbers a.set_file(tmpf.name) - assert a.get_attribute('formulae') is None - assert a.get_attribute('spacegroup_numbers') is None + assert a.base.attributes.get('formulae') is None + assert a.base.attributes.get('spacegroup_numbers') is None # this should populate formulae a.parse() @@ -3379,7 +3379,7 @@ def test_aiida_roundtrip(self): roundtrip_struc = spglib_tuple_to_structure(struc_tuple, kind_info, kinds) assert round(abs(np.sum(np.abs(np.array(struc.cell) - np.array(roundtrip_struc.cell))) - 0.), 7) == 0 - assert struc.get_attribute('kinds') == roundtrip_struc.get_attribute('kinds') + assert struc.base.attributes.get('kinds') == roundtrip_struc.base.attributes.get('kinds') assert [_.kind_name for _ in struc.sites] == [_.kind_name for _ in roundtrip_struc.sites] assert np.sum( np.abs(np.array([_.position for _ in struc.sites]) - np.array([_.position for _ in roundtrip_struc.sites])) diff --git a/tests/test_dbimporters.py b/tests/test_dbimporters.py index cd19edf7a3..aadc758a9e 100644 --- a/tests/test_dbimporters.py +++ b/tests/test_dbimporters.py @@ -144,7 +144,7 @@ def test_dbentry_to_cif_node(self): cif = entry.get_cif_node() assert isinstance(cif, CifData) is True - assert cif.get_attribute('md5') == '070711e8e99108aade31d20cd5c94c48' + assert cif.base.attributes.get('md5') == '070711e8e99108aade31d20cd5c94c48' assert cif.source == { 'db_name': 'Crystallography Open Database', 'db_uri': 'http://www.crystallography.net/cod', diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 0d88e595fb..ca521724c4 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -74,9 +74,9 @@ class TestNodeHashing: @staticmethod def create_simple_node(a, b=0, c=0): n = orm.Data() - n.set_attribute('a', a) - n.set_attribute('b', b) - n.set_attribute('c', c) + n.base.attributes.set('a', a) + n.base.attributes.set('b', b) + n.base.attributes.set('c', c) return n def test_node_uuid_hashing_for_querybuidler(self): @@ -295,78 +295,78 @@ def test_attribute_mutability(self): disabled in the call """ a = orm.Data() - a.set_attribute('bool', self.boolval) - a.set_attribute('integer', self.intval) + a.base.attributes.set('bool', self.boolval) + a.base.attributes.set('integer', self.intval) a.store() # After storing attributes should now be immutable with pytest.raises(ModificationNotAllowed): - a.delete_attribute('bool') + a.base.attributes.delete('bool') with pytest.raises(ModificationNotAllowed): - a.set_attribute('integer', self.intval) + a.base.attributes.set('integer', self.intval) def test_attr_before_storing(self): a = orm.Data() - a.set_attribute('k1', self.boolval) - a.set_attribute('k2', self.intval) - a.set_attribute('k3', self.floatval) - a.set_attribute('k4', self.stringval) - a.set_attribute('k5', self.dictval) - a.set_attribute('k6', self.listval) - a.set_attribute('k7', self.emptydict) - a.set_attribute('k8', self.emptylist) - a.set_attribute('k9', None) + a.base.attributes.set('k1', self.boolval) + a.base.attributes.set('k2', self.intval) + a.base.attributes.set('k3', self.floatval) + a.base.attributes.set('k4', self.stringval) + a.base.attributes.set('k5', self.dictval) + a.base.attributes.set('k6', self.listval) + a.base.attributes.set('k7', self.emptydict) + a.base.attributes.set('k8', self.emptylist) + a.base.attributes.set('k9', None) # Now I check if I can retrieve them, before the storage - assert self.boolval == a.get_attribute('k1') - assert self.intval == a.get_attribute('k2') - assert self.floatval == a.get_attribute('k3') - assert self.stringval == a.get_attribute('k4') - assert self.dictval == a.get_attribute('k5') - assert self.listval == a.get_attribute('k6') - assert self.emptydict == a.get_attribute('k7') - assert self.emptylist == a.get_attribute('k8') - assert a.get_attribute('k9') is None + assert self.boolval == a.base.attributes.get('k1') + assert self.intval == a.base.attributes.get('k2') + assert self.floatval == a.base.attributes.get('k3') + assert self.stringval == a.base.attributes.get('k4') + assert self.dictval == a.base.attributes.get('k5') + assert self.listval == a.base.attributes.get('k6') + assert self.emptydict == a.base.attributes.get('k7') + assert self.emptylist == a.base.attributes.get('k8') + assert a.base.attributes.get('k9') is None # And now I try to delete the keys - a.delete_attribute('k1') - a.delete_attribute('k2') - a.delete_attribute('k3') - a.delete_attribute('k4') - a.delete_attribute('k5') - a.delete_attribute('k6') - a.delete_attribute('k7') - a.delete_attribute('k8') - a.delete_attribute('k9') + a.base.attributes.delete('k1') + a.base.attributes.delete('k2') + a.base.attributes.delete('k3') + a.base.attributes.delete('k4') + a.base.attributes.delete('k5') + a.base.attributes.delete('k6') + a.base.attributes.delete('k7') + a.base.attributes.delete('k8') + a.base.attributes.delete('k9') with pytest.raises(AttributeError): # I delete twice the same attribute - a.delete_attribute('k1') + a.base.attributes.delete('k1') with pytest.raises(AttributeError): # I delete a non-existing attribute - a.delete_attribute('nonexisting') + a.base.attributes.delete('nonexisting') with pytest.raises(AttributeError): # I get a deleted attribute - a.get_attribute('k1') + a.base.attributes.get('k1') with pytest.raises(AttributeError): # I get a non-existing attribute - a.get_attribute('nonexisting') + a.base.attributes.get('nonexisting') def test_get_attrs_before_storing(self): a = orm.Data() - a.set_attribute('k1', self.boolval) - a.set_attribute('k2', self.intval) - a.set_attribute('k3', self.floatval) - a.set_attribute('k4', self.stringval) - a.set_attribute('k5', self.dictval) - a.set_attribute('k6', self.listval) - a.set_attribute('k7', self.emptydict) - a.set_attribute('k8', self.emptylist) - a.set_attribute('k9', None) + a.base.attributes.set('k1', self.boolval) + a.base.attributes.set('k2', self.intval) + a.base.attributes.set('k3', self.floatval) + a.base.attributes.set('k4', self.stringval) + a.base.attributes.set('k5', self.dictval) + a.base.attributes.set('k6', self.listval) + a.base.attributes.set('k7', self.emptydict) + a.base.attributes.set('k8', self.emptylist) + a.base.attributes.set('k9', None) target_attrs = { 'k1': self.boolval, @@ -381,32 +381,32 @@ def test_get_attrs_before_storing(self): } # Now I check if I can retrieve them, before the storage - assert a.attributes == target_attrs + assert a.base.attributes.all == target_attrs # And now I try to delete the keys - a.delete_attribute('k1') - a.delete_attribute('k2') - a.delete_attribute('k3') - a.delete_attribute('k4') - a.delete_attribute('k5') - a.delete_attribute('k6') - a.delete_attribute('k7') - a.delete_attribute('k8') - a.delete_attribute('k9') - - assert a.attributes == {} + a.base.attributes.delete('k1') + a.base.attributes.delete('k2') + a.base.attributes.delete('k3') + a.base.attributes.delete('k4') + a.base.attributes.delete('k5') + a.base.attributes.delete('k6') + a.base.attributes.delete('k7') + a.base.attributes.delete('k8') + a.base.attributes.delete('k9') + + assert a.base.attributes.all == {} def test_get_attrs_after_storing(self): a = orm.Data() - a.set_attribute('k1', self.boolval) - a.set_attribute('k2', self.intval) - a.set_attribute('k3', self.floatval) - a.set_attribute('k4', self.stringval) - a.set_attribute('k5', self.dictval) - a.set_attribute('k6', self.listval) - a.set_attribute('k7', self.emptydict) - a.set_attribute('k8', self.emptylist) - a.set_attribute('k9', None) + a.base.attributes.set('k1', self.boolval) + a.base.attributes.set('k2', self.intval) + a.base.attributes.set('k3', self.floatval) + a.base.attributes.set('k4', self.stringval) + a.base.attributes.set('k5', self.dictval) + a.base.attributes.set('k6', self.listval) + a.base.attributes.set('k7', self.emptydict) + a.base.attributes.set('k8', self.emptylist) + a.base.attributes.set('k9', None) a.store() @@ -423,18 +423,18 @@ def test_get_attrs_after_storing(self): } # Now I check if I can retrieve them, before the storage - assert a.attributes == target_attrs + assert a.base.attributes.all == target_attrs def test_store_object(self): """Trying to set objects as attributes should fail, because they are not json-serializable.""" a = orm.Data() - a.set_attribute('object', object()) + a.base.attributes.set('object', object()) with pytest.raises(ValidationError): a.store() b = orm.Data() - b.set_attribute('object_list', [object(), object()]) + b.base.attributes.set('object_list', [object(), object()]) with pytest.raises(ValidationError): b.store() @@ -453,20 +453,20 @@ def test_attributes_on_clone(self): } for k, v in attrs_to_set.items(): - a.set_attribute(k, v) + a.base.attributes.set(k, v) # Create a copy b = copy.deepcopy(a) # I modify an attribute and add a new one; I mirror it in the dictionary # for later checking b_expected_attributes = copy.deepcopy(attrs_to_set) - b.set_attribute('integer', 489) + b.base.attributes.set('integer', 489) b_expected_attributes['integer'] = 489 - b.set_attribute('new', 'cvb') + b.base.attributes.set('new', 'cvb') b_expected_attributes['new'] = 'cvb' # I check before storing that the attributes are ok - assert b.attributes == b_expected_attributes + assert b.base.attributes.all == b_expected_attributes # Note that during copy, I do not copy the extras! assert b.extras == {} @@ -477,10 +477,10 @@ def test_attributes_on_clone(self): b_expected_extras = {'meta': 'textofext', '_aiida_hash': AnyValue()} # Now I check that the attributes of the original node have not changed - assert a.attributes == attrs_to_set + assert a.base.attributes.all == attrs_to_set # I check then on the 'b' copy - assert b.attributes == b_expected_attributes + assert b.base.attributes.all == b_expected_attributes assert b.extras == b_expected_extras def test_files(self): @@ -687,45 +687,45 @@ def test_folders(self): def test_attr_after_storing(self): a = orm.Data() - a.set_attribute('none', None) - a.set_attribute('bool', self.boolval) - a.set_attribute('integer', self.intval) - a.set_attribute('float', self.floatval) - a.set_attribute('string', self.stringval) - a.set_attribute('dict', self.dictval) - a.set_attribute('list', self.listval) + a.base.attributes.set('none', None) + a.base.attributes.set('bool', self.boolval) + a.base.attributes.set('integer', self.intval) + a.base.attributes.set('float', self.floatval) + a.base.attributes.set('string', self.stringval) + a.base.attributes.set('dict', self.dictval) + a.base.attributes.set('list', self.listval) a.store() # Now I check if I can retrieve them, before the storage - assert a.get_attribute('none') is None - assert self.boolval == a.get_attribute('bool') - assert self.intval == a.get_attribute('integer') - assert self.floatval == a.get_attribute('float') - assert self.stringval == a.get_attribute('string') - assert self.dictval == a.get_attribute('dict') - assert self.listval == a.get_attribute('list') + assert a.base.attributes.get('none') is None + assert self.boolval == a.base.attributes.get('bool') + assert self.intval == a.base.attributes.get('integer') + assert self.floatval == a.base.attributes.get('float') + assert self.stringval == a.base.attributes.get('string') + assert self.dictval == a.base.attributes.get('dict') + assert self.listval == a.base.attributes.get('list') def test_attr_with_reload(self): a = orm.Data() - a.set_attribute('none', None) - a.set_attribute('bool', self.boolval) - a.set_attribute('integer', self.intval) - a.set_attribute('float', self.floatval) - a.set_attribute('string', self.stringval) - a.set_attribute('dict', self.dictval) - a.set_attribute('list', self.listval) + a.base.attributes.set('none', None) + a.base.attributes.set('bool', self.boolval) + a.base.attributes.set('integer', self.intval) + a.base.attributes.set('float', self.floatval) + a.base.attributes.set('string', self.stringval) + a.base.attributes.set('dict', self.dictval) + a.base.attributes.set('list', self.listval) a.store() b = orm.load_node(uuid=a.uuid) - assert a.get_attribute('none') is None - assert self.boolval == b.get_attribute('bool') - assert self.intval == b.get_attribute('integer') - assert self.floatval == b.get_attribute('float') - assert self.stringval == b.get_attribute('string') - assert self.dictval == b.get_attribute('dict') - assert self.listval == b.get_attribute('list') + assert a.base.attributes.get('none') is None + assert self.boolval == b.base.attributes.get('bool') + assert self.intval == b.base.attributes.get('integer') + assert self.floatval == b.base.attributes.get('float') + assert self.stringval == b.base.attributes.get('string') + assert self.dictval == b.base.attributes.get('dict') + assert self.listval == b.base.attributes.get('list') def test_extra_with_reload(self): a = orm.Data() @@ -805,7 +805,7 @@ def test_attr_listing(self): } for k, v in attrs_to_set.items(): - a.set_attribute(k, v) + a.base.attributes.set(k, v) a.store() @@ -817,10 +817,10 @@ def test_attr_listing(self): all_extras = dict(_aiida_hash=AnyValue(), **extras_to_set) - assert set(list(a.attributes.keys())) == set(attrs_to_set.keys()) + assert set(list(a.base.attributes.keys())) == set(attrs_to_set.keys()) assert set(list(a.extras.keys())) == set(all_extras.keys()) - assert a.attributes == attrs_to_set + assert a.base.attributes.all == attrs_to_set assert a.extras == all_extras @@ -926,28 +926,28 @@ def test_basetype_as_attr(self): # Manages to store, and value is converted to its base type p = orm.Dict(dict={'b': orm.Str('sometext'), 'c': l1}) p.store() - assert p.get_attribute('b') == 'sometext' - assert isinstance(p.get_attribute('b'), str) - assert p.get_attribute('c') == ['b', [1, 2]] - assert isinstance(p.get_attribute('c'), (list, tuple)) + assert p.base.attributes.get('b') == 'sometext' + assert isinstance(p.base.attributes.get('b'), str) + assert p.base.attributes.get('c') == ['b', [1, 2]] + assert isinstance(p.base.attributes.get('c'), (list, tuple)) # Check also before storing n = orm.Data() - n.set_attribute('a', orm.Str('sometext2')) - n.set_attribute('b', l2) - assert n.get_attribute('a').value == 'sometext2' - assert isinstance(n.get_attribute('a'), orm.Str) - assert n.get_attribute('b').get_list() == ['f', True, {'gg': None}] - assert isinstance(n.get_attribute('b'), orm.List) + n.base.attributes.set('a', orm.Str('sometext2')) + n.base.attributes.set('b', l2) + assert n.base.attributes.get('a').value == 'sometext2' + assert isinstance(n.base.attributes.get('a'), orm.Str) + assert n.base.attributes.get('b').get_list() == ['f', True, {'gg': None}] + assert isinstance(n.base.attributes.get('b'), orm.List) # Check also deep in a dictionary/list n = orm.Data() - n.set_attribute('a', {'b': [orm.Str('sometext3')]}) - assert n.get_attribute('a')['b'][0].value == 'sometext3' - assert isinstance(n.get_attribute('a')['b'][0], orm.Str) + n.base.attributes.set('a', {'b': [orm.Str('sometext3')]}) + assert n.base.attributes.get('a')['b'][0].value == 'sometext3' + assert isinstance(n.base.attributes.get('a')['b'][0], orm.Str) n.store() - assert n.get_attribute('a')['b'][0] == 'sometext3' - assert isinstance(n.get_attribute('a')['b'][0], str) + assert n.base.attributes.get('a')['b'][0] == 'sometext3' + assert isinstance(n.base.attributes.get('a')['b'][0], str) def test_basetype_as_extra(self): """ diff --git a/tests/tools/archive/orm/test_attributes.py b/tests/tools/archive/orm/test_attributes.py index be5fe2451b..3e2b53c71e 100644 --- a/tests/tools/archive/orm/test_attributes.py +++ b/tests/tools/archive/orm/test_attributes.py @@ -20,7 +20,7 @@ def test_import_of_attributes(tmp_path, aiida_profile): # Create Data with attributes data = orm.Data() data.label = 'my_test_data_node' - data.set_attribute_many({'b': 2, 'c': 3}) + data.base.attributes.set_many({'b': 2, 'c': 3}) data.store() # Export @@ -35,8 +35,8 @@ def test_import_of_attributes(tmp_path, aiida_profile): assert builder.count() == 1 imported_node = builder.all(flat=True)[0] - assert imported_node.get_attribute('b') == 2 - assert imported_node.get_attribute('c') == 3 + assert imported_node.base.attributes.get('b') == 2 + assert imported_node.base.attributes.get('c') == 3 @pytest.mark.usefixtures('aiida_profile_clean') diff --git a/tests/tools/archive/test_complex.py b/tests/tools/archive/test_complex.py index 4b882e87c4..2e1860862a 100644 --- a/tests/tools/archive/test_complex.py +++ b/tests/tools/archive/test_complex.py @@ -128,7 +128,7 @@ def test_reexport(aiida_profile, tmp_path): calc = orm.CalculationNode() # setting also trial dict as attributes, but randomizing the keys) for key, value in trial_dict.items(): - calc.set_attribute(str(int(key) + np.random.randint(10)), value) + calc.base.attributes.set(str(int(key) + np.random.randint(10)), value) array = orm.ArrayData() array.set_array('array', nparr) array.store() diff --git a/tests/tools/archive/test_simple.py b/tests/tools/archive/test_simple.py index bee2df8468..624030fa03 100644 --- a/tests/tools/archive/test_simple.py +++ b/tests/tools/archive/test_simple.py @@ -64,7 +64,7 @@ def test_calc_of_structuredata(aiida_profile_clean, tmp_path, aiida_localhost): node = orm.load_node(pk) attrs[node.uuid] = {} for k in node.attributes.keys(): - attrs[node.uuid][k] = node.get_attribute(k) + attrs[node.uuid][k] = node.base.attributes.get(k) filename = str(tmp_path / 'export.aiida') @@ -76,7 +76,7 @@ def test_calc_of_structuredata(aiida_profile_clean, tmp_path, aiida_localhost): for uuid, value in attrs.items(): node = orm.load_node(uuid) for k in value.keys(): - assert value[k] == node.get_attribute(k) + assert value[k] == node.base.attributes.get(k) def test_check_for_export_format_version(aiida_profile_clean, tmp_path): diff --git a/tests/tools/archive/test_specific_import.py b/tests/tools/archive/test_specific_import.py index 732ecc78fd..9f7569c602 100644 --- a/tests/tools/archive/test_specific_import.py +++ b/tests/tools/archive/test_specific_import.py @@ -88,10 +88,10 @@ def test_cycle_structure_data(aiida_profile_clean, aiida_localhost, tmp_path): structure.store() parent_process = orm.CalculationNode() - parent_process.set_attribute('key', 'value') + parent_process.base.attributes.set('key', 'value') parent_process.store() child_calculation = orm.CalculationNode() - child_calculation.set_attribute('key', 'value') + child_calculation.base.attributes.set('key', 'value') remote_folder = orm.RemoteData(computer=aiida_localhost, remote_path='/').store() remote_folder.add_incoming(parent_process, link_type=LinkType.CREATE, link_label='link') diff --git a/tests/tools/groups/test_paths.py b/tests/tools/groups/test_paths.py index 5068d8c097..d0d8ede1d4 100644 --- a/tests/tools/groups/test_paths.py +++ b/tests/tools/groups/test_paths.py @@ -130,7 +130,7 @@ def test_walk_nodes(aiida_profile_clean): """Test the ``GroupPath.walk_nodes()`` function.""" group, _ = orm.Group.objects.get_or_create('a') node = orm.Data() - node.set_attribute_many({'i': 1, 'j': 2}) + node.base.attributes.set_many({'i': 1, 'j': 2}) node.store() group.add_nodes(node) group_path = GroupPath() From 68298d687ba652819f00552ab521dd54968edd9c Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Thu, 7 Apr 2022 11:40:47 +0200 Subject: [PATCH 2/7] Apply suggestions from code review Co-authored-by: Sebastiaan Huber --- aiida/orm/nodes/attributes.py | 1 - aiida/orm/nodes/node.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/aiida/orm/nodes/attributes.py b/aiida/orm/nodes/attributes.py index 780047dfe8..099d7e022d 100644 --- a/aiida/orm/nodes/attributes.py +++ b/aiida/orm/nodes/attributes.py @@ -7,7 +7,6 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=too-many-lines,too-many-arguments """Interface to the attributes of a node instance.""" import copy from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Tuple diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index d773a75314..dfced1e6f0 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -117,7 +117,7 @@ def repository(self) -> 'NodeRepository': @cached_property def attributes(self) -> 'NodeAttributes': - """Return the attributes for this node.""" + """Return an interface to interact with the attributes of this node.""" return NodeAttributes(self._node) From da5320ef06afbe0e28bbb62fd5d26eea4ef242e3 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Thu, 7 Apr 2022 11:56:49 +0200 Subject: [PATCH 3/7] Update attributes.py --- aiida/orm/nodes/attributes.py | 48 ++++++++++++++++------------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/aiida/orm/nodes/attributes.py b/aiida/orm/nodes/attributes.py index 099d7e022d..90d54539e3 100644 --- a/aiida/orm/nodes/attributes.py +++ b/aiida/orm/nodes/attributes.py @@ -31,12 +31,8 @@ class NodeAttributes: def __init__(self, node: 'Node') -> None: """Initialize the interface.""" - self._entity = node - self._backend_entity = node.backend_entity - - def __contains__(self, key: str) -> bool: - """Check if the node contains an attribute with the given key.""" - return key in self._backend_entity.attributes + self._node = node + self._backend_node = node.backend_entity @property def all(self) -> Dict[str, Any]: @@ -52,9 +48,9 @@ def all(self) -> Dict[str, Any]: :return: the attributes as a dictionary """ - attributes = self._backend_entity.attributes + attributes = self._backend_node.attributes - if self._entity.is_stored: + if self._node.is_stored: attributes = copy.deepcopy(attributes) return attributes @@ -74,13 +70,13 @@ def get(self, key: str, default=_NO_DEFAULT) -> Any: :raises AttributeError: if the attribute does not exist and no default is specified """ try: - attribute = self._backend_entity.get_attribute(key) + attribute = self._backend_node.get_attribute(key) except AttributeError: if default is _NO_DEFAULT: raise attribute = default - if self._entity.is_stored: + if self._node.is_stored: attribute = copy.deepcopy(attribute) return attribute @@ -100,9 +96,9 @@ def get_many(self, keys: List[str]) -> List[Any]: :return: a list of attribute values :raises AttributeError: if at least one attribute does not exist """ - attributes = self._backend_entity.get_attribute_many(keys) + attributes = self._backend_node.get_attribute_many(keys) - if self._entity.is_stored: + if self._node.is_stored: attributes = copy.deepcopy(attributes) return attributes @@ -115,8 +111,8 @@ def set(self, key: str, value: Any) -> None: :raise aiida.common.ValidationError: if the key is invalid, i.e. contains periods :raise aiida.common.ModificationNotAllowed: if the entity is stored """ - self._entity._check_mutability_attributes([key]) # pylint: disable=protected-access - self._backend_entity.set_attribute(key, value) + self._node._check_mutability_attributes([key]) # pylint: disable=protected-access + self._backend_node.set_attribute(key, value) def set_many(self, attributes: Dict[str, Any]) -> None: """Set multiple attributes. @@ -127,8 +123,8 @@ def set_many(self, attributes: Dict[str, Any]) -> None: :raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods :raise aiida.common.ModificationNotAllowed: if the entity is stored """ - self._entity._check_mutability_attributes(list(attributes)) # pylint: disable=protected-access - self._backend_entity.set_attribute_many(attributes) + self._node._check_mutability_attributes(list(attributes)) # pylint: disable=protected-access + self._backend_node.set_attribute_many(attributes) def reset(self, attributes: Dict[str, Any]) -> None: """Reset the attributes. @@ -139,8 +135,8 @@ def reset(self, attributes: Dict[str, Any]) -> None: :raise aiida.common.ValidationError: if any of the keys are invalid, i.e. contain periods :raise aiida.common.ModificationNotAllowed: if the entity is stored """ - self._entity._check_mutability_attributes() # pylint: disable=protected-access - self._backend_entity.reset_attributes(attributes) + self._node._check_mutability_attributes() # pylint: disable=protected-access + self._backend_node.reset_attributes(attributes) def delete(self, key: str) -> None: """Delete an attribute. @@ -149,8 +145,8 @@ def delete(self, key: str) -> None: :raises AttributeError: if the attribute does not exist :raise aiida.common.ModificationNotAllowed: if the entity is stored """ - self._entity._check_mutability_attributes([key]) # pylint: disable=protected-access - self._backend_entity.delete_attribute(key) + self._node._check_mutability_attributes([key]) # pylint: disable=protected-access + self._backend_node.delete_attribute(key) def delete_many(self, keys: List[str]) -> None: """Delete multiple attributes. @@ -159,24 +155,24 @@ def delete_many(self, keys: List[str]) -> None: :raises AttributeError: if at least one of the attribute does not exist :raise aiida.common.ModificationNotAllowed: if the entity is stored """ - self._entity._check_mutability_attributes(keys) # pylint: disable=protected-access - self._backend_entity.delete_attribute_many(keys) + self._node._check_mutability_attributes(keys) # pylint: disable=protected-access + self._backend_node.delete_attribute_many(keys) def clear(self) -> None: """Delete all attributes.""" - self._entity._check_mutability_attributes() # pylint: disable=protected-access - self._backend_entity.clear_attributes() + self._node._check_mutability_attributes() # pylint: disable=protected-access + self._backend_node.clear_attributes() def items(self) -> Iterator[Tuple[str, Any]]: """Return an iterator over the attributes. :return: an iterator with attribute key value pairs """ - return self._backend_entity.attributes_items() + return self._backend_node.attributes_items() def keys(self) -> Iterable[str]: """Return an iterator over the attribute keys. :return: an iterator with attribute keys """ - return self._backend_entity.attributes_keys() + return self._backend_node.attributes_keys() From f77ca6a23152a84d08a566b0a9fea0a2283cda31 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Thu, 7 Apr 2022 14:54:00 +0200 Subject: [PATCH 4/7] Update aiida/orm/nodes/attributes.py --- aiida/orm/nodes/attributes.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/aiida/orm/nodes/attributes.py b/aiida/orm/nodes/attributes.py index 90d54539e3..bbc286357b 100644 --- a/aiida/orm/nodes/attributes.py +++ b/aiida/orm/nodes/attributes.py @@ -34,6 +34,10 @@ def __init__(self, node: 'Node') -> None: self._node = node self._backend_node = node.backend_entity + def __contains__(self, key: str) -> bool: + """Check if the node contains an attribute with the given key.""" + return key in self._backend_entity.attributes + @property def all(self) -> Dict[str, Any]: """Return the complete attributes dictionary. From 9ec19ba4ec96865d6020c93f708f91729666676a Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Thu, 7 Apr 2022 19:09:08 +0200 Subject: [PATCH 5/7] Update aiida/orm/nodes/attributes.py --- aiida/orm/nodes/attributes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiida/orm/nodes/attributes.py b/aiida/orm/nodes/attributes.py index bbc286357b..5bd3ed48b8 100644 --- a/aiida/orm/nodes/attributes.py +++ b/aiida/orm/nodes/attributes.py @@ -36,7 +36,7 @@ def __init__(self, node: 'Node') -> None: def __contains__(self, key: str) -> bool: """Check if the node contains an attribute with the given key.""" - return key in self._backend_entity.attributes + return key in self._backend_node.attributes @property def all(self) -> Dict[str, Any]: From fb1b17b91fd0670706edb5cc750cec1dd79c5f43 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 8 Apr 2022 09:54:25 +0200 Subject: [PATCH 6/7] Fix deprecation warnings --- tests/tools/archive/test_complex.py | 6 +++--- tests/tools/archive/test_simple.py | 2 +- tests/tools/archive/test_specific_import.py | 4 ++-- tests/tools/groups/test_paths.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/tools/archive/test_complex.py b/tests/tools/archive/test_complex.py index 2e1860862a..016c3d8a20 100644 --- a/tests/tools/archive/test_complex.py +++ b/tests/tools/archive/test_complex.py @@ -178,13 +178,13 @@ def get_hash_from_db_content(grouplabel): # The hash is given from the preservable entries in an export-import cycle, # uuids, attributes, labels, descriptions, arrays, link-labels, link-types: hash_ = make_hash([( - item['param']['*'].attributes, + item['param']['*'].base.attributes.all, item['param']['*'].uuid, item['param']['*'].label, item['param']['*'].description, item['calc']['*'].uuid, - item['calc']['*'].attributes, - item['array']['*'].attributes, + item['calc']['*'].base.attributes.all, + item['array']['*'].base.attributes.all, [item['array']['*'].get_array(name).tolist() for name in item['array']['*'].get_arraynames()], item['array']['*'].uuid, item['group']['*'].uuid, diff --git a/tests/tools/archive/test_simple.py b/tests/tools/archive/test_simple.py index 624030fa03..cb18439412 100644 --- a/tests/tools/archive/test_simple.py +++ b/tests/tools/archive/test_simple.py @@ -63,7 +63,7 @@ def test_calc_of_structuredata(aiida_profile_clean, tmp_path, aiida_localhost): for pk in pks: node = orm.load_node(pk) attrs[node.uuid] = {} - for k in node.attributes.keys(): + for k in node.base.attributes.keys(): attrs[node.uuid][k] = node.base.attributes.get(k) filename = str(tmp_path / 'export.aiida') diff --git a/tests/tools/archive/test_specific_import.py b/tests/tools/archive/test_specific_import.py index 9f7569c602..af953f4e5f 100644 --- a/tests/tools/archive/test_specific_import.py +++ b/tests/tools/archive/test_specific_import.py @@ -121,8 +121,8 @@ def test_cycle_structure_data(aiida_profile_clean, aiida_localhost, tmp_path): # Verify that orm.CalculationNodes have non-empty attribute dictionaries builder = orm.QueryBuilder().append(orm.CalculationNode) for [calculation] in builder.iterall(): - assert isinstance(calculation.attributes, dict) - assert len(calculation.attributes) != 0 + assert isinstance(calculation.base.attributes.all, dict) + assert len(calculation.base.attributes.all) != 0 # Verify that the structure data maintained its label, cell and kinds builder = orm.QueryBuilder().append(orm.StructureData) diff --git a/tests/tools/groups/test_paths.py b/tests/tools/groups/test_paths.py index d0d8ede1d4..5a68c2c2a9 100644 --- a/tests/tools/groups/test_paths.py +++ b/tests/tools/groups/test_paths.py @@ -134,7 +134,7 @@ def test_walk_nodes(aiida_profile_clean): node.store() group.add_nodes(node) group_path = GroupPath() - assert [(r.group_path.path, r.node.attributes) for r in group_path.walk_nodes()] == [('a', {'i': 1, 'j': 2})] + assert [(r.group_path.path, r.node.base.attributes.all) for r in group_path.walk_nodes()] == [('a', {'i': 1, 'j': 2})] def test_cls(aiida_profile_clean): From c3f9ef90fe592bd1c03d34752fe237c28a4e54ed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 Apr 2022 07:56:36 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tools/groups/test_paths.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/tools/groups/test_paths.py b/tests/tools/groups/test_paths.py index 5a68c2c2a9..c1320fae76 100644 --- a/tests/tools/groups/test_paths.py +++ b/tests/tools/groups/test_paths.py @@ -134,7 +134,11 @@ def test_walk_nodes(aiida_profile_clean): node.store() group.add_nodes(node) group_path = GroupPath() - assert [(r.group_path.path, r.node.base.attributes.all) for r in group_path.walk_nodes()] == [('a', {'i': 1, 'j': 2})] + assert [(r.group_path.path, r.node.base.attributes.all) for r in group_path.walk_nodes() + ] == [('a', { + 'i': 1, + 'j': 2 + })] def test_cls(aiida_profile_clean):