diff --git a/src/spox/_adapt.py b/src/spox/_adapt.py index 8877e40..42aad5b 100644 --- a/src/spox/_adapt.py +++ b/src/spox/_adapt.py @@ -1,8 +1,9 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import warnings -from typing import Optional import onnx import onnx.version_converter @@ -22,7 +23,7 @@ def adapt_node( source_version: int, target_version: int, var_names: dict[_VarInfo, str], -) -> Optional[list[onnx.NodeProto]]: +) -> list[onnx.NodeProto] | None: if source_version == target_version: return None @@ -93,7 +94,7 @@ def adapt_best_effort( opsets: dict[str, int], var_names: dict[_VarInfo, str], node_names: dict[Node, str], -) -> Optional[list[onnx.NodeProto]]: +) -> list[onnx.NodeProto] | None: if isinstance(node, _Inline): return adapt_inline( node, diff --git a/src/spox/_fields.py b/src/spox/_fields.py index 07c09ed..5cb77be 100644 --- a/src/spox/_fields.py +++ b/src/spox/_fields.py @@ -10,9 +10,9 @@ from dataclasses import Field, dataclass from typing import Optional, get_type_hints +from . import _type_system from ._attributes import Attr from ._exceptions import InferenceWarning -from ._type_system import Optional as tOptional from ._value_prop import PropDict, PropValue from ._var import Var, _VarInfo @@ -160,7 +160,10 @@ def _create_var(key: str, var_info: _VarInfo) -> Var: if var_info.type is None or key not in prop_values: return ret - if not isinstance(var_info.type, tOptional) and prop_values[key] is None: + if ( + not isinstance(var_info.type, _type_system.Optional) + and prop_values[key] is None + ): return ret prop = PropValue(var_info.type, prop_values[key]) diff --git a/src/spox/_public.py b/src/spox/_public.py index e12240e..369a67f 100644 --- a/src/spox/_public.py +++ b/src/spox/_public.py @@ -3,10 +3,12 @@ """Module implementing the main public interface functions in Spox.""" +from __future__ import annotations + import contextlib import itertools from collections.abc import Iterator -from typing import Optional, Protocol +from typing import Protocol import numpy as np import onnx @@ -51,8 +53,8 @@ def _temporary_renames(**kwargs: Var) -> Iterator[None]: # The build code can't really special-case variable names that are # not just ``Var._name``. So we set names here and reset them # afterwards. - name: Optional[str] - pre: dict[Var, Optional[str]] = {} + name: str | None + pre: dict[Var, str | None] = {} try: for name, arg in kwargs.items(): pre[arg] = arg._var_info._name diff --git a/src/spox/_schemas.py b/src/spox/_schemas.py index 7486dcc..4eb51c2 100644 --- a/src/spox/_schemas.py +++ b/src/spox/_schemas.py @@ -3,9 +3,11 @@ """Exposes information related to reference ONNX operator schemas, used by StandardOpNode.""" +from __future__ import annotations + import itertools from collections.abc import Iterable -from typing import Any, Callable, Optional, Protocol, TypeVar +from typing import Any, Callable, Protocol, TypeVar from onnx.defs import OpSchema, get_all_schemas_with_history @@ -30,8 +32,8 @@ def _key_groups( def _current_schema( - schemas: Iterable[OpSchema], version: Optional[int] = None -) -> Optional[OpSchema]: + schemas: Iterable[OpSchema], version: int | None = None +) -> OpSchema | None: """ Find the schema for the current ``version`` from the list (the latest existing version). If ``version`` is None (or left to default), the newest of the schemas is returned. diff --git a/src/spox/_traverse.py b/src/spox/_traverse.py index 6b64189..4a32775 100644 --- a/src/spox/_traverse.py +++ b/src/spox/_traverse.py @@ -1,8 +1,10 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + from collections.abc import Iterable, Iterator -from typing import Callable, Optional, TypeVar +from typing import Callable, TypeVar V = TypeVar("V") @@ -10,7 +12,7 @@ def iterative_dfs( sources: Iterable[V], adj: Callable[[V], Iterable[V]], - post_callback: Optional[Callable[[V], None]] = None, + post_callback: Callable[[V], None] | None = None, raise_on_cycle: bool = True, ) -> list[V]: """ diff --git a/src/spox/_utils.py b/src/spox/_utils.py index af6ed1a..7b00de4 100644 --- a/src/spox/_utils.py +++ b/src/spox/_utils.py @@ -1,7 +1,7 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause -from typing import Optional +from __future__ import annotations import numpy as np import numpy.typing as npt @@ -42,7 +42,7 @@ def dtype_to_tensor_type(dtype_like: npt.DTypeLike) -> int: raise TypeError(err_msg) -def from_array(arr: np.ndarray, name: Optional[str] = None) -> TensorProto: +def from_array(arr: np.ndarray, name: str | None = None) -> TensorProto: """Convert the given ``numpy.array`` into an ``onnx.TensorProto``. As it may be useful to name the TensorProto (e.g. in