From ca8c557839b7d07f68e86182c94ec7e7712f3974 Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Wed, 2 Aug 2023 14:32:23 +0100 Subject: [PATCH] Test and fix constructor arg conflict in FindExpressionRoot --- loki/expression/expr_visitors.py | 5 ++++- tests/test_visitor.py | 24 ++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/loki/expression/expr_visitors.py b/loki/expression/expr_visitors.py index fc3fcfde0..c19ff3ba2 100644 --- a/loki/expression/expr_visitors.py +++ b/loki/expression/expr_visitors.py @@ -256,7 +256,10 @@ class FindExpressionRoot(ExpressionFinder): """ def __init__(self, expr, recurse_to_parent=True, **kwargs): self._retriever = ExpressionRetriever(lambda e: e is expr, recurse_to_parent=recurse_to_parent) - super().__init__(unique=False, retrieve=lambda e: e if self._retriever.retrieve(e) else (), **kwargs) + if kwargs.get('unique'): + raise ValueError('FindExpressionRoot requires unique=False') + kwargs['unique'] = False + super().__init__(retrieve=lambda e: e if self._retriever.retrieve(e) else (), **kwargs) class SubstituteExpressions(Transformer): diff --git a/tests/test_visitor.py b/tests/test_visitor.py index 5165bd799..b8e90c6ec 100644 --- a/tests/test_visitor.py +++ b/tests/test_visitor.py @@ -391,6 +391,30 @@ def test_find_expression_root(frontend): assert cast_root[0].source.lines == (13, 13) +@pytest.mark.parametrize('frontend', available_frontends()) +def test_find_expression_root_constructor_args(frontend): + """ + Test correct handling for various constructor arguments + """ + fcode = """ +subroutine my_routine + implicit none + integer :: i + i = 1 + 1 +end subroutine my_routine + """.strip() + + routine = Subroutine.from_source(fcode, frontend=frontend) + exprs = FindExpressions().visit(routine.body) + some_expr = [expr for expr in exprs if isinstance(expr, IntLiteral)][0] + + with pytest.raises(ValueError): + FindExpressionRoot(some_expr, unique=True).visit(routine.body) + + expr_root = FindExpressionRoot(some_expr, unique=False).visit(routine.body) + assert expr_root == (routine.body.body[0].rhs,) + + @pytest.mark.parametrize('frontend', available_frontends()) def test_find_variables_associates(frontend): """