Skip to content

Commit

Permalink
[Unity][Parser] Check well-formedness in the parser (#16569)
Browse files Browse the repository at this point in the history
* Check well-formedness in the parser

* Correct packed funcs in NN frontend

* Support the check_well_formed optional argument to I.ir_module

* Also check well-formedness in TIR

* Enable normalization for individual Relax functions and PrimFuncs

* Use the error raised by the TIR well-formed checker for the message

* Fix tvmscript test failures

* Whitespace

* Fix errors in verify_well_formed test

* Include a more helpful error message

* Fix TIR test failures

* Address well-formed failures in test_tir_specialize

* Correct well-formedness error in test_tir_analysis_oob

* Correct further well-formedness failures

* Remove __tvm_meta__ from test case to avoid parsing error

* Avoid circular import in entryy.py

* Formatting fixes

* lint fix

* Add pylint exceptions

* Fix whitespace

* Fix more failed test cases

* Catch inappropriate use of decl_function instead of segfaulting

* Fix test_lower.py

* Mark purity in test_relax_2d_buffer_allocation.py

* Mark purity in test_dma_builtin.py

* Remove __tvm_meta___ from test_tir_usmp_analysis_extract_bufferinfo.py

* Suppress well-formed check in test_tir_transform_convert_blocks_to_opaque.py

* Remove __tvm_meta__ in test_tir_usmp_algo.py

* Remove __tvm_meta__ from more USMP tests

* Fix incorrect var in test_tir_transform_storage_flatten.py

* Remove all remaining instances of __tvm_meta__

* Fix purity error in test_dataflow_pattern.py

* Fix purity error in test_ast_printer

* Fix test_arith_domain_touched example

* Okay to set check_well_formed to True in test_tir_analysis_identify_mcmcpy

* Define variable in test_tir_analysis_oob

* Typo fix

* Add explanatory comment to test case

* Define the undefined vars in test_tir_transform_common_subexpr_elim

* Exception no longer necessary in test_tir_transform_inject_rolling_buffer

* Remove unnecessary check exemption in test_tir_transform_convert_ssa

* Avoid checking exemption in test_inject_ptx_ldg32

* Note special case in test_distributed_transform_propagate_sharding

* Exempt well-formed error in dlight/test_benchmark

* Exempt well-formedness errors in test_ethosu/, mostly uninitialized vars

* Whitespace

* Include non-CUDA GPUs in IsScheduledOnGPU

* Fix thread binding bug by changing thread binding var dtype

* Include overrides in test_runtime_builtin_paged_attention_kv_cache.py

* add exemptions in test_ethosu/test_replace_conv2d

* Add more ethosu exemptions

* More exemptions for ethosu tests

* Remove unused reference

* Indicate purity in test_transform_rewrite_cuda_graph

* Indicate purity in test_transform_normalize

* Reorder MergeSharedMemoryAllocations in GPU codegen

* Add target parameter for FP8StorageLegalize and FP8ComputeLegalize

* Don't re-import Target in tvm/tir/transform/transform.py
  • Loading branch information
slyubomirsky authored Mar 21, 2024
1 parent f9b38ab commit 6c701fe
Show file tree
Hide file tree
Showing 68 changed files with 603 additions and 389 deletions.
11 changes: 8 additions & 3 deletions python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@
class FunctionScope(object):
"""Auxiliary scope for function"""

def __init__(self, block_builder, name, params, attrs):
def __init__(self, block_builder, name, params, attrs, is_pure):
self._bb = block_builder
self._name = name
self._params = params
self._attrs = attrs
self._is_pure = is_pure

# Blocks that have been collected within the function
self._blocks = []
Expand Down Expand Up @@ -208,6 +209,7 @@ def function(
name: str,
params: Optional[Union[Var, Tuple, List[Var]]] = None,
attrs: Optional[Dict[str, Object]] = None,
pure: bool = True,
private: bool = False,
) -> FunctionScope:
"""Annotate a Relax function.
Expand All @@ -225,6 +227,9 @@ def function(
attrs : Dict[str, Object], optional
The function attrs
pure : bool, optional
Whether the function is annotated as pure.
private : bool, optional
Whether the function is annotated as private.
If the function is private, it will not have a global symbol attribute.
Expand Down Expand Up @@ -254,7 +259,7 @@ def function(
if not private:
attrs["global_symbol"] = name

return FunctionScope(self, name, params, attrs)
return FunctionScope(self, name, params, attrs, is_pure=pure)

def testing_scope(self, def_vars: List[tir.Var]) -> TestingScope:
"""Start a scope for unit-testing purposes.
Expand Down Expand Up @@ -640,7 +645,7 @@ def emit_func_output(

# do not specify ret_struct_info and let constructor deduce
# from seqe.struct_info
func = rx.Function(self._func._params, seqe)
func = rx.Function(self._func._params, seqe, is_pure=self._func._is_pure)
for key, value in self._func._attrs.items():
func = func.with_attr(key, value)
self.end_scope()
Expand Down
44 changes: 18 additions & 26 deletions python/tvm/relax/frontend/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import List, Optional, Sequence, Union

from tvm import relax as rx
from tvm import tir, ir
from tvm import tir

from . import op
from .core import Effect, Module, ModuleList, Parameter, Tensor, get_default_dtype
Expand Down Expand Up @@ -599,15 +599,12 @@ def emit_init(self, name_hint: str, bb: rx.BlockBuilder): # pylint: disable=arg
init_shape = rx.ShapeExpr([self.init_seq_len] + self.unit_shape)
return [
bb.emit(
rx.Call(
ir.Op.get("relax.call_pure_packed"),
args=[
rx.extern("vm.builtin.attention_kv_cache_create"),
rx.op.zeros(init_shape, self.dtype),
init_shape,
rx.PrimValue(0),
],
sinfo_args=[rx.ObjectStructInfo()],
rx.op.call_pure_packed(
"vm.builtin.attention_kv_cache_create",
rx.op.zeros(init_shape, self.dtype),
init_shape,
rx.PrimValue(0),
sinfo_args=rx.ObjectStructInfo(),
),
name_hint=name_hint,
)
Expand Down Expand Up @@ -675,14 +672,11 @@ def view(self, seq_len: tir.Var) -> Tensor:
shape = rx.ShapeExpr([seq_len] + self.unit_shape)
return Tensor(
_expr=rx.BlockBuilder.current().emit(
rx.Call(
ir.Op.get("relax.call_pure_packed"),
args=[
rx.extern("vm.builtin.attention_kv_cache_view"),
self.cache,
shape,
],
sinfo_args=[rx.TensorStructInfo(shape, self.dtype)],
rx.op.call_pure_packed(
"vm.builtin.attention_kv_cache_view",
self.cache,
shape,
sinfo_args=rx.TensorStructInfo(shape, self.dtype),
)
)
)
Expand All @@ -702,14 +696,12 @@ def append(self, new_element: Tensor) -> None:
f'but got "{new_element.dtype}"'
)
self.cache = rx.BlockBuilder.current().emit(
rx.Call(
ir.Op.get("relax.call_pure_packed"),
args=[
rx.extern("vm.builtin.attention_kv_cache_append"),
self.cache,
new_element._expr,
],
sinfo_args=[rx.ObjectStructInfo()],
rx.op.call_inplace_packed(
"vm.builtin.attention_kv_cache_append",
self.cache,
new_element._expr,
inplace_indices=[0],
sinfo_args=rx.ObjectStructInfo(),
)
)

Expand Down
8 changes: 6 additions & 2 deletions python/tvm/script/ir_builder/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar:
func_name : str
The function unique name.
func_signature: Optional[BaseFunc]
func_signature: BaseFunc
A Function w/o body, which used to specify the function signature
(i.e. func params and func return type/shape).
Expand All @@ -55,7 +55,11 @@ def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar:
gv : GlobalVar
The corresponding GlobalVar.
"""

if not isinstance(func_signature, BaseFunc):
raise ValueError(
"decl_function expects an instance of BaseFunc, "
f"but {func_signature} is of type {type(func_signature)}"
)
return _ffi_api.DeclFunction( # type: ignore[attr-defined] # pylint: disable=no-member
func_name, func_signature
)
Expand Down
40 changes: 38 additions & 2 deletions python/tvm/script/parser/core/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,20 @@
import inspect
from typing import Any, Dict, Union

from ....ir.module import IRModule
from ...ir_builder import IRBuilder
from . import doc
from .diagnostics import Source
from .error import ParserError
from .parser import Parser

WELL_FORMED_ERROR_MESSAGE = (
"Program is not well-formed. If this is deliberate, consider "
"setting check_well_formed in the top-level decorator to False "
"(e.g., @I.ir_module(check_well_formed=False) or "
"@R.function(check_well_formed=False))."
)


def _default_globals() -> Dict[str, Any]:
import tvm # pylint: disable=import-outside-toplevel
Expand All @@ -43,7 +51,11 @@ def scan_macro(program: Union[Any, str], extra_vars: Dict[str, Any] = None) -> A
return source, closure_vars


def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None) -> Any:
def parse(
program: Union[doc.AST, Any, str],
extra_vars: Dict[str, Any] = None,
check_well_formed: bool = True,
) -> Any:
"""Register a method for a operand type, AST operator node and operand index.
Parameters
Expand All @@ -54,6 +66,9 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None)
extra_vars : Dict[str, Any]
The extra variable table for parsing.
check_well_formed : bool
Whether to check well-formedness after parsing.
Returns
-------
func : Any
Expand All @@ -77,4 +92,25 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: Dict[str, Any] = None)
parser.parse(extra_vars=extra_vars)
except ParserError as err:
parser.report_error(err.node, err.args[0])
return builder.get()
ret = builder.get()
# check well-formedness in both Relax and TIR
if check_well_formed:
# (C0415 = import-outside-toplevel. It is necessary here to avoid a circular dependency,
# since importing Relax imports a dependency on the parser)
from ....relax.analysis import well_formed as relax_well_formed # pylint: disable=C0415
from ....tir.analysis import verify_well_formed as tir_well_formed # pylint: disable=C0415

check_ret = ret
if not isinstance(check_ret, IRModule):
check_ret = IRModule.from_expr(ret)
source_ast = source.as_ast()
if not relax_well_formed(check_ret):
parser.report_error(source_ast, err=WELL_FORMED_ERROR_MESSAGE)
try:
tir_well_formed(check_ret)
except Exception as err: # pylint: disable=broad-exception-caught
parser.report_error(
source_ast,
err=f"{WELL_FORMED_ERROR_MESSAGE}\n\nTraceback: {str(err)}",
)
return ret
30 changes: 23 additions & 7 deletions python/tvm/script/parser/ir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,48 @@
"""The entry point of TVM parser for ir module."""

import inspect
from typing import Type
from typing import Optional, Type

from tvm.ir import IRModule

from .._core import parse, utils


def ir_module(mod: Type) -> IRModule:
# this formulation allows us to support having @I.ir_module
# appear as a decorator by itself or to have optional arguments
# like @I.ir_module(check_well_formed=False)
def ir_module(mod: Optional[Type] = None, check_well_formed: bool = True) -> IRModule:
"""The parsing method for ir module, by using `@ir_module` as decorator.
Parameters
----------
mod : Type
The class to be parsed as ir module.
check_well_formed : bool
Whether to check well-formedness during parsing.
Returns
-------
ir_module : IRModule
The parsed ir module.
"""
if not inspect.isclass(mod):
raise TypeError(f"Expect a class, but got: {mod}")

m = parse(mod, utils.inspect_class_capture(mod))
setattr(m, "__name__", mod.__name__)
return m
def decorator_wrapper(mod):
if not inspect.isclass(mod):
raise TypeError(f"Expect a class, but got: {mod}")
m = parse(mod, utils.inspect_class_capture(mod), check_well_formed=check_well_formed)
setattr(m, "__name__", mod.__name__)
return m

if mod is not None:
# if there are no optional args given, this will directly invoke the wrapper
return decorator_wrapper(mod)
else:
# if there is a optional arg given, it returns the wrapper function
# as a new decorator and applies it
setattr(decorator_wrapper, "dispatch_token", "ir")
return decorator_wrapper


setattr(ir_module, "dispatch_token", "ir")
4 changes: 2 additions & 2 deletions python/tvm/script/parser/relax/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
# appear as a decorator by itself or to have optional arguments
# like @R.function(pure=False)
def function(
f: Optional[FType] = None, pure: bool = True, private: bool = False
f: Optional[FType] = None, pure: bool = True, private: bool = False, check_well_formed=True
) -> Union[Function, FType]:
# pylint: disable=unused-argument
# (pure and private aren't used here, but are used later in parsing)
Expand All @@ -66,7 +66,7 @@ def decorator_wrapper(f):
raise TypeError(f"Expect a function, but got: {f}")
if utils.is_defined_in_class(orig_stack, f):
return f
return parse(f, utils.inspect_function_capture(f))
return parse(f, utils.inspect_function_capture(f), check_well_formed=check_well_formed)

if f is not None:
# if there are no optional args given, this will directly invoke the wrapper
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/script/parser/tir/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
from ..core.parser import Parser, ScriptMacro


def prim_func(func: Optional[Callable] = None, private: bool = False) -> Union[PrimFunc, Callable]:
def prim_func(
func: Optional[Callable] = None, private: bool = False, check_well_formed=True
) -> Union[PrimFunc, Callable]:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters
Expand Down Expand Up @@ -60,7 +62,7 @@ def decorator_wrapper(func):
raise TypeError(f"Expect a function, but got: {func}")
if utils.is_defined_in_class(outer_stack, func):
return func
f = parse(func, utils.inspect_function_capture(func))
f = parse(func, utils.inspect_function_capture(func), check_well_formed=check_well_formed)
setattr(f, "__name__", func.__name__)
return f

Expand Down
9 changes: 6 additions & 3 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,6 @@ def enabled_targets():


class Feature:

"""A feature that may be required to run a test.
Parameters
Expand Down Expand Up @@ -1952,6 +1951,8 @@ def expected(A: T.Buffer(1, "int32")):
"""

check_well_formed: bool = True

def __init_subclass__(cls):
assert len([getattr(cls, name) for name in ["before", "Before"] if hasattr(cls, name)]) <= 1
assert (
Expand Down Expand Up @@ -1995,7 +1996,9 @@ def inner(self):
func_dict[name] = method.with_attr("global_symbol", name)
else:
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(method))
prim_func = tvm.script.from_source(source_code)
prim_func = tvm.script.from_source(
source_code, check_well_formed=self.check_well_formed
)
func_dict[name] = prim_func.with_attr("global_symbol", name)
return tvm.IRModule(func_dict)

Expand All @@ -2004,7 +2007,7 @@ def inner(self):
def inner(self):
# pylint: disable=unused-argument
source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func))
return tvm.script.from_source(source_code)
return tvm.script.from_source(source_code, check_well_formed=self.check_well_formed)

return pytest.fixture(inner)

Expand Down
Loading

0 comments on commit 6c701fe

Please sign in to comment.