Skip to content

Commit

Permalink
Merge pull request #55 from ecmwf-ifs/nabr-linter-fixes
Browse files Browse the repository at this point in the history
Additional fixes for IR
  • Loading branch information
mlange05 authored May 4, 2023
2 parents c1dd882 + 9cef8f4 commit 14c6c82
Show file tree
Hide file tree
Showing 13 changed files with 281 additions and 46 deletions.
31 changes: 31 additions & 0 deletions loki/expression/expr_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,25 @@ def __init__(self, expr_map, invalidate_source=True, **kwargs):
invalidate_source=invalidate_source)

def visit_Expression(self, o, **kwargs):
"""
call :any:`SubstituteExpressionsMapper` for the given expression node
"""
if kwargs.get('recurse_to_declaration_attributes'):
return self.expr_mapper(o, recurse_to_declaration_attributes=True)
return self.expr_mapper(o)

def visit_Import(self, o, **kwargs):
"""
For :any:`Import` (as well as :any:`VariableDeclaration` and :any:`ProcedureDeclaration`)
we set ``recurse_to_declaration_attributes=True`` to make sure properties in the symbol
table are updated during dispatch to the expression mapper
"""
kwargs['recurse_to_declaration_attributes'] = True
return super().visit_Node(o, **kwargs)

visit_VariableDeclaration = visit_Import
visit_ProcedureDeclaration = visit_Import


class AttachScopes(Visitor):
"""
Expand Down Expand Up @@ -344,6 +361,8 @@ def visit_Expression(self, o, **kwargs):
"""
Dispatch :any:`AttachScopesMapper` for :any:`Expression` tree nodes
"""
if kwargs.get('recurse_to_declaration_attributes'):
return self.expr_mapper(o, scope=kwargs['scope'], recurse_to_declaration_attributes=True)
return self.expr_mapper(o, scope=kwargs['scope'])

def visit_list(self, o, **kwargs):
Expand All @@ -363,6 +382,18 @@ def visit_Node(self, o, **kwargs):
children = tuple(self.visit(i, **kwargs) for i in o.children)
return self._update(o, children)

def visit_Import(self, o, **kwargs):
"""
For :any:`Import` (as well as :any:`VariableDeclaration` and :any:`ProcedureDeclaration`)
we set ``recurse_to_declaration_attributes=True`` to make sure properties in the symbol
table are updated during dispatch to the expression mapper
"""
kwargs['recurse_to_declaration_attributes'] = True
return self.visit_Node(o, **kwargs)

visit_VariableDeclaration = visit_Import
visit_ProcedureDeclaration = visit_Import

def visit_Scope(self, o, **kwargs):
"""
Generic handler for :any:`Scope` objects
Expand Down
92 changes: 62 additions & 30 deletions loki/expression/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ def __init__(self, invalidate_source=True):
def __call__(self, expr, *args, **kwargs):
if expr is None:
return None
kwargs.setdefault('recurse_to_declaration_attributes', False)
new_expr = super().__call__(expr, *args, **kwargs)
if getattr(expr, 'source', None):
if isinstance(new_expr, tuple):
Expand Down Expand Up @@ -551,38 +552,63 @@ def map_int_literal(self, expr, *args, **kwargs):
map_float_literal = map_int_literal

def map_variable_symbol(self, expr, *args, **kwargs):
kind = self.rec(expr.type.kind, *args, **kwargs)
if kind is not expr.type.kind and expr.scope:
# Update symbol table entry for kind directly because with a scope attached
# it does not affect the outcome of expr.clone
expr.scope.symbol_attrs[expr.name] = expr.type.clone(kind=kind)

if expr.scope and expr.type.initial and expr.name == expr.type.initial:
# FIXME: This is a hack to work around situations where a constant
# symbol (from a parent scope) with the same name as the declared
# variable is used as initializer. This hands down the correct scope
# (in case this traversal is part of ``AttachScopesMapper``) and thus
# interrupts an otherwise infinite recursion (see LOKI-52).
_kwargs = kwargs.copy()
_kwargs['scope'] = expr.scope.parent
initial = self.rec(expr.type.initial, *args, **_kwargs)
else:
initial = self.rec(expr.type.initial, *args, **kwargs)
if initial is not expr.type.initial and expr.scope:
# Update symbol table entry for initial directly because with a scope attached
# it does not affect the outcome of expr.clone
expr.scope.symbol_attrs[expr.name] = expr.type.clone(initial=initial)

