Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ORM: Switch to pydantic for code schema definition #6190

Merged
merged 1 commit into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ py:meth click.Option.get_default
py:meth fail

py:class ComputedFieldInfo
py:class pydantic.fields.Field
py:class pydantic.main.BaseModel

py:class requests.models.Response
Expand Down
4 changes: 2 additions & 2 deletions src/aiida/cmdline/commands/cmd_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def show(code):
table.append(['PK', code.pk])
table.append(['UUID', code.uuid])
table.append(['Type', code.entry_point.name])
for key in code.get_cli_options().keys():
for key in code.Model.model_fields.keys():
try:
table.append([key.capitalize().replace('_', ' '), getattr(code, key)])
except AttributeError:
Expand All @@ -242,7 +242,7 @@ def export(code, output_file):

code_data = {}

for key in code.get_cli_options().keys():
for key in code.Model.model_fields.keys():
if key == 'computer':
value = getattr(code, key).label
else:
Expand Down
109 changes: 72 additions & 37 deletions src/aiida/cmdline/groups/dynamic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Subclass of :class:`click.Group` that loads subcommands dynamically from entry points."""
from __future__ import annotations

import copy
import functools
import re
import typing as t
import warnings

import click

Expand Down Expand Up @@ -88,10 +88,35 @@ def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None
command = super().get_command(ctx, cmd_name)
return command

def call_command(self, ctx, cls, **kwargs):
"""Call the ``command`` after validating the provided inputs."""
from pydantic import ValidationError

if hasattr(cls, 'Model'):
# The plugin defines a pydantic model: use it to validate the provided arguments
try:
model = cls.Model(**kwargs)
except ValidationError as exception:
param_hint = [
f'--{loc.replace("_", "-")}' # type: ignore[union-attr]
for loc in exception.errors()[0]['loc']
]
message = '\n'.join([str(e['ctx']['error']) for e in exception.errors()])
raise click.BadParameter(
message,
param_hint=param_hint or 'multiple parameters', # type: ignore[arg-type]
) from exception

# Update the arguments with the dictionary representation of the model. This will include any type coercions
# that may have been applied with validators defined for the model.
kwargs.update(**model.model_dump())

return self._command(ctx, cls, **kwargs)

def create_command(self, ctx: click.Context, entry_point: str) -> click.Command:
"""Create a subcommand for the given ``entry_point``."""
cls = self.factory(entry_point)
command = functools.partial(self._command, ctx, cls)
command = functools.partial(self.call_command, ctx, cls)
command.__doc__ = cls.__doc__
return click.command(entry_point)(self.create_options(entry_point)(command))

Expand Down Expand Up @@ -131,61 +156,71 @@ def list_options(self, entry_point: str) -> list:

cls = self.factory(entry_point)

if not hasattr(cls, 'Configuration'):
# This should be enabled once the ``Code`` classes are migrated to using pydantic to define their model.
# See https://github.com/aiidateam/aiida-core/pull/6190
# from aiida.common.warnings import warn_deprecation
# warn_deprecation(
# 'Relying on `_get_cli_options` is deprecated. The options should be defined through a '
# '`pydantic.BaseModel` that should be assigned to the `Config` class attribute.',
# version=3
# )
if not hasattr(cls, 'Model'):
from aiida.common.warnings import warn_deprecation

warn_deprecation(
'Relying on `_get_cli_options` is deprecated. The options should be defined through a '
'`pydantic.BaseModel` that should be assigned to the `Config` class attribute.',
'`pydantic.BaseModel` that should be assigned to the `Model` class attribute.',
version=3,
)
options_spec = self.factory(entry_point).get_cli_options() # type: ignore[union-attr]
else:
options_spec = {}

for key, field_info in cls.Configuration.model_fields.items():
default = field_info.default_factory if field_info.default is PydanticUndefined else field_info.default

# The ``field_info.annotation`` property returns the annotation of the field. This can be a plain type
# or a type from ``typing``, e.g., ``Union[int, float]`` or ``Optional[str]``. In these cases, the type
# that needs to be passed to ``click`` is the arguments of the type, which can be obtained using the
# ``typing.get_args()`` method. If it is not a compound type, this returns an empty tuplem so in that
# case, the type is simply the ``field_info.annotation``.
options_spec[key] = {
'required': field_info.is_required(),
'type': t.get_args(field_info.annotation) or field_info.annotation,
'prompt': field_info.title,
'default': default,
'help': field_info.description,
}

return [self.create_option(*item) for item in options_spec.items()]
return [self.create_option(*item) for item in options_spec]

options_spec = {}

for key, field_info in cls.Model.model_fields.items():
default = field_info.default_factory if field_info.default is PydanticUndefined else field_info.default

# If the annotation has the ``__args__`` attribute it is an instance of a type from ``typing`` and the real
# type can be gotten from the arguments. For example it could be ``typing.Union[str, None]`` calling
# ``typing.Union[str, None].__args__`` will return the tuple ``(str, NoneType)``. So to get the real type,
# we simply remove all ``NoneType`` and the remaining type should be the type of the option.
if hasattr(field_info.annotation, '__args__'):
args = list(filter(lambda e: e != type(None), field_info.annotation.__args__))
if len(args) > 1:
warnings.warn(
f'field `{key}` defines multiple types, but can take only one, taking the first: `{args[0]}`',
UserWarning,
)
field_type = args[0]
else:
field_type = field_info.annotation

options_spec[key] = {
'required': field_info.is_required(),
'type': field_type,
'is_flag': field_type is bool,
'prompt': field_info.title,
'default': default,
'help': field_info.description,
}
for metadata in field_info.metadata:
for metadata_key, metadata_value in metadata.items():
options_spec[key][metadata_key] = metadata_value

options_ordered = []

for name, spec in sorted(options_spec.items(), key=lambda x: x[1].get('priority', 0), reverse=True):
spec.pop('priority', None)
options_ordered.append(self.create_option(name, spec))

return options_ordered

@staticmethod
def create_option(name, spec: dict) -> t.Callable[[t.Any], t.Any]:
"""Create a click option from a name and a specification."""
spec = copy.deepcopy(spec)

is_flag = spec.pop('is_flag', False)
default = spec.get('default')
name_dashed = name.replace('_', '-')
option_name = f'--{name_dashed}/--no-{name_dashed}' if is_flag else f'--{name_dashed}'
option_short_name = spec.pop('short_name', None)
option_names = (option_short_name, option_name) if option_short_name else (option_name,)

kwargs = {'cls': spec.pop('cls', InteractiveOption), 'show_default': True, 'is_flag': is_flag, **spec}
kwargs = {'cls': spec.pop('option_cls', InteractiveOption), 'show_default': True, 'is_flag': is_flag, **spec}

# If the option is a flag with no default, make sure it is not prompted for, as that will force the user to
# specify it to be on or off, but cannot let it unspecified.
if kwargs['cls'] is InteractiveOption and is_flag and default is None:
if kwargs['cls'] is InteractiveOption and is_flag and spec.get('default') is None:
kwargs['cls'] = functools.partial(InteractiveOption, prompt_fn=lambda ctx: False)

return click.option(*(option_names), **kwargs)
46 changes: 46 additions & 0 deletions src/aiida/common/pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Utilities related to ``pydantic``."""
from __future__ import annotations

