Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IR: Update to Pydantic >2.0 compatibility #349

Merged
merged 4 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion loki/backend/tests/test_cufgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_cufgen(frontend, tmp_path):
call_map = {}
for call in FindNodes(ir.CallStatement).visit(driver.body):
if "kernel" in str(call.name):
with pytest.raises(AssertionError):
with pytest.raises(ValidationError):
_ = call.clone(chevron=(sym.IntLiteral(1), sym.IntLiteral(1), sym.IntLiteral(1), sym.IntLiteral(1),
sym.IntLiteral(1)))
with pytest.raises(ValidationError):
Expand Down
4 changes: 3 additions & 1 deletion loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,9 @@ def visit_Intrinsic_Stmt(self, o, **kwargs):
"""
Universal routine to capture nodes as plain string in the IR
"""
return ir.Intrinsic(text=o.tostr(), label=kwargs.get('label'), source=kwargs.get('source'))
label = kwargs.get('label')
label = str(label) if label else label # Ensure srting labels
return ir.Intrinsic(text=o.tostr(), label=label, source=kwargs.get('source'))

#
# Base blocks
Expand Down
95 changes: 47 additions & 48 deletions loki/ir/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from dataclasses import dataclass
from functools import partial
from itertools import chain
from typing import Any, Tuple, Union
from typing import Any, Tuple, Union, Optional

from pymbolic.primitives import Expression

Expand Down Expand Up @@ -44,7 +44,6 @@

# Configuration for validation mechanism via pydantic
dataclass_validation_config = {
'validate_assignment': True,
'arbitrary_types_allowed': True,
}

Expand Down Expand Up @@ -79,8 +78,8 @@ class Node:

