From bc3261dc98adee44a9829d2cccc987fd5bfb64a6 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 12 Mar 2024 13:05:42 -0700 Subject: [PATCH 01/13] [torch-mlir][sparse] replace temporary metadata/export with upstream PyTorch This is still WIP but shows all cleanup that will happen once https://github.com/pytorch/pytorch/pull/117907 and follow-up friends are submitted --- python/torch_mlir/extras/fx_importer.py | 64 +++++--------- test/python/fx_importer/sparse_test.py | 106 +----------------------- 2 files changed, 23 insertions(+), 147 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 952b638c1988..27c83134d338 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -263,57 +263,38 @@ } -@dataclass(frozen=True) -class SparsityMeta: - """ - Class for keeping track of sparsity meta data. - - NOTE: this will be fully replaced by - torch.fx.passes.shape_prop.SparseTensorMetadata - """ - - layout: torch.layout - batch_dim: int - sparse_dim: int - dense_dim: int - blocksize: Optional[tuple[int, int]] - pos_dtype: torch.dtype - crd_dtype: torch.dtype - - -def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str: - """Returns sparse tensor encoding for the given sparse layout as string.""" - assert sparsity is not None +def sparsity_encoding(tm: TensorMetadata) -> str: + """Returns sparse tensor encoding for the given sparse tensor as string.""" # Sparse tensors have the form # [ , , ] # which map directly to MLIR types. batch_dim, sparse_dim, dense_dim = ( - sparsity.batch_dim, - sparsity.sparse_dim, - sparsity.dense_dim, + tm.batch_dim, + tm.sparse_dim, + tm.dense_dim, ) dim = batch_dim + sparse_dim + dense_dim - assert dim == len(shape) - blocksize = sparsity.blocksize + assert dim == len(tm.shape) + blocksize = tm.blocksize dims = ",".join(f"d{d}" for d in range(0, dim)) - if sparsity.layout is torch.sparse_coo: + if tm.layout is torch.sparse_coo: assert sparse_dim == 2 and blocksize is None # TODO: deeper sparse dims lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton(soa)" - elif sparsity.layout is torch.sparse_csr: + elif tm.layout is torch.sparse_csr: assert sparse_dim == 2 and blocksize is None lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed" - elif sparsity.layout is torch.sparse_csc: + elif tm.layout is torch.sparse_csc: assert sparse_dim == 2 and blocksize is None lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed" else: assert sparse_dim == 2 and blocksize is not None - if sparsity.layout is torch.sparse_bsr: + if tm.layout is torch.sparse_bsr: i, j = batch_dim, batch_dim + 1 else: - assert sparsity.layout is torch.sparse_bsc + assert tm.layout is torch.sparse_bsc j, i = batch_dim, batch_dim + 1 m, n = blocksize lvls = ( @@ -329,8 +310,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str: dense = ",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim)) lvls = f"{lvls},{dense}" - posw = torch.iinfo(sparsity.pos_dtype).bits - crdw = torch.iinfo(sparsity.crd_dtype).bits + posw = crdw = torch.iinfo(tm.idx_dtype).bits return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>" @@ -865,15 +845,15 @@ def get_vtensor_type( shape: torch.Size, dtype: torch.dtype, *, - sparsity: Optional[SparsityMeta] = None, + tm: Optional[TensorMetadata] = None, mutable: bool = False, ): """Return IrType for !torch.vtensor with the given shape and dtype""" stem = "torch.tensor" if mutable else "torch.vtensor" shape_asm = self.format_asm_shape(shape) mlir_dtype = str(self.dtype_to_type(dtype)) - if sparsity is not None: - encoding = sparsity_encoding(shape, sparsity) + if tm is not None and tm.sparse_dim is not None: + encoding = sparsity_encoding(tm) assert encoding is not None return IrType.parse( f"!{stem}<[{shape_asm}],{str(mlir_dtype)},{encoding}>", @@ -887,7 +867,7 @@ def node_val_to_type(self, node: torch_fx.Node, *, mutable: bool = False) -> IrT try: tensor_meta = node.meta.get("tensor_meta") val = node.meta.get("val") - sparsity = node.meta.get("sparsity", None) + print('BIK SEES', tensor_meta) if tensor_meta is not None: assert isinstance(tensor_meta, TensorMetadata) # Quantized tensor meta data is not preserved in our lowering, @@ -898,14 +878,14 @@ def node_val_to_type(self, node: torch_fx.Node, *, mutable: bool = False) -> IrT ) else: return self.tensor_metadata_to_type( - tensor_meta, sparsity=sparsity, mutable=mutable + tensor_meta, mutable=mutable ) elif val is not None: # some nodes with symbolic inputs pass a 'val' attribute rather than # tensor_meta if isinstance(val, TorchFakeTensor): return self.get_vtensor_type( - val.size(), val.dtype, sparsity=sparsity, mutable=mutable + val.size(), val.dtype, mutable=mutable ) t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val)) @@ -925,18 +905,18 @@ def tensor_metadata_to_type( self, tm: TensorMetadata, *, - sparsity: Optional[SparsityMeta] = None, mutable: bool = False, ) -> IrType: tm_shape = tuple( item.node if is_symbolic(item) else item for item in list(tm.shape) ) - key = (tm_shape, tm.dtype, sparsity, mutable) + sparse_key = (tm.layout, tm.sparse_dim, tm.dense_dim, tm.blocksize, tm.idx_dtype) + key = (tm_shape, tm.dtype, sparse_key, mutable) t = self._tensor_metadata_cache.get(key) if t is None: t = self.get_vtensor_type( - tm.shape, tm.dtype, sparsity=sparsity, mutable=mutable + tm.shape, tm.dtype, tm=tm, mutable=mutable ) self._tensor_metadata_cache[key] = t return t diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 6260a5bbaab3..37567236320f 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -8,11 +8,9 @@ from typing import Any, Callable, Optional import torch -import torch.export import torch.nn as nn from torch_mlir.extras.fx_importer import FxImporter -from torch_mlir.extras.fx_importer import SparsityMeta from torch_mlir import ir from torch_mlir.dialects import torch as torch_d from torch_mlir.compiler_utils import run_pipeline_with_repro_report @@ -21,114 +19,12 @@ ) -# All sparse layouts currently supported in torch.sparse. -SPARSE_LAYOUTS = [ - torch.sparse_coo, - torch.sparse_csr, - torch.sparse_csc, - torch.sparse_bsr, - torch.sparse_bsc, -] - - -def sparse_metadata(a: torch.Tensor) -> SparsityMeta: - """ - Returns a meta data tuple for the given sparse tensor. - - NOTE: this will be fully replaced by fx graph SparseTensorMetadata - """ - sparse_dim = a.sparse_dim() - dense_dim = a.dense_dim() - batch_dim = a.ndim - dense_dim - sparse_dim - blocksize = None - if a.layout is torch.sparse_coo: - return SparsityMeta( - a.layout, - batch_dim, - sparse_dim, - dense_dim, - blocksize, - a.indices().dtype, - a.indices().dtype, - ) - elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: - if a.layout is torch.sparse_bsr: - blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3] - return SparsityMeta( - a.layout, - batch_dim, - sparse_dim, - dense_dim, - blocksize, - a.crow_indices().dtype, - a.col_indices().dtype, - ) - elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc: - if a.layout is torch.sparse_bsc: - blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3] - return SparsityMeta( - a.layout, - batch_dim, - sparse_dim, - dense_dim, - blocksize, - a.ccol_indices().dtype, - a.row_indices().dtype, - ) - else: - raise RuntimeError(f"Unsupported sparse layout for {a}") - - -def sparse_export( - f: Callable, args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None -) -> torch.export.ExportedProgram: - """ - This is a ***temporary*** wrapper around `torch.export.export` - that eventually should be removed and simply replaced by the - standard API for exporting traced graphs. - - But until issue - - https://github.com/pytorch/pytorch/pull/117907 - - is addressed, this wrapper provides support for the sparse - tensor types by first converting all operands to dense tensors, - building the traced graph as for the dense case, and then - annotation sparse parameters with their actual sparse layout - attributes. This temporary solution accelerates testing - torch-mlir with PyTorch sparse tensors until the issue is - resolved. - """ - # Convert all arguments to dense. - dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args) - mask = [a.layout in SPARSE_LAYOUTS for a in args] - # Build the regular FX traced graph with only dense arguments - # (the current version would crash otherwise, see issue above). - prog = torch.export.export(f, dargs, kwargs) - # Annotate sparse arguments in the graph. Note that we currently - # only account for sparsity defined by the user inputs to the model. - # TODO: support sparsity in model parameters (weights, biases) - # TODO: propagate sparsity into the layers - specs = prog.graph_signature.input_specs - alen = len(specs) - k = 0 - for i, node in enumerate(prog.graph.nodes): - if i >= alen: - break - spec = specs[i] - if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: - if mask[k]: - node.meta["sparsity"] = sparse_metadata(args[k]) - k = k + 1 - return prog - - def export_and_import(f, *args, **kwargs): """This method implements Stella's importer, stripped down to essentials.""" context = ir.Context() torch_d.register_dialect(context) fx_importer = FxImporter(context=context) - prog = sparse_export(f, args, kwargs) + prog = torch.export.export(f, args, kwargs) fx_importer.import_frozen_program(prog) return fx_importer.module From 398d23d347346a1be9028ed9412666373680ffd7 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 12 Mar 2024 13:12:17 -0700 Subject: [PATCH 02/13] edit --- python/torch_mlir/extras/fx_importer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 27c83134d338..ec15998ee825 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -867,7 +867,6 @@ def node_val_to_type(self, node: torch_fx.Node, *, mutable: bool = False) -> IrT try: tensor_meta = node.meta.get("tensor_meta") val = node.meta.get("val") - print('BIK SEES', tensor_meta) if tensor_meta is not None: assert isinstance(tensor_meta, TensorMetadata) # Quantized tensor meta data is not preserved in our lowering, From c3e45c5f5f17c9a838f4f6bd1abf2fffca1c4046 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 19 Mar 2024 16:02:06 -0700 Subject: [PATCH 03/13] resolve conflict --- python/torch_mlir/extras/fx_importer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index a81d7f9930b5..a8c72a67cbc1 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -278,11 +278,14 @@ def sparsity_encoding(tm: TensorMetadata) -> str: assert dim == len(tm.shape) blocksize = tm.blocksize - dims = ",".join(f"d{d}" for d in range(0, dim)) + dims = ",".join(f"d{d}" for d in range(dim)) if tm.layout is torch.sparse_coo: - assert sparse_dim == 2 and blocksize is None # TODO: deeper sparse dims - lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton(soa)" + assert sparse_dim >= 2 and blocksize is None + trail_dim = batch_dim + sparse_dim - 1 + coords = ",".join(f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim+1, trail_dim)) + sep = "," if sparse_dim > 2 else "" + lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)" elif tm.layout is torch.sparse_csr: assert sparse_dim == 2 and blocksize is None lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed" @@ -303,7 +306,7 @@ def sparsity_encoding(tm: TensorMetadata) -> str: ) if batch_dim > 0: - batch = ",".join(f"d{d}:dense" for d in range(0, batch_dim)) + batch = ",".join(f"d{d}:dense" for d in range(batch_dim)) lvls = f"{batch},{lvls}" if dense_dim > 0: From c734d2ca11d9274307626084a076d129e635260f Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 25 Mar 2024 16:18:25 -0700 Subject: [PATCH 04/13] Adjusted meta data extraction with latest pending PyTorch PR --- python/torch_mlir/extras/fx_importer.py | 67 ++++++++++++++----------- 1 file changed, 39 insertions(+), 28 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 52a2a673a64f..ce33114ccdea 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -262,43 +262,54 @@ torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default, } +SPARSE_LAYOUTS = [ + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, +] + -def sparsity_encoding(tm: TensorMetadata) -> str: +def sparsity_encoding(t: torch.Tensor) -> str: """Returns sparse tensor encoding for the given sparse tensor as string.""" # Sparse tensors have the form # [ , , ] # which map directly to MLIR types. - batch_dim, sparse_dim, dense_dim = ( - tm.batch_dim, - tm.sparse_dim, - tm.dense_dim, + dim, batch_dim, sparse_dim, dense_dim = ( + t.ndim, + t.ndim - t.sparse_dim() - t.dense_dim(), + t.sparse_dim(), + t.dense_dim(), ) - dim = batch_dim + sparse_dim + dense_dim - assert dim == len(tm.shape) - blocksize = tm.blocksize - dims = ",".join(f"d{d}" for d in range(dim)) - if tm.layout is torch.sparse_coo: - assert sparse_dim >= 2 and blocksize is None + if t.layout is torch.sparse_coo: + assert sparse_dim >= 2 trail_dim = batch_dim + sparse_dim - 1 coords = ",".join(f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim+1, trail_dim)) sep = "," if sparse_dim > 2 else "" lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)" - elif tm.layout is torch.sparse_csr: - assert sparse_dim == 2 and blocksize is None + idx_dtype = t._indices().dtype # supports uncoalesced COO tensors + elif t.layout is torch.sparse_csr: + assert sparse_dim == 2 lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed" - elif tm.layout is torch.sparse_csc: - assert sparse_dim == 2 and blocksize is None + idx_dtype = t.col_indices().dtype + elif t.layout is torch.sparse_csc: + assert sparse_dim == 2 lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed" + idx_dtype = t.row_indices().dtype else: - assert sparse_dim == 2 and blocksize is not None - if tm.layout is torch.sparse_bsr: + assert sparse_dim == 2 + blocksize = t.values().shape[batch_dim + 1 : batch_dim + 3] + if t.layout is torch.sparse_bsr: i, j = batch_dim, batch_dim + 1 + idx_dtype = t.col_indices().dtype else: - assert tm.layout is torch.sparse_bsc + assert t.layout is torch.sparse_bsc j, i = batch_dim, batch_dim + 1 + idx_dtype = t.row_indices().dtype m, n = blocksize lvls = ( f"d{i} floordiv {m}:dense,d{j} floordiv {n}:compressed," @@ -313,7 +324,7 @@ def sparsity_encoding(tm: TensorMetadata) -> str: dense = ",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim)) lvls = f"{lvls},{dense}" - posw = crdw = torch.iinfo(tm.idx_dtype).bits + posw = crdw = torch.iinfo(idx_dtype).bits return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>" @@ -848,19 +859,18 @@ def get_vtensor_type( shape: torch.Size, dtype: torch.dtype, *, - tm: Optional[TensorMetadata] = None, + val: Optional[torch.Tensor] = None, mutable: bool = False, ): """Return IrType for !torch.vtensor with the given shape and dtype""" stem = "torch.tensor" if mutable else "torch.vtensor" shape_asm = self.format_asm_shape(shape) mlir_dtype = str(self.dtype_to_type(dtype)) - if tm is not None and tm.sparse_dim is not None: - encoding = sparsity_encoding(tm) - assert encoding is not None + if val is not None and val.layout in SPARSE_LAYOUTS: + encoding = sparsity_encoding(val) return IrType.parse( f"!{stem}<[{shape_asm}],{str(mlir_dtype)},{encoding}>", - context=self._c, + context=self._c ) return IrType.parse( f"!{stem}<[{shape_asm}],{str(mlir_dtype)}>", context=self._c @@ -880,14 +890,14 @@ def node_val_to_type(self, node: torch_fx.Node, *, mutable: bool = False) -> IrT ) else: return self.tensor_metadata_to_type( - tensor_meta, mutable=mutable + tensor_meta, val=val, mutable=mutable ) elif val is not None: # some nodes with symbolic inputs pass a 'val' attribute rather than # tensor_meta if isinstance(val, TorchFakeTensor): return self.get_vtensor_type( - val.size(), val.dtype, mutable=mutable + val.size(), val.dtype, val=val, mutable=mutable ) t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val)) @@ -907,18 +917,19 @@ def tensor_metadata_to_type( self, tm: TensorMetadata, *, + val: Optional[torch.Tensor] = None, mutable: bool = False, ) -> IrType: tm_shape = tuple( item.node if is_symbolic(item) else item for item in list(tm.shape) ) - sparse_key = (tm.layout, tm.sparse_dim, tm.dense_dim, tm.blocksize, tm.idx_dtype) + sparse_key = None #(tm.layout, tm.sparse_dim, tm.dense_dim, tm.blocksize, tm.idx_dtype) key = (tm_shape, tm.dtype, sparse_key, mutable) t = self._tensor_metadata_cache.get(key) if t is None: t = self.get_vtensor_type( - tm.shape, tm.dtype, tm=tm, mutable=mutable + tm.shape, tm.dtype, val=val, mutable=mutable ) self._tensor_metadata_cache[key] = t return t From cb156045a4a5359078f43124ff79264f296338bf Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 25 Mar 2024 16:25:37 -0700 Subject: [PATCH 05/13] edit --- python/torch_mlir/extras/fx_importer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index ce33114ccdea..c32207497ac7 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -870,7 +870,7 @@ def get_vtensor_type( encoding = sparsity_encoding(val) return IrType.parse( f"!{stem}<[{shape_asm}],{str(mlir_dtype)},{encoding}>", - context=self._c + context=self._c, ) return IrType.parse( f"!{stem}<[{shape_asm}],{str(mlir_dtype)}>", context=self._c @@ -924,8 +924,7 @@ def tensor_metadata_to_type( item.node if is_symbolic(item) else item for item in list(tm.shape) ) - sparse_key = None #(tm.layout, tm.sparse_dim, tm.dense_dim, tm.blocksize, tm.idx_dtype) - key = (tm_shape, tm.dtype, sparse_key, mutable) + key = (tm_shape, tm.dtype, val, mutable) t = self._tensor_metadata_cache.get(key) if t is None: t = self.get_vtensor_type( From e0201598b1702c432623b52d8dbe0498df940f17 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 1 Apr 2024 09:54:44 -0700 Subject: [PATCH 06/13] prepare for MLIR bump --- python/torch_mlir/extras/fx_importer.py | 2 +- test/python/fx_importer/sparse_test.py | 84 +++++++++++++++++++------ 2 files changed, 67 insertions(+), 19 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 9d2a90802fae..2d9f6d8a1f66 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -320,7 +320,7 @@ def sparsity_encoding(t: torch.Tensor) -> str: ) if batch_dim > 0: - batch = ",".join(f"d{d}:dense" for d in range(batch_dim)) + batch = ",".join(f"d{d}:batch" for d in range(batch_dim)) lvls = f"{batch},{lvls}" if dense_dim > 0: diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 85b1a9339761..0fcbcfe1d9f1 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -56,7 +56,6 @@ def sparse_jit(f, *args, **kwargs): xargs = [] for a in args: if a.layout is torch.sparse_coo: - xargs.append(a.values().numpy()) # Construct the additional position array required by MLIR with data # array([0, nnz]). xargs.append(torch.tensor([0, a._nnz()], dtype=a.indices().dtype).numpy()) @@ -64,14 +63,15 @@ def sparse_jit(f, *args, **kwargs): # MLIR SoA COO representation. for idx in a.indices(): xargs.append(idx.numpy()) - elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: xargs.append(a.values().numpy()) + elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: xargs.append(a.crow_indices().numpy()) xargs.append(a.col_indices().numpy()) - elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc: xargs.append(a.values().numpy()) + elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc: xargs.append(a.ccol_indices().numpy()) xargs.append(a.row_indices().numpy()) + xargs.append(a.values().numpy()) else: xargs.append(a.numpy()) # Invoke. @@ -211,21 +211,20 @@ def forward(self, x, y): # CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32> # CHECK: } # -# CHECK: torch.sparse -# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]), -# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, -# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]), -# CHECK: values=tensor({{\[}}[ -1., -2.], -# CHECK: [ -3., -4.], -# ... -# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32, -# CHECK: layout=torch.sparse_csr) -# CHECK: torch.mlir -# CHECK: {{\[\[}}[ -1. -2.] -# CHECK: [ -3. -4.] -# ... -# CHECK: [-61. -62.] -# CHECK: [-63. -64.]{{\]\]}} +# CHECK: torch.sparse +# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]), +# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, +# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]), +# CHECK: values=tensor({{\[}}[ -1., -2.], +# ... +# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32, +# CHECK: layout=torch.sparse_csr) +# CHECK: torch.mlir +# CHECK: {{\[\[}}[ -1. -2.] +# CHECK: [ -3. -4.] +# ... +# CHECK: [-61. -62.] +# CHECK: [-63. -64.]{{\]\]}} # def test_sparse_eltwise(): class EltNet(torch.nn.Module): @@ -262,3 +261,52 @@ def forward(self, x): print(res1) print("torch.mlir") print(res2) + + +@run +# CHECK-LABEL: test_sparse_coo3 +# CHECK: #[[$COO3:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20,30],f64,#sparse>) -> !torch.vtensor<[10,20,30],f64,#sparse> { +# CHECK: %[[R:.*]] = torch.aten.relu %[[A]] : !torch.vtensor<[10,20,30],f64,#sparse> -> !torch.vtensor<[10,20,30],f64,#sparse> +# CHECK: return %[[R]] : !torch.vtensor<[10,20,30],f64,#sparse> +# CHECK: } +# +# CHECK: torch.sparse +# CHECK: tensor(indices=tensor([[ 0, 1, 1, 4, 9, 9], +# CHECK: [ 0, 1, 1, 5, 19, 19], +# CHECK: [ 0, 1, 3, 6, 28, 29]]), +# CHECK: values=tensor([ 0., 0., 1., 2., 3., 1000.]), +# CHECK: size=(10, 20, 30), nnz=6, dtype=torch.float64, layout=torch.sparse_coo) +# CHECK: torch.mlir +# CHECK: tensor(indices=tensor([[ 0, 1, 1, 4, 9, 9], +# CHECK: [ 0, 1, 1, 5, 19, 19], +# CHECK: [ 0, 1, 3, 6, 28, 29]]), +# CHECK: values=tensor([ 0., 0., 1., 2., 3., 1000.]), +# +def test_sparse_coo3(): + class COO3Net(torch.nn.Module): + def __init__(self): + super(COO3Net, self).__init__() + self.relu = nn.ReLU() + + def forward(self, x): + return self.relu(x) + + net = COO3Net() + + # Direct 3-dim COO construction. + idx = torch.tensor([[0, 1, 1, 4, 9, 9], [0, 1, 1, 5, 19, 19], [0, 1, 3, 6, 28, 29]]) + val = torch.tensor([-1000.0, -1.0, 1.0, 2.0, 3.0, 1000.0], dtype=torch.float64) + sparse_input = torch.sparse_coo_tensor(idx, val, size=[10, 20, 30]) + + m = export_and_import(net, sparse_input) + print(m) + + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + res1 = net(sparse_input) + res2 = sparse_jit(net, sparse_input) + print("torch.sparse") + print(res1) + print("torch.mlir") + print(res2) From 975050ad75b41ea6e076aa1787e8a57d205f1dad Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 1 Apr 2024 16:25:17 -0700 Subject: [PATCH 07/13] rebased and resolved conflicts of pending WIP PR with main --- python/torch_mlir/extras/fx_importer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 629cef5dac78..b6309856e8ee 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -898,7 +898,7 @@ def node_val_to_type(self, node: torch_fx.Node, *, mutable: bool = False) -> IrT val = node.meta.get("val") sparsity = node.meta.get("sparsity", None) return self.value_info_to_type( - val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable + val, tensor_meta=tensor_meta, mutable=mutable ) except KeyError as e: raise RuntimeError( @@ -910,7 +910,6 @@ def value_info_to_type( val, *, tensor_meta: Optional[TensorMetadata] = None, - sparsity=None, mutable: bool = False, ): if tensor_meta is not None: @@ -923,14 +922,14 @@ def value_info_to_type( ) else: return self.tensor_metadata_to_type( - tensor_meta, sparsity=sparsity, mutable=mutable + tensor_meta, val=val, mutable=mutable ) elif val is not None: # some nodes with symbolic inputs pass a 'val' attribute rather than # tensor_meta if isinstance(val, TorchFakeTensor): return self.get_vtensor_type( - val.size(), val.dtype, sparsity=sparsity, mutable=mutable + val.size(), val.dtype, val=val, mutable=mutable ) else: t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val)) From 9058470cbe8e75fe2e61ccbb241469c9d277d125 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 8 Apr 2024 16:59:27 -0700 Subject: [PATCH 08/13] rebased changes with mainline --- test/python/fx_importer/sparse_test.py | 27 ++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 2a0c3fefcd3f..d632e92285b4 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -274,12 +274,23 @@ def forward(self, x): # CHECK-LABEL: test_sparse_coo3 # CHECK: #[[$COO3:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton(nonunique, soa), d2 : singleton(soa)), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( -# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20,30],f64,#sparse>) -> !torch.vtensor<[10,20,30],f64> { -# CHECK: %[[R:.*]] = torch.aten.relu %[[A]] : !torch.vtensor<[10,20,30],f64,#sparse> -> !torch.vtensor<[10,20,30],f64> -# CHECK: return %[[R]] : !torch.vtensor<[10,20,30],f64> +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20,30],f64,#sparse>) -> !torch.vtensor<[10,20,30],f64,#sparse> { +# CHECK: %[[R:.*]] = torch.aten.relu %[[A]] : !torch.vtensor<[10,20,30],f64,#sparse> -> !torch.vtensor<[10,20,30],f64,#sparse> +# CHECK: return %[[R]] : !torch.vtensor<[10,20,30],f64,#sparse> # CHECK: } # -# TODO: make sure sparsity propagates through relu into the output and test actual JIT output +# CHECK: torch.sparse +# CHECK: tensor(indices=tensor([[ 0, 1, 1, 4, 9, 9], +# CHECK: [ 0, 1, 1, 5, 19, 19], +# CHECK: [ 0, 1, 3, 6, 28, 29]]), +# CHECK: values=tensor([ 0., 0., 1., 2., 3., 1000.]), +# CHECK: size=(10, 20, 30), nnz=6, dtype=torch.float64, layout=torch.sparse_coo) +# CHECK: torch.mlir +# CHECK: tensor(indices=tensor([[ 0, 1, 1, 4, 9, 9], +# CHECK: [ 0, 1, 1, 5, 19, 19], +# CHECK: [ 0, 1, 3, 6, 28, 29]]), +# CHECK: values=tensor([ 0., 0., 1., 2., 3., 1000.]), +# CHECK: size=(10, 20, 30), nnz=6, dtype=torch.float64, layout=torch.sparse_coo) # def test_sparse_coo3(): class COO3Net(torch.nn.Module): @@ -299,3 +310,11 @@ def forward(self, x): m = export_and_import(net, sparse_input) print(m) + + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + res1 = net(sparse_input) + res2 = sparse_jit(net, sparse_input) + print("torch.sparse") + print(res1) + print("torch.mlir") + print(res2) From f644d597c3a950081c08f202edeb7bc7b5cf1653 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 8 Apr 2024 17:16:35 -0700 Subject: [PATCH 09/13] edit --- test/python/fx_importer/sparse_test.py | 41 +++++++++++++------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index d632e92285b4..928785707941 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -211,27 +211,26 @@ def forward(self, x, y): # CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32> # CHECK: } # -# CHECK: torch.sparse -# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]), -# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, -# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]), -# CHECK: values=tensor({{\[}}[ -1., -2.], -# ... -# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32, -# CHECK: layout=torch.sparse_csr) -# CHECK: torch.mlir -# CHECK: {{\[\[}}[ -1. -2.] -# CHECK: [ -3. -4.] -# ... -# CHECK: [-61. -62.] -# CHECK: [-63. -64.]{{\]\]}} -# -# CHECK: torch.mlir.batch -# CHECK: {{\[\[}}[ -1. -2.] -# CHECK: [ -3. -4.] -# ... -# CHECK: [-61. -62.] -# CHECK: [-63. -64.]{{\]\]}} +# CHECK: torch.sparse +# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]), +# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, +# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]), +# CHECK: values=tensor({{\[}}[ -1., -2.], +# ... +# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32, +# CHECK: layout=torch.sparse_csr) +# CHECK: torch.mlir +# CHECK: {{\[\[}}[ -1. -2.] +# CHECK: [ -3. -4.] +# ... +# CHECK: [-61. -62.] +# CHECK: [-63. -64.]{{\]\]}} +# CHECK: torch.mlir.batch +# CHECK: {{\[\[}}[ -1. -2.] +# CHECK: [ -3. -4.] +# ... +# CHECK: [-61. -62.] +# CHECK: [-63. -64.]{{\]\]}} def test_sparse_eltwise(): class EltNet(torch.nn.Module): def __init__(self): From 37043d4fdd658ac82975887961217338b003c18c Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 8 Apr 2024 17:19:48 -0700 Subject: [PATCH 10/13] edit --- test/python/fx_importer/sparse_test.py | 41 +++++++++++++------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 928785707941..a88a382ff94f 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -211,26 +211,27 @@ def forward(self, x, y): # CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32> # CHECK: } # -# CHECK: torch.sparse -# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]), -# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, -# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]), -# CHECK: values=tensor({{\[}}[ -1., -2.], -# ... -# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32, -# CHECK: layout=torch.sparse_csr) -# CHECK: torch.mlir -# CHECK: {{\[\[}}[ -1. -2.] -# CHECK: [ -3. -4.] -# ... -# CHECK: [-61. -62.] -# CHECK: [-63. -64.]{{\]\]}} -# CHECK: torch.mlir.batch -# CHECK: {{\[\[}}[ -1. -2.] -# CHECK: [ -3. -4.] -# ... -# CHECK: [-61. -62.] -# CHECK: [-63. -64.]{{\]\]}} +# CHECK: torch.sparse +# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]), +# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, +# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]), +# CHECK: values=tensor({{\[}}[ -1., -2.], +# ... +# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32, +# CHECK: layout=torch.sparse_csr) +# CHECK: torch.mlir +# CHECK: {{\[\[}}[ -1. -2.] +# CHECK: [ -3. -4.] +# ... +# CHECK: [-61. -62.] +# CHECK: [-63. -64.]{{\]\]}} +# CHECK: torch.mlir.batch +# CHECK: {{\[\[}}[ -1. -2.] +# CHECK: [ -3. -4.] +# ... +# CHECK: [-61. -62.] +# CHECK: [-63. -64.]{{\]\]}} +# def test_sparse_eltwise(): class EltNet(torch.nn.Module): def __init__(self): From 890091a9434e156ef0fc1cc95ee69071670d550c Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 8 Apr 2024 17:22:01 -0700 Subject: [PATCH 11/13] edit --- test/python/fx_importer/sparse_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index ef8a76bdb51f..22ecc3c87708 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -258,7 +258,7 @@ def forward(self, x, y): # CHECK: values=tensor({{\[}}[ -1., -2.], # ... # CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32, -# CHECK: layout=torch.sparse_csr) +# CHECK: layout=torch.sparse_csr) # CHECK: torch.mlir # CHECK: {{\[\[}}[ -1. -2.] # CHECK: [ -3. -4.] From 520064ac94880def83ba3d8a9a90448dff536f6d Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 9 Apr 2024 11:14:41 -0700 Subject: [PATCH 12/13] merge with mainline --- test/python/fx_importer/sparse_test.py | 40 ++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 22ecc3c87708..b1d441383ac8 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -236,6 +236,46 @@ def forward(self, x, y): print(res2) +@run +# CHECK-LABEL: test_sparse_id +# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20],f64,#[[$COO]]>) -> !torch.vtensor<[10,20],f64,#[[$COO]]> { +# CHECK: return %[[A]] : !torch.vtensor<[10,20],f64,#[[$COO]]> +# CHECK: } +# +# CHECK: torch.sparse +# CHECK: tensor(indices=tensor({{\[}}[ 0, 1, 2, 9], +# CHECK: [ 0, 1, 10, 19]{{\]}}), +# CHECK: values=tensor([-1000., -1., 1., 1000.]), +# CHECK: size=(10, 20), nnz=4, dtype=torch.float64, layout=torch.sparse_coo) +# CHECK: torch.mlir +# +def test_sparse_id(): + class IdNet(torch.nn.Module): + def __init__(self): + super(IdNet, self).__init__() + + def forward(self, x): + return x + + net = IdNet() + idx = torch.tensor([[0, 1, 2, 9], [0, 1, 10, 19]]) + val = torch.tensor([-1000.0, -1.0, 1.0, 1000.0], dtype=torch.float64) + sparse_input = torch.sparse_coo_tensor(idx, val, size=[10, 20]) + m = export_and_import(net, sparse_input) + print(m) + + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + # TODO: make output work + res1 = net(sparse_input) + # res2 = sparse_jit(net, sparse_input) + print("torch.sparse") + print(res1) + print("torch.mlir") + # print(res2) + + @run # CHECK-LABEL: test_sparse_eltwise # CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }> From 4dd03cf65d1be90fc890d7a4fd189df849d4d96e Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Wed, 17 Apr 2024 17:00:48 -0700 Subject: [PATCH 13/13] rebased --- test/python/fx_importer/sparse_test.py | 41 -------------------------- 1 file changed, 41 deletions(-) diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 7d0d525b2c15..73e5906265a2 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -242,47 +242,6 @@ def forward(self, x, y): print("torch.mlir") print(res2) - -@run -# CHECK-LABEL: test_sparse_id -# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }> -# CHECK: func.func @main( -# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[10,20],f64,#[[$COO]]>) -> !torch.vtensor<[10,20],f64,#[[$COO]]> { -# CHECK: return %[[A]] : !torch.vtensor<[10,20],f64,#[[$COO]]> -# CHECK: } -# -# CHECK: torch.sparse -# CHECK: tensor(indices=tensor({{\[}}[ 0, 1, 2, 9], -# CHECK: [ 0, 1, 10, 19]{{\]}}), -# CHECK: values=tensor([-1000., -1., 1., 1000.]), -# CHECK: size=(10, 20), nnz=4, dtype=torch.float64, layout=torch.sparse_coo) -# CHECK: torch.mlir -# -def test_sparse_id(): - class IdNet(torch.nn.Module): - def __init__(self): - super(IdNet, self).__init__() - - def forward(self, x): - return x - - net = IdNet() - idx = torch.tensor([[0, 1, 2, 9], [0, 1, 10, 19]]) - val = torch.tensor([-1000.0, -1.0, 1.0, 1000.0], dtype=torch.float64) - sparse_input = torch.sparse_coo_tensor(idx, val, size=[10, 20]) - m = export_and_import(net, sparse_input) - print(m) - - # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. - # TODO: make output work - res1 = net(sparse_input) - # res2 = sparse_jit(net, sparse_input) - print("torch.sparse") - print(res1) - print("torch.mlir") - # print(res2) - - @run # CHECK-LABEL: test_sparse_eltwise # CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }>