From 9d3f5eb9ed4628de4b48bf11862df25774018363 Mon Sep 17 00:00:00 2001 From: Daniel Smith Date: Tue, 9 Jul 2024 22:55:52 -0400 Subject: [PATCH] MyPy fixes --- opt_einsum/backends/dispatch.py | 2 +- opt_einsum/backends/jax.py | 4 ++-- opt_einsum/backends/object_arrays.py | 2 +- opt_einsum/backends/tensorflow.py | 2 +- opt_einsum/backends/theano.py | 2 +- opt_einsum/backends/torch.py | 2 +- opt_einsum/parser.py | 5 ++--- opt_einsum/tests/test_backends.py | 4 ++-- opt_einsum/tests/test_input.py | 4 ++-- opt_einsum/tests/test_sharing.py | 2 +- pyproject.toml | 2 +- 11 files changed, 15 insertions(+), 16 deletions(-) diff --git a/opt_einsum/backends/dispatch.py b/opt_einsum/backends/dispatch.py index e2a1b42..dd71642 100644 --- a/opt_einsum/backends/dispatch.py +++ b/opt_einsum/backends/dispatch.py @@ -59,7 +59,7 @@ def _import_func(func: str, backend: str, default: Any = None) -> Any: } try: - import numpy as np + import numpy as np # type: ignore _cached_funcs[("tensordot", "numpy")] = np.tensordot _cached_funcs[("transpose", "numpy")] = np.transpose diff --git a/opt_einsum/backends/jax.py b/opt_einsum/backends/jax.py index c76c592..45aa4b2 100644 --- a/opt_einsum/backends/jax.py +++ b/opt_einsum/backends/jax.py @@ -10,7 +10,7 @@ def _get_jax_and_to_jax(): global _JAX if _JAX is None: - import jax + import jax # type: ignore @to_backend_cache_wrap @jax.jit @@ -29,7 +29,7 @@ def build_expression(_, expr): # pragma: no cover jax_expr = jax.jit(expr._contract) def jax_contract(*arrays): - import numpy as np + import numpy as np # type: ignore return np.asarray(jax_expr(arrays)) diff --git a/opt_einsum/backends/object_arrays.py b/opt_einsum/backends/object_arrays.py index f34a78e..8954a18 100644 --- a/opt_einsum/backends/object_arrays.py +++ b/opt_einsum/backends/object_arrays.py @@ -27,7 +27,7 @@ def object_einsum(eq: str, *arrays: ArrayType) -> ArrayType: out : numpy.ndarray The output tensor, with ``dtype=object``. """ - import numpy as np + import numpy as np # type: ignore # when called by ``opt_einsum`` we will always be given a full eq lhs, output = eq.split("->") diff --git a/opt_einsum/backends/tensorflow.py b/opt_einsum/backends/tensorflow.py index 3a86dad..a6544f8 100644 --- a/opt_einsum/backends/tensorflow.py +++ b/opt_einsum/backends/tensorflow.py @@ -12,7 +12,7 @@ def _get_tensorflow_and_device(): global _CACHED_TF_DEVICE if _CACHED_TF_DEVICE is None: - import tensorflow as tf + import tensorflow as tf # type: ignore try: eager = tf.executing_eagerly() diff --git a/opt_einsum/backends/theano.py b/opt_einsum/backends/theano.py index 86a3391..5b54aab 100644 --- a/opt_einsum/backends/theano.py +++ b/opt_einsum/backends/theano.py @@ -9,7 +9,7 @@ @to_backend_cache_wrap(constants=True) def to_theano(array, constant=False): """Convert a numpy array to ``theano.tensor.TensorType`` instance.""" - import theano + import theano # type: ignore if has_array_interface(array): if constant: diff --git a/opt_einsum/backends/torch.py b/opt_einsum/backends/torch.py index b5c641d..f7f01f2 100644 --- a/opt_einsum/backends/torch.py +++ b/opt_einsum/backends/torch.py @@ -24,7 +24,7 @@ def _get_torch_and_device(): global _TORCH_HAS_TENSORDOT if _TORCH_DEVICE is None: - import torch + import torch # type: ignore device = "cuda" if torch.cuda.is_available() else "cpu" _TORCH_DEVICE = torch, device diff --git a/opt_einsum/parser.py b/opt_einsum/parser.py index bbab323..316a421 100644 --- a/opt_einsum/parser.py +++ b/opt_einsum/parser.py @@ -1,8 +1,7 @@ """A functionally equivalent parser of the numpy.einsum input parser.""" import itertools -from collections.abc import Sequence -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Sequence, Tuple from opt_einsum.typing import ArrayType, TensorShapeType @@ -219,7 +218,7 @@ def possibly_convert_to_numpy(x: Any) -> Any: """ if not hasattr(x, "shape"): try: - import numpy as np + import numpy as np # type: ignore except ModuleNotFoundError: raise ModuleNotFoundError( "numpy is required to convert non-array objects to arrays. This function will be deprecated in the future." diff --git a/opt_einsum/tests/test_backends.py b/opt_einsum/tests/test_backends.py index 9088335..8145fb1 100644 --- a/opt_einsum/tests/test_backends.py +++ b/opt_einsum/tests/test_backends.py @@ -9,10 +9,10 @@ try: # needed so tensorflow doesn't allocate all gpu mem try: - from tensorflow import ConfigProto + from tensorflow import ConfigProto # type: ignore from tensorflow import Session as TFSession except ImportError: - from tensorflow.compat.v1 import ConfigProto + from tensorflow.compat.v1 import ConfigProto # type: ignore from tensorflow.compat.v1 import Session as TFSession _TF_CONFIG = ConfigProto() _TF_CONFIG.gpu_options.allow_growth = True diff --git a/opt_einsum/tests/test_input.py b/opt_einsum/tests/test_input.py index e0fb21c..8ce9b0a 100644 --- a/opt_einsum/tests/test_input.py +++ b/opt_einsum/tests/test_input.py @@ -2,7 +2,7 @@ Tests the input parsing for opt_einsum. Duplicates the np.einsum input tests. """ -from typing import Any +from typing import Any, List import pytest @@ -12,7 +12,7 @@ np = pytest.importorskip("numpy") -def build_views(string: str) -> list[ArrayType]: +def build_views(string: str) -> List[ArrayType]: """Builds random numpy arrays for testing by using a fixed size dictionary and an input string.""" chars = "abcdefghij" diff --git a/opt_einsum/tests/test_sharing.py b/opt_einsum/tests/test_sharing.py index b90d4d8..8b0f1c0 100644 --- a/opt_einsum/tests/test_sharing.py +++ b/opt_einsum/tests/test_sharing.py @@ -30,7 +30,7 @@ cupy_if_found = pytest.param("cupy", marks=[pytest.mark.skip(reason="CuPy not installed.")]) # type: ignore try: - import torch # noqa + import torch # type: ignore # noqa torch_if_found = "torch" except ImportError: diff --git a/pyproject.toml b/pyproject.toml index 18819ca..0d185dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,5 +88,5 @@ branch = true relative_files = true [[tool.mypy.overrides]] -module = "cupy.*, jax.*, theano.*, tensorflow.*, torch.*" +module = "cupy.*, jax.*, numpy.*, theano.*, tensorflow.*, torch.*" ignore_missing_imports = true \ No newline at end of file