Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FParser: Perform in-place scope attachment during parse #242

Merged
merged 11 commits into from
Mar 12, 2024
Merged
89 changes: 70 additions & 19 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
)
from loki.expression import ExpressionDimensionsMapper, AttachScopes, AttachScopesMapper
from loki.logging import debug, perf, info, warning, error
from loki.tools import as_tuple, flatten, CaseInsensitiveDict, LazyNodeLookup
from loki.tools import (
as_tuple, flatten, CaseInsensitiveDict, LazyNodeLookup, dict_override
)
from loki.pragma_utils import (
attach_pragmas, process_dimension_pragmas, detach_pragmas, pragmas_attached
)
Expand Down Expand Up @@ -346,7 +348,14 @@ def visit_Name(self, o, **kwargs):

:class:`fparser.two.Fortran2003.Name` has no children.
"""
return sym.Variable(name=o.tostr(), parent=kwargs.get('parent'))
name = o.tostr()
scope = kwargs.get('scope', None)
parent = kwargs.get('parent')
if parent:
scope = parent.scope
if scope:
scope = scope.get_symbol_scope(name)
Comment on lines +351 to +357
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming this is the discussed handling of derived type members, that ensures

  1. that we create symbol table entries in the scope where the derived type instance has been declared (and not the typedef scope).
  2. we attach symbols to the scope where they are declared (and not the closest one)

For future sanity: Would it be worth adding a comment explaining why it's done like this?

return sym.Variable(name=name, parent=parent, scope=scope)

def visit_Type_Name(self, o, **kwargs):
"""
Expand All @@ -368,7 +377,9 @@ def visit_Part_Ref(self, o, **kwargs):
subscript (or `None`)
"""
name = self.visit(o.children[0], **kwargs)
dimensions = self.visit(o.children[1], **kwargs)
with dict_override(kwargs, {'parent': None}):
# Don't pass any parent on to dimension symbols
dimensions = self.visit(o.children[1], **kwargs)
if dimensions:
name = name.clone(dimensions=dimensions)

Expand All @@ -394,14 +405,20 @@ def visit_Data_Ref(self, o, **kwargs):
var = self.visit(o.children[0], **kwargs)
for c in o.children[1:]:
parent = var
kwargs['parent'] = parent
var = self.visit(c, **kwargs)
if isinstance(var, sym.InlineCall):
# This is a function call with a type-bound procedure, so we need to
# update the name slightly differently
function = var.function.clone(name=f'{parent.name}%{var.function.name}', parent=parent)
var = var.clone(function=function)
else:
var = var.clone(name=f'{parent.name}%{var.name}', parent=parent)
# Hack: Need to force re-evaluation of the type from parent here via `type=None`
# We know there's a parent, but we cannot trust the auto-generation of the type,
# since the type lookup via parents can create mismatched DeferredTypeSymbols.
var = var.clone(
name=f'{parent.name}%{var.name}', parent=parent, scope=parent.scope, type=None
)
return var

#
Expand Down Expand Up @@ -761,7 +778,12 @@ def visit_Entity_Decl(self, o, **kwargs):
* char length (:class:`fparser.two.Fortran2003.Char_Length`)
* init (:class:`fparser.two.Fortran2003.Initialization`)
"""
var = self.visit(o.children[0], **kwargs)

# Do not pass scope down, as it might alias with previously
# created symbols. Instead, let the rescope in the Declaration
# assign the right scope, always!
with dict_override(kwargs, {'scope': None}):
var = self.visit(o.children[0], **kwargs)

if o.children[1]:
dimensions = as_tuple(self.visit(o.children[1], **kwargs))
Expand Down Expand Up @@ -840,7 +862,7 @@ def visit_External_Stmt(self, o, **kwargs):
scope = kwargs['scope']
for var in symbols:
_type = scope.symbol_attrs.lookup(var.name)
if _type is None:
if _type is None or _type.dtype == BasicType.DEFERRED:
dtype = ProcedureType(var.name, is_function=False)
else:
dtype = _type.dtype
Expand Down Expand Up @@ -1329,22 +1351,31 @@ def visit_Specific_Binding(self, o, **kwargs):
interface = None
if o.children[0]:
# Procedure interface provided
# (we pass the parent scope down for this)
kwargs['scope'] = scope.parent
interface = self.visit(o.children[0], **kwargs)
interface = AttachScopesMapper()(interface, scope=scope)
bind_names = as_tuple(interface)
func_names = [interface.name] * len(symbols)
assert o.children[4] is None
kwargs['scope'] = scope
elif o.children[4]:
# we pass the parent scope down for this
kwargs['scope'] = scope.parent
bind_names = as_tuple(self.visit(o.children[4], **kwargs))
bind_names = AttachScopesMapper()(bind_names, scope=scope)
assert len(bind_names) == len(symbols)
func_names = [i.name for i in bind_names]
kwargs['scope'] = scope
else:
bind_names = None
func_names = [s.name for s in symbols]

