Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix wrong classification as StatementFunction in translation to Loki IR #327

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions loki/expression/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,12 @@ def __getinitargs__(self):
def init_arg_names(self):
return self.symbol.init_arg_names

def _lookup_type(self, scope):
"""
Helper method to look-up type information in any :data:`scope`
"""
return self.symbol._lookup_type(scope)

def clone(self, **kwargs):
"""
Replicate the object with the provided overrides.
Expand Down
71 changes: 71 additions & 0 deletions loki/expression/tests/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2197,3 +2197,74 @@ def static_func(a):
test_str = 'foo%bar%barbar%barbarbar%val_barbarbar + 1'
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
assert parsed == 6


@pytest.mark.parametrize('frontend', available_frontends(
skip={OMNI: "OMNI fails on missing module"}
))
def test_stmt_func_heuristic(frontend, tmp_path):
"""
Our Fparser/OFP translation has a heuristic to detect statement function declarations,
but that falsely misinterpreted some assignments as statement functions due to
missing shape information (reported in #326)
"""
fcode = """
SUBROUTINE SOME_ROUTINE(YDFIELDS,YDMODEL,YDCST)
USE FIELDS_MOD , ONLY : FIELDS
USE TYPE_MODEL , ONLY : MODEL
USE VAR_MOD , ONLY : ARR, FNAME
IMPLICIT NONE
TYPE(FIELDS) ,INTENT(INOUT) :: YDFIELDS
TYPE(MODEL) ,INTENT(IN) :: YDMODEL
TYPE(TOMCST) ,INTENT(IN) :: YDCST
CHARACTER(LEN=20) :: CLFILE
REAL :: ZALFA
REAL :: ZALFAG(3)
REAL :: FOEALFA
REAL :: PTARE
FOEALFA(PTARE) = MIN(1.0, PTARE)
#include "fcttre.func.h"

ASSOCIATE(YDSURF=>YDFIELDS%YRSURF,RTT=>YDCST%RTT)
ASSOCIATE(SD_VN=>YDSURF%SD_VN,YSD_VN=>YDSURF%YSD_VN, &
& LEGBRAD=>YDMODEL%YRML_PHY_EC%YREPHY%LEGBRAD)
IF(LEGBRAD)SD_VN(:,YSD_VN%YACCPR5%MP,:)=SD_VN(:,YSD_VN%YACCPR%MP,:)
IF(LEGBRAD)ARR(:,YSD_VN%YACCPR5%MP,:)=SD_VN(:,YSD_VN%YACCPR%MP,:)
CLFILE(1:20)=FNAME
ZALFA=FOEDELTA(RTT)
ZALFAG(1)=FOEDELTA(RTT-1.)
ZALFAG(2)=FOEALFA(RTT)
ZALFAG(3)=FOEALFA(RTT-1.)
END ASSOCIATE
END ASSOCIATE
END SUBROUTINE SOME_ROUTINE
""".strip()
source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
routine = source['some_routine']

assignments = FindNodes(ir.Assignment).visit(routine.body)

assert [
ass.lhs.name.lower() for ass in assignments
] == [
'sd_vn', 'arr', 'clfile', 'zalfa', 'zalfag', 'zalfag', 'zalfag'
]

sd_vn = assignments[0].lhs
assert isinstance(sd_vn, sym.Array)

arr = assignments[1].lhs
assert isinstance(arr, sym.Array)
assert arr.type.imported

# FOEDELTA cannot be identified as a statement function due to the declarations
# hidden in the external header
assert isinstance(assignments[3].rhs, sym.Array)
assert isinstance(assignments[4].rhs, sym.Array)

# FOEALFA should have been identified as a statement function
stmt_funcs = FindNodes(ir.StatementFunction).visit(routine.ir)
assert len(stmt_funcs) == 1
assert stmt_funcs[0].name.lower() == 'foealfa'
assert isinstance(assignments[5].rhs, sym.InlineCall)
assert isinstance(assignments[6].rhs, sym.InlineCall)
39 changes: 24 additions & 15 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2983,22 +2983,31 @@ def visit_Assignment_Stmt(self, o, **kwargs):

# Special-case: Identify statement functions using our internal symbol table
symbol_attrs = kwargs['scope'].symbol_attrs
if isinstance(lhs, sym.Array) and not lhs.parent and lhs.name in symbol_attrs:

def _create_stmt_func_type(stmt_func):
name = str(stmt_func.variable)
procedure = LazyNodeLookup(
anchor=kwargs['scope'],
query=lambda x: [
f for f in FindNodes(ir.StatementFunction).visit(x.spec) if f.variable == name
][0]
)
proc_type = ProcedureType(is_function=True, procedure=procedure, name=name)
return SymbolAttributes(dtype=proc_type, is_stmt_func=True)
if isinstance(lhs, sym.Array) and symbol_attrs.lookup(lhs.name) is not None:
# If this looks like an array but we have an explicit scalar declaration then
# this might in fact be a statement function.
# To avoid the costly lookup for declarations on each array assignment, we run through
# some sanity checks instead that allow us to bail out early in most cases
lhs_type = lhs.type
could_be_a_statement_func = not (
lhs_type.shape or lhs_type.length # Declaration with length or dimensions
or lhs.parent # Derived type member (we might lack information from enrichment)
or lhs_type.intent or lhs_type.imported # Dummy argument or imported from module
or isinstance(lhs.scope, ir.Associate) # Symbol stems from an associate
)

if could_be_a_statement_func:
def _create_stmt_func_type(stmt_func):
name = str(stmt_func.variable)
procedure = LazyNodeLookup(
anchor=kwargs['scope'],
query=lambda x: [
f for f in FindNodes(ir.StatementFunction).visit(x.spec) if f.variable == name
][0]
)
proc_type = ProcedureType(is_function=True, procedure=procedure, name=name)
return SymbolAttributes(dtype=proc_type, is_stmt_func=True)

if not lhs.type.shape and not lhs.type.intent:
# If the LHS array access is actually declared as a scalar,
# we are actually dealing with a statement function!
f_symbol = sym.ProcedureSymbol(name=lhs.name, scope=kwargs['scope'])
stmt_func = ir.StatementFunction(
variable=f_symbol, arguments=lhs.dimensions,
Expand Down
82 changes: 69 additions & 13 deletions loki/frontend/ofp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from loki import ir
from loki.ir import (
GenericVisitor, attach_pragmas, process_dimension_pragmas,
detach_pragmas, pragmas_attached
detach_pragmas, pragmas_attached, FindNodes
)
import loki.expression.symbols as sym
from loki.expression.operations import (
Expand All @@ -37,6 +37,7 @@
from loki.expression import ExpressionDimensionsMapper, AttachScopesMapper
from loki.tools import (
as_tuple, disk_cached, flatten, gettempdir, filehash, CaseInsensitiveDict,
LazyNodeLookup
)
from loki.logging import debug, info, warning, error
from loki.types import BasicType, DerivedType, ProcedureType, SymbolAttributes
Expand Down Expand Up @@ -457,6 +458,45 @@ def visit_where_stmt(self, o, **kwargs):
def visit_assignment(self, o, **kwargs):
lhs = self.visit(o.find('target'), **kwargs)
rhs = self.visit(o.find('value'), **kwargs)

# Special-case: Identify statement functions using our internal symbol table
symbol_attrs = kwargs['scope'].symbol_attrs
if isinstance(lhs, sym.Array) and symbol_attrs.lookup(lhs.name) is not None:
# If this looks like an array but we have an explicit scalar declaration then
# this might in fact be a statement function.
# To avoid the costly lookup for declarations on each array assignment, we run through
# some sanity checks instead that allow us to bail out early in most cases
lhs_type = lhs.type
could_be_a_statement_func = not (
lhs_type.shape or lhs_type.length # Declaration with length or dimensions
or lhs.parent # Derived type member (we might lack information from enrichment)
or lhs_type.intent or lhs_type.imported # Dummy argument or imported from module
or isinstance(lhs.scope, ir.Associate) # Symbol stems from an associate
)

if could_be_a_statement_func:
def _create_stmt_func_type(stmt_func):
name = str(stmt_func.variable)
procedure = LazyNodeLookup(
anchor=kwargs['scope'],
query=lambda x: [
f for f in FindNodes(ir.StatementFunction).visit(x.spec) if f.variable == name
][0]
)
proc_type = ProcedureType(is_function=True, procedure=procedure, name=name)
return SymbolAttributes(dtype=proc_type, is_stmt_func=True)

f_symbol = sym.ProcedureSymbol(name=lhs.name, scope=kwargs['scope'])
stmt_func = ir.StatementFunction(
variable=f_symbol, arguments=lhs.dimensions,
rhs=rhs, return_type=symbol_attrs[lhs.name],
label=kwargs.get('label'), source=kwargs.get('source')
)

# Update the type in the local scope and return stmt func node
symbol_attrs[str(stmt_func.variable)] = _create_stmt_func_type(stmt_func)
return stmt_func

return ir.Assignment(lhs=lhs, rhs=rhs, label=kwargs['label'], source=kwargs['source'])

def visit_pointer_assignment(self, o, **kwargs):
Expand Down Expand Up @@ -943,19 +983,19 @@ def visit_declaration(self, o, **kwargs):

length = None
kind = None
if tk1 in ('', 'len'):
if tk1.lower() in ('', 'len'):
# The first child _should_ be the length selector
length = self.visit(o[0], **kwargs)

if tk2 == 'kind' or selector_idx > 2:
if tk2.lower() == 'kind' or selector_idx > 2:
# There is another value, presumably the kind specifier, which
# should be right before the char-selector
kind = self.visit(o[selector_idx-1], **kwargs)
elif tk1 == 'kind':
elif tk1.lower() == 'kind':
# The first child _should_ be the kind selector
kind = self.visit(o[0], **kwargs)

if tk2 == 'len':
if tk2.lower() == 'len':
# The second child should then be the length selector
assert selector_idx > 2
length = self.visit(o[1], **kwargs)
Expand Down Expand Up @@ -1125,7 +1165,7 @@ def visit_defined_operator(self, o, **kwargs):
name = f'OPERATOR({o.attrib["definedOp"]})'
return sym.Variable(name=name)

def _create_Subroutine_object(self, o, scope):
def _create_Subroutine_object(self, o, scope, prefix=None):
"""Helper method to instantiate a Subroutine object"""
from loki.subroutine import Subroutine # pylint: disable=import-outside-toplevel,cyclic-import
assert o.tag in ('subroutine', 'function')
Expand Down Expand Up @@ -1165,7 +1205,12 @@ def _create_Subroutine_object(self, o, scope):
if suffix.attrib['result'] == 'result':
result_name = header_ast.find('name').attrib['name']

prefix = [a.attrib['spec'].upper() for a in header_ast.findall('t-prefix-spec')] or None
if prefix:
prefix = [a.attrib['spec'].upper() for a in prefix if a.tag == 't-prefix-spec']
else:
prefix = []
if header_ast:
prefix += [a.attrib['spec'].upper() for a in header_ast.findall('t-prefix-spec')]

if routine is None:
routine = Subroutine(
Expand Down Expand Up @@ -1276,10 +1321,14 @@ def visit_module(self, o, **kwargs):
# subroutine objects using the weakref pointers stored in the symbol table.
# I know, it's not pretty but alternatively we could hand down this array as part of
# kwargs but that feels like carrying around a lot of bulk, too.
contains = [
self._create_Subroutine_object(member_ast, kwargs['scope'])
for member_ast in contains_ast if member_ast.tag in ('subroutine', 'function')
]
contains = []
prefix = []
for member_ast in contains_ast:
if member_ast.tag in ('subroutine', 'function'):
contains += [self._create_Subroutine_object(member_ast, kwargs['scope'], prefix)]
prefix = []
else:
prefix += [member_ast]

# Parse the spec
spec = self.visit(spec_ast, **kwargs)
Expand Down Expand Up @@ -1567,10 +1616,14 @@ def visit_names(self, o, **kwargs):
return tuple(self.visit(c, **kwargs) for c in o.findall('name'))

def visit_name(self, o, **kwargs):
scope = kwargs.get('scope', None)

if o.find('generic-name-list-part') is not None:
# From an external-stmt or use-stmt
return sym.Variable(name=o.attrib['id'])
name = o.attrib['id']
if scope:
scope = scope.get_symbol_scope(name)
return sym.Variable(name=name, scope=scope)

if o.find('generic_spec') is not None:
return self.visit(o.find('generic_spec'), **kwargs)
Expand All @@ -1582,6 +1635,10 @@ def visit_name(self, o, **kwargs):
name, parent = self.visit(part_ref, **kwargs), name
if parent:
name = name.clone(name=f'{parent.name}%{name.name}', parent=parent)
scope = parent.scope
if scope:
scope = scope.get_symbol_scope(name.name)
name = name.clone(scope=scope)

if part_ref.attrib['hasSectionSubscriptList'] == 'true':
if i < num_part_ref - 1 or o.attrib['type'] == 'variable':
Expand Down Expand Up @@ -1776,7 +1833,6 @@ def create_typedef_procedure_declaration(self, comps, iface=None, attrs=None, sc
symbols = tuple(s.rescope(scope=scope) for s in symbols)
return ir.ProcedureDeclaration(symbols=symbols, interface=iface, source=source)


def create_typedef_variable_declaration(self, t, comps, attr=None, scope=None, source=None):
"""
Utility method to create individual declarations from a group of AST nodes.
Expand Down
4 changes: 1 addition & 3 deletions loki/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,7 @@ def test_symbol_attributes_compare():
assert not someint.compare(somereal)


@pytest.mark.parametrize('frontend', available_frontends(xfail=[
(OFP, 'OFP needs preprocessing to support contiguous keyword'
)]))
@pytest.mark.parametrize('frontend', available_frontends())
def test_type_declaration_attributes(frontend):
"""
Test recognition of different declaration attributes.
Expand Down
6 changes: 2 additions & 4 deletions loki/transformations/build_system/tests/test_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,7 @@ def test_dependency_transformation_inline_call(frontend):
assert 'kernel_test' in [str(s) for s in imports[0].symbols]


@pytest.mark.parametrize('frontend', available_frontends(
xfail=[(OFP, 'OFP does not correctly handle result variable declaration.')]))
@pytest.mark.parametrize('frontend', available_frontends())
def test_dependency_transformation_inline_call_result_var(frontend):
"""
Test injection of suffixed kernel, accessed through inline function call.
Expand Down Expand Up @@ -675,8 +674,7 @@ def test_dependency_transformation_contained_member(frontend, use_scheduler, tmp
assert calls[0].name == 'get_b'


@pytest.mark.parametrize('frontend', available_frontends(
xfail=[(OFP, 'OFP does not correctly handle result variable declaration.')]))
@pytest.mark.parametrize('frontend', available_frontends())
def test_dependency_transformation_item_filter(frontend, tmp_path, config):
"""
Test that injection is not applied to modules that have no procedures
Expand Down
3 changes: 3 additions & 0 deletions loki/transformations/tests/test_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def test_transform_inline_elemental_functions(here, builder, frontend):
routine = Subroutine.from_source(fcode, definitions=module, frontend=frontend)
inline_elemental_functions(routine)

# Make sure there are no more inline calls in the routine body
assert not FindInlineCalls().visit(routine.body)

# Verify correct scope of inlined elements
assert all(v.scope is routine for v in FindVariables().visit(routine.body))

Expand Down
2 changes: 1 addition & 1 deletion loki/transformations/tests/test_transform_derived_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,7 +1048,7 @@ def test_transform_derived_type_arguments_optional_named_arg(frontend):
assert calls[2].kwarguments == (('opt1', '1'), ('val', '1'), ('t_arr', 't%arr'), ('opt2', '2'))


@pytest.mark.parametrize('frontend', available_frontends(xfail=[(OFP, 'No support for recursive prefix')]))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_recursive(frontend):
fcode = """
module some_mod
Expand Down
Loading