Skip to content

Commit

Permalink
Process functions: Infer argument valid_type from type hints
Browse files Browse the repository at this point in the history
An advantage of process functions over calculation jobs and work chains
is that they are lightweight and very easy to start using. But there
are also disadvantages; due to the process specification being inferred
from the function signature, the user cannot benefit from all the
features of the process specification. For example, it is not possible
to explicitly define a help string or valid type for the arguments of
the process function.

With PEP 484 type hints were introduced in Python 3.5 that made it
possible to indicate the expected types for function arguments. Here we
make use of this functionality to infer the valid type of function
arguments based on the provided type hint, if any.

As an example, a user can now define the following calcfunction:

    @calcfunction
    def add(a: int, b: int):
        return a + b

and when it is called with anything other than an `int` or `Int` for
either of the two arguments, a validation error is raised reporting that
an argument with an invalid type was provided.

The new functionality is fully optional and process functions without
typing will continue to work as before. If incorrect typing is provided,
a warning is logged and the typing is ignored.

The annotation of the wrapped process function is parsed using the
`inspect.get_annotations` method. This was added in Python 3.10 and so to
provide support in older versions we install the `get-annotations`
backport package. We have to use `eval_str=True` in the call to get
unstringized versions of the types. This will fail for the
backport implementation if the type uses union syntax of PEP 604, e.g.
`str | int` instead of `typing.Union[str, int]`, even if this
functionality is enabled using `from __future__ import annotations`.
  • Loading branch information
sphuber committed Mar 16, 2023
1 parent ec8cb73 commit fc6d734
Show file tree
Hide file tree
Showing 13 changed files with 295 additions and 4 deletions.
79 changes: 75 additions & 4 deletions aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,41 @@
import inspect
import logging
import signal
import types
import typing as t
from typing import TYPE_CHECKING

from aiida.common.lang import override
from aiida.manage import get_manager
from aiida.orm import CalcFunctionNode, Data, ProcessNode, WorkFunctionNode, to_aiida_type
from aiida.orm import (
Bool,
CalcFunctionNode,
Data,
Dict,
Float,
Int,
List,
ProcessNode,
Str,
WorkFunctionNode,
to_aiida_type,
)
from aiida.orm.utils.mixins import FunctionCalculationMixin

from .process import Process

try:
UnionType = types.UnionType # type: ignore[attr-defined]
except AttributeError:
# This type is not available for Python 3.9 and older
UnionType = None # pylint: disable=invalid-name

try:
get_annotations = inspect.get_annotations # type: ignore[attr-defined]
except AttributeError:
# This is the backport for Python 3.9 and older
from get_annotations import get_annotations # type: ignore[no-redef]

if TYPE_CHECKING:
from .exit_code import ExitCode

Expand Down Expand Up @@ -191,6 +216,43 @@ def decorated_function(*args, **kwargs):
return decorator


def infer_valid_type_from_type_annotation(annotation: t.Any) -> tuple[t.Any, ...]:
    """Infer the value for the ``valid_type`` of an input port from the given function argument annotation.

    Supported annotations are ``Data`` subclasses, the Python base types listed in the mapping below, and unions of
    those (``typing.Union``, ``typing.Optional`` and, where the interpreter provides it, the ``X | Y`` syntax).
    Union members that cannot be mapped, including ``None``, are dropped from the result.

    :param annotation: The annotation of a function argument as returned by ``inspect.get_annotations``.
    :returns: A tuple of valid types. If no valid types were defined or they could not be successfully parsed, an empty
        tuple is returned.
    """

    def get_type_from_annotation(annotation):
        """Return the ``Data`` class corresponding to a single annotation, or ``None`` if it cannot be mapped."""
        valid_type_map = {
            bool: Bool,
            dict: Dict,
            t.Dict: Dict,
            float: Float,
            int: Int,
            list: List,
            t.List: List,
            str: Str,
        }

        try:
            if inspect.isclass(annotation) and issubclass(annotation, Data):
                return annotation
        except TypeError:
            # Parametrized builtin generics such as ``list[int]`` can pass ``inspect.isclass`` on some Python versions
            # (3.9 and 3.10) while ``issubclass`` rejects them. Treat these as annotations that cannot be mapped
            # instead of letting the exception escape while the process specification is being built.
            return None

        return valid_type_map.get(annotation)

    origin = t.get_origin(annotation)
    candidates: tuple[t.Any, ...] = ()

    if inspect.isclass(annotation):
        candidates = (get_type_from_annotation(annotation),)
    elif origin is t.Union or (UnionType is not None and origin is UnionType):
        # ``typing.Optional[X]`` is represented as ``typing.Union[X, None]``, so it is handled here as well; there is
        # no annotation for which ``typing.get_origin`` returns ``typing.Optional``. The explicit ``None`` check on
        # ``UnionType`` prevents the comparison from matching the ``None`` origin of non-generic annotations on
        # interpreters that do not provide ``types.UnionType``.
        candidates = tuple(get_type_from_annotation(member) for member in t.get_args(annotation))

    return tuple(valid_type for valid_type in candidates if valid_type is not None)


