Skip to content

Commit

Permalink
STMT FUNC: allow other variables to be declared alongside statement f…
Browse files Browse the repository at this point in the history
…unction return var
  • Loading branch information
awnawab committed Nov 21, 2023
1 parent fc0d872 commit c1881b8
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
15 changes: 10 additions & 5 deletions loki/backend/fgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,20 +328,25 @@ def visit_VariableDeclaration(self, o, **kwargs):
# the symbol has a known derived type
ignore = ['shape', 'dimensions', 'variables', 'source', 'initial']

if isinstance(types[0].dtype, ProcedureType):
# Statement functions can share declarations with scalars, so we collect the variable types here
_var_types = [t.dtype.return_type.dtype if isinstance(t.dtype, ProcedureType) else t.dtype for t in types]
_procedure_types = [t for t in types if isinstance(t.dtype, ProcedureType)]

if len(_procedure_types) > 0:
# Statement functions are the only symbol with ProcedureType that should appear
# in a VariableDeclaration as all other forms of procedure declarations (bindings,
# pointers, EXTERNAL statements) are handled by ProcedureDeclaration.
# However, the fact that statement function declarations can appear mixed with actual
# variable declarations forbids this in this case.
assert types[0].is_stmt_func
assert _procedure_types[0].is_stmt_func
# TODO: We can't fully compare statement functions, yet but we can make at least sure
# other declared attributes are compatible and that all have the same return type
ignore += ['dtype']
assert all(t.dtype.return_type == types[0].dtype.return_type or
t.dtype.return_type.compare(types[0].dtype.return_type, ignore=ignore) for t in types)
assert all(t.dtype.return_type == _procedure_types[0].dtype.return_type or
t.dtype.return_type.compare(_procedure_types[0].dtype.return_type, ignore=ignore)
for t in _procedure_types)

assert all(t.compare(types[0], ignore=ignore) for t in types)
assert all((t == _var_types[0]) for t in _var_types)

is_function = isinstance(types[0].dtype, ProcedureType) and types[0].dtype.is_function
if is_function:
Expand Down
7 changes: 4 additions & 3 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1812,11 +1812,12 @@ def visit_Subroutine_Subprogram(self, o, **kwargs):
rescope_symbols=True, source=source, incomplete=False
)

# Once statement functions are in place, we need to update the original declaration symbol
# Once statement functions are in place, we need to update the original declaration so that it
# contains ProcedureSymbols rather than Scalars
for decl in FindNodes(ir.VariableDeclaration).visit(spec):
if any(routine.symbol_attrs[s.name].is_stmt_func for s in decl.symbols):
assert all(routine.symbol_attrs[s.name].is_stmt_func for s in decl.symbols)
decl._update(symbols=tuple(s.clone() for s in decl.symbols))
decl._update(symbols=tuple(s.clone() if routine.symbol_attrs[s.name].is_stmt_func else s
for s in decl.symbols))

# Big, but necessary hack:
# For deferred array dimensions on allocatables, we infer the conceptual
Expand Down
3 changes: 1 addition & 2 deletions tests/test_subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,8 +1493,7 @@ def test_subroutine_stmt_func(here, frontend):
integer, intent(in) :: a
integer, intent(out) :: b
integer :: array(a)
integer :: i, j
integer :: plus, minus
integer :: i, j, plus, minus
plus(i, j) = i + j
minus(i, j) = i - j
integer :: mult
Expand Down

0 comments on commit c1881b8

Please sign in to comment.