diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 8289ea96ae25..18abc0ca5d01 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -19,6 +19,7 @@ import functools import inspect from numbers import Integral +import sys from typing import Any, Callable, Dict, List, Optional, Tuple, Union # isort: off @@ -1764,14 +1765,31 @@ def f(): # pylint: disable=invalid-name -def _op_wrapper(func): - @functools.wraps(func) - def wrapped(*args, **kwargs): - if "dtype" in kwargs: - kwargs.pop("dtype") - return func(*args, **kwargs) +if sys.version_info >= (3, 10): + from typing import ParamSpec, TypeVar # pylint: disable=import-error - return wrapped + T = TypeVar("T") + P = ParamSpec("P") + + def _op_wrapper(func: Callable[P, T]) -> Callable[P, T]: + @functools.wraps(func) + def wrapped(*args, **kwargs) -> T: + if "dtype" in kwargs: + kwargs.pop("dtype") + return func(*args, **kwargs) + + return wrapped + +else: + + def _op_wrapper(func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + if "dtype" in kwargs: + kwargs.pop("dtype") + return func(*args, **kwargs) + + return wrapped abs = _op_wrapper(_tir_op.abs) # pylint: disable=redefined-builtin