Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IR: Symbol management on scoped nodes #375

Merged
merged 4 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 77 additions & 5 deletions loki/ir/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
:ref:`internal_representation:Control flow tree`
"""

from abc import abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
from functools import partial
Expand All @@ -22,10 +23,11 @@
from pydantic.dataclasses import dataclass as dataclass_validated
from pydantic import model_validator

from loki.expression import Variable, parse_expr
from loki.frontend.source import Source
from loki.scope import Scope
from loki.tools import flatten, as_tuple, is_iterable, truncate_string, CaseInsensitiveDict
from loki.types import DataType, BasicType, DerivedType, SymbolAttributes
from loki.frontend.source import Source


__all__ = [
Expand Down Expand Up @@ -329,6 +331,80 @@ def __setstate__(self, s):
symbol_attrs = s.pop('symbol_attrs', None)
self._update(**s, symbol_attrs=symbol_attrs, rescope_symbols=True)

@property
@abstractmethod
def variables(self):
"""
Return the variables defined in this :any:`ScopedNode`.
"""

@property
def variable_map(self):
"""
Map of variable names to :any:`Variable` objects
"""
return CaseInsensitiveDict((v.name, v) for v in self.variables)

def get_symbol(self, name):
"""
Returns the symbol for a given name as defined in its declaration.

The returned symbol might include dimension symbols if it was
declared as an array.

Parameters
----------
name : str
Base name of the symbol to be retrieved
"""
return self.get_symbol_scope(name).variable_map.get(name)

def Variable(self, **kwargs):
"""
Factory method for :any:`TypedSymbol` or :any:`MetaSymbol` classes.

This invokes the :any:`Variable` with this node as the scope.

Parameters
----------
name : str
The name of the variable.
type : optional
The type of that symbol. Defaults to :any:`BasicType.DEFERRED`.
parent : :any:`Scalar` or :any:`Array`, optional
The derived type variable this variable belongs to.
dimensions : :any:`ArraySubscript`, optional
The array subscript expression.
"""
kwargs['scope'] = self
return Variable(**kwargs)

def parse_expr(self, expr_str, strict=False, evaluate=False, context=None):
"""
Uses :meth:`parse_expr` to convert expression(s) represented
in a string to Loki expression(s)/IR.

Parameters
----------
expr_str : str
The expression as a string
strict : bool, optional
Whether to raise exception for unknown variables/symbols when
evaluating an expression (default: `False`)
evaluate : bool, optional
Whether to evaluate the expression or not (default: `False`)
context : dict, optional
Symbol context, defining variables/symbols/procedures to help/support
evaluating an expression

Returns
-------
:any:`Expression`
The expression tree corresponding to the expression
"""
return parse_expr(expr_str, scope=self, strict=strict, evaluate=evaluate, context=context)


# Intermediate node types


Expand Down Expand Up @@ -1579,10 +1655,6 @@ def comments(self):
def variables(self):
return tuple(flatten([decl.symbols for decl in self.declarations]))

@property
def variable_map(self):
return CaseInsensitiveDict((s.name, s) for s in self.variables)

@property
def imported_symbols(self):
"""
Expand Down
198 changes: 198 additions & 0 deletions loki/ir/tests/test_scoped_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# A set of tests for the symbol accessort and management API built into `ScopedNode`.

import pytest

from loki import Module
from loki.frontend import available_frontends
from loki.ir import nodes as ir, FindNodes
from loki.expression import symbols as sym
from loki.types import BasicType


@pytest.mark.parametrize('frontend', available_frontends())
def test_scoped_node_get_symbols(frontend, tmp_path):
""" Test :method:`get_symbol` functionality on scoped nodes. """
fcode = """
module test_scoped_node_symbols_mod
implicit none
integer, parameter :: jprb = 8

contains
subroutine test_scoped_node_symbols(n, a, b, c)
integer, intent(in) :: n
real(kind=jprb), intent(inout) :: a(n), b(n), c
integer :: i

a(1) = 42.0_jprb

associate(d => a)
do i=1, n
b(i) = a(i) + c
end do
end associate
end subroutine test_scoped_node_symbols
end module test_scoped_node_symbols_mod
"""
module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
routine = module['test_scoped_node_symbols']
associate = FindNodes(ir.Associate).visit(routine.body)[0]

