[linalg] Add handling for leading and trailing size-1 dims in ViewOp
This commit adds handling in the `aten.view` lowering for the
following cases (illustrated by the sketch after the list):

- `(..., a.size(i))` -> `(..., a.size(i), 1, ..., 1)`
- `(..., a.size(i), 1, ..., 1)` -> `(..., a.size(i))`
- `(a.size(i), ...)` -> `(1, ..., 1, a.size(i), ...)`
- `(1, ..., 1, a.size(i), ...)` -> `(a.size(i), ...)`
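
A minimal Python sketch of the four cases at the `torch` level; the tensor `a` and the
concrete shapes are illustrative only, not taken from the patch:

import torch

a = torch.rand(3, 5)

# (..., a.size(i)) -> (..., a.size(i), 1, ..., 1): append trailing size-1 dims.
expanded_trailing = a.view(a.size(0), a.size(1), 1, 1)             # shape (3, 5, 1, 1)

# (..., a.size(i), 1, ..., 1) -> (..., a.size(i)): drop trailing size-1 dims.
collapsed_trailing = expanded_trailing.view(a.size(0), a.size(1))  # shape (3, 5)

# (a.size(i), ...) -> (1, ..., 1, a.size(i), ...): prepend leading size-1 dims.
expanded_leading = a.view(1, 1, a.size(0), a.size(1))              # shape (1, 1, 3, 5)

# (1, ..., 1, a.size(i), ...) -> (a.size(i), ...): drop leading size-1 dims.
collapsed_leading = expanded_leading.view(a.size(0), a.size(1))    # shape (3, 5)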
ramiro050 committed Oct 3, 2023
1 parent 1c508af commit 2e5d650
Showing 2 changed files with 128 additions and 4 deletions.
28 changes: 25 additions & 3 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -193,6 +193,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
ArrayRef<int64_t> yDims,
SmallVector<int64_t> &xIndices,
SmallVector<int64_t> &yIndices) {
if (xDims.empty() || yDims.empty())
return failure();

auto isValidReduction = [](int64_t expectedReductionProduct,
ArrayRef<int64_t> arrayToReduce) -> bool {
if (llvm::count(arrayToReduce, kUnknownSize) > 0 ||
@@ -262,6 +265,8 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
// all the dimensions in `outputShape`.
static void calculateSingleDynamicSize(MutableArrayRef<int64_t> inputShape,
MutableArrayRef<int64_t> outputShape) {
if (inputShape.empty() || outputShape.empty())
return;
int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize);
int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize);
if (inputDynamicDimCount + outputDynamicDimCount != 1)
@@ -488,12 +493,29 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
outputDim = outputAssociations.back().back() + 1;
}

// Append the associations for the dims matching `aten.size.int`
if (nextUnchangedInput != inputRank &&
nextUnchangedOutput != resultRank) {
// Handle any leading or trailing size-1 dimensions and append the
// associations for the dims matching `aten.size.int`.
if (nextUnchangedInput != inputRank) {
assert(nextUnchangedOutput != resultRank &&
"`nextUnchangedInput` and `nextUnchangedOutput` should equal "
"the respective input and output rank at the same time");
inputAssociations.emplace_back();
outputAssociations.emplace_back();
}
while (inputDim <= nextUnchangedInput && inputDim < inputRank) {
if (inputDim != nextUnchangedInput && inputShape[inputDim] != 1) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only collapsing of static size-1 into "
"unchanged dim supported");
}
inputAssociations.back().push_back(inputDim++);
}
while (outputDim <= nextUnchangedOutput && outputDim < resultRank) {
if (outputDim != nextUnchangedOutput && outputShape[outputDim] != 1) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only expanding of static size-1 out of "
"unchanged dim supported");
}
outputAssociations.back().push_back(outputDim++);
}
}
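
A rough Python sketch (illustrative only) of the grouping the two `while` loops above
perform: leading or trailing static size-1 dims are folded into the reassociation group
of the unchanged dim. The function name and standalone form are assumptions for
illustration; the actual lowering operates on MLIR shapes in C++.

def group_size1_dims(shape, dim, next_unchanged, assoc):
    # Walk `dim` up to and including `next_unchanged`; every dim before it must be a
    # static 1, otherwise the pattern bails out (match failure in the real lowering).
    while dim <= next_unchanged and dim < len(shape):
        if dim != next_unchanged and shape[dim] != 1:
            return None  # only collapsing/expanding of static size-1 dims is supported
        assoc[-1].append(dim)
        dim += 1
    return dim

# Example: an input of shape (1, 1, 128) viewed as (128,) groups dims [0, 1, 2] together.
assoc = [[]]
group_size1_dims([1, 1, 128], 0, 2, assoc)
# assoc == [[0, 1, 2]]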
104 changes: 103 additions & 1 deletion python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -672,6 +672,108 @@ def forward(self, a):
def ViewNegativeStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 128))

class ViewSizeDimFollowedByExpandedOnesModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1], torch.float32, True),
])

def forward(self, a):
return a.view(a.size(0), 1, 1, 1)

@register_test_case(module_factory=lambda: ViewSizeDimFollowedByExpandedOnesModule())
def ViewSizeDimFollowedByExpandedOnesModule_basic(module, tu: TestUtils):
module.forward(tu.rand(128))

class ViewSizeDimFollowedByCollapsedOnesModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, 1, 1, 1], torch.float32, True),
])

def forward(self, a):
return a.view(a.size(0))

@register_test_case(module_factory=lambda: ViewSizeDimFollowedByCollapsedOnesModule())
def ViewSizeDimFollowedByCollapsedOnesModule_basic(module, tu: TestUtils):
module.forward(tu.rand(128, 1, 1, 1))

class ViewSizeDimLedByExpandedOnesModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1], torch.float32, True),
])

def forward(self, a):
return a.view(1, 1, 1, a.size(0))

@register_test_case(module_factory=lambda: ViewSizeDimLedByExpandedOnesModule())
def ViewSizeDimLedByExpandedOnesModule_basic(module, tu: TestUtils):
module.forward(tu.rand(128))

class ViewSizeDimLedByCollapsedOnesModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([1, 1, 1, -1], torch.float32, True),
])

def forward(self, a):
return a.view(a.size(3))

@register_test_case(module_factory=lambda: ViewSizeDimLedByCollapsedOnesModule())
def ViewSizeDimLedByCollapsedOnesModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 1, 128))

class ViewSizeDimLedAndFollowedByExpandedOnesModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1], torch.float32, True),
])

def forward(self, a):
return a.view(1, 1, 1, a.size(0), 1, 1, 1)

@register_test_case(module_factory=lambda: ViewSizeDimLedAndFollowedByExpandedOnesModule())
def ViewSizeDimLedAndFollowedByExpandedOnesModule_basic(module, tu: TestUtils):
module.forward(tu.rand(128))

class ViewSizeDimLedAndFollowedByCollapsedOnesModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([1, 1, 1, -1, 1, 1, 1], torch.float32, True),
])

def forward(self, a):
return a.view(a.size(3))

@register_test_case(module_factory=lambda: ViewSizeDimLedAndFollowedByCollapsedOnesModule())
def ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 1, 1, 128, 1, 1, 1))

# ==============================================================================

class ReshapeAliasExpandModule(torch.nn.Module):
@@ -710,4 +812,4 @@ def forward(self, a):

@register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4))
