Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Dec 16, 2024
1 parent a9afad5 commit 2bbe13f
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 15 deletions.
7 changes: 4 additions & 3 deletions src/spox/_adapt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import warnings
from typing import Optional

import onnx
import onnx.version_converter
Expand All @@ -22,7 +23,7 @@ def adapt_node(
source_version: int,
target_version: int,
var_names: dict[_VarInfo, str],
) -> Optional[list[onnx.NodeProto]]:
) -> list[onnx.NodeProto] | None:
if source_version == target_version:
return None

Expand Down Expand Up @@ -93,7 +94,7 @@ def adapt_best_effort(
opsets: dict[str, int],
var_names: dict[_VarInfo, str],
node_names: dict[Node, str],
) -> Optional[list[onnx.NodeProto]]:
) -> list[onnx.NodeProto] | None:
if isinstance(node, _Inline):
return adapt_inline(
node,
Expand Down
7 changes: 5 additions & 2 deletions src/spox/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from dataclasses import Field, dataclass
from typing import Optional, get_type_hints

from . import _type_system
from ._attributes import Attr
from ._exceptions import InferenceWarning
from ._type_system import Optional as tOptional
from ._value_prop import PropDict, PropValue
from ._var import Var, _VarInfo

Expand Down Expand Up @@ -160,7 +160,10 @@ def _create_var(key: str, var_info: _VarInfo) -> Var:
if var_info.type is None or key not in prop_values:
return ret

if not isinstance(var_info.type, tOptional) and prop_values[key] is None:
if (
not isinstance(var_info.type, _type_system.Optional)
and prop_values[key] is None
):
return ret

prop = PropValue(var_info.type, prop_values[key])
Expand Down
8 changes: 5 additions & 3 deletions src/spox/_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

"""Module implementing the main public interface functions in Spox."""

from __future__ import annotations

import contextlib
import itertools
from collections.abc import Iterator
from typing import Optional, Protocol
from typing import Protocol

import numpy as np
import onnx
Expand Down Expand Up @@ -51,8 +53,8 @@ def _temporary_renames(**kwargs: Var) -> Iterator[None]:
# The build code can't really special-case variable names that are
# not just ``Var._name``. So we set names here and reset them
# afterwards.
name: Optional[str]
pre: dict[Var, Optional[str]] = {}
name: str | None
pre: dict[Var, str | None] = {}
try:
for name, arg in kwargs.items():
pre[arg] = arg._var_info._name
Expand Down
8 changes: 5 additions & 3 deletions src/spox/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

"""Exposes information related to reference ONNX operator schemas, used by StandardOpNode."""

from __future__ import annotations

import itertools
from collections.abc import Iterable
from typing import Any, Callable, Optional, Protocol, TypeVar
from typing import Any, Callable, Protocol, TypeVar

from onnx.defs import OpSchema, get_all_schemas_with_history

Expand All @@ -30,8 +32,8 @@ def _key_groups(


def _current_schema(
schemas: Iterable[OpSchema], version: Optional[int] = None
) -> Optional[OpSchema]:
schemas: Iterable[OpSchema], version: int | None = None
) -> OpSchema | None:
"""
Find the schema for the current ``version`` from the list (the latest existing version).
If ``version`` is None (or left to default), the newest of the schemas is returned.
Expand Down
6 changes: 4 additions & 2 deletions src/spox/_traverse.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

from collections.abc import Iterable, Iterator
from typing import Callable, Optional, TypeVar
from typing import Callable, TypeVar

V = TypeVar("V")


def iterative_dfs(
sources: Iterable[V],
adj: Callable[[V], Iterable[V]],
post_callback: Optional[Callable[[V], None]] = None,
post_callback: Callable[[V], None] | None = None,
raise_on_cycle: bool = True,
) -> list[V]:
"""
Expand Down
4 changes: 2 additions & 2 deletions src/spox/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional
from __future__ import annotations

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -42,7 +42,7 @@ def dtype_to_tensor_type(dtype_like: npt.DTypeLike) -> int:
raise TypeError(err_msg)


def from_array(arr: np.ndarray, name: Optional[str] = None) -> TensorProto:
def from_array(arr: np.ndarray, name: str | None = None) -> TensorProto:
"""Convert the given ``numpy.array`` into an ``onnx.TensorProto``.
As it may be useful to name the TensorProto (e.g. in
Expand Down

0 comments on commit 2bbe13f

Please sign in to comment.