Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ProgramUnit.resolve_typebound_var: raise error if top-level parent is not declared #325

Merged
merged 6 commits into from
Jun 14, 2024
Merged
2 changes: 2 additions & 0 deletions loki/expression/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,8 @@ def get_derived_type_member(self, name_str):
Resolve type-bound variables of arbitrary nested depth.
"""
name_parts = name_str.split('%', maxsplit=1)
if self.type.dtype is not BasicType.DEFERRED and self.type.dtype.typedef is not BasicType.DEFERRED:
assert self.type.dtype.typedef.variable_map[name_parts[0]]
declared_var = Variable(name=f'{self.name}%{name_parts[0]}', scope=self.scope, parent=self)
if len(name_parts) > 1:
return declared_var.get_derived_type_member(name_parts[1]) # pylint:disable=no-member
Expand Down
4 changes: 3 additions & 1 deletion loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2120,7 +2120,9 @@ def visit_Module_Stmt(self, o, **kwargs):
if module_type and module_type.dtype.module != BasicType.DEFERRED:
return module_type.dtype.module

return Module(name=name, parent=kwargs['scope'])
module = Module(name=name, parent=kwargs['scope'])
self.definitions[name] = module
return module

visit_Module_Name = visit_Name

Expand Down
4 changes: 3 additions & 1 deletion loki/frontend/ofp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,7 +1253,9 @@ def _create_Module_object(self, o, scope):
if module_type and module_type.dtype.module != BasicType.DEFERRED:
return module_type.dtype.module

return Module(name=name, parent=scope)
module = Module(name=name, parent=scope)
self.definitions[name] = module
return module

def visit_module(self, o, **kwargs):
# Extract known sections
Expand Down
5 changes: 3 additions & 2 deletions loki/frontend/omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,9 @@ def _create_Module_object(self, o, scope):
if module_type and module_type.dtype.module != BasicType.DEFERRED:
return module_type.dtype.module

return Module(name=name, parent=scope)

module = Module(name=name, parent=scope)
self.definitions[name] = module
return module

def visit_FmoduleDefinition(self, o, **kwargs):
# Update the symbol map with local entries
Expand Down
3 changes: 2 additions & 1 deletion loki/program_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,7 @@ def resolve_typebound_var(self, name, variable_map=None):
_variable_map = self.variable_map

name_parts = name.split('%', maxsplit=1)
if (var := _variable_map.get(name_parts[0], None)) and len(name_parts) > 1:
var = _variable_map[name_parts[0]]
if len(name_parts) > 1:
var = var.get_derived_type_member(name_parts[1])
return var
50 changes: 50 additions & 0 deletions loki/tests/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from loki.build import jit_compile, clean_test
from loki.frontend import available_frontends, OFP, OMNI
from loki.sourcefile import Sourcefile


@pytest.fixture(scope='module', name='here')
Expand Down Expand Up @@ -1320,3 +1321,52 @@ def test_module_all_imports(frontend):
assert routine.symbol_map['b'].type.module is header_a
assert routine_mod.symbol_map['b_b'].type.module is header_b
assert routine_mod.symbol_map['b_b'].type.use_name == 'b'


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_enrichment_within_file(frontend, tmp_path):
fcode = """
module foo
implicit none
integer, parameter :: j = 16

contains
integer function SUM(v)
implicit none
integer, intent(in) :: v
SUM = v + 1
end function SUM
end module foo

