Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional fixes for IR #55

Merged
merged 10 commits into from
May 4, 2023
31 changes: 31 additions & 0 deletions loki/expression/expr_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,25 @@ def __init__(self, expr_map, invalidate_source=True, **kwargs):
invalidate_source=invalidate_source)

def visit_Expression(self, o, **kwargs):
"""
call :any:`SubstituteExpressionsMapper` for the given expression node
"""
if kwargs.get('recurse_to_declaration_attributes'):
return self.expr_mapper(o, recurse_to_declaration_attributes=True)
return self.expr_mapper(o)

def visit_Import(self, o, **kwargs):
"""
For :any:`Import` (as well as :any:`VariableDeclaration` and :any:`ProcedureDeclaration`)
we set ``recurse_to_declaration_attributes=True`` to make sure properties in the symbol
table are updated during dispatch to the expression mapper
"""
kwargs['recurse_to_declaration_attributes'] = True
return super().visit_Node(o, **kwargs)

visit_VariableDeclaration = visit_Import
visit_ProcedureDeclaration = visit_Import


class AttachScopes(Visitor):
"""
Expand Down Expand Up @@ -344,6 +361,8 @@ def visit_Expression(self, o, **kwargs):
"""
Dispatch :any:`AttachScopesMapper` for :any:`Expression` tree nodes
"""
if kwargs.get('recurse_to_declaration_attributes'):
return self.expr_mapper(o, scope=kwargs['scope'], recurse_to_declaration_attributes=True)
return self.expr_mapper(o, scope=kwargs['scope'])

def visit_list(self, o, **kwargs):
Expand All @@ -363,6 +382,18 @@ def visit_Node(self, o, **kwargs):
children = tuple(self.visit(i, **kwargs) for i in o.children)
return self._update(o, children)

def visit_Import(self, o, **kwargs):
"""
For :any:`Import` (as well as :any:`VariableDeclaration` and :any:`ProcedureDeclaration`)
we set ``recurse_to_declaration_attributes=True`` to make sure properties in the symbol
table are updated during dispatch to the expression mapper
"""
kwargs['recurse_to_declaration_attributes'] = True
return self.visit_Node(o, **kwargs)

visit_VariableDeclaration = visit_Import
visit_ProcedureDeclaration = visit_Import

def visit_Scope(self, o, **kwargs):
"""
Generic handler for :any:`Scope` objects
Expand Down
92 changes: 62 additions & 30 deletions loki/expression/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ def __init__(self, invalidate_source=True):
def __call__(self, expr, *args, **kwargs):
if expr is None:
return None
kwargs.setdefault('recurse_to_declaration_attributes', False)
new_expr = super().__call__(expr, *args, **kwargs)
if getattr(expr, 'source', None):
if isinstance(new_expr, tuple):
Expand Down Expand Up @@ -551,38 +552,63 @@ def map_int_literal(self, expr, *args, **kwargs):
map_float_literal = map_int_literal

def map_variable_symbol(self, expr, *args, **kwargs):
kind = self.rec(expr.type.kind, *args, **kwargs)
if kind is not expr.type.kind and expr.scope:
# Update symbol table entry for kind directly because with a scope attached
# it does not affect the outcome of expr.clone
expr.scope.symbol_attrs[expr.name] = expr.type.clone(kind=kind)

if expr.scope and expr.type.initial and expr.name == expr.type.initial:
# FIXME: This is a hack to work around situations where a constant
# symbol (from a parent scope) with the same name as the declared
# variable is used as initializer. This hands down the correct scope
# (in case this traversal is part of ``AttachScopesMapper``) and thus
# interrupts an otherwise infinite recursion (see LOKI-52).
_kwargs = kwargs.copy()
_kwargs['scope'] = expr.scope.parent
initial = self.rec(expr.type.initial, *args, **_kwargs)
else:
initial = self.rec(expr.type.initial, *args, **kwargs)
if initial is not expr.type.initial and expr.scope:
# Update symbol table entry for initial directly because with a scope attached
# it does not affect the outcome of expr.clone
expr.scope.symbol_attrs[expr.name] = expr.type.clone(initial=initial)

