diff --git a/core/shark_turbine/importers/fx_importer.py b/core/shark_turbine/importers/fx_importer.py
index bcab246c5..adf0baaa0 100644
--- a/core/shark_turbine/importers/fx_importer.py
+++ b/core/shark_turbine/importers/fx_importer.py
@@ -41,6 +41,7 @@
     Attribute,
     Block,
     Context,
+    DenseElementsAttr,
     DenseResourceElementsAttr,
     FloatAttr,
     BF16Type,
@@ -573,9 +574,11 @@ def _import_symbolic_torch_op(
         # operations on symbolic arguments as regular python expressions rather than as torch ops
         if is_builtin_function_or_method(target):
             arg_types = [
-                arg.meta["val"].node.pytype
-                if isinstance(arg, torch.fx.Node)
-                else type(arg)
+                (
+                    arg.meta["val"].node.pytype
+                    if isinstance(arg, torch.fx.Node)
+                    else type(arg)
+                )
                 for arg in node.args
             ]
             is_int = [item == int for item in arg_types]
@@ -905,7 +908,7 @@ def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType:
         tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type)
         return tensor_type
     except KeyError:
-        raise TypeError(f"Could not map Torch dtype {dtype} to an IREE type")
+        raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")
 
 
 def _make_vtensor_literal_op(
@@ -925,15 +928,28 @@
         # buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
         # desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
         np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
-        bytes_view = memoryview(np_tensor)
-        tensor_type = create_mlir_tensor_type(tensor)
-        shape_desc = "_".join([str(d) for d in tensor.shape])
-        blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
-        elements_attr = DenseResourceElementsAttr.get_from_buffer(
-            bytes_view,
-            blob_name,
-            tensor_type,
-        )
+        # One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not
+        # support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling
+        # 0d tensors.
+        if np_tensor.size == 1:
+            try:
+                dtype = tensor.dtype
+                element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
+            except KeyError:
+                raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")
+            elements_attr = DenseElementsAttr.get(
+                type=element_type, array=np_tensor, shape=np_tensor.shape
+            )
+        else:
+            bytes_view = memoryview(np_tensor)
+            tensor_type = create_mlir_tensor_type(tensor)
+            shape_desc = "_".join([str(d) for d in tensor.shape])
+            blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
+            elements_attr = DenseResourceElementsAttr.get_from_buffer(
+                bytes_view,
+                blob_name,
+                tensor_type,
+            )
         mapping.value = elements_attr
     else:
         elements_attr = mapping.value
diff --git a/core/tests/dynamo/importer_basic_test.py b/core/tests/dynamo/importer_basic_test.py
index 45ac5c8e0..07ccb7d9b 100644
--- a/core/tests/dynamo/importer_basic_test.py
+++ b/core/tests/dynamo/importer_basic_test.py
@@ -96,6 +96,49 @@ def foo(x):
         opt_foo = torch.compile(foo, backend=create_backend())
         opt_foo(torch.randn(4, 4, 4, 4))
 
+    def testScalarLiteralConversion(self):
+        """
+        Test whether scalar tensors are appropriately converted to literals
+        """
+
+        def foo():
+            a = torch.tensor(0, dtype=torch.int32)
+            b = torch.tensor(0, dtype=torch.int64)
+            c = torch.tensor(0, dtype=torch.float32)
+            d = torch.tensor(0, dtype=torch.float64)
+            e = torch.tensor(0, dtype=torch.complex64)
+            f = torch.tensor(0, dtype=torch.complex128)
+            g = torch.tensor(0, dtype=torch.bool)
+            h = torch.tensor(0, dtype=torch.uint8)
+            i = torch.tensor(0, dtype=torch.int8)
+            j = torch.tensor(0, dtype=torch.int16)
+            return a, b, c, d, e, f, g, h, i, j
+
+        opt_foo = torch.compile(foo, backend=create_backend())
+        opt_foo()
+        print(opt_foo())
+
+    def testSingleElementTensor(self):
+        """
+        Test whether single element tensors are properly converted to scalars
+        """
+
+        def foo():
+            a = torch.tensor([0], dtype=torch.int32)
+            b = torch.tensor([0], dtype=torch.int64)
+            c = torch.tensor([0], dtype=torch.float32)
+            d = torch.tensor([0], dtype=torch.float64)
+            e = torch.tensor([0], dtype=torch.complex64)
+            f = torch.tensor([0], dtype=torch.complex128)
+            g = torch.tensor([0], dtype=torch.bool)
+            h = torch.tensor([0], dtype=torch.uint8)
+            i = torch.tensor([0], dtype=torch.int8)
+            j = torch.tensor([0], dtype=torch.int16)
+            return a[0], b[0], c[0], d[0], e[0], f[0], g[0], h[0], i[0], j[0]
+
+        opt_foo = torch.compile(foo, backend=create_backend())
+        opt_foo()
+
     def testPromoteScalarTensor(self):
         """
         Test whether scalar arguments are properly promoted to 0-rank Tensors for torch ops with no Scalar equivalent
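
For context, here is a minimal, self-contained sketch of the distinction the patch relies on: a single-element constant can be expressed as a plain `DenseElementsAttr` (which MLIR treats as a splat and can fold cheaply), while larger constants stay in an out-of-line `DenseResourceElementsAttr` blob. This is illustration only, not part of the patch; it is written against the upstream MLIR Python bindings (`mlir.ir`), and the module path plus the `"example_blob"` name are assumptions, so adjust the import to whichever module provides these classes in your build.

```python
# Illustrative sketch only (not the importer's code path). Assumes the upstream
# MLIR Python bindings are installed as `mlir`.
import numpy as np
from mlir.ir import (
    Context,
    DenseElementsAttr,
    DenseResourceElementsAttr,
    F32Type,
    RankedTensorType,
)

with Context():
    f32 = F32Type.get()

    # Single-element (0-d) constant: build a plain DenseElementsAttr, mirroring the
    # new np_tensor.size == 1 branch. This prints as a splat dense<...> attribute.
    scalar = np.array(3.14, dtype=np.float32)
    splat_attr = DenseElementsAttr.get(array=scalar, type=f32, shape=scalar.shape)
    print(splat_attr)

    # Multi-element constant: keep the bytes out-of-line as a named resource blob,
    # mirroring the existing DenseResourceElementsAttr.get_from_buffer call.
    data = np.arange(16, dtype=np.float32).reshape(4, 4)
    tensor_type = RankedTensorType.get(data.shape, f32)
    blob_attr = DenseResourceElementsAttr.get_from_buffer(
        memoryview(data), "example_blob", tensor_type
    )
    print(blob_attr)
```

With this split, the new tests should see `torch.vtensor.literal` ops carrying a splat `dense<...>` value for scalar and single-element tensors, while larger constants keep the existing `dense_resource<torch_tensor_...>` encoding.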