From e45b43f514119df97c4036e68a38739a87d1027f Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 21 Feb 2023 12:22:24 +0100 Subject: [PATCH 1/3] Process functions: Replace `getfullargspec` with `signature` The `inspect.getfullargspec` method is used to analyze the signature of the function wrapped by the process function decorator. This method is outdated and really just kept around for backwards-compatibility. Internally it is using the `inspect.signature` method which has a more modern interface, so here we change the code to use that method instead. This change also prepares the next move that will allow to parse any type annotations for signature parameters to dynamically determine the `valid_type` attribute of the dynamically determined input ports. --- aiida/engine/processes/functions.py | 73 +++++++++++++++++------------ 1 file changed, 43 insertions(+), 30 deletions(-) diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py index 051b5252b1..2a2dd7f610 100644 --- a/aiida/engine/processes/functions.py +++ b/aiida/engine/processes/functions.py @@ -222,10 +222,22 @@ def build(func: Callable[..., Any], node_class: Type['ProcessNode']) -> Type['Fu if not issubclass(node_class, ProcessNode) or not issubclass(node_class, FunctionCalculationMixin): raise TypeError('the node_class should be a sub class of `ProcessNode` and `FunctionCalculationMixin`') - args, varargs, keywords, defaults, _, _, _ = inspect.getfullargspec(func) - nargs = len(args) - ndefaults = len(defaults) if defaults else 0 - first_default_pos = nargs - ndefaults + signature = inspect.signature(func) + + args: list[str] = [] + varargs: str | None = None + keywords: str | None = None + + for key, parameter in signature.parameters.items(): + + if parameter.kind in [parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD, parameter.KEYWORD_ONLY]: + args.append(key) + + if parameter.kind is parameter.VAR_POSITIONAL: + varargs = key + + if parameter.kind is parameter.VAR_KEYWORD: + varargs = key def _define(cls, spec): # pylint: disable=unused-argument """Define the spec dynamically""" @@ -233,37 +245,38 @@ def _define(cls, spec): # pylint: disable=unused-argument super().define(spec) - for i, arg in enumerate(args): + for parameter in signature.parameters.values(): - default = UNSPECIFIED + if parameter.kind in [parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD]: + continue - if defaults and i >= first_default_pos: - default = defaults[i - first_default_pos] + default = parameter.default if parameter.default is not parameter.empty else UNSPECIFIED # If the keyword was already specified, simply override the default - if spec.has_input(arg): - spec.inputs[arg].default = default + if spec.has_input(parameter.name): + spec.inputs[parameter.name].default = default + continue + + # If the default is ``None`` make sure that the port also accepts a ``NoneType``. Note that we cannot + # use ``None`` because the validation will call ``isinstance`` which does not work when passing ``None`` + # but it does work with ``NoneType`` which is returned by calling ``type(None)``. + if default is None: + valid_type = (Data, type(None)) else: - # If the default is `None` make sure that the port also accepts a `NoneType` - # Note that we cannot use `None` because the validation will call `isinstance` which does not work - # when passing `None`, but it does work with `NoneType` which is returned by calling `type(None)` - if default is None: - valid_type = (Data, type(None)) - else: - valid_type = (Data,) - - # If a default is defined and it is not a ``Data`` instance it should be serialized, but this should - # be done lazily using a lambda, just as any port defaults should not define node instances directly - # as is also checked by the ``spec.input`` call. - if ( - default is not None and default != UNSPECIFIED and not isinstance(default, Data) and - not callable(default) - ): - indirect_default = lambda value=default: to_aiida_type(value) - else: - indirect_default = default # type: ignore[assignment] - - spec.input(arg, valid_type=valid_type, default=indirect_default, serializer=to_aiida_type) + valid_type = (Data,) + + # If a default is defined and it is not a ``Data`` instance it should be serialized, but this should be + # done lazily using a lambda, just as any port defaults should not define node instances directly as is + # also checked by the ``spec.input`` call. + if ( + default is not None and default != UNSPECIFIED and not isinstance(default, Data) and + not callable(default) + ): + indirect_default = lambda value=default: to_aiida_type(value) + else: + indirect_default = default # type: ignore[assignment] + + spec.input(parameter.name, valid_type=valid_type, default=indirect_default, serializer=to_aiida_type) # Set defaults for label and description based on function name and docstring, if not explicitly defined port_label = spec.inputs['metadata']['label'] From ef3d7723b037dd14d509ab9adc244d104ea0122e Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 22 Feb 2023 01:12:29 +0100 Subject: [PATCH 2/3] Typing: Use modern syntax for `aiida.engine.processes.functions` Use the new union syntax from PEP 604. --- .pre-commit-config.yaml | 1 - aiida/engine/processes/functions.py | 35 +++++++++++++++-------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index be5aaa3cc9..4c9f688f07 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -110,7 +110,6 @@ repos: aiida/engine/processes/calcjobs/monitors.py| aiida/engine/processes/calcjobs/tasks.py| aiida/engine/processes/control.py| - aiida/engine/processes/functions.py| aiida/engine/processes/ports.py| aiida/manage/configuration/__init__.py| aiida/manage/configuration/config.py| diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py index 2a2dd7f610..69dbeebb9a 100644 --- a/aiida/engine/processes/functions.py +++ b/aiida/engine/processes/functions.py @@ -15,7 +15,8 @@ import inspect import logging import signal -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Tuple, Type, TypeVar +import typing as t +from typing import TYPE_CHECKING from aiida.common.lang import override from aiida.manage import get_manager @@ -31,7 +32,7 @@ LOGGER = logging.getLogger(__name__) -FunctionType = TypeVar('FunctionType', bound=Callable[..., Any]) +FunctionType = t.TypeVar('FunctionType', bound=t.Callable[..., t.Any]) def calcfunction(function: FunctionType) -> FunctionType: @@ -88,14 +89,14 @@ def workfunction(function: FunctionType) -> FunctionType: return process_function(node_class=WorkFunctionNode)(function) -def process_function(node_class: Type['ProcessNode']) -> Callable[[Callable[..., Any]], Callable[..., Any]]: +def process_function(node_class: t.Type['ProcessNode']) -> t.Callable[[FunctionType], FunctionType]: """ The base function decorator to create a FunctionProcess out of a normal python function. :param node_class: the ORM class to be used as the Node record for the FunctionProcess """ - def decorator(function: Callable[..., Any]) -> Callable[..., Any]: + def decorator(function: FunctionType) -> FunctionType: """ Turn the decorated function into a FunctionProcess. @@ -104,7 +105,7 @@ def decorator(function: Callable[..., Any]) -> Callable[..., Any]: """ process_class = FunctionProcess.build(function, node_class=node_class) - def run_get_node(*args, **kwargs) -> Tuple[Optional[Dict[str, Any]], 'ProcessNode']: + def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, 'ProcessNode']: """ Run the FunctionProcess with the supplied inputs in a local runner. @@ -159,7 +160,7 @@ def kill_process(_num, _frame): return result, process.node - def run_get_pk(*args, **kwargs) -> Tuple[Optional[Dict[str, Any]], int]: + def run_get_pk(*args, **kwargs) -> tuple[dict[str, t.Any] | None, int]: """Recreate the `run_get_pk` utility launcher. :param args: input arguments to construct the FunctionProcess @@ -185,7 +186,7 @@ def decorated_function(*args, **kwargs): decorated_function.recreate_from = process_class.recreate_from # type: ignore[attr-defined] decorated_function.spec = process_class.spec # type: ignore[attr-defined] - return decorated_function + return decorated_function # type: ignore[return-value] return decorator @@ -193,7 +194,7 @@ def decorated_function(*args, **kwargs): class FunctionProcess(Process): """Function process class used for turning functions into a Process""" - _func_args: Sequence[str] = () + _func_args: t.Sequence[str] = () _varargs: str | None = None @staticmethod @@ -205,7 +206,7 @@ def _func(*_args, **_kwargs) -> dict: return {} @staticmethod - def build(func: Callable[..., Any], node_class: Type['ProcessNode']) -> Type['FunctionProcess']: + def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['FunctionProcess']: """ Build a Process from the given function. @@ -274,7 +275,7 @@ def _define(cls, spec): # pylint: disable=unused-argument ): indirect_default = lambda value=default: to_aiida_type(value) else: - indirect_default = default # type: ignore[assignment] + indirect_default = default spec.input(parameter.name, valid_type=valid_type, default=indirect_default, serializer=to_aiida_type) @@ -306,7 +307,7 @@ def _define(cls, spec): # pylint: disable=unused-argument ) @classmethod - def validate_inputs(cls, *args: Any, **kwargs: Any) -> None: # pylint: disable=unused-argument + def validate_inputs(cls, *args: t.Any, **kwargs: t.Any) -> None: # pylint: disable=unused-argument """ Validate the positional and keyword arguments passed in the function call. @@ -327,7 +328,7 @@ def validate_inputs(cls, *args: Any, **kwargs: Any) -> None: # pylint: disable= raise TypeError(f'{name}() takes {nparameters} positional arguments but {nargs} were given') @classmethod - def create_inputs(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: + def create_inputs(cls, *args: t.Any, **kwargs: t.Any) -> dict[str, t.Any]: """Create the input args for the FunctionProcess.""" cls.validate_inputs(*args, **kwargs) @@ -339,7 +340,7 @@ def create_inputs(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: return ins @classmethod - def args_to_dict(cls, *args: Any) -> Dict[str, Any]: + def args_to_dict(cls, *args: t.Any) -> dict[str, t.Any]: """ Create an input dictionary (of form label -> value) from supplied args. @@ -388,7 +389,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(enable_persistence=False, *args, **kwargs) # type: ignore @property - def process_class(self) -> Callable[..., Any]: + def process_class(self) -> t.Callable[..., t.Any]: """ Return the class that represents this Process, for the FunctionProcess this is the function itself. @@ -401,7 +402,7 @@ class that really represents what was being executed. """ return self._func - def execute(self) -> Optional[Dict[str, Any]]: + def execute(self) -> dict[str, t.Any] | None: """Execute the process.""" result = super().execute() @@ -418,7 +419,7 @@ def _setup_db_record(self) -> None: self.node.store_source_info(self._func) @override - def run(self) -> Optional['ExitCode']: + def run(self) -> 'ExitCode' | None: """Run the process.""" from .exit_code import ExitCode @@ -427,7 +428,7 @@ def run(self) -> Optional['ExitCode']: # been overridden by the engine to `Running` so we cannot check that, but if the `exit_status` is anything other # than `None`, it should mean this node was taken from the cache, so the process should not be rerun. if self.node.exit_status is not None: - return self.node.exit_status + return ExitCode(self.node.exit_status, self.node.exit_message) # Split the inputs into positional and keyword arguments args = [None] * len(self._func_args) From b39fafbb1bb933e99aa7a579d1b79ab40d7268e7 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 22 Feb 2023 01:25:09 +0100 Subject: [PATCH 3/3] Process functions: Infer argument `valid_type` from type hints An advantage of process functions over calculation jobs and work chains are that they are light weight and very easy to start using. But there are also disadvantages; due to the process specification being inferred from the function signature, the user cannot benefit from all the features of the process specification. For example, it is not possible to explicitly define a help string or valid type for the arguments of the process function. With PEP 484 type hints were introduced in Python 3.5 that made it possible to indicate the expected types for function arguments. Here we make use of this functionality to infer the valid type of function arguments based on the provided type hint, if any. As an example, a user can now define the following calcfunction: @calcfunction def add(a: int, b: int): return a + b and when it is called with anything other than an `int` or `Int` for either of the two arguments, a validation error is raised reporting that an argument with an invalid type was provided. The new functionality is fully optional and process functions without typing will continue to work as before. If incorrect typing is provided, a warning is logged and the typing is ignored. The annotation of the wrapped process function is parsed using the `inspect.get_annotation` method. This was added in Python 3.10 and so to provide support in older versions we install the `get-annotations` backport package. Since we have to use `eval_str=True` in the call to get unstringized versions of the types. This will fail for the backport implementation if the type uses union syntax of PEP 604, e.g `str | int` instead of `typing.Union[str, int]`, even if this functionality is enabled using `from __future__ import annotations`. --- aiida/engine/processes/functions.py | 79 ++++++++++++++- docs/source/topics/calculations/concepts.rst | 2 + docs/source/topics/processes/functions.rst | 51 ++++++++++ .../snippets/functions/typing_call_raise.py | 10 ++ .../include/snippets/functions/typing_none.py | 16 +++ .../snippets/functions/typing_pep_563.py | 11 +++ .../snippets/functions/typing_pep_604.py | 16 +++ .../snippets/functions/typing_union.py | 12 +++ environment.yml | 1 + pyproject.toml | 2 + requirements/requirements-py-3.8.txt | 1 + requirements/requirements-py-3.9.txt | 1 + tests/engine/test_process_function.py | 97 +++++++++++++++++++ 13 files changed, 295 insertions(+), 4 deletions(-) create mode 100644 docs/source/topics/processes/include/snippets/functions/typing_call_raise.py create mode 100644 docs/source/topics/processes/include/snippets/functions/typing_none.py create mode 100644 docs/source/topics/processes/include/snippets/functions/typing_pep_563.py create mode 100644 docs/source/topics/processes/include/snippets/functions/typing_pep_604.py create mode 100644 docs/source/topics/processes/include/snippets/functions/typing_union.py diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py index 69dbeebb9a..23d8a87522 100644 --- a/aiida/engine/processes/functions.py +++ b/aiida/engine/processes/functions.py @@ -15,16 +15,41 @@ import inspect import logging import signal +import types import typing as t from typing import TYPE_CHECKING from aiida.common.lang import override from aiida.manage import get_manager -from aiida.orm import CalcFunctionNode, Data, ProcessNode, WorkFunctionNode, to_aiida_type +from aiida.orm import ( + Bool, + CalcFunctionNode, + Data, + Dict, + Float, + Int, + List, + ProcessNode, + Str, + WorkFunctionNode, + to_aiida_type, +) from aiida.orm.utils.mixins import FunctionCalculationMixin from .process import Process +try: + UnionType = types.UnionType # type: ignore[attr-defined] +except AttributeError: + # This type is not available for Python 3.9 and older + UnionType = None # pylint: disable=invalid-name + +try: + get_annotations = inspect.get_annotations # type: ignore[attr-defined] +except AttributeError: + # This is the backport for Python 3.9 and older + from get_annotations import get_annotations # type: ignore[no-redef] + if TYPE_CHECKING: from .exit_code import ExitCode @@ -191,6 +216,43 @@ def decorated_function(*args, **kwargs): return decorator +def infer_valid_type_from_type_annotation(annotation: t.Any) -> tuple[t.Any, ...]: + """Infer the value for the ``valid_type`` of an input port from the given function argument annotation. + + :param annotation: The annotation of a function argument as returned by ``inspect.get_annotation``. + :returns: A tuple of valid types. If no valid types were defined or they could not be successfully parsed, an empty + tuple is returned. + """ + + def get_type_from_annotation(annotation): + valid_type_map = { + bool: Bool, + dict: Dict, + t.Dict: Dict, + float: Float, + int: Int, + list: List, + t.List: List, + str: Str, + } + + if inspect.isclass(annotation) and issubclass(annotation, Data): + return annotation + + return valid_type_map.get(annotation) + + inferred_valid_type: tuple[t.Any, ...] = () + + if inspect.isclass(annotation): + inferred_valid_type = (get_type_from_annotation(annotation),) + elif t.get_origin(annotation) is t.Union or t.get_origin(annotation) is UnionType: + inferred_valid_type = tuple(get_type_from_annotation(valid_type) for valid_type in t.get_args(annotation)) + elif t.get_origin(annotation) is t.Optional: + inferred_valid_type = (t.get_args(annotation),) + + return tuple(valid_type for valid_type in inferred_valid_type if valid_type is not None) + + class FunctionProcess(Process): """Function process class used for turning functions into a Process""" @@ -229,6 +291,14 @@ def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['Func varargs: str | None = None keywords: str | None = None + try: + annotations = get_annotations(func, eval_str=True) + except Exception as exception: # pylint: disable=broad-except + # Since we are running with ``eval_str=True`` to unstringize the annotations, the call can except if the + # annotations are incorrect. In this case we simply want to log a warning and continue with type inference. + LOGGER.warning(f'function `{func.__name__}` has invalid type hints: {exception}') + annotations = {} + for key, parameter in signature.parameters.items(): if parameter.kind in [parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD, parameter.KEYWORD_ONLY]: @@ -251,6 +321,9 @@ def _define(cls, spec): # pylint: disable=unused-argument if parameter.kind in [parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD]: continue + annotation = annotations.get(parameter.name) + valid_type = infer_valid_type_from_type_annotation(annotation) or (Data,) + default = parameter.default if parameter.default is not parameter.empty else UNSPECIFIED # If the keyword was already specified, simply override the default @@ -262,9 +335,7 @@ def _define(cls, spec): # pylint: disable=unused-argument # use ``None`` because the validation will call ``isinstance`` which does not work when passing ``None`` # but it does work with ``NoneType`` which is returned by calling ``type(None)``. if default is None: - valid_type = (Data, type(None)) - else: - valid_type = (Data,) + valid_type += (type(None),) # If a default is defined and it is not a ``Data`` instance it should be serialized, but this should be # done lazily using a lambda, just as any port defaults should not define node instances directly as is diff --git a/docs/source/topics/calculations/concepts.rst b/docs/source/topics/calculations/concepts.rst index c6a00a9fea..afb178e9ae 100644 --- a/docs/source/topics/calculations/concepts.rst +++ b/docs/source/topics/calculations/concepts.rst @@ -55,6 +55,8 @@ To solve this, one only has to wrap them in the :py:class:`~aiida.orm.nodes.data The only difference with the previous snippet is that all inputs have been wrapped in the :py:class:`~aiida.orm.nodes.data.int.Int` class. The result that is returned by the function, is now also an :py:class:`~aiida.orm.nodes.data.int.Int` node that can be stored in the provenance graph, and contains the result of the computation. +.. _topics:calculations:concepts:calcfunctions:automatic-serialization: + .. versionadded:: 2.1 If a function argument is a Python base type (i.e. a value of type ``bool``, ``dict``, ``Enum``, ``float``, ``int``, ``list`` or ``str``), it can be passed straight away to the function, without first having to wrap it in the corresponding AiiDA data type. diff --git a/docs/source/topics/processes/functions.rst b/docs/source/topics/processes/functions.rst index 23b067f9e9..6a9d6ae8c0 100644 --- a/docs/source/topics/processes/functions.rst +++ b/docs/source/topics/processes/functions.rst @@ -116,6 +116,57 @@ The link labels for the example above will therefore be ``args_0``, ``args_1`` a If any of these labels were to overlap with the label of a positional or keyword argument, a ``RuntimeError`` will be raised. In this case, the conflicting argument name needs to be changed to something that does not overlap with the automatically generated labels for the variadic arguments. +Type validation +=============== + +.. versionadded:: 2.3 + +Type hints (introduced with `PEP 484 `_ in Python 3.5) can be used to add automatic type validation of process function arguments. +For example, the following will raise a ``ValueError`` exception: + +.. include:: include/snippets/functions/typing_call_raise.py + :code: python + +When the process function is declared, the process specification (``ProcessSpec``) is built dynamically. +For each function argument, if a correct type hint is provided, it is set as the ``valid_type`` attribute of the corresponding input port. +In the example above, the ``x`` and ``y`` inputs have ``Int`` as type hint, which is why the call that passes a ``Float`` raises a ``ValueError``. + +.. note:: + + Type hints for return values are currently not parsed. + +If an argument accepts multiple types, the ``typing.Union`` class can be used as normal: + +.. include:: include/snippets/functions/typing_union.py + :code: python + +The call with an ``Int`` and a ``Float`` will now finish correctly. +Similarly, optional arguments, with ``None`` as a default, can be declared using ``typing.Optional``: + +.. include:: include/snippets/functions/typing_none.py + :code: python + +The `postponed evaluation of annotations introduced by PEP 563 `_ is also supported. +This means it is possible to use Python base types for the type hint instead of AiiDA's ``Data`` node equivalent: + +.. include:: include/snippets/functions/typing_pep_563.py + :code: python + +The type hints are automatically serialized just as the actual inputs are when the function is called, :ref:`as introduced in v2.1`. + +The alternative syntax for union types ``X | Y`` `as introduced by PEP 604 `_ is also supported: + +.. include:: include/snippets/functions/typing_pep_604.py + :code: python + +.. warning:: + + The usage of notation as defined by PEP 563 and PEP 604 are not supported for Python versions older than 3.10, even if the ``from __future__ import annotations`` statement is added. + The reason is that the type inference uses the `inspect.get_annotations `_ method, which was introduced in Python 3.10. + For older Python versions, the `get-annotations `_ backport is used, but that does not work with PEP 563 and PEP 604, so the constructs from the ``typing`` module have to be used instead. + +If a process function has invalid type hints, they will simply be ignored and a warning message is logged: ``function 'function_name' has invalid type hints``. +This ensures backwards compatibility in the case existing process functions had invalid type hints. Return values ============= diff --git a/docs/source/topics/processes/include/snippets/functions/typing_call_raise.py b/docs/source/topics/processes/include/snippets/functions/typing_call_raise.py new file mode 100644 index 0000000000..c62cf149e8 --- /dev/null +++ b/docs/source/topics/processes/include/snippets/functions/typing_call_raise.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +from aiida.engine import calcfunction +from aiida.orm import Float, Int + + +@calcfunction +def add(x: Int, y: Int): + return x + y + +add(Int(1), Float(1.0)) diff --git a/docs/source/topics/processes/include/snippets/functions/typing_none.py b/docs/source/topics/processes/include/snippets/functions/typing_none.py new file mode 100644 index 0000000000..00405051f9 --- /dev/null +++ b/docs/source/topics/processes/include/snippets/functions/typing_none.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +import typing as t + +from aiida.engine import calcfunction +from aiida.orm import Int + + +@calcfunction +def add_multiply(x: Int, y: Int, z: typing.Optional[Int] = None): + if z is None: + z = Int(3) + + return (x + y) * z + +result = add_multiply(Int(1), Int(2)) +result = add_multiply(Int(1), Int(2), Int(3)) diff --git a/docs/source/topics/processes/include/snippets/functions/typing_pep_563.py b/docs/source/topics/processes/include/snippets/functions/typing_pep_563.py new file mode 100644 index 0000000000..cc52ca3ddf --- /dev/null +++ b/docs/source/topics/processes/include/snippets/functions/typing_pep_563.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +from aiida.engine import calcfunction + + +@calcfunction +def add(x: int, y: int): + return x + y + +add(1, 2) diff --git a/docs/source/topics/processes/include/snippets/functions/typing_pep_604.py b/docs/source/topics/processes/include/snippets/functions/typing_pep_604.py new file mode 100644 index 0000000000..64f9c30290 --- /dev/null +++ b/docs/source/topics/processes/include/snippets/functions/typing_pep_604.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +from aiida.engine import calcfunction +from aiida.orm import Int + + +@calcfunction +def add_multiply(x: int, y: int, z: int | None = None): + if z is None: + z = Int(3) + + return (x + y) * z + +result = add_multiply(1, 2) +result = add_multiply(1, 2, 3) diff --git a/docs/source/topics/processes/include/snippets/functions/typing_union.py b/docs/source/topics/processes/include/snippets/functions/typing_union.py new file mode 100644 index 0000000000..3811cad719 --- /dev/null +++ b/docs/source/topics/processes/include/snippets/functions/typing_union.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +import typing as t + +from aiida.engine import calcfunction +from aiida.orm import Float, Int + + +@calcfunction +def add(x: t.Union[Int, Float], y: t.Union[Int, Float]): + return x + y + +add(Int(1), Float(1.0)) diff --git a/environment.yml b/environment.yml index cffb828ef4..50a4acde10 100644 --- a/environment.yml +++ b/environment.yml @@ -14,6 +14,7 @@ dependencies: - click-spinner~=0.1.8 - click~=8.1 - disk-objectstore~=0.6.0 +- get-annotations~=0.1 - python-graphviz~=0.13 - ipython<9,>=7 - jinja2~=3.0 diff --git a/pyproject.toml b/pyproject.toml index e387fc424e..4e05598fe1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "click-spinner~=0.1.8", "click~=8.1", "disk-objectstore~=0.6.0", + "get-annotations~=0.1;python_version<'3.10'", "graphviz~=0.13", "ipython>=7,<9", "jinja2~=3.0", @@ -393,6 +394,7 @@ module = [ 'docutils.*', 'flask_cors.*', 'flask_restful.*', + 'get_annotations.*', 'graphviz.*', 'importlib._bootstrap.*', 'IPython.*', diff --git a/requirements/requirements-py-3.8.txt b/requirements/requirements-py-3.8.txt index 2f71868b15..b3be3fadbc 100644 --- a/requirements/requirements-py-3.8.txt +++ b/requirements/requirements-py-3.8.txt @@ -38,6 +38,7 @@ Flask-Cors==3.0.10 Flask-RESTful==0.3.9 fonttools==4.28.2 future==0.18.3 +get-annotations==0.1.2 graphviz==0.19 greenlet==1.1.2 idna==3.3 diff --git a/requirements/requirements-py-3.9.txt b/requirements/requirements-py-3.9.txt index 95162db0a7..f5d0477403 100644 --- a/requirements/requirements-py-3.9.txt +++ b/requirements/requirements-py-3.9.txt @@ -38,6 +38,7 @@ Flask-Cors==3.0.10 Flask-RESTful==0.3.9 fonttools==4.28.2 future==0.18.3 +get-annotations==0.1.2 graphviz==0.19 greenlet==1.1.2 idna==3.3 diff --git a/tests/engine/test_process_function.py b/tests/engine/test_process_function.py index 9bf6d90fc3..11767c4783 100644 --- a/tests/engine/test_process_function.py +++ b/tests/engine/test_process_function.py @@ -16,8 +16,12 @@ fly, but then anytime inputs or outputs would be attached to it in the tests, the ``validate_link`` function would complain as the dummy node class is not recognized as a valid process node. """ +from __future__ import annotations + import enum import re +import sys +import typing as t import pytest @@ -616,3 +620,96 @@ def function(**kwargs): } with pytest.raises(ValueError): function.run_get_node(**inputs) + + +def test_type_hinting_spec_inference(): + """Test the parsing of type hinting to define the valid types of the dynamically generated input ports.""" + + @calcfunction # type: ignore[misc] + def function( + a, + b: str, + c: bool, + d: orm.Str, + e: t.Union[orm.Str, orm.Int], + f: t.Union[str, int], + g: t.Optional[t.Dict] = None, + ): + # pylint: disable=invalid-name,unused-argument + pass + + input_namespace = function.spec().inputs + + expected = ( + ('a', (orm.Data,)), + ('b', (orm.Str,)), + ('c', (orm.Bool,)), + ('d', (orm.Str,)), + ('e', (orm.Str, orm.Int)), + ('f', (orm.Str, orm.Int)), + ('g', (orm.Dict, type(None))), + ) + + for key, valid_types in expected: + assert key in input_namespace + assert input_namespace[key].valid_type == valid_types, key + + +def test_type_hinting_spec_inference_pep_604(aiida_caplog): + """Test the parsing of type hinting that uses union typing of PEP 604 which is only available to Python 3.10 and up. + + Even though adding ``from __future__ import annotations`` should backport this functionality to Python 3.9 and older + the ``get_annotations`` method (which was also added to the ``inspect`` package in Python 3.10) as provided by the + ``get-annotations`` backport package fails for this new syntax when called with ``eval_str=True``. Therefore type + inference using this syntax only works on Python 3.10 and up. + + See https://peps.python.org/pep-0604 + """ + + @calcfunction # type: ignore[misc] + def function( + a: str | int, + b: orm.Str | orm.Int, + c: dict | None = None, + ): + # pylint: disable=invalid-name,unused-argument + pass + + input_namespace = function.spec().inputs + + # Since the PEP 604 union syntax is only available starting from Python 3.10 the type inference will not be + # available for older versions, and so the valid type will be the default ``(orm.Data,)``. + if sys.version_info[:2] >= (3, 10): + expected = ( + ('a', (orm.Str, orm.Int)), + ('b', (orm.Str, orm.Int)), + ('c', (orm.Dict, type(None))), + ) + else: + assert 'function `function` has invalid type hints: unsupported operand type' in aiida_caplog.records[0].message + expected = ( + ('a', (orm.Data,)), + ('b', (orm.Data,)), + ('c', (orm.Data, type(None))), + ) + + for key, valid_types in expected: + assert key in input_namespace + assert input_namespace[key].valid_type == valid_types, key + + +def test_type_hinting_validation(): + """Test that type hints are converted to automatic type checking through the process specification.""" + + @calcfunction # type: ignore[misc] + def function_type_hinting(a: t.Union[int, float]): + # pylint: disable=invalid-name + return a + 1 + + with pytest.raises(ValueError, match=r'.*value \'a\' is not of the right type.*'): + function_type_hinting('string') + + assert function_type_hinting(1) == 2 + assert function_type_hinting(orm.Int(1)) == 2 + assert function_type_hinting(1.0) == 2.0 + assert function_type_hinting(orm.Float(1)) == 2.0