# Look up the type of the procedure
types = [scope.symbol_attrs.lookup(name) or SymbolAttributes(dtype=ProcedureType(name)) for name in func_names]
types = [scope.symbol_attrs.lookup(name) for name in func_names]
types = [
SymbolAttributes(dtype=ProcedureType(name))
if not t or t.dtype == BasicType.DEFERRED else t
for t, name in zip(types, func_names)
]

# Any declared attributes
attrs = self.visit(o.children[1], **kwargs) if o.children[1] else ()
Expand Down Expand Up @@ -1560,8 +1591,11 @@ def visit_Interface_Block(self, o, **kwargs):
elif spec is not None:
# This has a generic specification (and we might need to update symbol table)
scope = kwargs['scope']
if spec.name not in scope.symbol_attrs:
scope.symbol_attrs[spec.name] = SymbolAttributes(ProcedureType(name=spec.name, is_generic=True))
spec_type = scope.symbol_attrs.lookup(spec.name)
if not spec_type or spec_type.dtype == BasicType.DEFERRED:
scope.symbol_attrs[spec.name] = SymbolAttributes(
ProcedureType(name=spec.name, is_generic=True)
)
spec = spec.rescope(scope=scope)

# Traverse the body and build the object
Expand Down Expand Up @@ -1741,6 +1775,10 @@ def visit_Subroutine_Subprogram(self, o, **kwargs):
spec = ir.Section(body=as_tuple(spec_parts))
spec = sanitize_ir(spec, FP, pp_registry=sanitize_registry[FP], pp_info=self.pp_info)

# As variables may be defined out of sequence, we need to re-generate
# symbols in the spec part to make them coherent with the symbol table
spec = AttachScopes().visit(spec, scope=routine, recurse_to_declaration_attributes=True)

# Now all declarations are well-defined and we can parse the member routines
if contains_ast is not None:
contains = self.visit(contains_ast, **kwargs)
Expand Down Expand Up @@ -1809,7 +1847,7 @@ def visit_Subroutine_Subprogram(self, o, **kwargs):
name=routine.name, args=routine._dummies, docstring=docs, spec=spec,
body=body, contains=contains, ast=o, prefix=routine.prefix, bind=routine.bind,
result_name=routine.result_name, is_function=routine.is_function,
rescope_symbols=True, source=source, incomplete=False
rescope_symbols=False, source=source, incomplete=False
)

# Once statement functions are in place, we need to update the original declaration so that it
Expand Down Expand Up @@ -1869,7 +1907,8 @@ def visit_Subroutine_Stmt(self, o, **kwargs):
routine = None
if kwargs['scope'] is not None and name in kwargs['scope'].symbol_attrs:
proc_type = kwargs['scope'].symbol_attrs[name] # Look-up only in current scope!
if proc_type and proc_type.dtype.procedure != BasicType.DEFERRED:
if proc_type and proc_type.dtype != BasicType.DEFERRED and \
proc_type.dtype.procedure != BasicType.DEFERRED:
routine = proc_type.dtype.procedure
if not routine._incomplete:
# We return the existing object right away, unless it exists from a
Expand Down Expand Up @@ -2036,6 +2075,10 @@ def visit_Module(self, o, **kwargs):
docs = []
spec = None

# As variables may be defined out of sequence, we need to re-generate
# symbols in the spec part to make them coherent with the symbol table
spec = AttachScopes().visit(spec, scope=module, recurse_to_declaration_attributes=True)

# Now that all declarations are well-defined we can parse the member routines
if contains_ast is not None:
contains = self.visit(contains_ast, **kwargs)
Expand All @@ -2048,7 +2091,7 @@ def visit_Module(self, o, **kwargs):
module.__initialize__(
name=module.name, docstring=docs, spec=spec, contains=contains,
default_access_spec=module.default_access_spec, public_access_spec=module.public_access_spec,
private_access_spec=module.private_access_spec, ast=o, rescope_symbols=True, source=source,
private_access_spec=module.private_access_spec, ast=o, rescope_symbols=False, source=source,
incomplete=False
)