bind_names = self.rec(expr.type.bind_names, *args, **kwargs)
if not (bind_names is None or all(new is old for new, old in zip_longest(bind_names, expr.type.bind_names))):
# Update symbol table entry for bind_names directly because with a scope attached
# it does not affect the outcome of expr.clone
expr.scope.symbol_attrs[expr.name] = expr.type.clone(bind_names=as_tuple(bind_names))
# When updating declaration attributes, which are stored in the symbol table,
# we need to disable `recurse_to_declaration_attributes` to avoid infinite
# recursion because of the various ways that Fortran allows to use the declared
# symbol also inside the declaration expression
recurse_to_declaration_attributes = kwargs['recurse_to_declaration_attributes'] or expr.scope is None
kwargs['recurse_to_declaration_attributes'] = False

if recurse_to_declaration_attributes:
old_type = expr.type
kind = self.rec(old_type.kind, *args, **kwargs)

if expr.scope and expr.name == old_type.initial:
# FIXME: This is a hack to work around situations where a constant
# symbol (from a parent scope) with the same name as the declared
# variable is used as initializer. This hands down the correct scope
# (in case this traversal is part of ``AttachScopesMapper``) and thus
# interrupts an otherwise infinite recursion (see LOKI-52).
_kwargs = kwargs.copy()
_kwargs['scope'] = expr.scope.parent
initial = self.rec(old_type.initial, *args, **_kwargs)
else:
initial = self.rec(old_type.initial, *args, **kwargs)

if old_type.bind_names:
bind_names = ()
for bind_name in old_type.bind_names:
if bind_name == expr.name:
# FIXME: This is a hack to work around situations where an
# explicit interface is used with the same name as the
# type bound procedure. This hands down the correct scope.
_kwargs = kwargs.copy()
_kwargs['scope'] = expr.scope.parent
bind_names += (self.rec(bind_name, *args, **_kwargs),)
else:
bind_names += (self.rec(bind_name, *args, **kwargs),)
else:
bind_names = None

is_type_changed = (
kind is not old_type.kind or initial is not old_type.initial or
any(new is not old for new, old in zip_longest(as_tuple(bind_names), as_tuple(old_type.bind_names)))
)
if is_type_changed:
new_type = old_type.clone(kind=kind, initial=initial, bind_names=bind_names)
if expr.scope:
# Update symbol table entry
expr.scope.symbol_attrs[expr.name] = new_type

parent = self.rec(expr.parent, *args, **kwargs)
if parent is expr.parent and (kind is expr.type.kind or expr.scope):
if expr.scope is None:
if parent is expr.parent and not is_type_changed:
return expr
return expr.clone(parent=parent, type=new_type)

if parent is expr.parent:
return expr
return expr.clone(parent=parent, type=expr.type.clone(kind=kind))
return expr.clone(parent=parent)

map_deferred_type_symbol = map_variable_symbol
map_procedure_symbol = map_variable_symbol
Expand Down Expand Up @@ -611,7 +637,13 @@ def map_array(self, expr, *args, **kwargs):
# and make sure we don't loose the call parameters (aka dimensions)
return InlineCall(function=symbol.clone(parent=parent), parameters=dimensions)

shape = self.rec(symbol.type.shape, *args, **kwargs)
if kwargs['recurse_to_declaration_attributes']:
_kwargs = kwargs.copy()
_kwargs['recurse_to_declaration_attributes'] = False
shape = self.rec(symbol.type.shape, *args, **_kwargs)
else:
shape = symbol.type.shape

if (getattr(symbol, 'symbol', symbol) is expr.symbol and
all(d is orig_d for d, orig_d in zip_longest(dimensions or (), expr.dimensions or ())) and
all(d is orig_d for d, orig_d in zip_longest(shape or (), symbol.type.shape or ()))):
Expand Down
16 changes: 13 additions & 3 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1782,7 +1782,8 @@ def visit_Subroutine_Subprogram(self, o, **kwargs):
if return_type is not None:
routine.symbol_attrs[routine.name] = return_type
return_var = sym.Variable(name=routine.name, scope=routine)
return_var_decl = ir.VariableDeclaration(symbols=(return_var,))
decl_source = self.get_source(subroutine_stmt, source=None)
return_var_decl = ir.VariableDeclaration(symbols=(return_var,), source=decl_source)

