[fx_importer] Add support for 0D tensors (#401)
Adds an escape hatch for single-value tensors: instead of creating a
DenseResourceElementsAttr, the importer materializes them as a splat DenseElementsAttr.

Addresses #398
dan-garvey authored Feb 6, 2024
1 parent e7f0f94 commit 66f79ab
Showing 2 changed files with 72 additions and 13 deletions.
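
For context on why the escape hatch helps: a DenseResourceElementsAttr stores tensor contents as an opaque named blob and cannot represent splats, whereas a DenseElementsAttr can encode a size-1 constant as a compact inline splat. A minimal sketch of the splat path, assuming the upstream `mlir.ir` Python bindings (the importer reaches the equivalent API through its own IR imports):

import numpy as np
from mlir.ir import Context, DenseElementsAttr, IntegerType

with Context():
    i32 = IntegerType.get_signless(32)
    scalar = np.array(7, dtype=np.int32)  # 0-d array, size == 1
    # Explicit element type and shape, mirroring the DenseElementsAttr.get
    # call this commit adds for the size-1 case.
    attr = DenseElementsAttr.get(scalar, type=i32, shape=list(scalar.shape))
    print(attr)           # dense<7> : tensor<i32>
    print(attr.is_splat)  # True: the value is inlined and easy to fold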
42 changes: 29 additions & 13 deletions core/shark_turbine/importers/fx_importer.py
@@ -41,6 +41,7 @@
     Attribute,
     Block,
     Context,
+    DenseElementsAttr,
     DenseResourceElementsAttr,
     FloatAttr,
     BF16Type,
@@ -573,9 +574,11 @@ def _import_symbolic_torch_op(
         # operations on symbolic arguments as regular python expressions rather than as torch ops
         if is_builtin_function_or_method(target):
             arg_types = [
-                arg.meta["val"].node.pytype
-                if isinstance(arg, torch.fx.Node)
-                else type(arg)
+                (
+                    arg.meta["val"].node.pytype
+                    if isinstance(arg, torch.fx.Node)
+                    else type(arg)
+                )
                 for arg in node.args
             ]
             is_int = [item == int for item in arg_types]
@@ -905,7 +908,7 @@ def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType:
         tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type)
         return tensor_type
     except KeyError:
-        raise TypeError(f"Could not map Torch dtype {dtype} to an IREE type")
+        raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")


 def _make_vtensor_literal_op(
@@ -925,15 +928,28 @@ def _make_vtensor_literal_op(
         # buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
         # desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
         np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
-        bytes_view = memoryview(np_tensor)
-        tensor_type = create_mlir_tensor_type(tensor)
-        shape_desc = "_".join([str(d) for d in tensor.shape])
-        blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
-        elements_attr = DenseResourceElementsAttr.get_from_buffer(
-            bytes_view,
-            blob_name,
-            tensor_type,
-        )
+        # One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not
+        # support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling
+        # 0d tensors.
+        if np_tensor.size == 1:
+            try:
+                dtype = tensor.dtype
+                element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]()
+            except KeyError:
+                raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type")
+            elements_attr = DenseElementsAttr.get(
+                type=element_type, array=np_tensor, shape=np_tensor.shape
+            )
+        else:
+            bytes_view = memoryview(np_tensor)
+            tensor_type = create_mlir_tensor_type(tensor)
+            shape_desc = "_".join([str(d) for d in tensor.shape])
+            blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
+            elements_attr = DenseResourceElementsAttr.get_from_buffer(
+                bytes_view,
+                blob_name,
+                tensor_type,
+            )
         mapping.value = elements_attr
     else:
         elements_attr = mapping.value
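For comparison, constants with more than one element keep the resource path shown in the else branch: the raw bytes are registered as a named blob on the module instead of being inlined. A sketch of that branch in isolation, under the same assumed `mlir.ir` bindings (the blob name here is illustrative):

import numpy as np
from mlir.ir import (
    Context,
    DenseResourceElementsAttr,
    IntegerType,
    RankedTensorType,
)

with Context():
    i32 = IntegerType.get_signless(32)
    data = np.arange(8, dtype=np.int32)  # size > 1, so no splat is possible
    tensor_type = RankedTensorType.get(list(data.shape), i32)
    attr = DenseResourceElementsAttr.get_from_buffer(
        memoryview(data), "torch_tensor_8_torch.int32", tensor_type
    )
    print(attr)  # dense_resource<torch_tensor_8_torch.int32> : tensor<8xi32>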
43 changes: 43 additions & 0 deletions core/tests/dynamo/importer_basic_test.py
@@ -96,6 +96,49 @@ def foo(x):
         opt_foo = torch.compile(foo, backend=create_backend())
         opt_foo(torch.randn(4, 4, 4, 4))

+    def testScalarLiteralConversion(self):
+        """
+        Test whether scalar tensors are appropriately converted to literals
+        """
+
+        def foo():
+            a = torch.tensor(0, dtype=torch.int32)
+            b = torch.tensor(0, dtype=torch.int64)
+            c = torch.tensor(0, dtype=torch.float32)
+            d = torch.tensor(0, dtype=torch.float64)
+            e = torch.tensor(0, dtype=torch.complex64)
+            f = torch.tensor(0, dtype=torch.complex128)
+            g = torch.tensor(0, dtype=torch.bool)
+            h = torch.tensor(0, dtype=torch.uint8)
+            i = torch.tensor(0, dtype=torch.int8)
+            j = torch.tensor(0, dtype=torch.int16)
+            return a, b, c, d, e, f, g, h, i, j
+
+        opt_foo = torch.compile(foo, backend=create_backend())
+        opt_foo()
+        print(opt_foo())
+
+    def testSingleElementTensor(self):
+        """
+        Test whether single element tensors are properly converted to scalars
+        """
+
+        def foo():
+            a = torch.tensor([0], dtype=torch.int32)
+            b = torch.tensor([0], dtype=torch.int64)
+            c = torch.tensor([0], dtype=torch.float32)
+            d = torch.tensor([0], dtype=torch.float64)
+            e = torch.tensor([0], dtype=torch.complex64)
+            f = torch.tensor([0], dtype=torch.complex128)
+            g = torch.tensor([0], dtype=torch.bool)
+            h = torch.tensor([0], dtype=torch.uint8)
+            i = torch.tensor([0], dtype=torch.int8)
+            j = torch.tensor([0], dtype=torch.int16)
+            return a[0], b[0], c[0], d[0], e[0], f[0], g[0], h[0], i[0], j[0]
+
+        opt_foo = torch.compile(foo, backend=create_backend())
+        opt_foo()
+
     def testPromoteScalarTensor(self):
         """
         Test whether scalar arguments are properly promoted to 0-rank Tensors for torch ops with no Scalar equivalent
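To make the intended effect of these tests concrete: importing a 0-d constant should now produce an inline splat literal rather than a resource-backed one. A hedged sketch meant to sit inside this test file (create_backend is the helper the surrounding tests use; the IR spelling in the comment is an expectation based on the torch dialect, not captured output):

import torch

def scalar_const():
    # With this commit the importer should emit roughly:
    #   %0 = torch.vtensor.literal(dense<0> : tensor<si32>) : !torch.vtensor<[],si32>
    # instead of a dense_resource-backed literal.
    return torch.tensor(0, dtype=torch.int32)

opt = torch.compile(scalar_const, backend=create_backend())
print(opt())  # tensor(0, dtype=torch.int32)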
