Skip to content

Commit

Permalink
Transform: Strip only the removed procedure symbols from Imports
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Apr 8, 2024
1 parent 6152cc4 commit a9e4770
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 32 deletions.
48 changes: 25 additions & 23 deletions loki/transform/transform_remove_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ class RemoveCodeTransformation(Transformation):
call_names : list of str
List of subroutine names against which to match
:any:`CallStatement` nodes during :method:`remove_calls`.
import_names : list of str
List of module names against which to match :any:`Import`
nodes during :method:`remove_calls`.
intrinsic_names : list of str
List of module names against which to match :any:`Intrinsic`
nodes during :method:`remove_calls`.
remove_imports : boolean
Flag indicating whether to remove symbols from :any:`Import`
objects during :method:`remove_calls`; default: ``True``
kernel_only : boolean
Only apply the configured removal to subroutines marked as
"kernel"; default: ``False``
Expand All @@ -70,8 +70,8 @@ class RemoveCodeTransformation(Transformation):
def __init__(
self, remove_marked_regions=True, mark_with_comment=True,
remove_dead_code=False, use_simplify=True,
call_names=None, import_names=None,
intrinsic_names=None, kernel_only=False
call_names=None, intrinsic_names=None,
remove_imports=True, kernel_only=False
):
self.remove_marked_regions = remove_marked_regions
self.mark_with_comment = mark_with_comment
Expand All @@ -80,8 +80,8 @@ def __init__(
self.use_simplify = use_simplify

self.call_names = as_tuple(call_names)
self.import_names = as_tuple(import_names)
self.intrinsic_names = as_tuple(intrinsic_names)
self.remove_imports = remove_imports

self.kernel_only = kernel_only

Expand All @@ -94,8 +94,8 @@ def transform_subroutine(self, routine, **kwargs):
if self.call_names or self.intrinsic_names:
do_remove_calls(
routine, call_names=self.call_names,
import_names=self.import_names,
intrinsic_names=self.intrinsic_names
intrinsic_names=self.intrinsic_names,
remove_imports=self.remove_imports
)

# Apply marked region removal
Expand Down Expand Up @@ -222,7 +222,7 @@ def visit_PragmaRegion(self, o, **kwargs):


def do_remove_calls(
routine, call_names=None, import_names=None, intrinsic_names=None,
routine, call_names=None, intrinsic_names=None, remove_imports=True
):
"""
Utility routine to remove all :any:`CallStatement` nodes
Expand All @@ -235,17 +235,17 @@ def do_remove_calls(
call_names : list of str
List of subroutine names against which to match
:any:`CallStatement` nodes.
import_names : list of str
List of module names against which to match :any:`Import`
nodes.
intrinsic_names : list of str
List of module names against which to match :any:`Intrinsic`
nodes.
remove_imports : boolean
Flag indicating whether to remove the respective procedure
symbols from :any:`Import` objects; default: ``True``.
"""

transformer = RemoveCallsTransformer(
call_names=call_names, intrinsic_names=intrinsic_names,
import_names=import_names
remove_imports=remove_imports
)
routine.spec = transformer.visit(routine.spec)
routine.body = transformer.visit(routine.body)
Expand Down Expand Up @@ -274,23 +274,23 @@ class RemoveCallsTransformer(Transformer):
call_names : list of str
List of subroutine names against which to match
:any:`CallStatement` nodes.
import_names : list of str
List of module names against which to match :any:`Import`
nodes.
intrinsic_names : list of str
List of module names against which to match :any:`Intrinsic`
nodes.
remove_imports : boolean
Flag indicating whether to remove the respective procedure
symbols from :any:`Import` objects; default: ``True``.
"""

def __init__(
self, call_names=None, import_names=None,
intrinsic_names=None, **kwargs
self, call_names=None, intrinsic_names=None,
remove_imports=True, **kwargs
):
super().__init__(**kwargs)

self.call_names = as_tuple(call_names)
self.intrinsic_names = as_tuple(intrinsic_names)
self.import_names = as_tuple(import_names)
self.remove_imports = remove_imports

def visit_CallStatement(self, o, **kwargs):
""" Match and remove :any:`CallStatement` nodes against name patterns """
Expand Down Expand Up @@ -322,10 +322,12 @@ def visit_Intrinsic(self, o, **kwargs):
return self._rebuild(o, rebuilt)

def visit_Import(self, o, **kwargs):
""" Match and remove :any:`Import` nodes against name patterns """
if self.import_names:
if any(str(c).lower() in o.module.lower() for c in self.import_names):
return None
""" Remove the symbol of any named calls from Import nodes """

symbols_found = any(s in self.call_names for s in o.symbols)
if self.remove_imports and symbols_found:
new_symbols = tuple(s for s in o.symbols if s not in self.call_names)
return o.clone(symbols=new_symbols) if new_symbols else None

rebuilt = tuple(self.visit(i, **kwargs) for i in o.children)
return self._rebuild(o, rebuilt)
3 changes: 1 addition & 2 deletions scripts/loki_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,7 @@ def convert(
if not remove_code_trafo:
remove_code_trafo = RemoveCodeTransformation(
remove_marked_regions=True, remove_dead_code=False,
call_names=('ABOR1', 'DR_HOOK'), import_names=('yomhook'),
intrinsic_names=('WRITE(NULOUT',)
call_names=('ABOR1', 'DR_HOOK'), intrinsic_names=('WRITE(NULOUT',)
)
scheduler.process(transformation=remove_code_trafo)

Expand Down
21 changes: 14 additions & 7 deletions tests/test_transform_remove_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ def test_transform_remove_code_pragma_region(frontend, mark_with_comment):


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_remove_calls(frontend):
@pytest.mark.parametrize('remove_imports', [True, False])
def test_transform_remove_calls(frontend, remove_imports):
"""
Test removal of utility calls and intrinsics with custom patterns.
"""
Expand All @@ -307,7 +308,6 @@ def test_transform_remove_calls(frontend):
fcode_abor1 = """
module abor1_mod
implicit none
integer(kind=8) :: NULOUT
contains
subroutine abor1(msg)
character(len=*), intent(in) :: msg
Expand All @@ -319,9 +319,10 @@ def test_transform_remove_calls(frontend):
fcode = """
subroutine never_gonna_give(dave)
use yomhook, only : lhook, dr_hook
use abor1_mod, only : abor1, NULOUT
use abor1_mod, only : abor1
implicit none
integer(kind=8), parameter :: NULOUT = 6
integer, parameter :: jprb = 8
logical, intent(in) :: dave
real(kind=jprb) :: zhook_handle
Expand Down Expand Up @@ -355,7 +356,7 @@ def test_transform_remove_calls(frontend):
do_remove_calls(
routine, call_names=('ABOR1', 'DR_HOOK'),
intrinsic_names=('WRITE(NULOUT', 'write(unit=nulout'),
import_names=('yomhook',)
remove_imports=remove_imports
)

# Check that all but one specific call have been removed
Expand All @@ -376,8 +377,14 @@ def test_transform_remove_calls(frontend):

# Check that the repsective imports have also been stripped
imports = FindNodes(ir.Import).visit(routine.spec)
assert len(imports) == 1
assert imports[0].module == 'abor1_mod'
assert len(imports) == 1 if remove_imports else 2
assert imports[0].module == 'yomhook'
if remove_imports:
assert imports[0].symbols == ('lhook',)
else:
assert imports[0].symbols == ('lhook', 'dr_hook')
assert imports[1].module == 'abor1_mod'
assert imports[1].symbols == ('abor1',)


@pytest.mark.parametrize('frontend', available_frontends(
Expand Down Expand Up @@ -405,7 +412,7 @@ def test_remove_code_transformation(frontend, source, include_intrinsics, kernel

# Apply the transformation to the call tree
transformation = RemoveCodeTransformation(
call_names=('ABOR1', 'DR_HOOK'), import_names=('yomhook'),
call_names=('ABOR1', 'DR_HOOK'),
intrinsic_names=('WRITE(NULOUT',) if include_intrinsics else (),
kernel_only=kernel_only
)
Expand Down

0 comments on commit a9e4770

Please sign in to comment.