Removed RPB and LSE flags from template arguments in favor of runtime args #164
assert(
(!has_rpb || !kHasCausalDims) && "Causal NA does not support RPB yet.");
Can you move this to the host side instead?
I strongly prefer keeping assertions outside the kernel as much as possible; I've had some bad experiences in the past with device-side asserts.
We might already have a check for it on the host, but if we don't, just add an if statement and raise an error in kernel_forward.h. There should already be some checks in there that you can copy and paste.
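As a rough sketch of what such a host-side check could look like: the condition and message come from the assert in the diff, while the function name check_rpb_causal_support and the use of std::invalid_argument are hypothetical stand-ins, not the actual contents of kernel_forward.h.

```cpp
#include <stdexcept>

// Hypothetical host-side guard, meant to run before the kernel is launched.
// The name and the exception type are illustrative assumptions only.
inline void check_rpb_causal_support(bool has_rpb, bool has_causal_dim) {
  if (has_rpb && has_causal_dim) {
    // Same condition the device-side assert was guarding against.
    throw std::invalid_argument("Causal NA does not support RPB yet.");
  }
}
```

Because the check runs on the host, the failure surfaces as an ordinary error at the call site instead of aborting inside the kernel.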
Done. Please take a look and see if that's fine.
@@ -410,6 +410,25 @@ inline NA3dDim tuple_to_na_dim(std::tuple<int32_t, int32_t, int32_t> v) {
return NA3dDim(std::get<0>(v), std::get<1>(v), std::get<2>(v));
}

template <typename BoolTupleType>
bool bool_tuple_or(BoolTupleType tuple);
Suggested change:
- bool bool_tuple_or(BoolTupleType tuple);
+ bool tuple_or(BoolTupleType tuple);
Nit: logical operators like "or" mostly make sense for boolean types, and in this case it is only used for the causal mask.
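The diff only shows the declaration of this helper. For illustration, a definition could look like the sketch below, which assumes C++17 and a std::tuple whose elements convert to bool. It uses the shorter name tuple_or and is not the code from this PR.

```cpp
#include <tuple>

// Illustrative definition (assumption): returns true if any element of the
// tuple is true. Works for tuples of any length via a C++17 fold expression.
template <typename BoolTupleType>
bool tuple_or(const BoolTupleType& t) {
  return std::apply(
      [](auto... values) { return (static_cast<bool>(values) || ...); }, t);
}
```

For a three-dimensional causal mask, this would be called as tuple_or(is_causal), where is_causal is a std::tuple<bool, bool, bool>.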
bool should_dump_lse = logsumexp_ptr != nullptr;

bool has_causal_dim = bool_tuple_or(is_causal);
assert((!has_rpb || !has_causal_dim) && "Causal NA does not support RPB yet.")
Please use NATTEN_CHECK like below so that the error is visible to the user, and more consistent with the rest of the code.
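To illustrate the difference: a plain assert is compiled out when NDEBUG is defined, whereas a check macro that throws always reaches the user. The sketch below uses EXAMPLE_CHECK, a hypothetical stand-in written for this example; it assumes a (condition, message) calling convention and is not NATTEN's actual NATTEN_CHECK definition.

```cpp
#include <sstream>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for a NATTEN_CHECK-style macro (assumed to take a
// condition and a message). Unlike assert, it is active in release builds.
#define EXAMPLE_CHECK(cond, msg)                            \
  do {                                                      \
    if (!(cond)) {                                          \
      std::ostringstream oss;                               \
      oss << __FILE__ << ":" << __LINE__ << ": " << (msg);  \
      throw std::runtime_error(oss.str());                  \
    }                                                       \
  } while (0)

// Illustrative caller; the function name is an assumption.
inline void validate_fna_args(bool has_rpb, bool has_causal_dim) {
  EXAMPLE_CHECK(
      !has_rpb || !has_causal_dim, "Causal NA does not support RPB yet.");
}
```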
#include <cassert>
Remove?
// removed in favor of runtime args
// bool kSupportsRPB_ = false,
// bool kStoresLSE_ = false
Just remove them.
// replaced in favor of runtime args
// static constexpr bool kSupportsRPB = kSupportsRPB_;
// static constexpr bool kStoresLSE = kStoresLSE_;
Ditto
@@ -179,6 +183,10 @@ struct FusedNeighborhoodAttentionKernel {
// [num_heads, num_queries_post_partitioning] - can be null
lse_scalar_t* logsumexp_ptr = nullptr;

// StoresLSE/SupportsRPB flags
Remove
@@ -272,12 +280,11 @@ struct FusedNeighborhoodAttentionKernel {
output_ptr += (first_query * o_strideM).sum() +
(dilation_idx * o_stride_dilation).sum() + head_id * o_strideH;

if constexpr (kSupportsRPB) {
if (rpb_ptr != nullptr) {

remove extra line
@alihassanijr
As discussed, I removed the RPB and LSE flags from the template args and added them to the runtime args (under struct Params). To make reviewing easier, I have only committed the kernel definition files and not the autogen'd files (the diff is too large to review). Requesting review.
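The general shape of this change can be sketched in plain host-side C++. This is a simplified illustration, not the kernel code from this PR: KernelBefore, KernelAfter, and store_lse are hypothetical names, and float stands in for lse_scalar_t. Only Params, logsumexp_ptr, and rpb_ptr are names taken from the diff.

```cpp
#include <cstddef>

// Before (simplified): the flag is a template argument, so each value of
// kStoresLSE produces a separate instantiation.
template <bool kStoresLSE>
struct KernelBefore {
  static void store_lse(float* lse, std::size_t i, float value) {
    if constexpr (kStoresLSE) {
      lse[i] = value;
    }
  }
};

// After (simplified): a single kernel type; the decision is made at runtime
// from the pointers carried in Params.
struct Params {
  float* logsumexp_ptr = nullptr;  // can be null
  float* rpb_ptr = nullptr;        // can be null
};

struct KernelAfter {
  static void store_lse(const Params& p, std::size_t i, float value) {
    if (p.logsumexp_ptr != nullptr) {
      p.logsumexp_ptr[i] = value;
    }
  }
};
```

The trade-off is one runtime branch per check in exchange for fewer template instantiations, which is presumably what shrinks the set of autogenerated files.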