fix bug in transpose propagation
zjgarvey committed Oct 30, 2024
1 parent 4f0818f commit 90789b8
Showing 2 changed files with 36 additions and 21 deletions.
41 changes: 28 additions & 13 deletions lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
@@ -505,24 +505,39 @@ class PropagateAtenTransposeIntPattern
}
backDP *= selfSizes[i];
}
-// [i0, i1, i2, i3, i4] -> i0*D1D2D3D4 + i1*D2D3D4 + i2*D3*D4 + i3*D4 + i4
-// i -> [i//(D1D2D3D4), i//(D2D3D4) % D1, i//(D3D4) % D2, i//D4 % D3, i % D4]
-//   -> [i//D1D2D3D4, i//D4 % D3, i//D3D4 % D2, i//D2D3D4 % D1, i % D4]
-//   -> (i/D1D2D3D4)*D1'D2'D3'D4' + (i//D4 % D3)*D2'D3'D4' + (i//D3D4 % D2)*D3'D4' + (i//D2D3D4 % D1)*D4' + (i % D4)
-//   -> (i/D1D2D3D4)*D3D2D1D4 + (i//D4 % D3)*D2D1D4 + (i//D3D4 % D2)*D1D4 + (i//D2D3D4 % D1)*D4 + (i % D4)

int64_t D1234 = dim0L * midDP * dim1L * backDP;
int64_t fullDP = frontDP * D1234;
-if (D1234 != (int64_t)elements.size())
+if (fullDP != (int64_t)elements.size())
return failure();

+// A generic transpose will look like...
+// [frontDimsFlat, dim0, midDimsFlat, dim1, backDimsFlat] ->
+// [frontDimsFlat, dim1, midDimsFlat, dim0, backDimsFlat]
+// If any of front, mid, or back don't actually exist (e.g. dim0 = 0, or
+// dim1 = dim0 + 1), the reassociation of completely flattened indices will
+// remain unaffected by the artificially unsqueezed dims.
+// --------
+// Setting some notation, let D0,D1,D2,D3,D4 be the respective dim sizes of
+// "self". Let D'j be the transpose dim sizes, and Djk = Dj*Dk. Let fl_trans
+// and fl_self be 1-D flattened tensors. Then:
+// --------
+// fl_trans[i]
+//   = trans[i/D'1234, i/(D'234) % D'1, i/(D'34) % D'2, i/D'4 % D'3, i % D'4]
+//   = trans[i/D1234, i/D214 % D3, i/D14 % D2, i/D4 % D1, i % D4]
+//   = self[i/D1234, i/D4 % D1, i/D14 % D2, i/D214 % D3, i % D4]
+//   = fl_self[dot.prod(indices, (D1234, D234, D34, D4, 1))]
+// --------
+// reassoc(i) = (i/(D1234)) * D1234 +
+//              (i/D4 % D1) * D234 +
+//              (i/(D14) % D2) * D34 +
+//              (i/(D214) % D3) * D4 +
+//              (i % D4)
auto reassoc = [&](int64_t i) {
-return (i / (D1234)) * D1234 +
-       ((i / backDP) % dim1L) * midDP * dim0L * backDP +
-       ((i / (dim1L * backDP)) % midDP) * dim0L * backDP +
-       ((i / (midDP * dim1L * backDP)) % dim0L) * backDP + (i % backDP);
+return (i / D1234) * D1234 +
+       ((i / backDP) % dim0L) * midDP * dim1L * backDP +
+       ((i / (dim0L * backDP)) % midDP) * dim1L * backDP +
+       ((i / (midDP * dim0L * backDP)) % dim1L) * backDP + (i % backDP);
};
SmallVector<OpFoldResult> transposedFolds;
for (int64_t i = 0; i < fullDP; i++)
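
As a sanity check on the corrected mapping, the standalone C++ sketch below (not part of the patch; the sizes are arbitrary but chosen pairwise distinct so that the pre-fix formula visibly disagrees) builds the flattened five-axis view described in the new comment, performs the dim0/dim1 transpose explicitly, and confirms that the fixed reassoc reproduces it while the old formula, with dim0L and dim1L exchanged, does not.

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <vector>

int main() {
  // Flattened view of "self": [frontDP, dim0L, midDP, dim1L, backDP].
  const int64_t frontDP = 2, dim0L = 3, midDP = 4, dim1L = 5, backDP = 6;
  const int64_t D1234 = dim0L * midDP * dim1L * backDP;
  const int64_t fullDP = frontDP * D1234;

  // self[j] = j, so the stored values double as flattened self indices.
  std::vector<int64_t> self(fullDP);
  for (int64_t j = 0; j < fullDP; ++j)
    self[j] = j;

  // Explicit transpose of the dim0 and dim1 axes; the transposed shape is
  // [frontDP, dim1L, midDP, dim0L, backDP], filled in row-major order.
  std::vector<int64_t> trans(fullDP);
  int64_t next = 0;
  for (int64_t f = 0; f < frontDP; ++f)
    for (int64_t d1 = 0; d1 < dim1L; ++d1)
      for (int64_t m = 0; m < midDP; ++m)
        for (int64_t d0 = 0; d0 < dim0L; ++d0)
          for (int64_t b = 0; b < backDP; ++b)
            trans[next++] = self[f * D1234 + d0 * midDP * dim1L * backDP +
                                 m * dim1L * backDP + d1 * backDP + b];

  // Corrected mapping from this commit: flattened transposed index ->
  // flattened self index.
  auto reassoc = [&](int64_t i) {
    return (i / D1234) * D1234 +
           ((i / backDP) % dim0L) * midDP * dim1L * backDP +
           ((i / (dim0L * backDP)) % midDP) * dim1L * backDP +
           ((i / (midDP * dim0L * backDP)) % dim1L) * backDP + (i % backDP);
  };
  // Pre-fix mapping (dim0L and dim1L exchanged), kept only for comparison.
  auto reassocOld = [&](int64_t i) {
    return (i / D1234) * D1234 +
           ((i / backDP) % dim1L) * midDP * dim0L * backDP +
           ((i / (dim1L * backDP)) % midDP) * dim0L * backDP +
           ((i / (midDP * dim1L * backDP)) % dim0L) * backDP + (i % backDP);
  };

  int64_t oldMismatches = 0;
  for (int64_t i = 0; i < fullDP; ++i) {
    if (trans[i] != self[reassoc(i)]) {
      std::cerr << "fixed reassoc is wrong at i = " << i << "\n";
      return EXIT_FAILURE;
    }
    if (trans[i] != self[reassocOld(i)])
      ++oldMismatches;
  }
  std::cout << "fixed reassoc matches the explicit transpose on all " << fullDP
            << " elements; the pre-fix formula mismatches " << oldMismatches
            << " of them\n";
  return EXIT_SUCCESS;
}

When dim0L == dim1L the two formulas coincide, so the bug only shows up when the two transposed dimensions differ in size.
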
16 changes: 8 additions & 8 deletions test/Dialect/Torch/scalarize-shapes.mlir
@@ -470,17 +470,17 @@ func.func @pytorch_dynamic_pad_export_slice$prop(%arg0: !torch.vtensor<[?,144,?,
// CHECK-LABEL: @pytorch_dynamic_pad_export_transpose$prop(
func.func @pytorch_dynamic_pad_export_transpose$prop(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.vtensor<[2,4],si64> {
// CHECK: %[[I0:.*]] = torch.constant.int 0
-// CHECK: %[[x0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[DIM0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I2:.*]] = torch.constant.int 2
-// CHECK: %[[x1:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[DIM2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I3:.*]] = torch.constant.int 3
-// CHECK: %[[x2:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[DIM3:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I0_0:.*]] = torch.constant.int 0
// CHECK: %[[I0_1:.*]] = torch.constant.int 0
// CHECK: %[[I0_2:.*]] = torch.constant.int 0
// CHECK: %[[I0_3:.*]] = torch.constant.int 0
-// CHECK: %[[I144:.*]] = torch.constant.int 144
-// CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[x1]], %[[I0_1]], %[[x2]], %[[I0_2]], %[[x0]], %[[I0_3]], %[[I144]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[DIM1:.*]] = torch.constant.int 144
+// CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]], %[[DIM2]], %[[DIM0]], %[[I0_2]], %[[I0_3]], %[[DIM3]], %[[DIM1]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[none:.*]] = torch.constant.none
// CHECK: %[[false:.*]] = torch.constant.bool false
// CHECK: %[[x4:.*]] = torch.aten.tensor %[[x3]], %[[none]], %[[none]], %[[false]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2,4],si64>
@@ -515,12 +515,12 @@ func.func @pytorch_dynamic_pad_export_transpose$prop(%arg0: !torch.vtensor<[?,14
// CHECK-LABEL: @pytorch_dynamic_pad_export_full(
func.func @pytorch_dynamic_pad_export_full(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.list<int> {
// CHECK: %[[I2:.*]] = torch.constant.int 2
-// CHECK: %[[x0:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[DIM2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int
// CHECK: %[[I0:.*]] = torch.constant.int 0
-// CHECK: %[[x1:.*]] = torch.prim.ListConstruct %[[x0]], %[[I0]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[x1:.*]] = torch.prim.ListConstruct %[[DIM2]], %[[I0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: return %[[x1]] : !torch.list<int>
%0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
-%1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
+%1 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
%int-9223372036854775807 = torch.constant.int -9223372036854775807
%int-1 = torch.constant.int -1
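
For context on what the updated CHECK lines pin down: after scalarization, the transposed shape tensor is rebuilt as a torch.prim.ListConstruct of the original size scalars in the row-major order of the transposed result. The short C++ sketch below reproduces that reordering for a 2-D transpose; the concrete [4,2] source list is an assumption for illustration (the full test input is elided in the hunk above), chosen so the output matches the operand order expected by the new ListConstruct CHECK line.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Reorder a row-major [rows, cols] element list into the row-major layout of
// its transpose, i.e. out[c][r] = src[r][c].
std::vector<std::string> transposeList(const std::vector<std::string> &src,
                                       int64_t rows, int64_t cols) {
  std::vector<std::string> out(src.size());
  for (int64_t r = 0; r < rows; ++r)
    for (int64_t c = 0; c < cols; ++c)
      out[c * rows + r] = src[r * cols + c];
  return out;
}

int main() {
  // Hypothetical [4,2] source holding zeros and symbolic dim sizes.
  std::vector<std::string> src = {"0", "0", "0", "0",
                                  "DIM2", "DIM3", "DIM0", "DIM1"};
  for (const std::string &e : transposeList(src, /*rows=*/4, /*cols=*/2))
    std::cout << e << " ";
  std::cout << "\n"; // prints: 0 0 DIM2 DIM0 0 0 DIM3 DIM1
  return 0;
}

Viewed as a [4,2] tensor, row r, column c of the source lands at row c, column r of the [2,4] result, which flattens to 0, 0, DIM2, DIM0, 0, 0, DIM3, DIM1, the same operand order as the updated torch.prim.ListConstruct.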