bind_names = self.rec(expr.type.bind_names, *args, **kwargs)
if not (bind_names is None or all(new is old for new, old in zip_longest(bind_names, expr.type.bind_names))):
# Update symbol table entry for bind_names directly because with a scope attached
# it does not affect the outcome of expr.clone
expr.scope.symbol_attrs[expr.name] = expr.type.clone(bind_names=as_tuple(bind_names))
# When updating declaration attributes, which are stored in the symbol table,
# we need to disable `recurse_to_declaration_attributes` to avoid infinite
# recursion because of the various ways that Fortran allows to use the declared
# symbol also inside the declaration expression
recurse_to_declaration_attributes = kwargs['recurse_to_declaration_attributes'] or expr.scope is None
kwargs['recurse_to_declaration_attributes'] = False

if recurse_to_declaration_attributes:
old_type = expr.type
kind = self.rec(old_type.kind, *args, **kwargs)

if expr.scope and expr.name == old_type.initial:
# FIXME: This is a hack to work around situations where a constant
# symbol (from a parent scope) with the same name as the declared
# variable is used as initializer. This hands down the correct scope
# (in case this traversal is part of ``AttachScopesMapper``) and thus
# interrupts an otherwise infinite recursion (see LOKI-52).
_kwargs = kwargs.copy()
_kwargs['scope'] = expr.scope.parent
initial = self.rec(old_type.initial, *args, **_kwargs)
else:
initial = self.rec(old_type.initial, *args, **kwargs)

if old_type.bind_names:
bind_names = ()
for bind_name in old_type.bind_names:
if bind_name == expr.name:
# FIXME: This is a hack to work around situations where an
# explicit interface is used with the same name as the
# type bound procedure. This hands down the correct scope.
_kwargs = kwargs.copy()
_kwargs['scope'] = expr.scope.parent
bind_names += (self.rec(bind_name, *args, **_kwargs),)
else:
bind_names += (self.rec(bind_name, *args, **kwargs),)
else:
bind_names = None

is_type_changed = (
kind is not old_type.kind or initial is not old_type.initial or
any(new is not old for new, old in zip_longest(as_tuple(bind_names), as_tuple(old_type.bind_names)))
)
if is_type_changed:
new_type = old_type.clone(kind=kind, initial=initial, bind_names=bind_names)
if expr.scope:
# Update symbol table entry
expr.scope.symbol_attrs[expr.name] = new_type

parent = self.rec(expr.parent, *args, **kwargs)
if parent is expr.parent and (kind is expr.type.kind or expr.scope):
if expr.scope is None:
if parent is expr.parent and not is_type_changed:
return expr
return expr.clone(parent=parent, type=new_type)

if parent is expr.parent:
return expr
return expr.clone(parent=parent, type=expr.type.clone(kind=kind))
return expr.clone(parent=parent)

map_deferred_type_symbol = map_variable_symbol
map_procedure_symbol = map_variable_symbol
Expand Down Expand Up @@ -611,7 +637,13 @@ def map_array(self, expr, *args, **kwargs):
# and make sure we don't loose the call parameters (aka dimensions)
return InlineCall(function=symbol.clone(parent=parent), parameters=dimensions)

shape = self.rec(symbol.type.shape, *args, **kwargs)
if kwargs['recurse_to_declaration_attributes']:
_kwargs = kwargs.copy()
_kwargs['recurse_to_declaration_attributes'] = False
shape = self.rec(symbol.type.shape, *args, **_kwargs)
else:
shape = symbol.type.shape

if (getattr(symbol, 'symbol', symbol) is expr.symbol and
all(d is orig_d for d, orig_d in zip_longest(dimensions or (), expr.dimensions or ())) and
all(d is orig_d for d, orig_d in zip_longest(shape or (), symbol.type.shape or ()))):
Expand Down
16 changes: 13 additions & 3 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1782,7 +1782,8 @@ def visit_Subroutine_Subprogram(self, o, **kwargs):
if return_type is not None:
routine.symbol_attrs[routine.name] = return_type
return_var = sym.Variable(name=routine.name, scope=routine)
return_var_decl = ir.VariableDeclaration(symbols=(return_var,))
decl_source = self.get_source(subroutine_stmt, source=None)
return_var_decl = ir.VariableDeclaration(symbols=(return_var,), source=decl_source)

