Skip to content

Commit

Permalink
[Codegen][VectorExt] Fix VectorExt ops for 0-d vectors (iree-org#18915)
Browse files Browse the repository at this point in the history
Upstream "AnyVector" does not actually allow 0d vectors. Instead, the
upstream macro, AnyVectorOfAnyRank allows them instead.
  • Loading branch information
Groverkss authored Nov 4, 2024
1 parent 9c85e30 commit ec7528c
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -256,25 +256,24 @@ def NestedLayoutAttr : IREEVectorExt_Attr<"NestedLayout",
}];

let parameters = (ins
ArrayRefParameter<"int64_t", "subgroup_tile">:$subgroupTile,
ArrayRefParameter<"int64_t", "batch_tile">:$batchTile,
ArrayRefParameter<"int64_t", "outer_tile">:$outerTile,
ArrayRefParameter<"int64_t", "thread_tile">:$threadTile,
ArrayRefParameter<"int64_t", "element_tile">:$elementTile,

ArrayRefParameter<"int64_t", "subgroup_strides">:$subgroupStrides,
ArrayRefParameter<"int64_t", "thread_strides">:$threadStrides
OptionalArrayRefParameter<"int64_t", "subgroup_tile">:$subgroupTile,
OptionalArrayRefParameter<"int64_t", "batch_tile">:$batchTile,
OptionalArrayRefParameter<"int64_t", "outer_tile">:$outerTile,
OptionalArrayRefParameter<"int64_t", "thread_tile">:$threadTile,
OptionalArrayRefParameter<"int64_t", "element_tile">:$elementTile,

OptionalArrayRefParameter<"int64_t", "subgroup_strides">:$subgroupStrides,
OptionalArrayRefParameter<"int64_t", "thread_strides">:$threadStrides
);

let assemblyFormat = [{
`<` `subgroup_tile` `=` `[` $subgroupTile `]` `,`
`batch_tile` `=` `[` $batchTile `]` `,`
`outer_tile` `=` `[` $outerTile `]` `,`
`thread_tile` `=` `[` $threadTile `]` `,`
`element_tile` `=` `[` $elementTile `]` `,`

`subgroup_strides` `=` `[` $subgroupStrides `]` `,`
`thread_strides` `=` `[` $threadStrides `]`
`<` `subgroup_tile` `=` `[` (`]`) : ($subgroupTile^ `]`)? `,`
`batch_tile` `=` `[` (`]`) : ($batchTile^ `]`)? `,`
`outer_tile` `=` `[` (`]`) : ($outerTile^ `]`)? `,`
`thread_tile` `=` `[` (`]`) : ($threadTile^ `]`)? `,`
`element_tile` `=` `[` (`]`) : ($elementTile^ `]`)? `,`
`subgroup_strides` `=` `[` (`]`) : ($subgroupStrides^ `]`)? `,`
`thread_strides` `=` `[` (`]`) : ($threadStrides^ `]`)?
`>`
}];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ def IREEVectorExt_ToSIMDOp : IREEVectorExt_PureOp<"to_simd",
distributed vectors.
}];
let arguments = (ins
AnyVector:$input
AnyVectorOfAnyRank:$input
);
let results = (outs
AnyVector:$output
AnyVectorOfAnyRank:$output
);
let extraClassDeclaration = [{}];
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
Expand All @@ -103,10 +103,10 @@ def IREEVectorExt_ToSIMTOp : IREEVectorExt_PureOp<"to_simt",
distributed vectors.
}];
let arguments = (ins
AnyVector:$input
AnyVectorOfAnyRank:$input
);
let results = (outs
AnyVector:$output
AnyVectorOfAnyRank:$output
);
let extraClassDeclaration = [{}];
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,37 @@ func.func @specify_nested(%lhs: memref<32x32xf16>) -> vector<32x32xf16> {

// -----

#nested_0 = #iree_vector_ext.nested_layout<
subgroup_tile = [],
batch_tile = [],
outer_tile = [],
thread_tile = [],
element_tile = [],

subgroup_strides = [],
thread_strides = []
>

func.func @specify_nested_0d(%lhs: vector<f16>) -> vector<f16> {
%result = iree_vector_ext.to_layout %lhs to layout(#nested_0) : vector<f16>
func.return %result : vector<f16>
}

// CHECK: #[[$LAYOUT0:.+]] = #iree_vector_ext.nested_layout<
// CHECK-SAME: subgroup_tile = [],
// CHECK-SAME: batch_tile = [],
// CHECK-SAME: outer_tile = [],
// CHECK-SAME: thread_tile = [],
// CHECK-SAME: element_tile = [],
// CHECK-SAME: subgroup_strides = [],
// CHECK-SAME: thread_strides = []>

// CHECK-LABEL: func.func @specify_nested_0d
// CHECK: to_layout
// CHECK-SAME: layout(#[[$LAYOUT0]])

// -----

func.func @to_simd_op(%simt: vector<4x4x4xf16>) -> vector<64x64xf16> {
%simd = iree_vector_ext.to_simd %simt : vector<4x4x4xf16> -> vector<64x64xf16>
func.return %simd : vector<64x64xf16>
Expand All @@ -103,3 +134,21 @@ func.func @to_simt_op(%simd: vector<64x64xf32>) -> vector<4x4x4xf32> {
}
// CHECK-LABEL: func.func @to_simt_op
// CHECK: iree_vector_ext.to_simd

// -----

func.func @to_simd_op_0d(%simt: vector<f16>) -> vector<f16> {
%simd = iree_vector_ext.to_simd %simt : vector<f16> -> vector<f16>
func.return %simd : vector<f16>
}
// CHECK-LABEL: func.func @to_simd_op
// CHECK: iree_vector_ext.to_simd

// -----

func.func @to_simt_op_0d(%simd: vector<f32>) -> vector<f32> {
%simt = iree_vector_ext.to_simd %simd : vector<f32> -> vector<f32>
func.return %simt : vector<f32>
}
// CHECK-LABEL: func.func @to_simt_op
// CHECK: iree_vector_ext.to_simd

0 comments on commit ec7528c

Please sign in to comment.