"""

source: Union[Source, str] = None
label: str = None
source: Optional[Union[Source, str]] = None
label: Optional[str] = None

_traversable = []

Expand Down Expand Up @@ -437,11 +436,11 @@ class _LoopBase():
variable: Expression
bounds: Expression
body: Tuple[Node, ...]
pragma: Tuple[Node, ...] = None
pragma_post: Tuple[Node, ...] = None
loop_label: Any = None
name: str = None
has_end_do: bool = True
pragma: Optional[Tuple[Node, ...]] = None
pragma_post: Optional[Tuple[Node, ...]] = None
loop_label: Optional[Any] = None
name: Optional[str] = None
has_end_do: Optional[bool] = True


@dataclass_strict(frozen=True)
Expand Down Expand Up @@ -498,13 +497,13 @@ def __repr__(self):
class _WhileLoopBase():
""" Type definitions for :any:`WhileLoop` node type. """

condition: Union[Expression, None]
condition: Optional[Expression]
body: Tuple[Node, ...]
pragma: Node = None
pragma_post: Node = None
loop_label: Any = None
name: str = None
has_end_do: bool = True
pragma: Optional[Node] = None
pragma_post: Optional[Node] = None
loop_label: Optional[Any] = None
name: Optional[str] = None
has_end_do: Optional[bool] = True


@dataclass_strict(frozen=True)
Expand Down Expand Up @@ -561,10 +560,10 @@ class _ConditionalBase():

condition: Expression
body: Tuple[Node, ...]
else_body: Tuple[Node, ...] = None
else_body: Optional[Tuple[Node, ...]] = None
inline: bool = False
has_elseif: bool = False
name: str = None
name: Optional[str] = None


@dataclass_strict(frozen=True)
Expand Down Expand Up @@ -669,7 +668,7 @@ class _InterfaceBase():

body: Tuple[Any, ...]
abstract: bool = False
spec: Union[Expression, str] = None
spec: Optional[Union[Expression, str]] = None


@dataclass_strict(frozen=True)
Expand Down Expand Up @@ -738,7 +737,7 @@ class _AssignmentBase():
lhs: Expression
rhs: Expression
ptr: bool = False
comment: Node = None
comment: Optional[Node] = None


@dataclass_strict(frozen=True)
Expand Down Expand Up @@ -776,10 +775,10 @@ def __repr__(self):
class _ConditionalAssignmentBase():
""" Type definitions for :any:`ConditionalAssignment` node type. """

lhs: Expression = None
condition: Expression = None
rhs: Expression = None
else_rhs: Expression = None
lhs: Optional[Expression] = None
condition: Optional[Expression] = None
rhs: Optional[Expression] = None
else_rhs: Optional[Expression] = None


@dataclass_strict(frozen=True)
Expand Down Expand Up @@ -821,11 +820,11 @@ class _CallStatementBase():
""" Type definitions for :any:`CallStatement` node type. """

name: Expression
arguments: Tuple[Expression, ...] = None
kwarguments: Tuple[Tuple[str, Expression], ...] = None
pragma: Tuple[Node, ...] = None
not_active: bool = None
chevron: Tuple[Expression, ...] = None
arguments: Optional[Tuple[Expression, ...]] = None
kwarguments: Optional[Tuple[Tuple[str, Expression], ...]] = None
pragma: Optional[Tuple[Node, ...]] = None
not_active: Optional[bool] = None
chevron: Optional[Tuple[Expression, ...]] = None


@dataclass_strict(frozen=True)
Expand Down Expand Up @@ -982,8 +981,8 @@ class _AllocationBase():
""" Type definitions for :any:`Allocation` node type. """

variables: Tuple[Expression, ...]
data_source: Expression = None
status_var: Expression = None
data_source: Optional[Expression] = None
status_var: Optional[Expression] = None


@dataclass_strict(frozen=True)
Expand Down Expand Up @@ -1021,7 +1020,7 @@ class _DeallocationBase():
""" Type definitions for :any:`Deallocation` node type. """

variables: Tuple[Expression, ...]
status_var: Expression = None
status_var: Optional[Expression] = None


@dataclass_strict(frozen=True)
Expand Down Expand Up @@ -1150,7 +1149,7 @@ class _PragmaBase():
""" Type definitions for :any:`Pragma` node type. """

keyword: str
content: str = None
content: Optional[str] = None


@dataclass_strict(frozen=True)
Expand Down Expand Up @@ -1210,13 +1209,13 @@ def __repr__(self):
class _ImportBase():
""" Type definitions for :any:`Import` node type. """

module: Union[str, None]
module: Optional[str]
symbols: Tuple[Expression, ...] = ()
nature: str = None
nature: Optional[str] = None
c_import: bool = False
f_include: bool = False
f_import: bool = False
rename_list: Tuple[Any, ...] = None
rename_list: Optional[Tuple[Any, ...]] = None


@dataclass_strict(frozen=True)
Expand Down Expand Up @@ -1274,9 +1273,9 @@ class _VariableDeclarationBase():
""" Type definitions for :any:`VariableDeclaration` node type. """

symbols: Tuple[Expression, ...]
dimensions: Tuple[Expression, ...] = None
comment: Node = None
pragma: Node = None
dimensions: Optional[Tuple[Expression, ...]] = None
comment: Optional[Node] = None
pragma: Optional[Node] = None


@dataclass_strict(frozen=True)
Expand Down Expand Up @@ -1325,13 +1324,13 @@ class _ProcedureDeclarationBase():
""" Type definitions for :any:`ProcedureDeclaration` node type. """

symbols: Tuple[Expression, ...]
interface: Union[Expression, DataType] = None
interface: Optional[Union[Expression, DataType]] = None
external: bool = False
module: bool = False
generic: bool = False
final: bool = False
comment: Node = None
pragma: Tuple[Node, ...] = None
comment: Optional[Node] = None
pragma: Optional[Tuple[Node, ...]] = None


@dataclass_strict(frozen=True)
Expand Down Expand Up @@ -1471,10 +1470,10 @@ def __repr__(self):
class _TypeDefBase():
""" Type definitions for :any:`TypeDef` node type. """

name: str = None
body: Tuple[Node, ...] = None
name: Optional[str] = None
body: Optional[Tuple[Node, ...]] = None
abstract: bool = False
extends: str = None
extends: Optional[str] = None
bind_c: bool = False
private: bool = False
public: bool = False
Expand Down Expand Up @@ -1624,7 +1623,7 @@ class _MultiConditionalBase():
values: Tuple[Any, ...]
bodies: Tuple[Any, ...]
else_body: Tuple[Node, ...]
name: str = None
name: Optional[str] = None


@dataclass_strict(frozen=True)
Expand Down Expand Up @@ -1670,8 +1669,8 @@ class _ForallBase():

named_bounds: Tuple[Tuple[Expression, Expression], ...]
body: Tuple[Node, ...]
mask: Expression = None
name: str = None
mask: Optional[Expression] = None
name: Optional[str] = None
inline: bool = False


Expand Down Expand Up @@ -1716,7 +1715,7 @@ class _MaskedStatementBase():

conditions: Tuple[Expression, ...]
bodies: Tuple[Tuple[Node, ...], ...]
default: Tuple[Node, ...] = None
default: Optional[Tuple[Node, ...]] = None
inline: bool = False


Expand Down
123 changes: 123 additions & 0 deletions loki/ir/tests/test_ir_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from dataclasses import FrozenInstanceError
from pymbolic.primitives import Expression
from pydantic import ValidationError

from loki.expression import symbols as sym, parse_expr
from loki.ir import nodes as ir
from loki.scope import Scope


@pytest.fixture(name='scope')
def fixture_scope():
return Scope()

@pytest.fixture(name='one')
def fixture_one():
return sym.Literal(1)

@pytest.fixture(name='i')
def fixture_i(scope):
return sym.Scalar('i', scope=scope)

@pytest.fixture(name='n')
def fixture_n(scope):
return sym.Scalar('n', scope=scope)

@pytest.fixture(name='a_i')
def fixture_a_i(scope, i):
return sym.Array('a', dimensions=(i,), scope=scope)


def test_assignment(scope, a_i):
"""
Test constructors of :any:`Assignment`.
"""
assign = ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0))
assert isinstance(assign.lhs, Expression)
assert isinstance(assign.rhs, Expression)
assert assign.comment is None

# Ensure "frozen" status of node objects
with pytest.raises(FrozenInstanceError) as error:
assign.lhs = sym.Scalar('b', scope=scope)
with pytest.raises(FrozenInstanceError) as error:
assign.rhs = sym.Scalar('b', scope=scope)

# Test errors for wrong contructor usage
with pytest.raises(ValidationError) as error:
ir.Assignment(lhs='a', rhs=sym.Literal(42.0))
with pytest.raises(ValidationError) as error:
ir.Assignment(lhs=a_i, rhs='42.0 + 6.0')
with pytest.raises(ValidationError) as error:
ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0), comment=a_i)


def test_loop(scope, one, i, n, a_i):
"""
Test constructors of :any:`Loop`.
"""
assign = ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0))
bounds = sym.Range((one, n))

loop = ir.Loop(variable=i, bounds=bounds, body=(assign,))
assert isinstance(loop.variable, Expression)
assert isinstance(loop.bounds, Expression)
assert isinstance(loop.body, tuple)
assert all(isinstance(n, ir.Node) for n in loop.body)

# Ensure "frozen" status of node objects
with pytest.raises(FrozenInstanceError) as error:
loop.variable = sym.Scalar('j', scope=scope)
with pytest.raises(FrozenInstanceError) as error:
loop.bounds = sym.Range((n, sym.Scalar('k', scope=scope)))
with pytest.raises(FrozenInstanceError) as error:
loop.body = (assign, assign, assign)

# Test errors for wrong contructor usage
with pytest.raises(ValidationError) as error:
ir.Loop(variable=i, bounds=bounds, body=assign)
with pytest.raises(ValidationError) as error:
ir.Loop(variable=None, bounds=bounds, body=(assign,))
with pytest.raises(ValidationError) as error:
ir.Loop(variable=i, bounds=None, body=(assign,))

# TODO: Test pragmas, names and labels


def test_conditional(scope, one, i, n, a_i):
"""
Test constructors of :any:`Conditional`.
"""
assign = ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0))
condition = parse_expr('i >= 2', scope=scope)

cond = ir.Conditional(
condition=condition, body=(assign,assign,), else_body=(assign,)
)
assert isinstance(cond.condition, Expression)
assert isinstance(cond.body, tuple) and len(cond.body) == 2
assert all(isinstance(n, ir.Node) for n in cond.body)
assert isinstance(cond.else_body, tuple) and len(cond.else_body) == 1
assert all(isinstance(n, ir.Node) for n in cond.else_body)

with pytest.raises(FrozenInstanceError) as error:
cond.condition = parse_expr('k == 0', scope=scope)
with pytest.raises(FrozenInstanceError) as error:
cond.body = (assign, assign, assign)
with pytest.raises(FrozenInstanceError) as error:
cond.else_body = (assign, assign, assign)

# Test errors for wrong contructor usage
with pytest.raises(ValidationError) as error:
ir.Conditional(condition=condition, body=assign)

# TODO: Test inline, name, has_elseif
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dependencies = [
"coloredlogs", # optional for loki-build utility
"junit_xml", # optional for JunitXML output in loki-lint
"codetiming", # essential for scheduler and sourcefile timings
"pydantic<2.0", # type checking for IR nodes
"pydantic", # type checking for IR nodes
]

[project.optional-dependencies]
Expand Down
Loading