decls = FindNodes((ir.VariableDeclaration, ir.ProcedureDeclaration)).visit(spec)
if not decls:
Expand Down Expand Up @@ -2099,8 +2100,17 @@ def visit_If_Construct(self, o, **kwargs):
else_if_stmt_index, else_if_stmts = zip(*else_if_stmts)
else:
else_if_stmt_index = ()
else_stmt = get_child(o, Fortran2003.Else_Stmt)
else_stmt_index = o.children.index(else_stmt) if else_stmt else end_if_stmt_index

# Note: we need to use here the same method as for else-if because finding Else_Stmt
# directly and checking its position via o.children.index may give the wrong result.
# This is because Else_Stmt may erronously compare equal to other node types.
# See https://github.com/stfc/fparser/issues/400
else_stmt = tuple((i, c) for i, c in enumerate(o.children) if isinstance(c, Fortran2003.Else_Stmt))
if else_stmt:
assert len(else_stmt) == 1
else_stmt_index, else_stmt = else_stmt[0]
else:
else_stmt_index = end_if_stmt_index
conditions = as_tuple(self.visit(c, **kwargs) for c in (if_then_stmt,) + else_if_stmts)
bodies = tuple(
tuple(flatten(as_tuple(self.visit(c, **kwargs) for c in o.children[start+1:stop])))
Expand Down
2 changes: 1 addition & 1 deletion loki/frontend/ofp.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ def visit_specific_binding(self, o, **kwargs):
type=SymbolAttributes(ProcedureType(interface))
)

_type = interface.type
_type = interface.type.clone(bind_names=(interface,))

elif o.attrib['procedureName']:
# Binding provided (<bindingName> => <procedureName>)
Expand Down
18 changes: 11 additions & 7 deletions loki/frontend/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def inject_statement_functions(routine):
def create_stmt_func(assignment):
arguments = assignment.lhs.dimensions
variable = assignment.lhs.clone(dimensions=None)
return StatementFunction(variable, arguments, assignment.rhs, variable.type)
return StatementFunction(variable, arguments, assignment.rhs, variable.type, source=assignment.source)

def create_type(stmt_func):
name = str(stmt_func.variable)
Expand Down Expand Up @@ -253,15 +253,19 @@ def create_type(stmt_func):
if variable.name.lower() in stmt_funcs:
if isinstance(variable, Array):
parameters = variable.dimensions
expr_map_spec[variable] = InlineCall(variable.clone(dimensions=None), parameters=parameters)
expr_map_spec[variable] = InlineCall(
variable.clone(dimensions=None), parameters=parameters, source=variable.source
)
elif not isinstance(variable, ProcedureSymbol):
expr_map_spec[variable] = variable.clone()
expr_map_body = {}
for variable in FindVariables().visit(routine.body):
if variable.name.lower() in stmt_funcs:
if isinstance(variable, Array):
parameters = variable.dimensions
expr_map_body[variable] = InlineCall(variable.clone(dimensions=None), parameters=parameters)
expr_map_body[variable] = InlineCall(
variable.clone(dimensions=None), parameters=parameters, source=variable.source
)
elif not isinstance(variable, ProcedureSymbol):
expr_map_body[variable] = variable.clone()

Expand All @@ -271,15 +275,15 @@ def create_type(stmt_func):

# Apply transformer with the built maps
if spec_map:
routine.spec = Transformer(spec_map).visit(routine.spec)
routine.spec = Transformer(spec_map, invalidate_source=False).visit(routine.spec)
if body_map:
routine.body = Transformer(body_map).visit(routine.body)
routine.body = Transformer(body_map, invalidate_source=False).visit(routine.body)
if spec_appendix:
routine.spec.append(spec_appendix)
if expr_map_spec:
routine.spec = SubstituteExpressions(expr_map_spec).visit(routine.spec)
routine.spec = SubstituteExpressions(expr_map_spec, invalidate_source=False).visit(routine.spec)
if expr_map_body:
routine.body = SubstituteExpressions(expr_map_body).visit(routine.body)
routine.body = SubstituteExpressions(expr_map_body, invalidate_source=False).visit(routine.body)

# And make sure all symbols have the right type
routine.rescope_symbols()
Expand Down
2 changes: 1 addition & 1 deletion loki/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
'Comment', 'CommentBlock', 'Pragma', 'PreprocessorDirective',
'Import', 'VariableDeclaration', 'ProcedureDeclaration', 'DataDeclaration',
'StatementFunction', 'TypeDef', 'MultiConditional', 'MaskedStatement',
'Intrinsic', 'Enumeration', 'RawSource'
'Intrinsic', 'Enumeration', 'RawSource',
]