decls = FindNodes((ir.VariableDeclaration, ir.ProcedureDeclaration)).visit(spec)
if not decls:
Expand Down Expand Up @@ -2099,8 +2100,17 @@ def visit_If_Construct(self, o, **kwargs):
else_if_stmt_index, else_if_stmts = zip(*else_if_stmts)
else:
else_if_stmt_index = ()
else_stmt = get_child(o, Fortran2003.Else_Stmt)
else_stmt_index = o.children.index(else_stmt) if else_stmt else end_if_stmt_index

# Note: we need to use here the same method as for else-if because finding Else_Stmt
# directly and checking its position via o.children.index may give the wrong result.
# This is because Else_Stmt may erronously compare equal to other node types.
# See https://github.com/stfc/fparser/issues/400
else_stmt = tuple((i, c) for i, c in enumerate(o.children) if isinstance(c, Fortran2003.Else_Stmt))
if else_stmt:
assert len(else_stmt) == 1
else_stmt_index, else_stmt = else_stmt[0]
else:
else_stmt_index = end_if_stmt_index
conditions = as_tuple(self.visit(c, **kwargs) for c in (if_then_stmt,) + else_if_stmts)
bodies = tuple(
tuple(flatten(as_tuple(self.visit(c, **kwargs) for c in o.children[start+1:stop])))
Expand Down
2 changes: 1 addition & 1 deletion loki/frontend/ofp.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ def visit_specific_binding(self, o, **kwargs):
type=SymbolAttributes(ProcedureType(interface))
)

_type = interface.type
_type = interface.type.clone(bind_names=(interface,))

elif o.attrib['procedureName']:
# Binding provided (<bindingName> => <procedureName>)
Expand Down
18 changes: 11 additions & 7 deletions loki/frontend/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def inject_statement_functions(routine):
def create_stmt_func(assignment):
arguments = assignment.lhs.dimensions
variable = assignment.lhs.clone(dimensions=None)
return StatementFunction(variable, arguments, assignment.rhs, variable.type)
return StatementFunction(variable, arguments, assignment.rhs, variable.type, source=assignment.source)

def create_type(stmt_func):
name = str(stmt_func.variable)
Expand Down Expand Up @@ -253,15 +253,19 @@ def create_type(stmt_func):
if variable.name.lower() in stmt_funcs:
if isinstance(variable, Array):
parameters = variable.dimensions
expr_map_spec[variable] = InlineCall(variable.clone(dimensions=None), parameters=parameters)
expr_map_spec[variable] = InlineCall(
variable.clone(dimensions=None), parameters=parameters, source=variable.source
)
elif not isinstance(variable, ProcedureSymbol):
expr_map_spec[variable] = variable.clone()
expr_map_body = {}
for variable in FindVariables().visit(routine.body):
if variable.name.lower() in stmt_funcs:
if isinstance(variable, Array):
parameters = variable.dimensions
expr_map_body[variable] = InlineCall(variable.clone(dimensions=None), parameters=parameters)
expr_map_body[variable] = InlineCall(
variable.clone(dimensions=None), parameters=parameters, source=variable.source
)
elif not isinstance(variable, ProcedureSymbol):
expr_map_body[variable] = variable.clone()

Expand All @@ -271,15 +275,15 @@ def create_type(stmt_func):

# Apply transformer with the built maps
if spec_map:
routine.spec = Transformer(spec_map).visit(routine.spec)
routine.spec = Transformer(spec_map, invalidate_source=False).visit(routine.spec)
if body_map:
routine.body = Transformer(body_map).visit(routine.body)
routine.body = Transformer(body_map, invalidate_source=False).visit(routine.body)
if spec_appendix:
routine.spec.append(spec_appendix)
if expr_map_spec:
routine.spec = SubstituteExpressions(expr_map_spec).visit(routine.spec)
routine.spec = SubstituteExpressions(expr_map_spec, invalidate_source=False).visit(routine.spec)
if expr_map_body:
routine.body = SubstituteExpressions(expr_map_body).visit(routine.body)
routine.body = SubstituteExpressions(expr_map_body, invalidate_source=False).visit(routine.body)

# And make sure all symbols have the right type
routine.rescope_symbols()
Expand Down
2 changes: 1 addition & 1 deletion loki/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
'Comment', 'CommentBlock', 'Pragma', 'PreprocessorDirective',
'Import', 'VariableDeclaration', 'ProcedureDeclaration', 'DataDeclaration',
'StatementFunction', 'TypeDef', 'MultiConditional', 'MaskedStatement',
'Intrinsic', 'Enumeration', 'RawSource'
'Intrinsic', 'Enumeration', 'RawSource',
]