Expand Down Expand Up @@ -2451,9 +2494,15 @@ def visit_Procedure_Designator(self, o, **kwargs):
* procedure name :class:`fparser.two.Fortran2003.Binding_Name`
"""
assert o.children[1] == '%'
scope = kwargs.get('scope', None)
parent = self.visit(o.children[0], **kwargs)
if parent:
scope = parent.scope
name = self.visit(o.children[2], **kwargs)
name = name.clone(name=f'{parent.name}%{name.name}', parent=parent)
# Hack: Need to force re-evaluation of the type from parent here via `type=None`
# To fix this, we should stop creating symbols in the enclosing scope
# when determining the type of derived type members from their parent.
name = name.clone(name=f'{parent.name}%{name.name}', parent=parent, scope=scope, type=None)
Comment on lines +2497 to +2505
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just looking at the visit_Name, not having actually tested this: Wouldn't passing down parent in the recursion to obtain name work around the scope issue here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alas, no (I tried!). I'm using this hack in two places, each associated with a specific failure in the test base (both in test_scheduler.py). Neither was fixed by passing the parent down (even though I can see what you mean).

I still believe (though I'm not 100% sure) that this is due to us (unintentionally) unrolling all potential local subcomponent variables of a derived-type variable in the local scope. I tried dropping that behaviour, but it opened another can of worms, so I abandoned the attempt. Maybe worth capturing this in an issue instead?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Captured in #243

return name

visit_Actual_Arg_Spec_List = visit_List
Expand Down Expand Up @@ -2529,10 +2578,11 @@ def visit_Structure_Constructor(self, o, **kwargs):
# https://github.com/stfc/fparser/issues/201 for some details
name = self.visit(o.children[0], **kwargs)
assert isinstance(name, DerivedType)
scope = kwargs.get('scope', None)

# `name` is a DerivedType but we represent a constructor call as InlineCall for
# which we need ProcedureSymbol
name = sym.Variable(name=name.name)
name = sym.Variable(name=name.name, scope=scope)

if o.children[1] is not None:
arguments = self.visit(o.children[1], **kwargs)
Expand Down Expand Up @@ -2929,7 +2979,7 @@ def visit_Assignment_Stmt(self, o, **kwargs):

# Special-case: Identify statement functions using our internal symbol table
symbol_attrs = kwargs['scope'].symbol_attrs
if isinstance(lhs, sym.Array) and lhs.name in symbol_attrs:
if isinstance(lhs, sym.Array) and not lhs.parent and lhs.name in symbol_attrs:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, I think I hit that one at the end of last week, too...


def _create_stmt_func_type(stmt_func):
name = str(stmt_func.variable)
Expand All @@ -2942,11 +2992,12 @@ def _create_stmt_func_type(stmt_func):
proc_type = ProcedureType(is_function=True, procedure=procedure, name=name)
return SymbolAttributes(dtype=proc_type, is_stmt_func=True)

if not symbol_attrs[lhs.name].shape and not symbol_attrs[lhs.name].intent:
if not lhs.type.shape and not lhs.type.intent:
# If the LHS array access is actually declared as a scalar,
# we are actually dealing with a statement function!
f_symbol = sym.ProcedureSymbol(name=lhs.name, scope=kwargs['scope'])
stmt_func = ir.StatementFunction(
variable=lhs.clone(dimensions=None), arguments=lhs.dimensions,
variable=f_symbol, arguments=lhs.dimensions,
rhs=rhs, return_type=symbol_attrs[lhs.name],
label=kwargs.get('label'), source=kwargs.get('source')
)
Expand Down
25 changes: 24 additions & 1 deletion loki/tools/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
'execute', 'CaseInsensitiveDict', 'strip_inline_comments',
'binary_insertion_sort', 'cached_func', 'optional', 'LazyNodeLookup',
'yaml_include_constructor', 'auto_post_mortem_debugger', 'set_excepthook',
'timeout', 'WeakrefProperty', 'group_by_class', 'replace_windowed'
'timeout', 'WeakrefProperty', 'group_by_class', 'replace_windowed', 'dict_override'
]


Expand Down Expand Up @@ -669,3 +669,26 @@ def replace_windowed(iterable, group, subs):
iterable, pred=lambda *args: args == group,
substitutes=as_tuple(subs), window_size=len(group)
))


@contextmanager
def dict_override(base, override):
    """
    Contextmanager to temporarily override a set of dictionary values.

    On leaving the context, every key of ``override`` that already existed
    in ``base`` gets its previous value back, and every key that was newly
    inserted from ``override`` is removed again. The restoration is also
    performed when the managed block raises an exception, so that ``base``
    is never left in the overridden state.

    Keys that are not part of ``override`` are not tracked; changes made to
    them inside the context persist.

    Parameters
    ----------
    base : dict
        The base dictionary in which to override values
    override : dict
        Replacement mapping to temporarily insert

    Yields
    ------
    dict
        The ``base`` object itself, with the overrides applied
    """
    # Take the snapshot before mutating ``base``, otherwise we could no
    # longer tell replaced keys from newly added ones
    original_values = tuple((k, base[k]) for k in override if k in base)
    added_keys = tuple(k for k in override if k not in base)
    base.update(override)

    try:
        yield base
    finally:
        # Restore unconditionally, also on the exception path
        base.update(original_values)
        for k in added_keys:
            # Tolerate the managed block having removed the key already
            base.pop(k, None)
14 changes: 13 additions & 1 deletion tests/test_derived_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
OMNI, OFP, Module, Subroutine, BasicType, DerivedType, TypeDef,
fgen, FindNodes, Intrinsic, ProcedureDeclaration, ProcedureType,
VariableDeclaration, Assignment, InlineCall, Builder, StringSubscript,
Conditional, CallStatement, ProcedureSymbol
Conditional, CallStatement, ProcedureSymbol, FindVariables
)


Expand Down Expand Up @@ -66,6 +66,18 @@ def test_simple_loops(here, frontend):
end module
"""
module = Module.from_source(fcode, frontend=frontend)
routine = module['simple_loops']

