diff --git a/ci_tools/gyselalib_static_analysis.py b/ci_tools/gyselalib_static_analysis.py index 7eb945b4a..df887aa4d 100644 --- a/ci_tools/gyselalib_static_analysis.py +++ b/ci_tools/gyselalib_static_analysis.py @@ -41,6 +41,8 @@ mirror_functions = {'create_mirror', 'create_mirror_and_copy', 'create_mirror_view', 'create_mirror_view_and_copy'} +parallel_functions = ['parallel_for', 'parallel_for_each', 'parallel_transform_reduce'] + HOME_DIR = Path(__file__).parent.parent.absolute() global_folders = [HOME_DIR / f for f in ('src', 'simulations', 'tests')] @@ -224,8 +226,8 @@ def search_for_unnecessary_auto(file): for v,t in zip(variables, variable_types) if t is not None and t.attrib['str'] == 'auto'] # Find the name and location of the first use of the auto variables - var_data_idx = [file.data.index(v.attrib) for v in auto_variables if v] - var_names = [v.attrib['str'] for v in auto_variables if v] + var_data_idx = [file.data.index(v.attrib) for v in auto_variables if v is not None] + var_names = [v.attrib['str'] for v in auto_variables if v is not None] for idx, var_name in enumerate(var_names): start = var_data_idx[idx] end = next(i for i,v in enumerate(file.data[start:], start) if v['str'] == ';') @@ -240,6 +242,27 @@ def search_for_unnecessary_auto(file): if not any(v['str'] in chain(mirror_functions, auto_functions) for v in file.data[start:end]): report_error(ERROR, file, file.data[start]['linenr'], f"Please use explicit types instead of auto ({var_name})") + # Find auto in arguments of lambda functions + for elem in file.data_xml.findall(".token[@str='[']"): + start_idx = file.data.index(elem.attrib)+1 + end_idx = next(j for j,a in enumerate(file.data[start_idx:], start_idx) if a['str'] == ']') + keys = ''.join(a['str'] for a in file.data[start_idx:end_idx]).split(',') + if not all(k[0] in ('&', '=') for k in keys): + continue + end_args_idx = next(j for j,a in enumerate(file.data[end_idx:], end_idx) if a['str'] == ')') + lambda_args = ' '.join(a['str'] for a in file.data[end_idx+2:end_args_idx]).split(',') + for a in lambda_args: + if 'auto' in a: + report_error(ERROR, file, file.data[start]['linenr'], f"Please use explicit types instead of auto ({a})") + + for elem in chain(file.data_xml.findall(".token[@str='KOKKOS_CLASS_LAMBDA']"), file.data_xml.findall(".token[@str='KOKKOS_LAMBDA']")): + start_idx = file.data.index(elem.attrib)+3 + end_args_idx = next(j for j,a in enumerate(file.data[start_idx:], start_idx) if a['str'] == ')') + lambda_args = ' '.join(a['str'] for a in file.data[start_idx:end_args_idx]).split(',') + for a in lambda_args: + if 'auto' in a.split(): + report_error(ERROR, file, file.data[start]['linenr'], f"Please use explicit types instead of auto ({a})") + def search_for_bad_create_mirror(file): """ Test for instances where create_mirror functions are called incorrectly. @@ -406,6 +429,52 @@ def check_directives(file): elif not possible_matches and include_str == '"': report_error(STYLE, file, linenr, f'Angle brackets should be used to include files from external libraries ({include_str}-><{include_file}>)') +def check_kokkos_lambda_use(file): + """ + Check that KOKKOS_LAMBDA expressions are present and that they are correctly used. + + The checks are: + - Check that lambda functions passed to ddc::parallel_X functions use KOKKOS_LAMBDA or KOKKOS_CLASS_LAMBDA + - Check that class variables are not used in lambda functions passed to ddc::parallel_X functions. + """ + for p in parallel_functions: + for elem in file.data_xml.findall(f".token[@str='{p}']"): + idx = file.data.index(elem.attrib) + open_brackets = file.data[idx+1] + close_brackets_id = open_brackets['link'] + scope = open_brackets['scope'] + end_idx = file.data.index(file.data_xml.find(f".token[@id='{close_brackets_id}'][@scope='{scope}']").attrib) + arg_splitters = [idx+1] + i = idx+3 + while i < end_idx: + if file.data[i]['str'] == ',': + arg_splitters.append(i) + i+=1 + elif file.data[i]['str'] in ('(','{', '['): + close_brackets_id = file.data[i]['link'] + scope = file.data[i]['scope'] + i = file.data.index(file.data_xml.find(f".token[@id='{close_brackets_id}'][@scope='{scope}']").attrib)+1 + else: + i+=1 + + if file.data[arg_splitters[-1]+1]['str'] == '[': + last_arg = ' '.join(a['str'] for a in file.data[arg_splitters[-1]+1:end_idx]) + correct_last_arg = 'KOKKOS_LAMBDA '+last_arg[last_arg.find(']')+1:] + report_error(ERROR, file, elem.attrib['linenr'], f'The lambda function passed to the function ddc::{p} must be a KOKKOS_LAMBDA\n{last_arg}\nShould be:\n{correct_last_arg}') + + lambda_body_start = next(i for i,a in enumerate(file.data[arg_splitters[-1]+1:], arg_splitters[-1]+1) if a['str'] == '{') + lambda_body_end_id = file.data[lambda_body_start]['link'] + lambda_body_scope = file.data[lambda_body_start]['scope'] + end_body_idx = file.data.index(file.data_xml.find(f".token[@id='{lambda_body_end_id}'][@scope='{lambda_body_scope}']").attrib) + for a in file.data[lambda_body_start+1:end_body_idx]: + var_id = a.get('variable', None) + if var_id: + var = next(v for v in file.variables if v['id'] == var_id) + var_scope = var['scope'] + scope_info = file.root.find(f"./dump/scopes/scope[@id='{var_scope}']").attrib + if scope_info['type'] == 'Class': + report_error(ERROR, file, a['linenr'], f'Please create a local variable with the suffix "_proxy" to store the class variable that is used in a parallel loop ({a["str"]})') + def check_licence_presence(file): """ Test to ensure that all .hpp files have the license information in the first line. @@ -454,5 +523,6 @@ def check_licence_presence(file): search_for_bad_memory(myfile) check_directives(myfile) check_licence_presence(myfile) + check_kokkos_lambda_use(myfile) sys.exit(error_level)