module test
use foo
implicit none
integer, parameter :: rk = selected_real_kind(12)
integer, parameter :: ik = selected_int_kind(9)
contains
subroutine calc (n, res)
integer, intent(in) :: n
real(kind=rk), intent(inout) :: res
integer(kind=ik) :: i
do i = 1, n
res = res + SUM(j)
end do
end subroutine calc
end module test
"""

source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
routine = source['calc']
calls = list(FindInlineCalls().visit(routine.body))
assert len(calls) == 1
assert calls[0].function == 'sum'
assert calls[0].function.type.imported
assert calls[0].function.type.module is source['foo']
assert calls[0].function.type.dtype.procedure is source['sum']
if frontend != OMNI:
# OMNI inlines parameters
assert calls[0].arguments[0].type.dtype == BasicType.INTEGER
assert calls[0].arguments[0].type.imported
assert calls[0].arguments[0].type.parameter
assert calls[0].arguments[0].type.initial == 16
assert calls[0].arguments[0].type.module is source['foo']
100 changes: 100 additions & 0 deletions loki/tests/test_subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2168,6 +2168,7 @@ def test_subroutine_deep_clone(frontend):
assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3
assert len(FindNodes(ir.Assignment).visit(new_routine.body)) == 0


@pytest.mark.parametrize('frontend', available_frontends())
def test_call_args_kwargs_conversion(frontend):

Expand Down Expand Up @@ -2238,3 +2239,102 @@ def test_call_args_kwargs_conversion(frontend):
for call in FindNodes(ir.CallStatement).visit(driver.body):
assert tuple(arg.name for arg in call.arguments) == call_args
assert call.kwarguments == ()


@pytest.mark.parametrize('frontend', available_frontends())
def test_resolve_typebound_var(frontend, tmp_path):
"""
Test correct behaviour of :any:`Scope.resolve_typebound_var` utility
"""
fcode = """
module header_mod
implicit none
type some_type
integer :: ival
end type some_type

type other_type
type(some_type) :: other
end type other_type

type third_type
type(other_type) :: some
end type third_type
end module header_mod