# Check symbol lookup from subroutine
assert routine.get_symbol('a') == 'a(n)'
assert routine.get_symbol('a').scope == routine
assert routine.get_symbol('b') == 'b(n)'
assert routine.get_symbol('b').scope == routine
assert routine.get_symbol('c') == 'c'
assert routine.get_symbol('c').scope == routine
assert routine.get_symbol('jprb') == 'jprb'
assert routine.get_symbol('jprb').scope == module
assert routine.get_symbol('jprb').initial == 8

# Check passthrough from the Associate (ScopedNode)
assert associate.get_symbol('a') == 'a(n)'
assert associate.get_symbol('a').scope == routine
assert associate.get_symbol('b') == 'b(n)'
assert associate.get_symbol('b').scope == routine
assert associate.get_symbol('c') == 'c'
assert associate.get_symbol('c').scope == routine
assert associate.get_symbol('d') == 'd'
assert associate.get_symbol('d').scope == associate
assert associate.get_symbol('jprb') == 'jprb'
assert associate.get_symbol('jprb').scope == module
assert associate.get_symbol('jprb').initial == 8


@pytest.mark.parametrize('frontend', available_frontends())
def test_scoped_node_variable_constructor(frontend, tmp_path):
""" Test :any:`Variable` constrcutore on scoped nodes. """
fcode = """
module test_scoped_nodes_mod
implicit none
integer, parameter :: jprb = 8

contains
subroutine test_scoped_nodes(n, a, b, c)
integer, intent(in) :: n
real(kind=jprb), intent(inout) :: a(n), b(n), c
integer :: i

a(1) = 42.0_jprb

associate(d => a)
do i=1, n
b(i) = a(i) + c
end do
end associate
end subroutine test_scoped_nodes
end module test_scoped_nodes_mod
"""
module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
routine = module['test_scoped_nodes']
associate = FindNodes(ir.Associate).visit(routine.body)[0]

# Build some symbiols and check their type
c = routine.Variable(name='c')
assert c.type.dtype == BasicType.REAL
assert c.type.kind in ('jprb', 8)
i = routine.Variable(name='i')
assert i.type.dtype == BasicType.INTEGER
jprb = routine.Variable(name='jprb')
assert jprb.type.dtype == BasicType.INTEGER
assert jprb.type.initial == 8

a_i = routine.Variable(name='a', dimensions=(i,))
assert a_i == 'a(i)'
assert isinstance(a_i, sym.Array)
assert a_i.dimensions == (i,)
assert a_i.type.dtype == BasicType.REAL
assert a_i.type.kind in ('jprb', 8)

# Build another, but from the associate node
b_i = associate.Variable(name='b', dimensions=(i,))
assert b_i == 'b(i)'
assert isinstance(b_i, sym.Array)
assert b_i.dimensions == (i,)
assert b_i.type.dtype == BasicType.REAL
assert b_i.type.kind in ('jprb', 8)


@pytest.mark.parametrize('frontend', available_frontends())
def test_scoped_node_parse_expr(frontend, tmp_path):
""" Test :any:`Variable` constrcutore on scoped nodes. """
fcode = """
module test_scoped_nodes_mod
implicit none
integer, parameter :: jprb = 8

contains
subroutine test_scoped_nodes(n, a, b, c)
integer, intent(in) :: n
real(kind=jprb), intent(inout) :: a(n), b(n), c
integer :: i

a(1) = 42.0_jprb

associate(d => a)
do i=1, n
b(i) = a(i) + c
end do
end associate
end subroutine test_scoped_nodes
end module test_scoped_nodes_mod
"""
module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
routine = module['test_scoped_nodes']
associate = FindNodes(ir.Associate).visit(routine.body)[0]

assigns = FindNodes(ir.Assignment).visit(routine.body)
assert assigns[0].lhs == 'a(1)'
assert assigns[1].rhs == 'a(i) + c'

