diff --git a/loki/transform/transform_remove_code.py b/loki/transform/transform_remove_code.py index f843151e1..dd6eb27a0 100644 --- a/loki/transform/transform_remove_code.py +++ b/loki/transform/transform_remove_code.py @@ -53,12 +53,12 @@ class RemoveCodeTransformation(Transformation): call_names : list of str List of subroutine names against which to match :any:`CallStatement` nodes during :method:`remove_calls`. - import_names : list of str - List of module names against which to match :any:`Import` - nodes during :method:`remove_calls`. intrinsic_names : list of str List of module names against which to match :any:`Intrinsic` nodes during :method:`remove_calls`. + remove_imports : boolean + Flag indicating whether to remove symbols from :any:`Import` + objects during :method:`remove_calls`; default: ``True`` kernel_only : boolean Only apply the configured removal to subroutines marked as "kernel"; default: ``False`` @@ -70,8 +70,8 @@ class RemoveCodeTransformation(Transformation): def __init__( self, remove_marked_regions=True, mark_with_comment=True, remove_dead_code=False, use_simplify=True, - call_names=None, import_names=None, - intrinsic_names=None, kernel_only=False + call_names=None, intrinsic_names=None, + remove_imports=True, kernel_only=False ): self.remove_marked_regions = remove_marked_regions self.mark_with_comment = mark_with_comment @@ -80,8 +80,8 @@ def __init__( self.use_simplify = use_simplify self.call_names = as_tuple(call_names) - self.import_names = as_tuple(import_names) self.intrinsic_names = as_tuple(intrinsic_names) + self.remove_imports = remove_imports self.kernel_only = kernel_only @@ -94,8 +94,8 @@ def transform_subroutine(self, routine, **kwargs): if self.call_names or self.intrinsic_names: do_remove_calls( routine, call_names=self.call_names, - import_names=self.import_names, - intrinsic_names=self.intrinsic_names + intrinsic_names=self.intrinsic_names, + remove_imports=self.remove_imports ) # Apply marked region removal @@ -222,7 +222,7 @@ def visit_PragmaRegion(self, o, **kwargs): def do_remove_calls( - routine, call_names=None, import_names=None, intrinsic_names=None, + routine, call_names=None, intrinsic_names=None, remove_imports=True ): """ Utility routine to remove all :any:`CallStatement` nodes @@ -235,17 +235,17 @@ def do_remove_calls( call_names : list of str List of subroutine names against which to match :any:`CallStatement` nodes. - import_names : list of str - List of module names against which to match :any:`Import` - nodes. intrinsic_names : list of str List of module names against which to match :any:`Intrinsic` nodes. + remove_imports : boolean + Flag indicating whether to remove the respective procedure + symbols from :any:`Import` objects; default: ``True``. """ transformer = RemoveCallsTransformer( call_names=call_names, intrinsic_names=intrinsic_names, - import_names=import_names + remove_imports=remove_imports ) routine.spec = transformer.visit(routine.spec) routine.body = transformer.visit(routine.body) @@ -274,23 +274,23 @@ class RemoveCallsTransformer(Transformer): call_names : list of str List of subroutine names against which to match :any:`CallStatement` nodes. - import_names : list of str - List of module names against which to match :any:`Import` - nodes. intrinsic_names : list of str List of module names against which to match :any:`Intrinsic` nodes. + remove_imports : boolean + Flag indicating whether to remove the respective procedure + symbols from :any:`Import` objects; default: ``True``. """ def __init__( - self, call_names=None, import_names=None, - intrinsic_names=None, **kwargs + self, call_names=None, intrinsic_names=None, + remove_imports=True, **kwargs ): super().__init__(**kwargs) self.call_names = as_tuple(call_names) self.intrinsic_names = as_tuple(intrinsic_names) - self.import_names = as_tuple(import_names) + self.remove_imports = remove_imports def visit_CallStatement(self, o, **kwargs): """ Match and remove :any:`CallStatement` nodes against name patterns """ @@ -322,10 +322,12 @@ def visit_Intrinsic(self, o, **kwargs): return self._rebuild(o, rebuilt) def visit_Import(self, o, **kwargs): - """ Match and remove :any:`Import` nodes against name patterns """ - if self.import_names: - if any(str(c).lower() in o.module.lower() for c in self.import_names): - return None + """ Remove the symbol of any named calls from Import nodes """ + + symbols_found = any(s in self.call_names for s in o.symbols) + if self.remove_imports and symbols_found: + new_symbols = tuple(s for s in o.symbols if s not in self.call_names) + return o.clone(symbols=new_symbols) if new_symbols else None rebuilt = tuple(self.visit(i, **kwargs) for i in o.children) return self._rebuild(o, rebuilt) diff --git a/scripts/loki_transform.py b/scripts/loki_transform.py index 60e1f223a..ea5e71153 100644 --- a/scripts/loki_transform.py +++ b/scripts/loki_transform.py @@ -194,8 +194,7 @@ def convert( if not remove_code_trafo: remove_code_trafo = RemoveCodeTransformation( remove_marked_regions=True, remove_dead_code=False, - call_names=('ABOR1', 'DR_HOOK'), import_names=('yomhook'), - intrinsic_names=('WRITE(NULOUT',) + call_names=('ABOR1', 'DR_HOOK'), intrinsic_names=('WRITE(NULOUT',) ) scheduler.process(transformation=remove_code_trafo) diff --git a/tests/test_transform_remove_code.py b/tests/test_transform_remove_code.py index e857fddba..2bdff7f89 100644 --- a/tests/test_transform_remove_code.py +++ b/tests/test_transform_remove_code.py @@ -288,7 +288,8 @@ def test_transform_remove_code_pragma_region(frontend, mark_with_comment): @pytest.mark.parametrize('frontend', available_frontends()) -def test_transform_remove_calls(frontend): +@pytest.mark.parametrize('remove_imports', [True, False]) +def test_transform_remove_calls(frontend, remove_imports): """ Test removal of utility calls and intrinsics with custom patterns. """ @@ -307,7 +308,6 @@ def test_transform_remove_calls(frontend): fcode_abor1 = """ module abor1_mod implicit none -integer(kind=8) :: NULOUT contains subroutine abor1(msg) character(len=*), intent(in) :: msg @@ -319,9 +319,10 @@ def test_transform_remove_calls(frontend): fcode = """ subroutine never_gonna_give(dave) use yomhook, only : lhook, dr_hook - use abor1_mod, only : abor1, NULOUT + use abor1_mod, only : abor1 implicit none + integer(kind=8), parameter :: NULOUT = 6 integer, parameter :: jprb = 8 logical, intent(in) :: dave real(kind=jprb) :: zhook_handle @@ -355,7 +356,7 @@ def test_transform_remove_calls(frontend): do_remove_calls( routine, call_names=('ABOR1', 'DR_HOOK'), intrinsic_names=('WRITE(NULOUT', 'write(unit=nulout'), - import_names=('yomhook',) + remove_imports=remove_imports ) # Check that all but one specific call have been removed @@ -376,8 +377,14 @@ def test_transform_remove_calls(frontend): # Check that the repsective imports have also been stripped imports = FindNodes(ir.Import).visit(routine.spec) - assert len(imports) == 1 - assert imports[0].module == 'abor1_mod' + assert len(imports) == 1 if remove_imports else 2 + assert imports[0].module == 'yomhook' + if remove_imports: + assert imports[0].symbols == ('lhook',) + else: + assert imports[0].symbols == ('lhook', 'dr_hook') + assert imports[1].module == 'abor1_mod' + assert imports[1].symbols == ('abor1',) @pytest.mark.parametrize('frontend', available_frontends( @@ -405,7 +412,7 @@ def test_remove_code_transformation(frontend, source, include_intrinsics, kernel # Apply the transformation to the call tree transformation = RemoveCodeTransformation( - call_names=('ABOR1', 'DR_HOOK'), import_names=('yomhook'), + call_names=('ABOR1', 'DR_HOOK'), intrinsic_names=('WRITE(NULOUT',) if include_intrinsics else (), kernel_only=kernel_only )