diff --git a/loki/ir/nodes.py b/loki/ir/nodes.py index 37c54194f..78a24e01f 100644 --- a/loki/ir/nodes.py +++ b/loki/ir/nodes.py @@ -11,6 +11,7 @@ :ref:`internal_representation:Control flow tree` """ +from abc import abstractmethod from collections import OrderedDict from dataclasses import dataclass from functools import partial @@ -22,10 +23,11 @@ from pydantic.dataclasses import dataclass as dataclass_validated from pydantic import model_validator +from loki.expression import Variable, parse_expr +from loki.frontend.source import Source from loki.scope import Scope from loki.tools import flatten, as_tuple, is_iterable, truncate_string, CaseInsensitiveDict from loki.types import DataType, BasicType, DerivedType, SymbolAttributes -from loki.frontend.source import Source __all__ = [ @@ -329,6 +331,80 @@ def __setstate__(self, s): symbol_attrs = s.pop('symbol_attrs', None) self._update(**s, symbol_attrs=symbol_attrs, rescope_symbols=True) + @property + @abstractmethod + def variables(self): + """ + Return the variables defined in this :any:`ScopedNode`. + """ + + @property + def variable_map(self): + """ + Map of variable names to :any:`Variable` objects + """ + return CaseInsensitiveDict((v.name, v) for v in self.variables) + + def get_symbol(self, name): + """ + Returns the symbol for a given name as defined in its declaration. + + The returned symbol might include dimension symbols if it was + declared as an array. + + Parameters + ---------- + name : str + Base name of the symbol to be retrieved + """ + return self.get_symbol_scope(name).variable_map.get(name) + + def Variable(self, **kwargs): + """ + Factory method for :any:`TypedSymbol` or :any:`MetaSymbol` classes. + + This invokes the :any:`Variable` with this node as the scope. + + Parameters + ---------- + name : str + The name of the variable. + type : optional + The type of that symbol. Defaults to :any:`BasicType.DEFERRED`. + parent : :any:`Scalar` or :any:`Array`, optional + The derived type variable this variable belongs to. + dimensions : :any:`ArraySubscript`, optional + The array subscript expression. + """ + kwargs['scope'] = self + return Variable(**kwargs) + + def parse_expr(self, expr_str, strict=False, evaluate=False, context=None): + """ + Uses :meth:`parse_expr` to convert expression(s) represented + in a string to Loki expression(s)/IR. + + Parameters + ---------- + expr_str : str + The expression as a string + strict : bool, optional + Whether to raise exception for unknown variables/symbols when + evaluating an expression (default: `False`) + evaluate : bool, optional + Whether to evaluate the expression or not (default: `False`) + context : dict, optional + Symbol context, defining variables/symbols/procedures to help/support + evaluating an expression + + Returns + ------- + :any:`Expression` + The expression tree corresponding to the expression + """ + return parse_expr(expr_str, scope=self, strict=strict, evaluate=evaluate, context=context) + + # Intermediate node types @@ -1579,10 +1655,6 @@ def comments(self): def variables(self): return tuple(flatten([decl.symbols for decl in self.declarations])) - @property - def variable_map(self): - return CaseInsensitiveDict((s.name, s) for s in self.variables) - @property def imported_symbols(self): """ diff --git a/loki/ir/tests/test_scoped_nodes.py b/loki/ir/tests/test_scoped_nodes.py new file mode 100644 index 000000000..d3bca5a4e --- /dev/null +++ b/loki/ir/tests/test_scoped_nodes.py @@ -0,0 +1,198 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +# A set of tests for the symbol accessort and management API built into `ScopedNode`. + +import pytest + +from loki import Module +from loki.frontend import available_frontends +from loki.ir import nodes as ir, FindNodes +from loki.expression import symbols as sym +from loki.types import BasicType + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_scoped_node_get_symbols(frontend, tmp_path): + """ Test :method:`get_symbol` functionality on scoped nodes. """ + fcode = """ +module test_scoped_node_symbols_mod +implicit none +integer, parameter :: jprb = 8 + +contains + subroutine test_scoped_node_symbols(n, a, b, c) + integer, intent(in) :: n + real(kind=jprb), intent(inout) :: a(n), b(n), c + integer :: i + + a(1) = 42.0_jprb + + associate(d => a) + do i=1, n + b(i) = a(i) + c + end do + end associate + end subroutine test_scoped_node_symbols +end module test_scoped_node_symbols_mod +""" + module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) + routine = module['test_scoped_node_symbols'] + associate = FindNodes(ir.Associate).visit(routine.body)[0] + + # Check symbol lookup from subroutine + assert routine.get_symbol('a') == 'a(n)' + assert routine.get_symbol('a').scope == routine + assert routine.get_symbol('b') == 'b(n)' + assert routine.get_symbol('b').scope == routine + assert routine.get_symbol('c') == 'c' + assert routine.get_symbol('c').scope == routine + assert routine.get_symbol('jprb') == 'jprb' + assert routine.get_symbol('jprb').scope == module + assert routine.get_symbol('jprb').initial == 8 + + # Check passthrough from the Associate (ScopedNode) + assert associate.get_symbol('a') == 'a(n)' + assert associate.get_symbol('a').scope == routine + assert associate.get_symbol('b') == 'b(n)' + assert associate.get_symbol('b').scope == routine + assert associate.get_symbol('c') == 'c' + assert associate.get_symbol('c').scope == routine + assert associate.get_symbol('d') == 'd' + assert associate.get_symbol('d').scope == associate + assert associate.get_symbol('jprb') == 'jprb' + assert associate.get_symbol('jprb').scope == module + assert associate.get_symbol('jprb').initial == 8 + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_scoped_node_variable_constructor(frontend, tmp_path): + """ Test :any:`Variable` constrcutore on scoped nodes. """ + fcode = """ +module test_scoped_nodes_mod +implicit none +integer, parameter :: jprb = 8 + +contains + subroutine test_scoped_nodes(n, a, b, c) + integer, intent(in) :: n + real(kind=jprb), intent(inout) :: a(n), b(n), c + integer :: i + + a(1) = 42.0_jprb + + associate(d => a) + do i=1, n + b(i) = a(i) + c + end do + end associate + end subroutine test_scoped_nodes +end module test_scoped_nodes_mod +""" + module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) + routine = module['test_scoped_nodes'] + associate = FindNodes(ir.Associate).visit(routine.body)[0] + + # Build some symbiols and check their type + c = routine.Variable(name='c') + assert c.type.dtype == BasicType.REAL + assert c.type.kind in ('jprb', 8) + i = routine.Variable(name='i') + assert i.type.dtype == BasicType.INTEGER + jprb = routine.Variable(name='jprb') + assert jprb.type.dtype == BasicType.INTEGER + assert jprb.type.initial == 8 + + a_i = routine.Variable(name='a', dimensions=(i,)) + assert a_i == 'a(i)' + assert isinstance(a_i, sym.Array) + assert a_i.dimensions == (i,) + assert a_i.type.dtype == BasicType.REAL + assert a_i.type.kind in ('jprb', 8) + + # Build another, but from the associate node + b_i = associate.Variable(name='b', dimensions=(i,)) + assert b_i == 'b(i)' + assert isinstance(b_i, sym.Array) + assert b_i.dimensions == (i,) + assert b_i.type.dtype == BasicType.REAL + assert b_i.type.kind in ('jprb', 8) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_scoped_node_parse_expr(frontend, tmp_path): + """ Test :any:`Variable` constrcutore on scoped nodes. """ + fcode = """ +module test_scoped_nodes_mod +implicit none +integer, parameter :: jprb = 8 + +contains + subroutine test_scoped_nodes(n, a, b, c) + integer, intent(in) :: n + real(kind=jprb), intent(inout) :: a(n), b(n), c + integer :: i + + a(1) = 42.0_jprb + + associate(d => a) + do i=1, n + b(i) = a(i) + c + end do + end associate + end subroutine test_scoped_nodes +end module test_scoped_nodes_mod +""" + module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) + routine = module['test_scoped_nodes'] + associate = FindNodes(ir.Associate).visit(routine.body)[0] + + assigns = FindNodes(ir.Assignment).visit(routine.body) + assert assigns[0].lhs == 'a(1)' + assert assigns[1].rhs == 'a(i) + c' + + # Check that all variables are identified + ai_c = routine.parse_expr('a(i) + c') + assert ai_c == assigns[1].rhs + assert isinstance(ai_c, sym.Sum) + assert isinstance(ai_c.children[0], sym.Array) + assert isinstance(ai_c.children[0].dimensions[0], sym.Scalar) + assert isinstance(ai_c.children[1], sym.Scalar) + assert ai_c.children[0].scope == routine + assert ai_c.children[0].dimensions[0].scope == routine + assert ai_c.children[1].scope == routine + + # Check that k is deferred + ai_k = routine.parse_expr('a(i) + k') + assert isinstance(ai_k, sym.Sum) + assert isinstance(ai_k.children[0], sym.Array) + assert isinstance(ai_k.children[0].dimensions[0], sym.Scalar) + assert isinstance(ai_k.children[1], sym.DeferredTypeSymbol) + assert ai_c.children[0].scope == routine + assert ai_c.children[0].dimensions[0].scope == routine + assert ai_c.children[1].scope == routine + + # Check that all variables are identified + ai_c = associate.parse_expr('a(i) + c') + assert ai_c == assigns[1].rhs + assert isinstance(ai_c, sym.Sum) + assert isinstance(ai_c.children[0], sym.Array) + assert isinstance(ai_c.children[0].dimensions[0], sym.Scalar) + assert isinstance(ai_c.children[1], sym.Scalar) + assert ai_c.children[0].scope == routine + assert ai_c.children[0].dimensions[0].scope == routine + assert ai_c.children[1].scope == routine + + # Check that k is deferred + ai_k = associate.parse_expr('a(i) + k') + assert isinstance(ai_k, sym.Sum) + assert isinstance(ai_k.children[0], sym.Array) + assert isinstance(ai_k.children[0].dimensions[0], sym.Scalar) + assert isinstance(ai_k.children[1], sym.DeferredTypeSymbol) + assert ai_c.children[0].scope == routine + assert ai_c.children[0].dimensions[0].scope == routine + assert ai_c.children[1].scope == routine diff --git a/loki/program_unit.py b/loki/program_unit.py index 31ad16f41..5c503811c 100644 --- a/loki/program_unit.py +++ b/loki/program_unit.py @@ -7,7 +7,7 @@ from abc import abstractmethod -from loki.expression import Variable +from loki.expression import Variable, parse_expr from loki.frontend import ( Frontend, parse_omni_source, parse_ofp_source, parse_fparser_source, RegexParserClass, preprocess_cpp, sanitize_input @@ -652,6 +652,65 @@ def symbol_map(self): (s.name, s) for s in self.symbols ) + def get_symbol(self, name): + """ + Returns the symbol for a given name as defined in its declaration. + + The returned symbol might include dimension symbols if it was + declared as an array. + + Parameters + ---------- + name : str + Base name of the symbol to be retrieved + """ + return self.get_symbol_scope(name).variable_map.get(name) + + def Variable(self, **kwargs): + """ + Factory method for :any:`TypedSymbol` or :any:`MetaSymbol` classes. + + This invokes the :any:`Variable` with this node as the scope. + + Parameters + ---------- + name : str + The name of the variable. + type : optional + The type of that symbol. Defaults to :any:`BasicType.DEFERRED`. + parent : :any:`Scalar` or :any:`Array`, optional + The derived type variable this variable belongs to. + dimensions : :any:`ArraySubscript`, optional + The array subscript expression. + """ + kwargs['scope'] = self + return Variable(**kwargs) + + def parse_expr(self, expr_str, strict=False, evaluate=False, context=None): + """ + Uses :meth:`parse_expr` to convert expression(s) represented + in a string to Loki expression(s)/IR. + + Parameters + ---------- + expr_str : str + The expression as a string + strict : bool, optional + Whether to raise exception for unknown variables/symbols when + evaluating an expression (default: `False`) + evaluate : bool, optional + Whether to evaluate the expression or not (default: `False`) + context : dict, optional + Symbol context, defining variables/symbols/procedures to help/support + evaluating an expression + + Returns + ------- + :any:`Expression` + The expression tree corresponding to the expression + """ + return parse_expr(expr_str, scope=self, strict=strict, evaluate=evaluate, context=context) + @property def subroutines(self): """