# Configuration for validation mechanism via pydantic
Expand Down
2 changes: 1 addition & 1 deletion loki/transform/transform_associates.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ResolveAssociatesTransformer(Transformer):
corresponding `selector` expression defined in ``associations``.
"""

def visit_Associate(self, o):
def visit_Associate(self, o, **kwargs): # pylint: disable=unused-argument
# First head-recurse, so that all associate blocks beneath are resolved
body = self.visit(o.body)

Expand Down
1 change: 0 additions & 1 deletion loki/visitors/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,6 @@ class NestedMaskedTransformer(MaskedTransformer):

# Handler for leaf nodes


def visit_object(self, o, **kwargs):
"""
Return the object unchanged.
Expand Down
44 changes: 43 additions & 1 deletion tests/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np

from conftest import jit_compile, clean_test, available_frontends
from loki import OMNI, Subroutine, FindNodes, Loop, Conditional, Node
from loki import OMNI, Subroutine, FindNodes, Loop, Conditional, Node, Intrinsic


@pytest.fixture(scope='module', name='here')
Expand Down Expand Up @@ -455,3 +455,45 @@ def test_conditional_bodies(frontend):
c.else_body and isinstance(c.else_body, tuple) and all(isinstance(n, Node) for n in c.else_body)
for c in conditionals
)


@pytest.mark.parametrize('frontend', available_frontends())
def test_conditional_else_body_return(frontend):
fcode = """
FUNCTION FUNC(PX,KN)
IMPLICIT NONE
INTEGER,INTENT(INOUT) :: KN
REAL,INTENT(IN) :: PX
REAL :: FUNC
INTEGER :: J
REAL :: Z0, Z1, Z2
Z0= 1.0
Z1= PX
IF (KN == 0) THEN
FUNC= Z0
RETURN
ELSEIF (KN == 1) THEN
FUNC= Z1
RETURN
ELSE
DO J=2,KN
Z2= Z0+Z1
Z0= Z1
Z1= Z2
ENDDO
FUNC= Z2
RETURN
ENDIF
END FUNCTION FUNC
""".strip()

routine = Subroutine.from_source(fcode, frontend=frontend)
conditionals = FindNodes(Conditional).visit(routine.body)
assert len(conditionals) == 2
assert isinstance(conditionals[0].body[-1], Intrinsic)
assert conditionals[0].body[-1].text.upper() == 'RETURN'
assert conditionals[0].else_body == (conditionals[1],)
assert isinstance(conditionals[1].body[-1], Intrinsic)
assert conditionals[1].body[-1].text.upper() == 'RETURN'
assert isinstance(conditionals[1].else_body[-1], Intrinsic)
assert conditionals[1].else_body[-1].text.upper() == 'RETURN'
39 changes: 39 additions & 0 deletions tests/test_derived_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,3 +1359,42 @@ def test_derived_types_nested_type(frontend):
assert assignment.rhs.parent.type.dtype.typedef is module['some_type']
assert assignment.rhs.parent.parent.type.dtype.name == 'other_type'
assert assignment.rhs.parent.parent.type.dtype.typedef is module['other_type']


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_types_abstract_deferred_procedure(frontend):
fcode = """
module some_mod
implicit none
type, abstract :: abstract_type
contains
procedure (some_proc), deferred :: some_proc
procedure (other_proc), deferred :: other_proc
end type abstract_type

abstract interface
subroutine some_proc(this)
import abstract_type
class(abstract_type), intent(in) :: this
end subroutine some_proc
end interface

abstract interface
subroutine other_proc(this)
import abstract_type
class(abstract_type), intent(inout) :: this
end subroutine other_proc
end interface
end module some_mod
""".strip()
module = Module.from_source(fcode, frontend=frontend)
typedef = module['abstract_type']
assert typedef.abstract is True
assert typedef.variables == ('some_proc', 'other_proc')
for symbol in typedef.variables:
assert isinstance(symbol, ProcedureSymbol)
assert isinstance(symbol.type.dtype, ProcedureType)
assert symbol.type.dtype.name.lower() == symbol.name.lower()
assert symbol.type.bind_names == (symbol,)
assert symbol.scope is typedef
assert symbol.type.bind_names[0].scope is module
Loading