# Configuration for validation mechanism via pydantic
Expand Down
2 changes: 1 addition & 1 deletion loki/transform/transform_associates.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ResolveAssociatesTransformer(Transformer):
corresponding `selector` expression defined in ``associations``.
"""

def visit_Associate(self, o):
def visit_Associate(self, o, **kwargs): # pylint: disable=unused-argument
# First head-recurse, so that all associate blocks beneath are resolved
body = self.visit(o.body)

Expand Down
1 change: 0 additions & 1 deletion loki/visitors/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,6 @@ class NestedMaskedTransformer(MaskedTransformer):

# Handler for leaf nodes


def visit_object(self, o, **kwargs):
"""
Return the object unchanged.
Expand Down
44 changes: 43 additions & 1 deletion tests/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np

from conftest import jit_compile, clean_test, available_frontends
from loki import OMNI, Subroutine, FindNodes, Loop, Conditional, Node
from loki import OMNI, Subroutine, FindNodes, Loop, Conditional, Node, Intrinsic


@pytest.fixture(scope='module', name='here')
Expand Down Expand Up @@ -455,3 +455,45 @@ def test_conditional_bodies(frontend):
c.else_body and isinstance(c.else_body, tuple) and all(isinstance(n, Node) for n in c.else_body)
for c in conditionals
)


@pytest.mark.parametrize('frontend', available_frontends())
def test_conditional_else_body_return(frontend):
fcode = """
FUNCTION FUNC(PX,KN)
IMPLICIT NONE
INTEGER,INTENT(INOUT) :: KN
REAL,INTENT(IN) :: PX
REAL :: FUNC
INTEGER :: J
REAL :: Z0, Z1, Z2
Z0= 1.0
Z1= PX
IF (KN == 0) THEN
FUNC= Z0
RETURN
ELSEIF (KN == 1) THEN
FUNC= Z1
RETURN
ELSE
DO J=2,KN
Z2= Z0+Z1
Z0= Z1
Z1= Z2
ENDDO
FUNC= Z2
RETURN
ENDIF
END FUNCTION FUNC
""".strip()

routine = Subroutine.from_source(fcode, frontend=frontend)
conditionals = FindNodes(Conditional).visit(routine.body)
assert len(conditionals) == 2
assert isinstance(conditionals[0].body[-1], Intrinsic)
assert conditionals[0].body[-1].text.upper() == 'RETURN'
assert conditionals[0].else_body == (conditionals[1],)
assert isinstance(conditionals[1].body[-1], Intrinsic)
assert conditionals[1].body[-1].text.upper() == 'RETURN'
assert isinstance(conditionals[1].else_body[-1], Intrinsic)
assert conditionals[1].else_body[-1].text.upper() == 'RETURN'
39 changes: 39 additions & 0 deletions tests/test_derived_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,3 +1359,42 @@ def test_derived_types_nested_type(frontend):
assert assignment.rhs.parent.type.dtype.typedef is module['some_type']
assert assignment.rhs.parent.parent.type.dtype.name == 'other_type'
assert assignment.rhs.parent.parent.type.dtype.typedef is module['other_type']


@pytest.mark.parametrize('frontend', available_frontends())
def test_derived_types_abstract_deferred_procedure(frontend):
fcode = """
module some_mod
implicit none
type, abstract :: abstract_type
contains
procedure (some_proc), deferred :: some_proc
procedure (other_proc), deferred :: other_proc
end type abstract_type
abstract interface
subroutine some_proc(this)
import abstract_type
class(abstract_type), intent(in) :: this
end subroutine some_proc
end interface
abstract interface
subroutine other_proc(this)
import abstract_type
class(abstract_type), intent(inout) :: this
end subroutine other_proc
end interface
end module some_mod
""".strip()
module = Module.from_source(fcode, frontend=frontend)
typedef = module['abstract_type']
assert typedef.abstract is True
assert typedef.variables == ('some_proc', 'other_proc')
for symbol in typedef.variables:
assert isinstance(symbol, ProcedureSymbol)
assert isinstance(symbol.type.dtype, ProcedureType)
assert symbol.type.dtype.name.lower() == symbol.name.lower()
assert symbol.type.bind_names == (symbol,)
assert symbol.scope is typedef
assert symbol.type.bind_names[0].scope is module
Loading

0 comments on commit 14c6c82

Please sign in to comment.