import typing as t

from pydantic import Field


def MetadataField( # noqa: N802
default: t.Any | None = None,
*,
priority: int = 0,
short_name: str | None = None,
option_cls: t.Any | None = None,
**kwargs,
):
"""Return a :class:`pydantic.fields.Field` instance with additional metadata.

.. code-block:: python

class Model(BaseModel):

attribute: MetadataField('default', priority=1000, short_name='-A')

This is a utility function that constructs a ``Field`` instance with an easy interface to add additional metadata.
It is possible to add metadata using ``Annotated``::

class Model(BaseModel):

attribute: Annotated[str, {'metadata': 'value'}] = Field(...)

However, when requiring multiple metadata, this notation can make the model difficult to read. Since this utility
is only used to automatically build command line interfaces from the model definition, it is possible to restrict
which metadata are accepted.

:param priority: Used to order the list of all fields in the model. Ordering is done from small to large priority.
:param short_name: Optional short name to use for an option on a command line interface.
:param option_cls: The :class:`click.Option` class to use to construct the option.
"""
field_info = Field(default, **kwargs)

for key, value in (('priority', priority), ('short_name', short_name), ('option_cls', option_cls)):
if value is not None:
field_info.metadata.append({key: value})

return field_info
2 changes: 1 addition & 1 deletion src/aiida/manage/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def create_profile(
from aiida.manage import get_manager
from aiida.orm import User

storage_config = storage_cls.Configuration(**{k: v for k, v in kwargs.items() if v is not None}).model_dump()
storage_config = storage_cls.Model(**{k: v for k, v in kwargs.items() if v is not None}).model_dump()
profile: Profile = config.create_profile(name=name, storage_cls=storage_cls, storage_config=storage_config)

with profile_context(profile.name, allow_switch=True):
Expand Down
Loading
Loading