Skip to content

Commit

Permalink
Merge pull request #388 from ecmwf-ifs/naml-resolve-assoc-merge
Browse files Browse the repository at this point in the history
Utilities to merge associate blocks and restrict depth of associate resolution
  • Loading branch information
reuterbal authored Oct 18, 2024
2 parents 1492d43 + 63b7f9c commit 56ab8d7
Show file tree
Hide file tree
Showing 6 changed files with 332 additions and 47 deletions.
30 changes: 3 additions & 27 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from loki.expression.operations import (
StringConcat, ParenthesisedAdd, ParenthesisedMul, ParenthesisedDiv, ParenthesisedPow
)
from loki.expression import ExpressionDimensionsMapper, AttachScopesMapper
from loki.expression import AttachScopesMapper
from loki.logging import debug, detail, info, warning, error
from loki.tools import (
as_tuple, flatten, CaseInsensitiveDict, LazyNodeLookup, dict_override
Expand Down Expand Up @@ -1507,35 +1507,11 @@ def visit_Associate_Construct(self, o, **kwargs):
kwargs['scope'] = associate

# Put associate expressions into the right scope and determine type of new symbols
rescoped_associations = []
for expr, name in associations:
# Put symbols in associated expression into the right scope
expr = AttachScopesMapper()(expr, scope=parent_scope)

# Determine type of new names
if isinstance(expr, (sym.TypedSymbol, sym.MetaSymbol)):
# Use the type of the associated variable
_type = expr.type.clone(parent=None)
if isinstance(expr, sym.Array) and expr.dimensions is not None:
shape = ExpressionDimensionsMapper()(expr)
if shape == (sym.IntLiteral(1),):
# For a scalar expression, we remove the shape
shape = None
_type = _type.clone(shape=shape)
else:
# TODO: Handle data type and shape of complex expressions
shape = ExpressionDimensionsMapper()(expr)
if shape == (sym.IntLiteral(1),):
# For a scalar expression, we remove the shape
shape = None
_type = SymbolAttributes(BasicType.DEFERRED, shape=shape)
name = name.clone(scope=associate, type=_type)
rescoped_associations += [(expr, name)]
associations = as_tuple(rescoped_associations)
associate._derive_local_symbol_types(parent_scope=parent_scope)

# The body
body = as_tuple(flatten(self.visit(c, **kwargs) for c in o.children[assoc_stmt_index+1:end_assoc_stmt_index]))
associate._update(associations=associations, body=body)
associate._update(body=body)

# Everything past the END ASSOCIATE (should be empty)
assert not o.children[end_assoc_stmt_index+1:]
Expand Down
39 changes: 36 additions & 3 deletions loki/ir/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from pydantic.dataclasses import dataclass as dataclass_validated
from pydantic import model_validator

from loki.expression import Variable, parse_expr
from loki.expression import (
symbols as sym, Variable, parse_expr, AttachScopesMapper,
ExpressionDimensionsMapper
)
from loki.frontend.source import Source
from loki.scope import Scope
from loki.tools import flatten, as_tuple, is_iterable, truncate_string, CaseInsensitiveDict
Expand Down Expand Up @@ -501,19 +504,49 @@ def association_map(self):
"""
An :any:`collections.OrderedDict` of associated expressions.
"""
return CaseInsensitiveDict((str(k), v) for k, v in self.associations)
return CaseInsensitiveDict((k, v) for k, v in self.associations)

@property
def inverse_map(self):
"""
An :any:`collections.OrderedDict` of associated expressions.
"""
return CaseInsensitiveDict((str(v), k) for k, v in self.associations)
return CaseInsensitiveDict((v, k) for k, v in self.associations)

@property
def variables(self):
return tuple(v for _, v in self.associations)

def _derive_local_symbol_types(self, parent_scope):
""" Derive the types of locally defined symbols from their associations. """

rescoped_associations = ()
for expr, name in self.associations:
# Put symbols in associated expression into the right scope
expr = AttachScopesMapper()(expr, scope=parent_scope)

# Determine type of new names
if isinstance(expr, (sym.TypedSymbol, sym.MetaSymbol)):
# Use the type of the associated variable
_type = expr.type.clone(parent=None)
if isinstance(expr, sym.Array) and expr.dimensions is not None:
shape = ExpressionDimensionsMapper()(expr)
if shape == (sym.IntLiteral(1),):
# For a scalar expression, we remove the shape
shape = None
_type = _type.clone(shape=shape)
else:
# TODO: Handle data type and shape of complex expressions
shape = ExpressionDimensionsMapper()(expr)
if shape == (sym.IntLiteral(1),):
# For a scalar expression, we remove the shape
shape = None
_type = SymbolAttributes(BasicType.DEFERRED, shape=shape)
name = name.clone(scope=self, type=_type)
rescoped_associations += ((expr, name),)

self._update(associations=rescoped_associations)

def __repr__(self):
if self.associations:
associations = ', '.join(f'{str(var)}={str(expr)}'
Expand Down
32 changes: 32 additions & 0 deletions loki/ir/tests/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,35 @@ def test_callstatement(scope, one, i, n, a_i):
)

# TODO: Test pragmas, active and chevron


def test_associate(scope, a_i):
"""
Test constructors and scoping bahviour of :any:`Associate`.
"""
b = sym.Scalar(name='b', scope=scope)
b_a = sym.Array(name='a', parent=b, scope=scope)
a = sym.Array(name='a', scope=scope)
assign = ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0))
assign2 = ir.Assignment(lhs=a_i.clone(parent=b), rhs=sym.Literal(66.6))