# Check that all variables are identified
ai_c = routine.parse_expr('a(i) + c')
assert ai_c == assigns[1].rhs
assert isinstance(ai_c, sym.Sum)
assert isinstance(ai_c.children[0], sym.Array)
assert isinstance(ai_c.children[0].dimensions[0], sym.Scalar)
assert isinstance(ai_c.children[1], sym.Scalar)
assert ai_c.children[0].scope == routine
assert ai_c.children[0].dimensions[0].scope == routine
assert ai_c.children[1].scope == routine

# Check that k is deferred
ai_k = routine.parse_expr('a(i) + k')
assert isinstance(ai_k, sym.Sum)
assert isinstance(ai_k.children[0], sym.Array)
assert isinstance(ai_k.children[0].dimensions[0], sym.Scalar)
assert isinstance(ai_k.children[1], sym.DeferredTypeSymbol)
assert ai_c.children[0].scope == routine
assert ai_c.children[0].dimensions[0].scope == routine
assert ai_c.children[1].scope == routine

# Check that all variables are identified
ai_c = associate.parse_expr('a(i) + c')
assert ai_c == assigns[1].rhs
assert isinstance(ai_c, sym.Sum)
assert isinstance(ai_c.children[0], sym.Array)
assert isinstance(ai_c.children[0].dimensions[0], sym.Scalar)
assert isinstance(ai_c.children[1], sym.Scalar)
assert ai_c.children[0].scope == routine
assert ai_c.children[0].dimensions[0].scope == routine
assert ai_c.children[1].scope == routine

# Check that k is deferred
ai_k = associate.parse_expr('a(i) + k')
assert isinstance(ai_k, sym.Sum)
assert isinstance(ai_k.children[0], sym.Array)
assert isinstance(ai_k.children[0].dimensions[0], sym.Scalar)
assert isinstance(ai_k.children[1], sym.DeferredTypeSymbol)
assert ai_c.children[0].scope == routine
assert ai_c.children[0].dimensions[0].scope == routine
assert ai_c.children[1].scope == routine
61 changes: 60 additions & 1 deletion loki/program_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from abc import abstractmethod

from loki.expression import Variable
from loki.expression import Variable, parse_expr
from loki.frontend import (
Frontend, parse_omni_source, parse_ofp_source, parse_fparser_source,
RegexParserClass, preprocess_cpp, sanitize_input
Expand Down Expand Up @@ -652,6 +652,65 @@ def symbol_map(self):
(s.name, s) for s in self.symbols
)

def get_symbol(self, name):
"""
Returns the symbol for a given name as defined in its declaration.

The returned symbol might include dimension symbols if it was
declared as an array.

Parameters
----------
name : str
Base name of the symbol to be retrieved
"""
return self.get_symbol_scope(name).variable_map.get(name)

def Variable(self, **kwargs):
"""
Factory method for :any:`TypedSymbol` or :any:`MetaSymbol` classes.

This invokes the :any:`Variable` with this node as the scope.

Parameters
----------
name : str
The name of the variable.
type : optional
The type of that symbol. Defaults to :any:`BasicType.DEFERRED`.
parent : :any:`Scalar` or :any:`Array`, optional
The derived type variable this variable belongs to.
dimensions : :any:`ArraySubscript`, optional
The array subscript expression.
"""
kwargs['scope'] = self
return Variable(**kwargs)

def parse_expr(self, expr_str, strict=False, evaluate=False, context=None):
"""
Uses :meth:`parse_expr` to convert expression(s) represented
in a string to Loki expression(s)/IR.

Parameters
----------
expr_str : str
The expression as a string
strict : bool, optional
Whether to raise exception for unknown variables/symbols when
evaluating an expression (default: `False`)
evaluate : bool, optional
Whether to evaluate the expression or not (default: `False`)
context : dict, optional
Symbol context, defining variables/symbols/procedures to help/support
evaluating an expression

Returns
-------
:any:`Expression`
The expression tree corresponding to the expression
"""
return parse_expr(expr_str, scope=self, strict=strict, evaluate=evaluate, context=context)

@property
def subroutines(self):
"""
Expand Down
Loading