class FunctionProcess(Process):
"""Function process class used for turning functions into a Process"""

Expand Down Expand Up @@ -229,6 +291,14 @@ def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['Func
varargs: str | None = None
keywords: str | None = None

try:
annotations = get_annotations(func, eval_str=True)
except Exception as exception: # pylint: disable=broad-except
# Since we are running with ``eval_str=True`` to unstringize the annotations, the call can raise if the
# annotations are incorrect. In this case we simply want to log a warning and continue without type inference.
LOGGER.warning(f'function `{func.__name__}` has invalid type hints: {exception}')
annotations = {}

for key, parameter in signature.parameters.items():

if parameter.kind in [parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD, parameter.KEYWORD_ONLY]:
Expand All @@ -251,6 +321,9 @@ def _define(cls, spec): # pylint: disable=unused-argument
if parameter.kind in [parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD]:
continue

annotation = annotations.get(parameter.name)
valid_type = infer_valid_type_from_type_annotation(annotation) or (Data,)

default = parameter.default if parameter.default is not parameter.empty else UNSPECIFIED

# If the keyword was already specified, simply override the default
Expand All @@ -262,9 +335,7 @@ def _define(cls, spec): # pylint: disable=unused-argument
# use ``None`` because the validation will call ``isinstance`` which does not work when passing ``None``
# but it does work with ``NoneType`` which is returned by calling ``type(None)``.
if default is None:
valid_type = (Data, type(None))
else:
valid_type = (Data,)
valid_type += (type(None),)

# If a default is defined and it is not a ``Data`` instance it should be serialized, but this should be
# done lazily using a lambda, just as any port defaults should not define node instances directly as is
Expand Down
2 changes: 2 additions & 0 deletions docs/source/topics/calculations/concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ To solve this, one only has to wrap them in the :py:class:`~aiida.orm.nodes.data
The only difference with the previous snippet is that all inputs have been wrapped in the :py:class:`~aiida.orm.nodes.data.int.Int` class.
The result that is returned by the function, is now also an :py:class:`~aiida.orm.nodes.data.int.Int` node that can be stored in the provenance graph, and contains the result of the computation.

.. _topics:calculations:concepts:calcfunctions:automatic-serialization:

.. versionadded:: 2.1

If a function argument is a Python base type (i.e. a value of type ``bool``, ``dict``, ``Enum``, ``float``, ``int``, ``list`` or ``str``), it can be passed straight away to the function, without first having to wrap it in the corresponding AiiDA data type.
Expand Down
51 changes: 51 additions & 0 deletions docs/source/topics/processes/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,57 @@ The link labels for the example above will therefore be ``args_0``, ``args_1`` a
If any of these labels were to overlap with the label of a positional or keyword argument, a ``RuntimeError`` will be raised.
In this case, the conflicting argument name needs to be changed to something that does not overlap with the automatically generated labels for the variadic arguments.

Type validation
===============

.. versionadded:: 2.3

Type hints (introduced with `PEP 484 <https://peps.python.org/pep-0484/>`_ in Python 3.5) can be used to add automatic type validation of process function arguments.
For example, the following will raise a ``ValueError`` exception:

.. include:: include/snippets/functions/typing_call_raise.py
:code: python

When the process function is declared, the process specification (``ProcessSpec``) is built dynamically.
For each function argument, if a correct type hint is provided, it is set as the ``valid_type`` attribute of the corresponding input port.
In the example above, the ``x`` and ``y`` inputs have ``Int`` as type hint, which is why the call that passes a ``Float`` raises a ``ValueError``.

.. note::

Type hints for return values are currently not parsed.

If an argument accepts multiple types, the ``typing.Union`` class can be used as normal:

.. include:: include/snippets/functions/typing_union.py
:code: python

The call with an ``Int`` and a ``Float`` will now finish correctly.
Similarly, optional arguments, with ``None`` as a default, can be declared using ``typing.Optional``:

.. include:: include/snippets/functions/typing_none.py
:code: python

The `postponed evaluation of annotations introduced by PEP 563 <https://peps.python.org/pep-0563/>`_ is also supported.
This means it is possible to use Python base types for the type hint instead of AiiDA's ``Data`` node equivalent:

.. include:: include/snippets/functions/typing_pep_563.py
:code: python

The type hints are automatically serialized just as the actual inputs are when the function is called, :ref:`as introduced in v2.1<topics:calculations:concepts:calcfunctions:automatic-serialization>`.

The alternative syntax for union types ``X | Y`` `as introduced by PEP 604 <https://peps.python.org/pep-0604/>`_ is also supported:

.. include:: include/snippets/functions/typing_pep_604.py
:code: python

.. warning::

The usage of notation as defined by PEP 563 and PEP 604 is not supported for Python versions older than 3.10, even if the ``from __future__ import annotations`` statement is added.
The reason is that the type inference uses the `inspect.get_annotations <https://docs.python.org/3/library/inspect.html#inspect.get_annotations>`_ method, which was introduced in Python 3.10.
For older Python versions, the `get-annotations <https://pypi.org/project/get-annotations>`_ backport is used, but that does not work with PEP 563 and PEP 604, so the constructs from the ``typing`` module have to be used instead.

If a process function has invalid type hints, they will simply be ignored and a warning message is logged: ``function 'function_name' has invalid type hints``.
This ensures backwards compatibility in the case existing process functions had invalid type hints.

Return values
=============
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
from aiida.engine import calcfunction
from aiida.orm import Float, Int


@calcfunction
def add(x: Int, y: Int):
return x + y

add(Int(1), Float(1.0))
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
import typing as t

from aiida.engine import calcfunction
from aiida.orm import Int


@calcfunction
def add_multiply(x: Int, y: Int, z: t.Optional[Int] = None):
    """Return ``(x + y) * z``, using ``Int(3)`` for ``z`` when it is not provided."""
    # The ``typing`` module is imported under the alias ``t`` in this snippet, so the hint has to be spelled
    # ``t.Optional``; ``typing.Optional`` would raise a ``NameError`` when the function is defined.
    if z is None:
        z = Int(3)

    return (x + y) * z

result = add_multiply(Int(1), Int(2))
result = add_multiply(Int(1), Int(2), Int(3))
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

from aiida.engine import calcfunction


@calcfunction
def add(x: int, y: int):
return x + y

add(1, 2)
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

from aiida.engine import calcfunction
from aiida.orm import Int


@calcfunction
def add_multiply(x: int, y: int, z: int | None = None):
if z is None:
z = Int(3)

return (x + y) * z

result = add_multiply(1, 2)
result = add_multiply(1, 2, 3)
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
import typing as t

from aiida.engine import calcfunction
from aiida.orm import Float, Int


@calcfunction
def add(x: t.Union[Int, Float], y: t.Union[Int, Float]):
return x + y

add(Int(1), Float(1.0))
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies:
- click-spinner~=0.1.8
- click~=8.1
- disk-objectstore~=0.6.0
- get-annotations~=0.1
- python-graphviz~=0.13
- ipython<9,>=7
- jinja2~=3.0
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies = [
"click-spinner~=0.1.8",
"click~=8.1",
"disk-objectstore~=0.6.0",
"get-annotations~=0.1;python_version<'3.10'",
"graphviz~=0.13",
"ipython>=7,<9",
"jinja2~=3.0",
Expand Down Expand Up @@ -393,6 +394,7 @@ module = [
'docutils.*',
'flask_cors.*',
'flask_restful.*',
'get_annotations.*',
'graphviz.*',
'importlib._bootstrap.*',
'IPython.*',
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-py-3.8.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Flask-Cors==3.0.10
Flask-RESTful==0.3.9
fonttools==4.28.2
future==0.18.3
get-annotations==0.1.2
graphviz==0.19
greenlet==1.1.2
idna==3.3
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-py-3.9.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Flask-Cors==3.0.10
Flask-RESTful==0.3.9
fonttools==4.28.2
future==0.18.3
get-annotations==0.1.2
graphviz==0.19
greenlet==1.1.2
idna==3.3
Expand Down
Loading

0 comments on commit fc6d734

Please sign in to comment.