Merge pull request #332 from Xilinx/bump_to_fe9db781
[AutoBump] Merge with fe9db78 (Jul 14) (9)
mgehre-amd authored Sep 18, 2024
2 parents 2263bfb + 3dae02a commit b8c5a81
Showing 10 changed files with 222 additions and 42 deletions.
31 changes: 25 additions & 6 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -373,16 +373,19 @@ static FailureOr<SmallVector<Value>> createTMTensorTopkOp(
}

namespace {
class ConvertAtenScatterSrcOp : public OpConversionPattern<AtenScatterSrcOp> {
template <typename AtenOpT>
class ConvertAtenScatterOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern::OpConversionPattern;
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenScatterSrcOp op, OpAdaptor adaptor,
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
const TypeConverter *typeConverter = getTypeConverter();
const TypeConverter *typeConverter =
OpConversionPattern<AtenOpT>::getTypeConverter();
Value self = adaptor.getSelf();
Value index = adaptor.getIndex();
Value src = adaptor.getSrc();
@@ -410,7 +413,19 @@ class ConvertAtenScatterSrcOp : public OpConversionPattern<AtenScatterSrcOp> {
/*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value updatesElement,
Value inputElement) {
b.create<TMTensor::YieldOp>(loc, updatesElement);
if (isa<AtenScatterSrcOp>(op)) {
b.create<TMTensor::YieldOp>(loc, updatesElement);
} else if (isa<AtenScatterAddOp>(op)) {
if (isa<mlir::IntegerType>(selfType.getElementType())) {
Value add =
b.create<arith::AddIOp>(loc, inputElement, updatesElement);
b.create<TMTensor::YieldOp>(loc, add);
} else if (isa<mlir::FloatType>(selfType.getElementType())) {
Value add =
b.create<arith::AddFOp>(loc, inputElement, updatesElement);
b.create<TMTensor::YieldOp>(loc, add);
}
}
});

auto resultType = cast<RankedTensorType>(
@@ -2172,7 +2187,11 @@ class ConvertTorchToTMTensor
context);

target.addIllegalOp<AtenScatterSrcOp>();
patterns.add<ConvertAtenScatterSrcOp>(typeConverter, context);
patterns.add<ConvertAtenScatterOp<AtenScatterSrcOp>>(typeConverter,
context);
target.addIllegalOp<AtenScatterAddOp>();
patterns.add<ConvertAtenScatterOp<AtenScatterAddOp>>(typeConverter,
context);
target.addIllegalOp<AtenKthvalueOp>();
patterns.add<ConvertAtenKthvalueOp>(typeConverter, context);

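The conversion above relies on the accumulation semantics of `scatter_add`: unlike `scatter.src`, which overwrites, `scatter_add` sums the update into the existing element, which is why the lowering yields `arith.addi`/`arith.addf` inside the TMTensor scatter region instead of yielding the update directly. A minimal eager-mode sketch of those semantics (example values are illustrative):

```python
import torch

self_t = torch.zeros(3, dtype=torch.float32)
index = torch.tensor([0, 1, 0])
src = torch.tensor([1.0, 2.0, 3.0])

# scatter_add accumulates into the destination: self[index[i]] += src[i]
# along dim 0, so duplicate indices sum rather than overwrite.
out = torch.scatter_add(self_t, 0, index, src)
print(out)  # tensor([4., 2., 0.])
```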
7 changes: 7 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9808,6 +9808,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.scatter.value\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.float) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.scatter_add\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -11605,6 +11608,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scatter_add\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.masked_scatter\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2874,6 +2874,7 @@
"ScatterReduceIntMaxModuleIncludeSelf",
"ScatterReduceIntMinModuleIncludeSelf",
"ScatterValueFloatModule_basic",
"ScatterAddStaticModule_basic",
# Failure - onnx_lowering: onnx.ScatterND
"IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DIntAccumulateModule_basic",
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1817,6 +1817,9 @@ def aten〇scatter〇src〡shape(self: List[int], dim: int, index: List[int], sr
def aten〇scatter〇value〡shape(self: List[int], dim: int, index: List[int], value: float) -> List[int]:
return self

def aten〇scatter_add〡shape(self: List[int], dim: int, index: List[int], src: List[int]) -> List[int]:
return self

def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]:
return upstream_shape_functions.index_select(self, dim, index)

@@ -3140,6 +3143,12 @@ def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, i
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(
[Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
def aten〇scatter_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(
[Invocation(TensorOfShape(3, dtype=dtype), TensorOfShape(3, dtype=torch.bool), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
def aten〇masked_scatter〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], source_rank_dtype: Tuple[int, int]) -> int:
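The new shape and dtype functions mirror the existing `scatter.src` entries: `scatter_add` always returns a tensor with the shape and dtype of `self`. A quick eager-mode check of that invariant (illustrative values):

```python
import torch

self_t = torch.zeros(4, 5)
index = torch.zeros(2, 3, dtype=torch.int64)
src = torch.ones(2, 3)

# The result inherits self's shape and dtype, which is all the
# shape/dtype functions above need to encode.
out = torch.scatter_add(self_t, 0, index, src)
assert out.shape == self_t.shape and out.dtype == self_t.dtype
```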
25 changes: 25 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
@@ -1053,6 +1053,31 @@ def ScatterValueIntModule_basic(module, tu: TestUtils):
# ==============================================================================


class ScatterAddStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([10, 8, 6], torch.float32, True),
([2, 4, 3], torch.int64, True),
([5, 8, 6], torch.float32, True),
]
)
def forward(self, input, index, src):
return torch.ops.aten.scatter_add(input, 0, index, src)


@register_test_case(module_factory=lambda: ScatterAddStaticModule())
def ScatterAddStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6))


# ==============================================================================


class ScatterReduceFloatModule(torch.nn.Module):
include_self: bool
reduce_type: str
15 changes: 8 additions & 7 deletions python/torch_mlir/compiler_utils.py
@@ -86,12 +86,12 @@ def run_pipeline_with_repro_report(

class OutputType(Enum):

# Output torch dialect. When converting from FX, this will be immediately
# after the import from FX to MLIR. When converting from torchscript,
# this will come after some cleanup passes which attempt to de-alias,
# decompose and infer shapes. These should be roughly the same level of
# abstraction since those steps are done within PyTorch itself
# when coming directly from Dynamo/FX.
# Output torch dialect in backend form. When converting from TorchDynamo,
# this comes after some decomposition and reduce op variants passes are
# applied to the raw torch dialect. When converting from TorchScript, this
# comes after some cleanup passes which attempt to de-alias, decompose and infer shapes.
# These should be roughly the same level of abstraction since those
# steps are done within PyTorch itself when coming directly from Dynamo/FX.
TORCH = "torch"

# The output type contains a mix of `linalg`-on-tensors ops, `scf`, and
@@ -108,7 +108,8 @@ class OutputType(Enum):
# as taking the `TORCH` output type and lowering it to StableHLO.
STABLEHLO = "stablehlo"

# Raw output of the JIT IR importer. This is not expected to be useful
# Raw output of the JIT IR importer in the TorchScript frontend or that of
# the FX IR importer in the TorchDynamo frontend. This is not expected to be useful
# for end-users, but can be convenient for development or reporting bugs.
RAW = "raw"

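A usage sketch of how these output types surface through the FX path (hypothetical module; `output_type` accepts either the enum or its string value, per the `Union[str, OutputType]` annotations in `fx.py`):

```python
import torch
from torch_mlir import fx
from torch_mlir.compiler_utils import OutputType

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1

# RAW is now the default: the untouched output of the FX importer.
raw = fx.export_and_import(AddOne(), torch.ones(3))
# Request backend-form torch dialect explicitly.
torch_ir = fx.export_and_import(AddOne(), torch.ones(3), output_type=OutputType.TORCH)
print(torch_ir)
```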
82 changes: 59 additions & 23 deletions python/torch_mlir/extras/fx_importer.py
@@ -1099,6 +1099,10 @@ def value_info_to_type(
return self.get_vtensor_type(
val.size(), val.dtype, sparsity=sparsity, mutable=mutable
)
elif isinstance(val, list) and all(
isinstance(x, TorchFakeTensor) for x in val
):
return IrType.parse("!torch.list<vtensor>", context=self._c)

# Note that None is a valid scalar here, so it is important that this
# is always checked as the last fallback.
@@ -1227,6 +1231,7 @@ class GraphNodeImporter:
"_v",
"_symbol_to_value",
"_multi_result_nodes",
"_unpack_list_values",
"fx_importer",
]

@@ -1251,6 +1256,10 @@ def __init__(
# Statically multi-result nodes which we have de-tupled are noted here.
# They will have their getitem calls short-circuited.
self._multi_result_nodes: Set[torch_fx.Node] = set()
# If an op returns a list, it needs to be unpacked entirely using
# prim.ListUnpack. Cache the unpacked results of these nodes so that the
# list is unpacked only once instead of every time getitem is used.
self._unpack_list_values: Dict[torch_fx.Node, Tuple[Value]] = {}

def bind_node_value(
self,
@@ -1420,29 +1429,7 @@ def import_nodes(
elif op == "call_function":
target = node.target
if target == operator.getitem:
# Special case handling of getitem for when it is resolving
# against a function call that we know has returned multiple
# results. We short-circuit this case because we have modeled
# function calls to natively return multiple results vs tupling.
getitem_ref, getitem_index = node.args
if getitem_ref in self._multi_result_nodes:
try:
self.bind_node_value(
node,
self.resolve_node_value(getitem_ref, getitem_index),
)
except IndexError:
raise RuntimeError(
f"getitem de-aliasing failed. This likely "
f"indicates a programmer error that usually "
f"would have happened at runtime. Please "
f"notify developers if this case happens "
f"(at {loc})."
)
else:
raise NotImplementedError(
f"General getitem access to non-multi-result ops"
)
self._import_getitem(loc, node)
elif target in SYMBOLIC_TORCH_OPS or (
is_symbolic(node.meta.get("val"))
and is_builtin_function_or_method(target)
@@ -2007,6 +1994,51 @@ def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value:
with loc:
return cvt(arg, self, self._cc)

def _import_getitem(self, loc: Location, node: torch.fx.Node):
ref_node, index = node.args
if ref_node in self._multi_result_nodes:
# Special case handling of getitem for when it is resolving
# against a function call that we know has returned multiple
# results. We short-circuit this case because we have modeled
# function calls to natively return multiple results vs tupling.
try:
self.bind_node_value(
node,
self.resolve_node_value(ref_node, index),
)
except IndexError:
raise RuntimeError(
f"getitem de-aliasing failed. This likely "
f"indicates a programmer error that usually "
f"would have happened at runtime. Please "
f"notify developers if this case happens "
f"(at {loc})."
)
else:
# Handle nodes that return a torch.list<...> at the MLIR level.
# NOTE: the length of the list must be knowable at compile time.
if ref_node not in self._unpack_list_values:
node_result = self.resolve_node_value(ref_node, 0)
if str(node_result.type) in TORCH_LIST_TYPES:
result_types = [
self._cc.value_info_to_type(v) for v in ref_node.meta["val"]
]
operation = Operation.create(
"torch.prim.ListUnpack",
results=result_types,
operands=[node_result],
loc=loc,
)
self._unpack_list_values[ref_node] = tuple(operation.results)

try:
self.bind_node_value(node, self._unpack_list_values[ref_node][index])
except IndexError:
raise RuntimeError(
f"getitem failed. "
f"getitem only supports lists of known length. (at {loc})"
)

def _unpack_node_result_types(
self, node: torch.fx.Node, schema: FunctionSchema
) -> List[IrType]:
@@ -2337,6 +2369,10 @@ def _ref_finalizer(self, ref_id: int):
"vtensor": "!torch.list<optional<vtensor>>",
}

TORCH_LIST_TYPES = set(PY_TYPE_TO_TORCH_LIST_TYPE.values()) | set(
PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE.values()
)

SCALAR_TYPE_TO_TORCH_MLIR_TYPE = {
torch.SymInt: "!torch.int",
torch.SymFloat: "!torch.float",
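The unpack-once behavior of `_import_getitem` can be summarized in a standalone sketch (hypothetical names; the real code materializes a single `torch.prim.ListUnpack` op and binds MLIR values, not Python objects):

```python
from typing import Callable, Dict, Sequence, Tuple

_unpack_cache: Dict[int, Tuple[object, ...]] = {}

def import_getitem(ref_id: int, unpack_all: Callable[[], Sequence[object]], index: int):
    # Unpack the producing node's list once; later getitem calls on the
    # same node reuse the cached results instead of emitting a new unpack.
    if ref_id not in _unpack_cache:
        _unpack_cache[ref_id] = tuple(unpack_all())
    try:
        return _unpack_cache[ref_id][index]
    except IndexError:
        # Mirrors the importer's constraint: the list length must be static.
        raise RuntimeError("getitem only supports lists of known length")
```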
6 changes: 3 additions & 3 deletions python/torch_mlir/fx.py
@@ -30,7 +30,7 @@ def _module_lowering(
extra_library_file_name=None,
):

if output_type == OutputType.TORCH:
if output_type == OutputType.RAW:
if verbose:
print(torch_mod)
return torch_mod
@@ -50,7 +50,7 @@ def export_and_import(
def export_and_import(
f: Union[nn.Module, ExportedProgram],
*args,
output_type: Union[str, OutputType] = OutputType.TORCH,
output_type: Union[str, OutputType] = OutputType.RAW,
fx_importer: Optional[FxImporter] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
experimental_support_mutation: bool = False,
@@ -99,7 +99,7 @@ def export_and_import(

def stateless_fx_import(
gm: torch.fx.GraphModule,
output_type: Union[str, OutputType] = OutputType.TORCH,
output_type: Union[str, OutputType] = OutputType.RAW,
fx_importer: Optional[FxImporter] = None,
hooks: Optional[FxImporterHooks] = None,
model_name: str = "main",
47 changes: 47 additions & 0 deletions test/python/fx_importer/custom_op_test.py
@@ -84,3 +84,50 @@ def forward(self, x, y, z):
import_symbolic_shape_expressions=True,
)
print(m)


@run
# CHECK-LABEL: test_custom_op_array_output
# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,3],f32>)
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = 10} : !torch.int
# CHECK: %[[int:.+]] = torch.constant.int 4
# CHECK: %[[V0:.+]] = torch.operator "torch.my_custom_library.array_output_op"(%[[int]], %[[ARG0]]) : (!torch.int, !torch.vtensor<[?,3],f32>) -> !torch.list<vtensor>
# CHECK: %[[V1:.+]]:4 = torch.prim.ListUnpack %[[V0]] : !torch.list<vtensor> -> !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[V1]]#0, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[V1]]#1, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[V1]]#2, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[V1]]#3, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32>
# CHECK: return %[[V1]]#0, %[[V1]]#1, %[[V1]]#2, %[[V1]]#3 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>
def test_custom_op_array_output():
m = Library("my_custom_library", "DEF")
m.define("array_output_op(int num_outs, Tensor a) -> Tensor[]")

@impl(m, "array_output_op", "CompositeExplicitAutograd")
def custom_op(num_outs, a):
return [a] * num_outs

@impl_abstract("my_custom_library::array_output_op")
def custom_op_meta(num_outs, a):
result = custom_op(num_outs, a)
return [torch.empty_like(t) for t in result]

class ArrayOutputCustomOp(nn.Module):
def __init__(self):
super().__init__()

def forward(self, a):
return torch.ops.my_custom_library.array_output_op(4, a)

dim = Dim("n", max=10)
dynamic_shapes = {
"a": {0: dim},
}

a = torch.rand(2, 3)
m = fx.export_and_import(
ArrayOutputCustomOp(),
a,
import_symbolic_shape_expressions=True,
dynamic_shapes=dynamic_shapes,
)
print(m)
