Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP [torch-mlir][sparse] replace temporary metadata/export with upstream PyTorch #3016

Closed
wants to merge 37 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
bc3261d
[torch-mlir][sparse] replace temporary metadata/export with upstream …
aartbik Mar 12, 2024
398d23d
edit
aartbik Mar 12, 2024
c911384
Merge branch 'llvm:main' into bik
aartbik Mar 13, 2024
a941f0c
Merge branch 'main' into bik
aartbik Mar 18, 2024
f5198a9
Merge branch 'llvm:main' into bik
aartbik Mar 19, 2024
1cfc4f0
Merge branch 'llvm:main' into bik
aartbik Mar 19, 2024
c3e45c5
resolve conflict
aartbik Mar 19, 2024
216a3df
Merge branch 'main' into bik
aartbik Mar 19, 2024
fba7407
Merge branch 'llvm:main' into bik
aartbik Mar 25, 2024
c734d2c
Adjusted meta data extraction with latest pending PyTorch PR
aartbik Mar 25, 2024
cb15604
edit
aartbik Mar 25, 2024
924b855
Merge branch 'llvm:main' into bik
aartbik Mar 27, 2024
e020159
prepare for MLIR bump
aartbik Apr 1, 2024
a7749ab
Merge branch 'main' into bik
aartbik Apr 1, 2024
fb01f72
Merge branch 'llvm:main' into bik
aartbik Apr 1, 2024
975050a
rebased and resolved conflicts of pending WIP PR with main
aartbik Apr 1, 2024
e98b85e
Merge branch 'main' into bik
aartbik Apr 8, 2024
fdac9c7
Merge branch 'main' into bik
aartbik Apr 8, 2024
9058470
rebased changes with mainline
aartbik Apr 8, 2024
f644d59
edit
aartbik Apr 9, 2024
37043d4
edit
aartbik Apr 9, 2024
6f36ed7
Merge branch 'llvm:main' into bik
aartbik Apr 9, 2024
a721656
Merge branch 'llvm:main' into bik
aartbik Apr 9, 2024
46d0573
Merge branch 'llvm:main' into bik
aartbik Apr 12, 2024
e9d5441
Merge branch 'llvm:main' into bik
aartbik Apr 12, 2024
890091a
edit
aartbik Apr 9, 2024
520064a
merge with mainline
aartbik Apr 9, 2024
e0715f7
Merge branch 'llvm:main' into bik
aartbik Apr 16, 2024
a8441f2
Merge branch 'llvm:main' into bik
aartbik Apr 17, 2024
cda0a12
Merge branch 'llvm:main' into bik
aartbik Apr 17, 2024
4dd03cf
rebased
aartbik Apr 18, 2024
988fe77
Merge branch 'llvm:main' into bik
aartbik Apr 24, 2024
db59b9a
Merge branch 'llvm:main' into bik
aartbik Apr 25, 2024
f18e8a2
Merge branch 'llvm:main' into bik
aartbik Apr 25, 2024
956faeb
Merge branch 'main' into bik
aartbik Apr 30, 2024
2d55a78
Merge branch 'llvm:main' into bik
aartbik May 3, 2024
757eeca
Merge branch 'llvm:main' into bik
aartbik May 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 21 additions & 42 deletions python/torch_mlir/extras/fx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,57 +263,38 @@
}


@dataclass(frozen=True)
class SparsityMeta:
"""
Class for keeping track of sparsity meta data.

NOTE: this will be fully replaced by
torch.fx.passes.shape_prop.SparseTensorMetadata
"""

layout: torch.layout
batch_dim: int
sparse_dim: int
dense_dim: int
blocksize: Optional[tuple[int, int]]
pos_dtype: torch.dtype
crd_dtype: torch.dtype


def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
"""Returns sparse tensor encoding for the given sparse layout as string."""
assert sparsity is not None
def sparsity_encoding(tm: TensorMetadata) -> str:
"""Returns sparse tensor encoding for the given sparse tensor as string."""

