-
Notifications
You must be signed in to change notification settings - Fork 31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TKW] Propagate dim index in thread shape analysis #288
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Ivan Butygin <[email protected]>
Signed-off-by: Ivan Butygin <[email protected]>
Signed-off-by: Ivan Butygin <[email protected]>
Signed-off-by: Ivan Butygin <[email protected]>
Signed-off-by: Ivan Butygin <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm! just some requests for comments and also if you could check if we still need the propagation in index_sequence_analysis of elements_per_thread? thanks!
@@ -1349,6 +1349,10 @@ def num_reduction_dims(self) -> int: | |||
def reduction_dim(self) -> IndexSymbol: | |||
return self.dim | |||
|
|||
@property |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a comment explaining why this is necessary?
@@ -51,7 +57,7 @@ def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]): | |||
# Anchor Indicies and Conflict resolution helpers | |||
################################################################# | |||
|
|||
anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape) | |||
anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape, Permute) | |||
noHandleTypes = (Placeholder, Output, ExtractSlice, Allocate) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also a comment explaining why Permute is added as an anchor op?
@@ -51,7 +57,7 @@ def set_index_size(custom: CustomOp, target_dim_sizes: list[DimSize]): | |||
# Anchor Indicies and Conflict resolution helpers | |||
################################################################# | |||
|
|||
anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape) | |||
anchorOpTypes = (Read, Write, MMA, ReduceOp, Reshape, Permute) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add the forward propagation of permute, just as safety measure to ensure we won't be generating "valid" but incorrect IRs. Speaking from experience, would be much better for program to crash than to debug why MLIR is wrong and where the wrong is coming from. 😄
Signed-off-by: Ivan Butygin <[email protected]>
Signed-off-by: Ivan Butygin <[email protected]>
Refactor
thread_shape_analysis
to take into account entire index instead of just elements per thread count.