Skip to content

Commit

Permalink
Merge pull request #325 from ecmwf-ifs/307-resolve_typebound_var-warning
Browse files Browse the repository at this point in the history
ProgramUnit.resolve_typebound_var: raise error if top-level parent is not declared
  • Loading branch information
mlange05 authored Jun 14, 2024
2 parents 181f6b2 + e9a2443 commit 3bc2f14
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 16 deletions.
2 changes: 2 additions & 0 deletions loki/expression/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,8 @@ def get_derived_type_member(self, name_str):
Resolve type-bound variables of arbitrary nested depth.
"""
name_parts = name_str.split('%', maxsplit=1)
if self.type.dtype is not BasicType.DEFERRED and self.type.dtype.typedef is not BasicType.DEFERRED:
assert self.type.dtype.typedef.variable_map[name_parts[0]]
declared_var = Variable(name=f'{self.name}%{name_parts[0]}', scope=self.scope, parent=self)
if len(name_parts) > 1:
return declared_var.get_derived_type_member(name_parts[1]) # pylint:disable=no-member
Expand Down
4 changes: 3 additions & 1 deletion loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2120,7 +2120,9 @@ def visit_Module_Stmt(self, o, **kwargs):
if module_type and module_type.dtype.module != BasicType.DEFERRED:
return module_type.dtype.module

return Module(name=name, parent=kwargs['scope'])
module = Module(name=name, parent=kwargs['scope'])
self.definitions[name] = module
return module

visit_Module_Name = visit_Name

Expand Down
4 changes: 3 additions & 1 deletion loki/frontend/ofp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,7 +1298,9 @@ def _create_Module_object(self, o, scope):
if module_type and module_type.dtype.module != BasicType.DEFERRED:
return module_type.dtype.module

return Module(name=name, parent=scope)
module = Module(name=name, parent=scope)
self.definitions[name] = module
return module

def visit_module(self, o, **kwargs):
# Extract known sections
Expand Down
5 changes: 3 additions & 2 deletions loki/frontend/omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,9 @@ def _create_Module_object(self, o, scope):
if module_type and module_type.dtype.module != BasicType.DEFERRED:
return module_type.dtype.module

return Module(name=name, parent=scope)

module = Module(name=name, parent=scope)
self.definitions[name] = module
return module

def visit_FmoduleDefinition(self, o, **kwargs):
# Update the symbol map with local entries
Expand Down
3 changes: 2 additions & 1 deletion loki/program_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,7 @@ def resolve_typebound_var(self, name, variable_map=None):
_variable_map = self.variable_map

name_parts = name.split('%', maxsplit=1)
if (var := _variable_map.get(name_parts[0], None)) and len(name_parts) > 1:
var = _variable_map[name_parts[0]]
if len(name_parts) > 1:
var = var.get_derived_type_member(name_parts[1])
return var
50 changes: 50 additions & 0 deletions loki/tests/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from loki.build import jit_compile, clean_test
from loki.frontend import available_frontends, OFP, OMNI
from loki.sourcefile import Sourcefile


@pytest.fixture(scope='module', name='here')
Expand Down Expand Up @@ -1320,3 +1321,52 @@ def test_module_all_imports(frontend):
assert routine.symbol_map['b'].type.module is header_a
assert routine_mod.symbol_map['b_b'].type.module is header_b
assert routine_mod.symbol_map['b_b'].type.use_name == 'b'


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_enrichment_within_file(frontend, tmp_path):
fcode = """
module foo
implicit none
integer, parameter :: j = 16
contains
integer function SUM(v)
implicit none
integer, intent(in) :: v
SUM = v + 1
end function SUM
end module foo
module test
use foo
implicit none
integer, parameter :: rk = selected_real_kind(12)
integer, parameter :: ik = selected_int_kind(9)
contains
subroutine calc (n, res)
integer, intent(in) :: n
real(kind=rk), intent(inout) :: res
integer(kind=ik) :: i
do i = 1, n
res = res + SUM(j)
end do
end subroutine calc
end module test
"""

source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
routine = source['calc']
calls = list(FindInlineCalls().visit(routine.body))
assert len(calls) == 1
assert calls[0].function == 'sum'
assert calls[0].function.type.imported
assert calls[0].function.type.module is source['foo']
assert calls[0].function.type.dtype.procedure is source['sum']
if frontend != OMNI:
# OMNI inlines parameters
assert calls[0].arguments[0].type.dtype == BasicType.INTEGER
assert calls[0].arguments[0].type.imported
assert calls[0].arguments[0].type.parameter
assert calls[0].arguments[0].type.initial == 16
assert calls[0].arguments[0].type.module is source['foo']
100 changes: 100 additions & 0 deletions loki/tests/test_subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2168,6 +2168,7 @@ def test_subroutine_deep_clone(frontend):
assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3
assert len(FindNodes(ir.Assignment).visit(new_routine.body)) == 0


@pytest.mark.parametrize('frontend', available_frontends())
def test_call_args_kwargs_conversion(frontend):

Expand Down Expand Up @@ -2238,3 +2239,102 @@ def test_call_args_kwargs_conversion(frontend):
for call in FindNodes(ir.CallStatement).visit(driver.body):
assert tuple(arg.name for arg in call.arguments) == call_args
assert call.kwarguments == ()


@pytest.mark.parametrize('frontend', available_frontends())
def test_resolve_typebound_var(frontend, tmp_path):
"""
Test correct behaviour of :any:`Scope.resolve_typebound_var` utility
"""
fcode = """
module header_mod
implicit none
type some_type
integer :: ival
end type some_type
type other_type
type(some_type) :: other
end type other_type
type third_type
type(other_type) :: some
end type third_type
end module header_mod
subroutine some_routine
use header_mod, only: third_type
implicit none
type(third_type) :: tt
end subroutine
""".strip()

source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
routine = source['some_routine']

tt_some = routine.resolve_typebound_var('tt%some')
assert tt_some == 'tt%some'
assert tt_some.type.dtype.name == 'other_type'
assert tt_some.type.dtype.typedef is source['header_mod']['other_type']

tt_some_other_ival = routine.resolve_typebound_var('tt%some%other%ival')
assert tt_some_other_ival == 'tt%some%other%ival'
assert tt_some_other_ival.type.dtype == BasicType.INTEGER
assert tt_some_other_ival.parent.type.dtype.name == 'some_type'
assert tt_some_other_ival.parent.type.dtype.typedef is source['header_mod']['some_type']

tt = routine.resolve_typebound_var('tt')
assert tt == 'tt'
assert tt.type.dtype.name == 'third_type'
assert tt.type.dtype.typedef is source['header_mod']['third_type']

# This throws an error as the type definition is available and therefore
# the invalid member can be deduced
with pytest.raises(KeyError):
routine.resolve_typebound_var('tt%invalid%val')

with pytest.raises(KeyError):
routine.resolve_typebound_var('tt%some%invalid')

# This throws errors as resolving derived type members for
# non-declared derived types should not be possible
with pytest.raises(KeyError):
routine.resolve_typebound_var('not_tt%invalid')

with pytest.raises(KeyError):
routine.resolve_typebound_var('not_a_var')

# Instead, we can creatae a deferred type variable in the scope and
# resolve members relative to it
not_tt = Variable(name='not_tt', scope=routine)
assert not_tt.type.dtype == BasicType.DEFERRED
not_tt_invalid = not_tt.get_derived_type_member('invalid')
assert not_tt_invalid == 'not_tt%invalid'
assert not_tt_invalid.type.dtype == BasicType.DEFERRED


@pytest.mark.parametrize('frontend', available_frontends(
xfail=[(OMNI, 'Parsing fails with no header information available')]
))
def test_resolve_typebound_var_missing_definition(frontend, tmp_path):
"""
Test correct behaviour of :any:`Scope.resolve_typebound_var` utility
in the absence of type information
"""
fcode = """
subroutine some_routine
use header_mod, only: third_type
implicit none
type(third_type) :: tt
end subroutine
""".strip()

source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
routine = source['some_routine']

# This does not throw an error as the use-case of incomplete type definitions
# may well require working with incomplete type definitions
tt_invalid_val = routine.resolve_typebound_var('tt%invalid%val')
assert tt_invalid_val == 'tt%invalid%val'
assert tt_invalid_val.type.dtype == BasicType.DEFERRED
assert tt_invalid_val.parent.type.dtype == BasicType.DEFERRED
12 changes: 10 additions & 2 deletions loki/transformations/single_column/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
symbols as sym, FindExpressions, SubstituteExpressions
)
from loki.ir import nodes as ir, FindNodes, Transformer
from loki.logging import debug
from loki.tools import as_tuple
from loki.types import SymbolAttributes, BasicType

Expand Down Expand Up @@ -180,8 +181,15 @@ def resolve_vector_dimension(cls, routine, loop_variable, bounds):
bounds_str = f'{bounds[0]}:{bounds[1]}'

variable_map = routine.variable_map
bounds_v = (routine.resolve_typebound_var(bounds[0], variable_map),
routine.resolve_typebound_var(bounds[1], variable_map))
try:
bounds_v = (routine.resolve_typebound_var(bounds[0], variable_map),
routine.resolve_typebound_var(bounds[1], variable_map))
except KeyError:
debug(
'SCCBaseTransformation.resolve_vector_dimension: '
f'Dimension bound {bounds[0]} or {bounds[1]} not found in {routine.name}.'
)
return

mapper = {}
for stmt in FindNodes(ir.Assignment).visit(routine.body):
Expand Down
18 changes: 9 additions & 9 deletions loki/transformations/tests/test_scc_cuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

@pytest.fixture(scope='module', name='horizontal')
def fixture_horizontal():
return Dimension(name='horizontal', size='nlon', index='jl', bounds=('start', 'end'))
return Dimension(name='horizontal', size='nlon', index='jl', bounds=('start', 'iend'))


@pytest.fixture(scope='module', name='vertical')
Expand Down Expand Up @@ -176,11 +176,11 @@ def test_scc_cuf_simple(frontend, horizontal, vertical, blocking):
REAL, INTENT(INOUT) :: t(nlon,nz,nb)
REAL, INTENT(INOUT) :: q(nlon,nz,nb)
REAL, INTENT(INOUT) :: z(nlon,nz+1,nb)
INTEGER :: b, start, end, ibl, icend
INTEGER :: b, start, iend, ibl, icend
start = 1
end = tot
do b=1,end,nlon
iend = tot
do b=1,iend,nlon
ibl = (b-1)/nlon+1
icend = MIN(nlon,tot-b+1)
call kernel(start, icend, nlon, nz, q(:,:,b), t(:,:,b), z(:,:,b))
Expand All @@ -189,8 +189,8 @@ def test_scc_cuf_simple(frontend, horizontal, vertical, blocking):
"""

fcode_kernel = """
SUBROUTINE kernel(start, end, nlon, nz, q, t, z)
INTEGER, INTENT(IN) :: start, end ! Iteration indices
SUBROUTINE kernel(start, iend, nlon, nz, q, t, z)
INTEGER, INTENT(IN) :: start, iend ! Iteration indices
INTEGER, INTENT(IN) :: nlon, nz ! Size of the horizontal and vertical
REAL, INTENT(INOUT) :: t(nlon,nz)
REAL, INTENT(INOUT) :: q(nlon,nz)
Expand All @@ -200,19 +200,19 @@ def test_scc_cuf_simple(frontend, horizontal, vertical, blocking):
c = 5.345
DO jk = 2, nz
DO jl = start, end
DO jl = start, iend
t(jl, jk) = c * jk
q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
END DO
END DO
DO jk = 2, nz
DO jl = start, end
DO jl = start, iend
z(jl, jk) = 0.0
END DO
END DO
! DO JL = START, END
! DO JL = START, IEND
! Q(JL, NZ) = Q(JL, NZ) * C
! END DO
END SUBROUTINE kernel
Expand Down

0 comments on commit 3bc2f14

Please sign in to comment.