# Sparse tensors have the form
# [ <batch_dimensions> , <sparse_dimensions>, <dense_dimensions> ]
# which map directly to MLIR types.
batch_dim, sparse_dim, dense_dim = (
sparsity.batch_dim,
sparsity.sparse_dim,
sparsity.dense_dim,
tm.batch_dim,
tm.sparse_dim,
tm.dense_dim,
)
dim = batch_dim + sparse_dim + dense_dim
assert dim == len(shape)
blocksize = sparsity.blocksize
assert dim == len(tm.shape)
blocksize = tm.blocksize

dims = ",".join(f"d{d}" for d in range(0, dim))

if sparsity.layout is torch.sparse_coo:
if tm.layout is torch.sparse_coo:
assert sparse_dim == 2 and blocksize is None # TODO: deeper sparse dims
lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton(soa)"
elif sparsity.layout is torch.sparse_csr:
elif tm.layout is torch.sparse_csr:
assert sparse_dim == 2 and blocksize is None
lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed"
elif sparsity.layout is torch.sparse_csc:
elif tm.layout is torch.sparse_csc:
assert sparse_dim == 2 and blocksize is None
lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed"
else:
assert sparse_dim == 2 and blocksize is not None
if sparsity.layout is torch.sparse_bsr:
if tm.layout is torch.sparse_bsr:
i, j = batch_dim, batch_dim + 1
else:
assert sparsity.layout is torch.sparse_bsc
assert tm.layout is torch.sparse_bsc
j, i = batch_dim, batch_dim + 1
m, n = blocksize
lvls = (
Expand All @@ -329,8 +310,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
dense = ",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim))
lvls = f"{lvls},{dense}"

posw = torch.iinfo(sparsity.pos_dtype).bits
crdw = torch.iinfo(sparsity.crd_dtype).bits
posw = crdw = torch.iinfo(tm.idx_dtype).bits
return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>"


Expand Down Expand Up @@ -865,15 +845,15 @@ def get_vtensor_type(
shape: torch.Size,
dtype: torch.dtype,
*,
sparsity: Optional[SparsityMeta] = None,
tm: Optional[TensorMetadata] = None,
mutable: bool = False,
):
"""Return IrType for !torch.vtensor with the given shape and dtype"""
stem = "torch.tensor" if mutable else "torch.vtensor"
shape_asm = self.format_asm_shape(shape)
mlir_dtype = str(self.dtype_to_type(dtype))
if sparsity is not None:
encoding = sparsity_encoding(shape, sparsity)
if tm is not None and tm.sparse_dim is not None:
encoding = sparsity_encoding(tm)
assert encoding is not None
return IrType.parse(
f"!{stem}<[{shape_asm}],{str(mlir_dtype)},{encoding}>",
Expand All @@ -887,7 +867,6 @@ def node_val_to_type(self, node: torch_fx.Node, *, mutable: bool = False) -> IrT
try:
tensor_meta = node.meta.get("tensor_meta")
val = node.meta.get("val")
sparsity = node.meta.get("sparsity", None)
if tensor_meta is not None:
assert isinstance(tensor_meta, TensorMetadata)
# Quantized tensor meta data is not preserved in our lowering,
Expand All @@ -898,14 +877,14 @@ def node_val_to_type(self, node: torch_fx.Node, *, mutable: bool = False) -> IrT
)
else:
return self.tensor_metadata_to_type(
tensor_meta, sparsity=sparsity, mutable=mutable
tensor_meta, mutable=mutable
)
elif val is not None:
# some nodes with symbolic inputs pass a 'val' attribute rather than
# tensor_meta
if isinstance(val, TorchFakeTensor):
return self.get_vtensor_type(
val.size(), val.dtype, sparsity=sparsity, mutable=mutable
val.size(), val.dtype, mutable=mutable
)

t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val))
Expand All @@ -925,18 +904,18 @@ def tensor_metadata_to_type(
self,
tm: TensorMetadata,
*,
sparsity: Optional[SparsityMeta] = None,
mutable: bool = False,
) -> IrType:
tm_shape = tuple(
item.node if is_symbolic(item) else item for item in list(tm.shape)
)

