Skip to content

Commit

Permalink
Merge pull request #175 from ecmwf-ifs/naml-fix-subroutine-deep-clone
Browse files Browse the repository at this point in the history
Fix deep-cloning of subroutiens and modules (fix #174)
  • Loading branch information
reuterbal authored Oct 17, 2023
2 parents 03324ef + 31e5e43 commit af8ec9f
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 5 deletions.
7 changes: 4 additions & 3 deletions loki/program_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,13 @@ def clone(self, **kwargs):
kwargs.setdefault('incomplete', self._incomplete)

# Rebuild IRs
rebuild = Transformer({}, rebuild_scopes=True)
if 'docstring' in kwargs:
kwargs['docstring'] = Transformer({}).visit(kwargs['docstring'])
kwargs['docstring'] = rebuild.visit(kwargs['docstring'])
if 'spec' in kwargs:
kwargs['spec'] = Transformer({}).visit(kwargs['spec'])
kwargs['spec'] = rebuild.visit(kwargs['spec'])
if 'contains' in kwargs:
kwargs['contains'] = Transformer({}).visit(kwargs['contains'])
kwargs['contains'] = rebuild.visit(kwargs['contains'])

# Rescope symbols if not explicitly disabled
kwargs.setdefault('rescope_symbols', True)
Expand Down
2 changes: 1 addition & 1 deletion loki/subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def clone(self, **kwargs):

# Rebuild body (other IR components are taken care of in super class)
if 'body' in kwargs:
kwargs['body'] = Transformer({}).visit(kwargs['body'])
kwargs['body'] = Transformer({}, rebuild_scopes=True).visit(kwargs['body'])

# Escalate to parent class
return super().clone(**kwargs)
Expand Down
51 changes: 50 additions & 1 deletion tests/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
OFP, OMNI, Module, Subroutine, VariableDeclaration, TypeDef, fexprgen,
BasicType, Assignment, FindNodes, FindInlineCalls, FindTypedSymbols,
Transformer, fgen, SymbolAttributes, Variable, Import, Section, Intrinsic,
Scalar, DeferredTypeSymbol
Scalar, DeferredTypeSymbol, FindVariables, SubstituteExpressions, Literal
)


Expand Down Expand Up @@ -529,6 +529,55 @@ def test_module_rescope_clone(frontend):
with pytest.raises(AttributeError):
fgen(other_module_copy)

@pytest.mark.parametrize('frontend', available_frontends(
xfail=[(OMNI, 'Parsing fails without dummy module provided')]
))
def test_module_deep_clone(frontend):
"""
Test the rescoping of variables in clone with nested scopes.
"""
fcode = """
module test_module_rescope_clone
use parkind1, only : jpim, jprb
implicit none
integer :: n
real :: array(n)
type my_type
real :: vector(n)
real :: matrix(n, n)
end type
end module test_module_rescope_clone
"""
module = Module.from_source(fcode, frontend=frontend)

# Deep-copy/clone the module
new_module = module.clone()

n = [v for v in FindVariables().visit(new_module.spec) if v.name == 'n'][0]
n_decl = FindNodes(VariableDeclaration).visit(new_module.spec)[0]

# Remove the declaration of `n` and replace it with `3`
new_module.spec = Transformer({n_decl: None}).visit(new_module.spec)
new_module.spec = SubstituteExpressions({n: Literal(3)}).visit(new_module.spec)

# Check the new module has been changed
assert len(FindNodes(VariableDeclaration).visit(new_module.spec)) == 1
new_type_decls = FindNodes(VariableDeclaration).visit(new_module['my_type'].body)
assert len(new_type_decls) == 2
assert new_type_decls[0].symbols[0] == 'vector(3)'
assert new_type_decls[1].symbols[0] == 'matrix(3, 3)'

# Check the old one has not changed
assert len(FindNodes(VariableDeclaration).visit(module.spec)) == 2
type_decls = FindNodes(VariableDeclaration).visit(module['my_type'].body)
assert len(type_decls) == 2
assert type_decls[0].symbols[0] == 'vector(n)'
assert type_decls[1].symbols[0] == 'matrix(n, n)'


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_access_spec_none(frontend):
Expand Down
45 changes: 45 additions & 0 deletions tests/test_subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2054,3 +2054,48 @@ def test_enrich_calls_explicit_interface(frontend):
# confirm that rescoping symbols has no effect
driver.rescope_symbols()
assert calls[0].routine is kernel


@pytest.mark.parametrize('frontend', available_frontends(
xfail=[(OMNI, 'OMNI cannot handle external type defs without source')]
))
def test_subroutine_deep_clone(frontend):
"""
Test that deep-cloning a subroutine actually ensures clean scope separation.
"""

fcode = """
subroutine myroutine(something)
use parkind1, only : jpim, jprb
implicit none
type(that_thing), intent(inout) :: something
real(kind=jprb) :: foo(something%n)
foo(:)=0.0_jprb
associate(thing=>something%else)
if (something%entirely%different) then
foo(:)=42.0_jprb
else
foo(:)=66.6_jprb
end if
end associate
end subroutine myroutine
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

# Create a deep-copy of the routine
new_routine = routine.clone()

# Replace all assignments with dummy calls
map_nodes={}
for assign in FindNodes(Assignment).visit(new_routine.body):
map_nodes[assign] = CallStatement(
name=DeferredTypeSymbol(name='testcall'), arguments=(assign.lhs,), scope=new_routine
)
new_routine.body = Transformer(map_nodes).visit(new_routine.body)

# Ensure that the original copy of the routine remains unaffected
assert len(FindNodes(Assignment).visit(routine.body)) == 3
assert len(FindNodes(Assignment).visit(new_routine.body)) == 0

0 comments on commit af8ec9f

Please sign in to comment.