Skip to content

Commit

Permalink
XspectraCrystalWorkChain: Correct Error Handling
Browse files Browse the repository at this point in the history
Changes requested for PR #1028 (2)

Corrects error handling behaviour for `validate_inputs` to use `return`
rather than `raise`.

Also fixes minor formatting errors in `symmetry_data` `help` entry and
removes unnecessary uses of `+`.
  • Loading branch information
PNOGillespie committed Jul 10, 2024
1 parent 29bbd3b commit e741275
Showing 1 changed file with 33 additions and 37 deletions.
70 changes: 33 additions & 37 deletions src/aiida_quantumespresso/workflows/xspectra/crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Uses QuantumESPRESSO pw.x and xspectra.x.
"""
from aiida import orm
from aiida.common import AttributeDict, ValidationError
from aiida.common import AttributeDict
from aiida.engine import ToContext, WorkChain, if_
from aiida.orm import UpfData as aiida_core_upf
from aiida.plugins import CalculationFactory, DataFactory, WorkflowFactory
Expand Down Expand Up @@ -180,10 +180,10 @@ def define(cls, spec):
required=False,
help=(
'Input namespace to define equivalent sites and spacegroup number for the system. If defined, will '
+ 'skip symmetry analysis and structure standardization. Use *only* if symmetry data are known'
+ 'for certain. Requires ``spacegroup_number`` (Int) and ``equivalent_sites_data`` (Dict) to be'
+ 'defined separately. All keys in `equivalent_sites_data` must be formatted as "site_<site_index>".'
+ 'See docstring of `get_xspectra_structures` for more information about inputs.'
'skip symmetry analysis and structure standardization. Use *only* if symmetry data are known '
'for certain. Requires ``spacegroup_number`` (Int) and ``equivalent_sites_data`` (Dict) to be '
'defined separately. All keys in `equivalent_sites_data` must be formatted as "site_<site_index>". '
'See docstring of `get_xspectra_structures` for more information about inputs.'
)
)
spec.inputs.validator = cls.validate_inputs
Expand Down Expand Up @@ -382,9 +382,8 @@ def get_builder_from_protocol( # pylint: disable=too-many-statements
return builder


# pylint: disable=too-many-statements
@staticmethod
def validate_inputs(inputs, _):
def validate_inputs(inputs, _): # pylint: disable=too-many-return-statements
"""Validate the inputs before launching the WorkChain."""
structure = inputs['structure']
kinds_present = [kind.name for kind in structure.kinds]
Expand All @@ -396,58 +395,57 @@ def validate_inputs(inputs, _):
if element not in elements_present:
extra_elements.append(element)
if len(extra_elements) > 0:
raise ValidationError(
return (
f'Some elements in ``elements_list`` {extra_elements} do not exist in the'
f' structure provided {elements_present}.'
)

abs_atom_marker = inputs['abs_atom_marker'].value
if abs_atom_marker in kinds_present:
raise ValidationError(
return (
f'The marker given for the absorbing atom ("{abs_atom_marker}") matches an existing Kind in the '
f'input structure ({kinds_present}).'
)

if not inputs['core']['get_powder_spectrum'].value:
raise ValidationError(
return (
'The ``get_powder_spectrum`` input for the XspectraCoreWorkChain namespace must be ``True``.'
)

if 'upf2plotcore_code' not in inputs and 'core_wfc_data' not in inputs:
raise ValidationError(
return (
'Neither a ``Code`` node for upf2plotcore.sh or a set of ``core_wfc_data`` were provided.'
)

if 'core_wfc_data' in inputs:
core_wfc_data_list = sorted(inputs['core_wfc_data'].keys())
if core_wfc_data_list != absorbing_elements_list:
raise ValidationError(
return (
f'The ``core_wfc_data`` provided ({core_wfc_data_list}) does not match the list of'
f' absorbing elements ({absorbing_elements_list})'
)
else:
empty_core_wfc_data = []
for key, value in inputs['core_wfc_data'].items():
header_line = value.get_content()[:40]
try:
num_core_states = int(header_line.split(' ')[5])
except Exception as exc:
raise ValidationError(
'The core wavefunction data file is not of the correct format'
) from exc
if num_core_states == 0:
empty_core_wfc_data.append(key)
if len(empty_core_wfc_data) > 0:
raise ValidationError(
f'The ``core_wfc_data`` provided for elements {empty_core_wfc_data} do not contain '
'any wavefunction data.'
)
empty_core_wfc_data = []
for key, value in inputs['core_wfc_data'].items():
header_line = value.get_content()[:40]
try:
num_core_states = int(header_line.split(' ')[5])
except: # pylint: disable=bare-except
return (
'The core wavefunction data file is not of the correct format'
) # pylint: enable=bare-except
if num_core_states == 0:
empty_core_wfc_data.append(key)
if len(empty_core_wfc_data) > 0:
return (
f'The ``core_wfc_data`` provided for elements {empty_core_wfc_data} do not contain '
'any wavefunction data.'
)

if 'symmetry_data' in inputs:
spacegroup_number = inputs['symmetry_data']['spacegroup_number'].value
equivalent_sites_data = inputs['symmetry_data']['equivalent_sites_data'].get_dict()
if spacegroup_number <= 0 or spacegroup_number >= 231:
raise ValidationError(
return (
f'Input spacegroup number ({spacegroup_number}) outside of valid range (1-230).'
)

Expand All @@ -466,25 +464,23 @@ def validate_inputs(inputs, _):
elif value['symbol'] not in input_elements:
input_elements.append(value['symbol'])
if value['site_index'] < 0 or value['site_index'] >= len(structure.sites):
raise ValidationError(
return (
f'The site index for {site_label} ({value["site_index"]}) is outside the range of '
+ f'sites within the structure (0-{len(structure.sites) -1}).'
)

if len(invalid_entries) != 0:
raise ValidationError(
return (
f'The required keys ({required_keys}) were not found in the following entries: {invalid_entries}'
)

sorted_input_elements = sorted(input_elements)
if sorted_input_elements != absorbing_elements_list:
raise ValidationError(
f'Elements defined for sites in `equivalent_sites_data` ({sorted_input_elements}) do not match the'
+ f' list of absorbing elements ({absorbing_elements_list})'
)
return (f'Elements defined for sites in `equivalent_sites_data` ({sorted_input_elements}) '
f'do not match the list of absorbing elements ({absorbing_elements_list})')


# pylint: enable=too-many-statements
# pylint: enable=too-many-return-statements
def setup(self):
"""Set required context variables."""
if 'core_wfc_data' in self.inputs.keys():
Expand Down

0 comments on commit e741275

Please sign in to comment.