Skip to content

Commit

Permalink
Merge pull request #157 from ecmwf-ifs/nabr_fix_scc_cuf
Browse files Browse the repository at this point in the history
Fix cudafor import in driver for SCC CUF
  • Loading branch information
reuterbal authored Sep 25, 2023
2 parents 2322c1b + 7c7a75d commit edc54f8
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 9 deletions.
3 changes: 3 additions & 0 deletions transformations/tests/test_scc_cuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def fixture_config():


def check_subroutine_driver(routine, blocking, disable=()):
# use of "use cudafor"
imports = [_import.module.lower() for _import in FindNodes(Import).visit(routine.spec)]
assert "cudafor" in imports
# device arrays
# device arrays: declaration
arrays = [var for var in routine.variables if isinstance(var, sym.Array)]
Expand Down
10 changes: 1 addition & 9 deletions transformations/transformations/scc_cuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,13 +671,6 @@ def __init__(self, horizontal, vertical, block_dim, transformation_type='paramet
self.derived_types = [_.upper() for _ in derived_types]
self.derived_type_variables = ()

def transform_module(self, module, **kwargs):

role = kwargs.get('role')

if role == 'driver':
module.spec.prepend(ir.Import(module="cudafor"))

def transform_subroutine(self, routine, **kwargs):

item = kwargs.get('item', None)
Expand All @@ -696,8 +689,7 @@ def transform_subroutine(self, routine, **kwargs):
single_variable_declaration(routine=routine, group_by_shape=True)
device_subroutine_prefix(routine, depth)

if depth > 0:
routine.spec.prepend(ir.Import(module="cudafor"))
routine.spec.prepend(ir.Import(module="cudafor"))

if role == 'driver':
self.process_routine_driver(routine, targets=targets)
Expand Down

0 comments on commit edc54f8

Please sign in to comment.