From 36b9535ff364c484d04b384555106731049f44cd Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Thu, 20 Jun 2024 20:35:38 +0800 Subject: [PATCH] [TVMScript] Better Type Annotation for TIR OP (#17107) Enable ParamType for TIR op, so that we can have better experience when writing TVMScript in Python with tools. However, ParamType is introduced in Python 3.10, so we only enable it when Python version is 3.10 or above. --- python/tvm/script/ir_builder/tir/ir.py | 32 ++++++++++++++++++++------ 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 8289ea96ae25..18abc0ca5d01 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -19,6 +19,7 @@ import functools import inspect from numbers import Integral +import sys from typing import Any, Callable, Dict, List, Optional, Tuple, Union # isort: off @@ -1764,14 +1765,31 @@ def f(): # pylint: disable=invalid-name -def _op_wrapper(func): - @functools.wraps(func) - def wrapped(*args, **kwargs): - if "dtype" in kwargs: - kwargs.pop("dtype") - return func(*args, **kwargs) +if sys.version_info >= (3, 10): + from typing import ParamSpec, TypeVar # pylint: disable=import-error - return wrapped + T = TypeVar("T") + P = ParamSpec("P") + + def _op_wrapper(func: Callable[P, T]) -> Callable[P, T]: + @functools.wraps(func) + def wrapped(*args, **kwargs) -> T: + if "dtype" in kwargs: + kwargs.pop("dtype") + return func(*args, **kwargs) + + return wrapped + +else: + + def _op_wrapper(func): + @functools.wraps(func) + def wrapped(*args, **kwargs): + if "dtype" in kwargs: + kwargs.pop("dtype") + return func(*args, **kwargs) + + return wrapped abs = _op_wrapper(_tir_op.abs) # pylint: disable=redefined-builtin