From a137e4dab8649b2bb9d0f157718c4ce67bab819a Mon Sep 17 00:00:00 2001 From: max Date: Thu, 20 Jul 2023 00:14:25 -0500 Subject: [PATCH] refactor func to be a class hierarchy --- mlir_utils/_configuration/configuration.py | 3 + mlir_utils/dialects/ext/func.py | 185 ++++++++++++--------- mlir_utils/dialects/ext/tensor.py | 3 +- tests/test_func.py | 63 ++++++- tests/test_regions.py | 2 +- tests/test_value_caster.py | 3 +- 6 files changed, 170 insertions(+), 89 deletions(-) diff --git a/mlir_utils/_configuration/configuration.py b/mlir_utils/_configuration/configuration.py index 80e94c6..97465a6 100644 --- a/mlir_utils/_configuration/configuration.py +++ b/mlir_utils/_configuration/configuration.py @@ -5,6 +5,7 @@ from base64 import urlsafe_b64encode from importlib.metadata import distribution, packages_distributions from importlib.resources import files +from importlib.resources.readers import MultiplexedPath from pathlib import Path from .module_alias_map import get_meta_path_insertion_index, AliasedModuleFinder @@ -12,6 +13,8 @@ __MLIR_PYTHON_PACKAGE_PREFIX__ = "__MLIR_PYTHON_PACKAGE_PREFIX__" PACKAGE = __package__.split(".")[0] PACKAGE_ROOT_PATH = files(PACKAGE) +if isinstance(PACKAGE_ROOT_PATH, MultiplexedPath): + PACKAGE_ROOT_PATH = PACKAGE_ROOT_PATH._paths[0] DIST = distribution(packages_distributions()[PACKAGE][0]) MLIR_PYTHON_PACKAGE_PREFIX_TOKEN_PATH = ( Path(__file__).parent / __MLIR_PYTHON_PACKAGE_PREFIX__ diff --git a/mlir_utils/dialects/ext/func.py b/mlir_utils/dialects/ext/func.py index 9fa1083..5a1a001 100644 --- a/mlir_utils/dialects/ext/func.py +++ b/mlir_utils/dialects/ext/func.py @@ -18,95 +18,114 @@ ) -def func_base( - FuncOp, - ReturnOp, - CallOp, - sym_visibility=None, - arg_attrs=None, - res_attrs=None, - loc=None, - ip=None, -): - ip = ip or InsertionPoint.current - - # if this is set to true then wrapper below won't emit a call op - # it is set below by a def emit fn that is attached to the body_builder - # wrapper; thus you can call wrapped_fn.emit() (i.e., without an operands) - # and the func will be emitted. - _emit = False - - def builder_wrapper(body_builder): - @wraps(body_builder) - def wrapper(*call_args): - # TODO(max): implement constexpr ie enable passing constants that skip being - # part of the signature - sig = inspect.signature(body_builder) - implicit_return = sig.return_annotation is inspect._empty - input_types = [p.annotation for p in sig.parameters.values()] - if not ( - len(input_types) == len(sig.parameters) - and all(isinstance(t, Type) for t in input_types) - ): - input_types = [a.type for a in call_args] - function_type = TypeAttr.get( - FunctionType.get( - inputs=input_types, - results=[] if implicit_return else sig.return_annotation, - ) +class FuncOpMeta(type): + def __call__(cls, *args, **kwargs): + cls_obj = cls.__new__(cls) + if len(args) == 1 and len(kwargs) == 0 and inspect.isfunction(args[0]): + return cls.__init__(cls_obj, args[0]) + else: + + def init_wrapper(f): + cls.__init__(cls_obj, f, *args, **kwargs) + return cls_obj + + return lambda f: init_wrapper(f) + + +class FuncBase(metaclass=FuncOpMeta): + def __init__( + self, + body_builder, + func_op_ctor, + return_op_ctor, + call_op_ctor, + sym_visibility=None, + arg_attrs=None, + res_attrs=None, + loc=None, + ip=None, + ): + assert inspect.isfunction(body_builder), body_builder + assert inspect.isclass(func_op_ctor), func_op_ctor + assert inspect.isclass(return_op_ctor), return_op_ctor + assert inspect.isclass(call_op_ctor), call_op_ctor + + self.body_builder = body_builder + self.func_name = self.body_builder.__name__ + + self.func_op_ctor = func_op_ctor + self.return_op_ctor = return_op_ctor + self.call_op_ctor = call_op_ctor + self.sym_visibility = ( + StringAttr.get(str(sym_visibility)) if sym_visibility is not None else None + ) + self.arg_attrs = arg_attrs + self.res_attrs = res_attrs + self.loc = loc + self.ip = ip or InsertionPoint.current + self.emitted = False + + def __str__(self): + return str(f"{self.__class__} {self.__dict__}") + + def body_builder_wrapper(self, *call_args): + sig = inspect.signature(self.body_builder) + implicit_return = sig.return_annotation is inspect._empty + input_types = [p.annotation for p in sig.parameters.values()] + if not ( + len(input_types) == len(sig.parameters) + and all(isinstance(t, Type) for t in input_types) + ): + input_types = [a.type for a in call_args] + function_type = TypeAttr.get( + FunctionType.get( + inputs=input_types, + results=[] if implicit_return else sig.return_annotation, ) - # FuncOp is extended but we do really want the base - func_name = body_builder.__name__ - func_op = FuncOp( - func_name, - function_type, - sym_visibility=StringAttr.get(str(sym_visibility)) - if sym_visibility is not None - else None, - arg_attrs=arg_attrs, - res_attrs=res_attrs, - loc=loc, - ip=ip, + ) + func_op = self.func_op_ctor( + self.func_name, + function_type, + sym_visibility=self.sym_visibility, + arg_attrs=self.arg_attrs, + res_attrs=self.res_attrs, + loc=self.loc, + ip=self.ip, + ) + func_op.regions[0].blocks.append(*input_types) + with InsertionPoint(func_op.regions[0].blocks[0]): + results = get_result_or_results( + self.body_builder(*func_op.regions[0].blocks[0].arguments) ) - func_op.regions[0].blocks.append(*input_types) - with InsertionPoint(func_op.regions[0].blocks[0]): - results = get_result_or_results( - body_builder(*func_op.regions[0].blocks[0].arguments) - ) - if results is not None: - if isinstance(results, (tuple, list)): - results = list(results) - else: - results = [results] + if results is not None: + if isinstance(results, (tuple, list)): + results = list(results) else: - results = [] - ReturnOp(results) - # Recompute the function type. - return_types = [v.type for v in results] - function_type = FunctionType.get(inputs=input_types, results=return_types) - func_op.attributes["function_type"] = TypeAttr.get(function_type) - - if _emit: - return maybe_cast(get_result_or_results(func_op)) + results = [results] else: - call_op = CallOp( - [r.type for r in results], - FlatSymbolRefAttr.get(func_name), - call_args, - ) - return maybe_cast(get_result_or_results(call_op)) + results = [] + self.return_op_ctor(results) - def emit(): - nonlocal _emit - _emit = True - wrapper() + return results, input_types, func_op - wrapper.emit = emit - return wrapper + def emit(self): + self.results, input_types, func_op = self.body_builder_wrapper() + return_types = [v.type for v in self.results] + function_type = FunctionType.get(inputs=input_types, results=return_types) + func_op.attributes["function_type"] = TypeAttr.get(function_type) + self.emitted = True + # this is the func op itself (funcs never have a resulting ssa value) + return maybe_cast(get_result_or_results(func_op)) - return builder_wrapper + def __call__(self, *call_args): + if not self.emitted: + self.emit() + call_op = CallOp( + [r.type for r in self.results], + FlatSymbolRefAttr.get(self.func_name), + call_args, + ) + return maybe_cast(get_result_or_results(call_op)) -func = make_maybe_no_args_decorator( - partial(func_base, FuncOp=FuncOp.__base__, ReturnOp=ReturnOp, CallOp=CallOp) -) +func = FuncBase(FuncOp.__base__, ReturnOp, CallOp.__base__) diff --git a/mlir_utils/dialects/ext/tensor.py b/mlir_utils/dialects/ext/tensor.py index 46138ff..5e9d99f 100644 --- a/mlir_utils/dialects/ext/tensor.py +++ b/mlir_utils/dialects/ext/tensor.py @@ -2,7 +2,7 @@ from typing import Union, Tuple, Sequence import numpy as np -from mlir.dialects.tensor import EmptyOp +from mlir.dialects.tensor import EmptyOp, GenerateOp from mlir.ir import Type, Value, RankedTensorType, DenseElementsAttr, ShapedType from mlir_utils.dialects.ext.arith import ArithValue @@ -62,7 +62,6 @@ def empty( shape: Union[list[Union[int, Value]], tuple[Union[int, Value], ...]], el_type: Type, ) -> "Tensor": - return cls(EmptyOp(shape, el_type).result) diff --git a/tests/test_func.py b/tests/test_func.py index a900fa5..cf4b4cc 100644 --- a/tests/test_func.py +++ b/tests/test_func.py @@ -2,7 +2,6 @@ from textwrap import dedent import pytest - from mlir_utils.dialects.ext.arith import constant from mlir_utils.dialects.ext.func import func @@ -20,7 +19,7 @@ def demo_fun1(): return one assert hasattr(demo_fun1, "emit") - assert inspect.isfunction(demo_fun1.emit) + assert inspect.ismethod(demo_fun1.emit) demo_fun1.emit() correct = dedent( """\ @@ -33,3 +32,63 @@ def demo_fun1(): """ ) filecheck(correct, ctx.module) + + +def test_func_base_meta(ctx: MLIRContext): + print() + + @func + def foo1(): + one = constant(1) + return one + + # print("wrapped foo", foo1) + foo1.emit() + correct = dedent( + """\ + module { + func.func @foo1() -> i64 { + %c1_i64 = arith.constant 1 : i64 + return %c1_i64 : i64 + } + } + """ + ) + filecheck(correct, ctx.module) + + foo1() + correct = dedent( + """\ + module { + func.func @foo1() -> i64 { + %c1_i64 = arith.constant 1 : i64 + return %c1_i64 : i64 + } + %0 = func.call @foo1() : () -> i64 + } + """ + ) + filecheck(correct, ctx.module) + + +def test_func_base_meta2(ctx: MLIRContext): + print() + + @func + def foo1(): + one = constant(1) + return one + + foo1() + correct = dedent( + """\ + module { + func.func @foo1() -> i64 { + %c1_i64 = arith.constant 1 : i64 + return %c1_i64 : i64 + } + %0 = func.call @foo1() : () -> i64 + } + """ + ) + filecheck(correct, ctx.module) diff --git a/tests/test_regions.py b/tests/test_regions.py index 2f76376..017f245 100644 --- a/tests/test_regions.py +++ b/tests/test_regions.py @@ -7,7 +7,7 @@ from mlir_utils.dialects.ext.arith import constant from mlir_utils.dialects.ext.func import func -from mlir_utils.dialects.ext.tensor import Tensor, S, rank +from mlir_utils.dialects.ext.tensor import S, rank # noinspection PyUnresolvedReferences from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext diff --git a/tests/test_value_caster.py b/tests/test_value_caster.py index 698ed8f..d80f9ba 100644 --- a/tests/test_value_caster.py +++ b/tests/test_value_caster.py @@ -1,5 +1,4 @@ import pytest -from mlir.ir import OpResult from mlir_utils.dialects.ext.tensor import S, empty from mlir_utils.dialects.ext.arith import constant @@ -9,6 +8,8 @@ from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext from mlir_utils.types import f64_t, RankedTensorType +from mlir.ir import OpResult + # needed since the fix isn't defined here nor conftest.py pytest.mark.usefixtures("ctx")