From 7c7a75dfb4146e2e2a1cb37b5076e56edef23f47 Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Fri, 22 Sep 2023 17:25:14 +0100 Subject: [PATCH] Fix cudafor import in driver for SCC CUF --- transformations/tests/test_scc_cuf.py | 3 +++ transformations/transformations/scc_cuf.py | 10 +--------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/transformations/tests/test_scc_cuf.py b/transformations/tests/test_scc_cuf.py index 240c93cb7..0df127e2d 100644 --- a/transformations/tests/test_scc_cuf.py +++ b/transformations/tests/test_scc_cuf.py @@ -60,6 +60,9 @@ def fixture_config(): def check_subroutine_driver(routine, blocking, disable=()): + # use of "use cudafor" + imports = [_import.module.lower() for _import in FindNodes(Import).visit(routine.spec)] + assert "cudafor" in imports # device arrays # device arrays: declaration arrays = [var for var in routine.variables if isinstance(var, sym.Array)] diff --git a/transformations/transformations/scc_cuf.py b/transformations/transformations/scc_cuf.py index 6ffc779cb..9235b167c 100644 --- a/transformations/transformations/scc_cuf.py +++ b/transformations/transformations/scc_cuf.py @@ -671,13 +671,6 @@ def __init__(self, horizontal, vertical, block_dim, transformation_type='paramet self.derived_types = [_.upper() for _ in derived_types] self.derived_type_variables = () - def transform_module(self, module, **kwargs): - - role = kwargs.get('role') - - if role == 'driver': - module.spec.prepend(ir.Import(module="cudafor")) - def transform_subroutine(self, routine, **kwargs): item = kwargs.get('item', None) @@ -696,8 +689,7 @@ def transform_subroutine(self, routine, **kwargs): single_variable_declaration(routine=routine, group_by_shape=True) device_subroutine_prefix(routine, depth) - if depth > 0: - routine.spec.prepend(ir.Import(module="cudafor")) + routine.spec.prepend(ir.Import(module="cudafor")) if role == 'driver': self.process_routine_driver(routine, targets=targets)