Skip to content

Commit

Permalink
Merge branch 'ebourne_static_analysis_improvements' into 'main'
Browse files Browse the repository at this point in the history
Static analysis improvements

See merge request gysela-developpers/gyselalibxx!674
  • Loading branch information
EmilyBourne committed Aug 27, 2024
1 parent 6077354 commit d9eb82c
Showing 1 changed file with 72 additions and 2 deletions.
74 changes: 72 additions & 2 deletions ci_tools/gyselalib_static_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@

mirror_functions = {'create_mirror', 'create_mirror_and_copy', 'create_mirror_view', 'create_mirror_view_and_copy'}

parallel_functions = ['parallel_for', 'parallel_for_each', 'parallel_transform_reduce']

HOME_DIR = Path(__file__).parent.parent.absolute()
global_folders = [HOME_DIR / f for f in ('src', 'simulations', 'tests')]

Expand Down Expand Up @@ -224,8 +226,8 @@ def search_for_unnecessary_auto(file):
for v,t in zip(variables, variable_types) if t is not None and t.attrib['str'] == 'auto']

# Find the name and location of the first use of the auto variables
var_data_idx = [file.data.index(v.attrib) for v in auto_variables if v]
var_names = [v.attrib['str'] for v in auto_variables if v]
var_data_idx = [file.data.index(v.attrib) for v in auto_variables if v is not None]
var_names = [v.attrib['str'] for v in auto_variables if v is not None]
for idx, var_name in enumerate(var_names):
start = var_data_idx[idx]
end = next(i for i,v in enumerate(file.data[start:], start) if v['str'] == ';')
Expand All @@ -240,6 +242,27 @@ def search_for_unnecessary_auto(file):
if not any(v['str'] in chain(mirror_functions, auto_functions) for v in file.data[start:end]):
report_error(ERROR, file, file.data[start]['linenr'], f"Please use explicit types instead of auto ({var_name})")

# Find auto in arguments of lambda functions
for elem in file.data_xml.findall(".token[@str='[']"):
start_idx = file.data.index(elem.attrib)+1
end_idx = next(j for j,a in enumerate(file.data[start_idx:], start_idx) if a['str'] == ']')
keys = ''.join(a['str'] for a in file.data[start_idx:end_idx]).split(',')
if not all(k[0] in ('&', '=') for k in keys):
continue
end_args_idx = next(j for j,a in enumerate(file.data[end_idx:], end_idx) if a['str'] == ')')
lambda_args = ' '.join(a['str'] for a in file.data[end_idx+2:end_args_idx]).split(',')
for a in lambda_args:
if 'auto' in a:
report_error(ERROR, file, file.data[start]['linenr'], f"Please use explicit types instead of auto ({a})")

for elem in chain(file.data_xml.findall(".token[@str='KOKKOS_CLASS_LAMBDA']"), file.data_xml.findall(".token[@str='KOKKOS_LAMBDA']")):
start_idx = file.data.index(elem.attrib)+3
end_args_idx = next(j for j,a in enumerate(file.data[start_idx:], start_idx) if a['str'] == ')')
lambda_args = ' '.join(a['str'] for a in file.data[start_idx:end_args_idx]).split(',')
for a in lambda_args:
if 'auto' in a.split():
report_error(ERROR, file, file.data[start]['linenr'], f"Please use explicit types instead of auto ({a})")

def search_for_bad_create_mirror(file):
"""
Test for instances where create_mirror functions are called incorrectly.
Expand Down Expand Up @@ -406,6 +429,52 @@ def check_directives(file):
elif not possible_matches and include_str == '"':
report_error(STYLE, file, linenr, f'Angle brackets should be used to include files from external libraries ({include_str}-><{include_file}>)')

def check_kokkos_lambda_use(file):
"""
Check that KOKKOS_LAMBDA expressions are present and that they are correctly used.
The checks are:
- Check that lambda functions passed to ddc::parallel_X functions use KOKKOS_LAMBDA or KOKKOS_CLASS_LAMBDA
- Check that class variables are not used in lambda functions passed to ddc::parallel_X functions.
"""
for p in parallel_functions:
for elem in file.data_xml.findall(f".token[@str='{p}']"):
idx = file.data.index(elem.attrib)
open_brackets = file.data[idx+1]
close_brackets_id = open_brackets['link']
scope = open_brackets['scope']
end_idx = file.data.index(file.data_xml.find(f".token[@id='{close_brackets_id}'][@scope='{scope}']").attrib)
arg_splitters = [idx+1]
i = idx+3
while i < end_idx:
if file.data[i]['str'] == ',':
arg_splitters.append(i)
i+=1
elif file.data[i]['str'] in ('(','{', '['):
close_brackets_id = file.data[i]['link']
scope = file.data[i]['scope']
i = file.data.index(file.data_xml.find(f".token[@id='{close_brackets_id}'][@scope='{scope}']").attrib)+1
else:
i+=1

if file.data[arg_splitters[-1]+1]['str'] == '[':
last_arg = ' '.join(a['str'] for a in file.data[arg_splitters[-1]+1:end_idx])
correct_last_arg = 'KOKKOS_LAMBDA '+last_arg[last_arg.find(']')+1:]
report_error(ERROR, file, elem.attrib['linenr'], f'The lambda function passed to the function ddc::{p} must be a KOKKOS_LAMBDA\n{last_arg}\nShould be:\n{correct_last_arg}')

lambda_body_start = next(i for i,a in enumerate(file.data[arg_splitters[-1]+1:], arg_splitters[-1]+1) if a['str'] == '{')
lambda_body_end_id = file.data[lambda_body_start]['link']
lambda_body_scope = file.data[lambda_body_start]['scope']
end_body_idx = file.data.index(file.data_xml.find(f".token[@id='{lambda_body_end_id}'][@scope='{lambda_body_scope}']").attrib)
for a in file.data[lambda_body_start+1:end_body_idx]:
var_id = a.get('variable', None)
if var_id:
var = next(v for v in file.variables if v['id'] == var_id)
var_scope = var['scope']
scope_info = file.root.find(f"./dump/scopes/scope[@id='{var_scope}']").attrib
if scope_info['type'] == 'Class':
report_error(ERROR, file, a['linenr'], f'Please create a local variable with the suffix "_proxy" to store the class variable that is used in a parallel loop ({a["str"]})')

def check_licence_presence(file):
"""
Test to ensure that all .hpp files have the license information in the first line.
Expand Down Expand Up @@ -454,5 +523,6 @@ def check_licence_presence(file):
search_for_bad_memory(myfile)
check_directives(myfile)
check_licence_presence(myfile)
check_kokkos_lambda_use(myfile)

sys.exit(error_level)

0 comments on commit d9eb82c

Please sign in to comment.