subroutine some_routine
use header_mod, only: third_type
implicit none
type(third_type) :: tt
end subroutine
""".strip()

source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
routine = source['some_routine']

tt_some = routine.resolve_typebound_var('tt%some')
assert tt_some == 'tt%some'
assert tt_some.type.dtype.name == 'other_type'
assert tt_some.type.dtype.typedef is source['header_mod']['other_type']

tt_some_other_ival = routine.resolve_typebound_var('tt%some%other%ival')
assert tt_some_other_ival == 'tt%some%other%ival'
assert tt_some_other_ival.type.dtype == BasicType.INTEGER
assert tt_some_other_ival.parent.type.dtype.name == 'some_type'
assert tt_some_other_ival.parent.type.dtype.typedef is source['header_mod']['some_type']

tt = routine.resolve_typebound_var('tt')
assert tt == 'tt'
assert tt.type.dtype.name == 'third_type'
assert tt.type.dtype.typedef is source['header_mod']['third_type']

# This throws an error as the type definition is available and therefore
# the invalid member can be deduced
with pytest.raises(KeyError):
routine.resolve_typebound_var('tt%invalid%val')

with pytest.raises(KeyError):
routine.resolve_typebound_var('tt%some%invalid')

# This throws errors as resolving derived type members for
# non-declared derived types should not be possible
with pytest.raises(KeyError):
routine.resolve_typebound_var('not_tt%invalid')

with pytest.raises(KeyError):
routine.resolve_typebound_var('not_a_var')

# Instead, we can creatae a deferred type variable in the scope and
# resolve members relative to it
not_tt = Variable(name='not_tt', scope=routine)
assert not_tt.type.dtype == BasicType.DEFERRED
not_tt_invalid = not_tt.get_derived_type_member('invalid')
assert not_tt_invalid == 'not_tt%invalid'
assert not_tt_invalid.type.dtype == BasicType.DEFERRED


@pytest.mark.parametrize('frontend', available_frontends(
xfail=[(OMNI, 'Parsing fails with no header information available')]
))
def test_resolve_typebound_var_missing_definition(frontend, tmp_path):
"""
Test correct behaviour of :any:`Scope.resolve_typebound_var` utility
in the absence of type information
"""
fcode = """
subroutine some_routine
use header_mod, only: third_type
implicit none
type(third_type) :: tt
end subroutine
""".strip()

source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
routine = source['some_routine']

# This does not throw an error as the use-case of incomplete type definitions
# may well require working with incomplete type definitions
tt_invalid_val = routine.resolve_typebound_var('tt%invalid%val')
assert tt_invalid_val == 'tt%invalid%val'
assert tt_invalid_val.type.dtype == BasicType.DEFERRED
assert tt_invalid_val.parent.type.dtype == BasicType.DEFERRED
12 changes: 10 additions & 2 deletions loki/transformations/single_column/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
symbols as sym, FindExpressions, SubstituteExpressions
)
from loki.ir import nodes as ir, FindNodes, Transformer
from loki.logging import debug
from loki.tools import as_tuple
from loki.types import SymbolAttributes, BasicType

Expand Down Expand Up @@ -180,8 +181,15 @@ def resolve_vector_dimension(cls, routine, loop_variable, bounds):
bounds_str = f'{bounds[0]}:{bounds[1]}'

variable_map = routine.variable_map
bounds_v = (routine.resolve_typebound_var(bounds[0], variable_map),
routine.resolve_typebound_var(bounds[1], variable_map))
try:
bounds_v = (routine.resolve_typebound_var(bounds[0], variable_map),
routine.resolve_typebound_var(bounds[1], variable_map))
except KeyError:
debug(
'SCCBaseTransformation.resolve_vector_dimension: '
f'Dimension bound {bounds[0]} or {bounds[1]} not found in {routine.name}.'
)
return

mapper = {}
for stmt in FindNodes(ir.Assignment).visit(routine.body):
Expand Down
18 changes: 9 additions & 9 deletions loki/transformations/tests/test_scc_cuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

@pytest.fixture(scope='module', name='horizontal')
def fixture_horizontal():
return Dimension(name='horizontal', size='nlon', index='jl', bounds=('start', 'end'))
return Dimension(name='horizontal', size='nlon', index='jl', bounds=('start', 'iend'))


@pytest.fixture(scope='module', name='vertical')
Expand Down Expand Up @@ -176,11 +176,11 @@ def test_scc_cuf_simple(frontend, horizontal, vertical, blocking):
REAL, INTENT(INOUT) :: t(nlon,nz,nb)
REAL, INTENT(INOUT) :: q(nlon,nz,nb)
REAL, INTENT(INOUT) :: z(nlon,nz+1,nb)
INTEGER :: b, start, end, ibl, icend
INTEGER :: b, start, iend, ibl, icend

start = 1
end = tot
do b=1,end,nlon
iend = tot
do b=1,iend,nlon
ibl = (b-1)/nlon+1
icend = MIN(nlon,tot-b+1)
call kernel(start, icend, nlon, nz, q(:,:,b), t(:,:,b), z(:,:,b))
Expand All @@ -189,8 +189,8 @@ def test_scc_cuf_simple(frontend, horizontal, vertical, blocking):
"""

fcode_kernel = """
SUBROUTINE kernel(start, end, nlon, nz, q, t, z)
INTEGER, INTENT(IN) :: start, end ! Iteration indices
SUBROUTINE kernel(start, iend, nlon, nz, q, t, z)
INTEGER, INTENT(IN) :: start, iend ! Iteration indices
INTEGER, INTENT(IN) :: nlon, nz ! Size of the horizontal and vertical
REAL, INTENT(INOUT) :: t(nlon,nz)
REAL, INTENT(INOUT) :: q(nlon,nz)
Expand All @@ -200,19 +200,19 @@ def test_scc_cuf_simple(frontend, horizontal, vertical, blocking):

c = 5.345
DO jk = 2, nz
DO jl = start, end
DO jl = start, iend
t(jl, jk) = c * jk
q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
END DO
END DO

DO jk = 2, nz
DO jl = start, end
DO jl = start, iend
z(jl, jk) = 0.0
END DO
END DO

! DO JL = START, END
! DO JL = START, IEND
! Q(JL, NZ) = Q(JL, NZ) * C
! END DO
END SUBROUTINE kernel
Expand Down
Loading