# Ensure type info is attached correctly
item_vars = [v for v in FindVariables(unique=False).visit(routine.body) if v.parent]
assert all(v.type.dtype == BasicType.REAL for v in item_vars)
assert item_vars[0].name == 'item%vector' and item_vars[0].shape == (3,)
assert item_vars[1].name == 'item%vector' and item_vars[1].shape == (3,)
assert item_vars[2].name == 'item%scalar' and item_vars[2].type.shape is None
assert item_vars[3].name == 'item%matrix' and item_vars[3].shape == (3, 3)
assert item_vars[4].name == 'item%matrix' and item_vars[4].shape == (3, 3)
assert item_vars[5].name == 'item%scalar' and item_vars[5].type.shape is None

filepath = here/(f'derived_types_simple_loops_{frontend}.f90')
mod = jit_compile(module, filepath=filepath, objname='derived_types_mod')

Expand Down
3 changes: 3 additions & 0 deletions tests/test_fgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def test_multiline_inline_conditional(frontend):
"""
fcode = """
subroutine test_fgen(DIMS, ZSURF_LOCAL)
type(DIMENSION_TYPE), intent(in) :: DIMS
type(SURFACE_TYPE), intent(inout) :: ZSURF_LOCAL
type(STATE_TYPE) :: TENDENCY_LOC
contains
subroutine test_inline_multiline(KDIMS, LBUD23)

Expand Down
7 changes: 6 additions & 1 deletion tests/test_subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,9 @@ def test_routine_variables_dim_shapes(frontend):
subroutine routine_dim_shapes(v1, v2, v3, v4, v5)
! Simple variable assignments with non-trivial sizes and indices
integer, parameter :: jprb = selected_real_kind(13,300)
integer, intent(in) :: v1, v2
real(kind=jprb), allocatable, intent(out) :: v3(:)
real(kind=jprb), intent(out) :: v4(v1,v2), v5(1:v1,v2-1)
integer, intent(in) :: v1, v2

allocate(v3(v1))
v3(v1-v2+1) = 1.
Expand All @@ -524,6 +524,11 @@ def test_routine_variables_dim_shapes(frontend):
assert shapes in (['(v1,)', '(v1, v2)', '(1:v1, v2 - 1)'],
['(v1,)', '(1:v1, 1:v2)', '(1:v1, 1:v2 - 1)'])

# Ensure that all spec variables (including dimension symbols) are scoped correctly
spec_vars = FindVariables(unique=False).visit(routine.spec)
assert all(v.scope == routine for v in spec_vars)
assert all(isinstance(v, (Scalar, Array)) for v in spec_vars)

# Ensure shapes of body variables are ok
b_shapes = [fexprgen(v.shape) for v in FindVariables(unique=False).visit(routine.body)
if isinstance(v, Array)]
Expand Down
14 changes: 13 additions & 1 deletion tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from conftest import stdchannel_is_captured, stdchannel_redirected
from loki.tools import (
JoinableStringList, truncate_string, binary_insertion_sort, is_subset,
optional, yaml_include_constructor, execute, timeout
optional, yaml_include_constructor, execute, timeout, dict_override
)


Expand Down Expand Up @@ -341,3 +341,15 @@ def test_timeout():
stop = perf_counter()
assert .9 < stop - start < 1.1
assert "My message" in str(exc.value)


def test_dict_override():
    """
    Verify that ``dict_override`` applies the given overrides inside the
    context and restores the original dictionary content on exit.
    """
    kwargs = {'rick' : 42, 'dave' : 'yeah'}
    overrides = {'dave' : 'nope', 'joe' : 'huh?'}

    with dict_override(kwargs, overrides):
        # Replaced key, untouched key and newly inserted key, in that order
        expected_inside = (('dave', 'nope'), ('rick', 42), ('joe', 'huh?'))
        for key, value in expected_inside:
            assert kwargs[key] == value

    # Afterwards: old value is back, untouched value kept, added key gone
    assert kwargs['dave'] == 'yeah'
    assert kwargs['rick'] == 42
    assert 'joe' not in kwargs
    assert len(kwargs) == 2
Loading