assoc = ir.Associate(associations=((b_a, a),), body=(assign, assign2), parent=scope)
assert isinstance(assoc.associations, tuple)
assert all(isinstance(n, tuple) and len(n) == 2 for n in assoc.associations)
assert isinstance(assoc.body, tuple)
assert all(isinstance(n, ir.Node) for n in assoc.body)

# TODO: Check constructor failures, auto-casting and frozen status

# Check provided symbol maps
assert 'B%a' in assoc.association_map and assoc.association_map['B%a'] is a
assert b_a in assoc.association_map and assoc.association_map[b_a] is a
assert 'a' in assoc.inverse_map and assoc.inverse_map['a'] is b_a
assert a in assoc.inverse_map and assoc.inverse_map[a] is b_a

# Check rescoping facility
assert assign.lhs.scope is scope
assert assign2.lhs.scope is scope
assoc.rescope_symbols()
assert assign.lhs.scope is assoc
assert assign2.lhs.scope is scope
24 changes: 16 additions & 8 deletions loki/tools/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,33 +239,41 @@ class CaseInsensitiveDict(OrderedDict):
https://stackoverflow.com/questions/2082152/case-insensitive-dictionary
"""
def __setitem__(self, key, value):
super().__setitem__(key.lower(), value)
key = key.lower() if isinstance(key, str) else key
super().__setitem__(key, value)

def __getitem__(self, key):
return super().__getitem__(key.lower())
key = key.lower() if isinstance(key, str) else key
return super().__getitem__(key)

def get(self, key, default=None):
return super().get(key.lower(), default)
key = key.lower() if isinstance(key, str) else key
return super().get(key, default)

def __contains__(self, key):
return super().__contains__(key.lower())
key = key.lower() if isinstance(key, str) else key
return super().__contains__(key)


class CaseInsensitiveDefaultDict(defaultdict):
"""
Variant of :any:`collections.defaultdict` that ignores the casing of string keys.
"""
def __setitem__(self, key, value):
super().__setitem__(key.lower(), value)
key = key.lower() if isinstance(key, str) else key
super().__setitem__(key, value)

def __getitem__(self, key):
return super().__getitem__(key.lower())
key = key.lower() if isinstance(key, str) else key
return super().__getitem__(key)

def get(self, key, default=None):
return super().get(key.lower(), default)
key = key.lower() if isinstance(key, str) else key
return super().get(key, default)

def __contains__(self, key):
return super().__contains__(key.lower())
key = key.lower() if isinstance(key, str) else key
return super().__contains__(key)


def strip_inline_comments(source, comment_char='!', str_delim='"\''):
Expand Down
Loading

0 comments on commit 56ab8d7

Please sign in to comment.