Skip to content

Commit

Permalink
Subroutine: Only clone symbol when inferring from allocatable
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Mar 12, 2024
1 parent 2dcde7e commit 16762e5
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 25 deletions.
3 changes: 1 addition & 2 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1819,10 +1819,9 @@ def visit_Subroutine_Subprogram(self, o, **kwargs):
decl._update(symbols=tuple(s.clone() if routine.symbol_attrs[s.name].is_stmt_func else s
for s in decl.symbols))

# Big, but necessary hack:
# For deferred array dimensions on allocatables, we infer the conceptual
# dimension by finding any `allocate(var(<dims>))` statements.
routine.spec, routine.body = routine._infer_allocatable_shapes(routine.spec, routine.body)
routine._infer_allocatable_shapes()

# Update array shapes with Loki dimension pragmas
with pragmas_attached(routine, ir.VariableDeclaration):
Expand Down
3 changes: 1 addition & 2 deletions loki/frontend/ofp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,10 +1221,9 @@ def visit_subroutine(self, o, **kwargs):
rescope_symbols=True, source=kwargs['source'], incomplete=False
)

# Big, but necessary hack:
# For deferred array dimensions on allocatables, we infer the conceptual
# dimension by finding any `allocate(var(<dims>))` statements.
routine.spec, routine.body = routine._infer_allocatable_shapes(routine.spec, routine.body)
routine._infer_allocatable_shapes()

# Update array shapes with Loki dimension pragmas
with pragmas_attached(routine, ir.VariableDeclaration):
Expand Down
3 changes: 1 addition & 2 deletions loki/frontend/omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,10 +426,9 @@ def visit_FfunctionDefinition(self, o, **kwargs):
source=routine.source, incomplete=False
)

# Big, but necessary hack:
# For deferred array dimensions on allocatables, we infer the conceptual
# dimension by finding any `allocate(var(<dims>))` statements.
routine.spec, routine.body = routine._infer_allocatable_shapes(routine.spec, routine.body)
routine._infer_allocatable_shapes()

# Update array shapes with Loki dimension pragmas
with pragmas_attached(routine, ir.VariableDeclaration):
Expand Down
28 changes: 9 additions & 19 deletions loki/subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# nor does it submit to any jurisdiction.

from loki import ir
from loki.expression import FindVariables, SubstituteExpressions, symbols as sym
from loki.expression import symbols as sym
from loki.frontend import (
parse_omni_ast, parse_ofp_ast, parse_fparser_ast, get_fparser_node,
parse_regex_source
Expand Down Expand Up @@ -131,31 +131,21 @@ def __setstate__(self, s):
# Ensure that we are attaching all symbols to the newly create ``self``.
self.rescope_symbols()

@staticmethod
def _infer_allocatable_shapes(spec, body):

def _infer_allocatable_shapes(self):
"""
Infer variable symbol shapes from allocations of ``allocatable`` arrays.
"""
alloc_map = {}
for alloc in FindNodes(ir.Allocation).visit(body):
for alloc in FindNodes(ir.Allocation).visit(self.body):
for v in alloc.variables:
if isinstance(v, sym.Array):
if alloc.data_source:
alloc_map[v.name.lower()] = alloc.data_source.type.shape
new_shape = alloc.data_source.type.shape
else:
alloc_map[v.name.lower()] = v.dimensions
vmap = {}
for v in FindVariables().visit(body):
if v.name.lower() in alloc_map:
vtype = v.type.clone(shape=alloc_map[v.name.lower()])
vmap[v] = v.clone(type=vtype)
smap = {}
for v in FindVariables().visit(spec):
if v.name.lower() in alloc_map:
vtype = v.type.clone(shape=alloc_map[v.name.lower()])
smap[v] = v.clone(type=vtype)
return (SubstituteExpressions(smap, invalidate_source=False).visit(spec),
SubstituteExpressions(vmap, invalidate_source=False).visit(body))
new_shape = v.dimensions

# Update the type to inject shape info into symbol table
v.type = v.type.clone(shape=new_shape)

@classmethod
def from_omni(cls, ast, raw_source, definitions=None, parent=None, type_map=None):
Expand Down

0 comments on commit 16762e5

Please sign in to comment.