diff --git a/loki/ir/nodes.py b/loki/ir/nodes.py index 6599df90b..56ae34e7b 100644 --- a/loki/ir/nodes.py +++ b/loki/ir/nodes.py @@ -879,6 +879,23 @@ class CallStatement(LeafNode, _CallStatementBase): _traversable = ['name', 'arguments', 'kwarguments'] + @model_validator(mode='before') + @classmethod + def pre_init(cls, values): + # Ensure non-nested tuples for arguments + if 'arguments' in values.kwargs: + values.kwargs['arguments'] = _sanitize_tuple(values.kwargs['arguments']) + else: + values.kwargs['arguments'] = () + # Ensure two-level nested tuples for kwarguments + if 'kwarguments' in values.kwargs: + kwarguments = as_tuple(values.kwargs['kwarguments']) + kwarguments = tuple(_sanitize_tuple(pair) for pair in kwarguments) + values.kwargs['kwarguments'] = kwarguments + else: + values.kwargs['kwarguments'] = () + return values + def __post_init__(self): super().__post_init__() assert isinstance(self.arguments, tuple) diff --git a/loki/ir/tests/test_ir_nodes.py b/loki/ir/tests/test_ir_nodes.py index cb17ea943..97a410c6b 100644 --- a/loki/ir/tests/test_ir_nodes.py +++ b/loki/ir/tests/test_ir_nodes.py @@ -196,3 +196,51 @@ def test_section(scope, one, i, n, a_n, a_i): assert sec.body == (assign, func, assign, assign) sec.insert(pos=3, node=func) assert sec.body == (assign, func, assign, func, assign) + + +def test_callstatement(scope, one, i, n, a_i): + """ Test constructor of :any:`CallStatement` nodes. """ + + cname = sym.ProcedureSymbol(name='test', scope=scope) + call = ir.CallStatement( + name=cname, arguments=(n, a_i), kwarguments=(('i', i), ('j', one)) + ) + assert isinstance(call.name, Expression) + assert isinstance(call.arguments, tuple) + assert all(isinstance(e, Expression) for e in call.arguments) + assert isinstance(call.kwarguments, tuple) + assert all(isinstance(e, tuple) for e in call.kwarguments) + assert all( + isinstance(k, str) and isinstance(v, Expression) + for k, v in call.kwarguments + ) + + # Ensure "frozen" status of node objects + with pytest.raises(FrozenInstanceError) as error: + call.name = sym.ProcedureSymbol('dave', scope=scope) + with pytest.raises(FrozenInstanceError) as error: + call.arguments = (a_i, n, one) + with pytest.raises(FrozenInstanceError) as error: + call.kwarguments = (('i', one), ('j', i)) + + # Test auto-casting of the body to tuple + call = ir.CallStatement(name=cname, arguments=[a_i, one]) + assert call.arguments == (a_i, one) and call.kwarguments == () + call = ir.CallStatement(name=cname, arguments=None) + assert call.arguments == () and call.kwarguments == () + call = ir.CallStatement(name=cname, kwarguments=[('i', i), ('j', one)]) + assert call.arguments == () and call.kwarguments == (('i', i), ('j', one)) + call = ir.CallStatement(name=cname, kwarguments=None) + assert call.arguments == () and call.kwarguments == () + + # Test errors for wrong contructor usage + with pytest.raises(ValidationError) as error: + ir.CallStatement(name='a', arguments=(sym.Literal(42.0),)) + with pytest.raises(ValidationError) as error: + ir.CallStatement(name=cname, arguments=('a',)) + with pytest.raises(ValidationError) as error: + ir.Assignment( + name=cname, arguments=(sym.Literal(42.0),), kwarguments=('i', 'i') + ) + + # TODO: Test pragmas, active and chevron