Skip to content

Commit

Permalink
Test and fix constructor arg conflict in FindExpressionRoot
Browse files Browse the repository at this point in the history
  • Loading branch information
reuterbal committed Aug 2, 2023
1 parent f85297e commit ca8c557
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
5 changes: 4 additions & 1 deletion loki/expression/expr_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,10 @@ class FindExpressionRoot(ExpressionFinder):
"""
def __init__(self, expr, recurse_to_parent=True, **kwargs):
self._retriever = ExpressionRetriever(lambda e: e is expr, recurse_to_parent=recurse_to_parent)
super().__init__(unique=False, retrieve=lambda e: e if self._retriever.retrieve(e) else (), **kwargs)
if kwargs.get('unique'):
raise ValueError('FindExpressionRoot requires unique=False')
kwargs['unique'] = False
super().__init__(retrieve=lambda e: e if self._retriever.retrieve(e) else (), **kwargs)


class SubstituteExpressions(Transformer):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,30 @@ def test_find_expression_root(frontend):
assert cast_root[0].source.lines == (13, 13)


@pytest.mark.parametrize('frontend', available_frontends())
def test_find_expression_root_constructor_args(frontend):
"""
Test correct handling for various constructor arguments
"""
fcode = """
subroutine my_routine
implicit none
integer :: i
i = 1 + 1
end subroutine my_routine
""".strip()

routine = Subroutine.from_source(fcode, frontend=frontend)
exprs = FindExpressions().visit(routine.body)
some_expr = [expr for expr in exprs if isinstance(expr, IntLiteral)][0]

with pytest.raises(ValueError):
FindExpressionRoot(some_expr, unique=True).visit(routine.body)

expr_root = FindExpressionRoot(some_expr, unique=False).visit(routine.body)
assert expr_root == (routine.body.body[0].rhs,)


@pytest.mark.parametrize('frontend', available_frontends())
def test_find_variables_associates(frontend):
"""
Expand Down

0 comments on commit ca8c557

Please sign in to comment.