key = (tm_shape, tm.dtype, sparsity, mutable)
sparse_key = (tm.layout, tm.sparse_dim, tm.dense_dim, tm.blocksize, tm.idx_dtype)
key = (tm_shape, tm.dtype, sparse_key, mutable)
t = self._tensor_metadata_cache.get(key)
if t is None:
t = self.get_vtensor_type(
tm.shape, tm.dtype, sparsity=sparsity, mutable=mutable
tm.shape, tm.dtype, tm=tm, mutable=mutable
)
self._tensor_metadata_cache[key] = t
return t
Expand Down
106 changes: 1 addition & 105 deletions test/python/fx_importer/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@
from typing import Any, Callable, Optional

import torch
import torch.export
import torch.nn as nn

from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir.extras.fx_importer import SparsityMeta
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
Expand All @@ -21,114 +19,12 @@
)


# All sparse layouts currently supported in torch.sparse.
SPARSE_LAYOUTS = [
torch.sparse_coo,
torch.sparse_csr,
torch.sparse_csc,
torch.sparse_bsr,
torch.sparse_bsc,
]


def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
"""
Returns a meta data tuple for the given sparse tensor.

NOTE: this will be fully replaced by fx graph SparseTensorMetadata
"""
sparse_dim = a.sparse_dim()
dense_dim = a.dense_dim()
batch_dim = a.ndim - dense_dim - sparse_dim
blocksize = None
if a.layout is torch.sparse_coo:
return SparsityMeta(
a.layout,
batch_dim,
sparse_dim,
dense_dim,
blocksize,
a.indices().dtype,
a.indices().dtype,
)
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
if a.layout is torch.sparse_bsr:
blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
return SparsityMeta(
a.layout,
batch_dim,
sparse_dim,
dense_dim,
blocksize,
a.crow_indices().dtype,
a.col_indices().dtype,
)
elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
if a.layout is torch.sparse_bsc:
blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
return SparsityMeta(
a.layout,
batch_dim,
sparse_dim,
dense_dim,
blocksize,
a.ccol_indices().dtype,
a.row_indices().dtype,
)
else:
raise RuntimeError(f"Unsupported sparse layout for {a}")


def sparse_export(
f: Callable, args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None
) -> torch.export.ExportedProgram:
"""
This is a ***temporary*** wrapper around `torch.export.export`
that eventually should be removed and simply replaced by the
standard API for exporting traced graphs.

But until issue

https://github.com/pytorch/pytorch/pull/117907

is addressed, this wrapper provides support for the sparse
tensor types by first converting all operands to dense tensors,
building the traced graph as for the dense case, and then
annotation sparse parameters with their actual sparse layout
attributes. This temporary solution accelerates testing
torch-mlir with PyTorch sparse tensors until the issue is
resolved.
"""
# Convert all arguments to dense.
dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args)
mask = [a.layout in SPARSE_LAYOUTS for a in args]
# Build the regular FX traced graph with only dense arguments
# (the current version would crash otherwise, see issue above).
prog = torch.export.export(f, dargs, kwargs)
# Annotate sparse arguments in the graph. Note that we currently
# only account for sparsity defined by the user inputs to the model.
# TODO: support sparsity in model parameters (weights, biases)
# TODO: propagate sparsity into the layers
specs = prog.graph_signature.input_specs
alen = len(specs)
k = 0
for i, node in enumerate(prog.graph.nodes):
if i >= alen:
break
spec = specs[i]
if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
if mask[k]:
node.meta["sparsity"] = sparse_metadata(args[k])
k = k + 1
return prog


def export_and_import(f, *args, **kwargs):
"""This method implements Stella's importer, stripped down to essentials."""
context = ir.Context()
torch_d.register_dialect(context)
fx_importer = FxImporter(context=context)
prog = sparse_export(f, args, kwargs)
prog = torch.export.export(f, args, kwargs)
fx_importer.import_frozen_program(prog)
return fx_importer.module

Expand Down
Loading