diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 6ed9d369e8e5..2897ff3423e9 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -193,6 +193,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
                                       ArrayRef<int64_t> yDims,
                                       SmallVector<int64_t> &xIndices,
                                       SmallVector<int64_t> &yIndices) {
+    if (xDims.empty() || yDims.empty())
+      return failure();
+
     auto isValidReduction = [](int64_t expectedReductionProduct,
                                ArrayRef<int64_t> arrayToReduce) -> bool {
       if (llvm::count(arrayToReduce, kUnknownSize) > 0 ||
@@ -262,6 +265,8 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
   // all the dimensions in `outputShape`.
   static void calculateSingleDynamicSize(MutableArrayRef<int64_t> inputShape,
                                          MutableArrayRef<int64_t> outputShape) {
+    if (inputShape.empty() || outputShape.empty())
+      return;
     int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize);
     int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize);
     if (inputDynamicDimCount + outputDynamicDimCount != 1)
@@ -488,12 +493,29 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
       outputDim = outputAssociations.back().back() + 1;
     }
 
-    // Append the associations for the dims matching `aten.size.int`
-    if (nextUnchangedInput != inputRank &&
-        nextUnchangedOutput != resultRank) {
+    // Handle any leading or trailing size-1 dimensions and append the
+    // associations for the dims matching `aten.size.int`.
+    if (nextUnchangedInput != inputRank) {
+      assert(nextUnchangedOutput != resultRank &&
+             "`nextUnchangedInput` and `nextUnchangedOutput` should equal "
+             "the respective input and output rank at the same time");
       inputAssociations.emplace_back();
       outputAssociations.emplace_back();
+    }
+    while (inputDim <= nextUnchangedInput && inputDim < inputRank) {
+      if (inputDim != nextUnchangedInput && inputShape[inputDim] != 1) {
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: only collapsing of static size-1 into "
+                "unchanged dim supported");
+      }
       inputAssociations.back().push_back(inputDim++);
+    }
+    while (outputDim <= nextUnchangedOutput && outputDim < resultRank) {
+      if (outputDim != nextUnchangedOutput && outputShape[outputDim] != 1) {
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: only expanding of static size-1 out of "
+                "unchanged dim supported");
+      }
       outputAssociations.back().push_back(outputDim++);
+    }
   }
diff --git a/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/python/torch_mlir_e2e_test/test_suite/reshape_like.py
index 1c2d810c13de..304d3025eb8d 100644
--- a/python/torch_mlir_e2e_test/test_suite/reshape_like.py
+++ b/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -672,6 +672,108 @@ def forward(self, a):
 def ViewNegativeStaticModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 128))
 
+class ViewSizeDimFollowedByExpandedOnesModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(a.size(0), 1, 1, 1)
+
+@register_test_case(module_factory=lambda: ViewSizeDimFollowedByExpandedOnesModule())
+def ViewSizeDimFollowedByExpandedOnesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(128))
+
+class ViewSizeDimFollowedByCollapsedOnesModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, 1, 1, 1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(a.size(0))
+
+@register_test_case(module_factory=lambda: ViewSizeDimFollowedByCollapsedOnesModule())
+def ViewSizeDimFollowedByCollapsedOnesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(128, 1, 1, 1))
+
+class ViewSizeDimLedByExpandedOnesModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(1, 1, 1, a.size(0))
+
+@register_test_case(module_factory=lambda: ViewSizeDimLedByExpandedOnesModule())
+def ViewSizeDimLedByExpandedOnesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(128))
+
+class ViewSizeDimLedByCollapsedOnesModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 1, 1, -1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(a.size(3))
+
+@register_test_case(module_factory=lambda: ViewSizeDimLedByCollapsedOnesModule())
+def ViewSizeDimLedByCollapsedOnesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 1, 128))
+
+class ViewSizeDimLedAndFollowedByExpandedOnesModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(1, 1, 1, a.size(0), 1, 1, 1)
+
+@register_test_case(module_factory=lambda: ViewSizeDimLedAndFollowedByExpandedOnesModule())
+def ViewSizeDimLedAndFollowedByExpandedOnesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(128))
+
+class ViewSizeDimLedAndFollowedByCollapsedOnesModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 1, 1, -1, 1, 1, 1], torch.float32, True),
+    ])
+
+    def forward(self, a):
+        return a.view(a.size(3))
+
+@register_test_case(module_factory=lambda: ViewSizeDimLedAndFollowedByCollapsedOnesModule())
+def ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 1, 128, 1, 1, 1))
+
 # ==============================================================================
 
 class ReshapeAliasExpandModule(torch.nn.Module):
@@ -710,4 +812,4 @@ def forward(self, a):
 
 @register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
 def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 4))
\ No newline at end of file
+    module.forward(tu.rand(2, 4))