Lower distributed matmul to pipelined algorithm for fine-grained overlap: AG+GEMM layout #3695
Conversation
…lower_matmul_to_hostir
For CI testing only, please ignore.
!test
PR Reviewer Guide 🔍 (o1-mini) (Review updated until commit 061955f)
Here are some key observations to aid the review process:

PR Code Suggestions ✨ (gpt-4o)

PR Code Suggestions ✨ (o1-mini)
No code suggestions found for the PR.

PR Reviewer Guide 🔍 (Qwen/Qwen2.5-Coder-32B-Instruct)
Here are some key observations to aid the review process:

PR Code Suggestions ✨ (Qwen/Qwen2.5-Coder-32B-Instruct)
@@ -235,6 +236,10 @@ void lowerToReduceScatter(
 std::vector<Expr*> HostIrLower::lower(Expr* c) {
Suggestion: Proposed documentation
/**
 * Lower a given expression to a series of communication operations.
 * This function checks the type of the expression and calls the appropriate
 * lowering function.
 *
 * @param c The expression to be lowered.
 * @return A vector of expressions representing the lowered communication
 *   operations.
 */
std::vector<Expr*> HostIrLower::lower(Expr* c) {
@@ -302,16 +307,19 @@
   return comms;
 }

-bool HostIrLower::canLower(Expr* expr) {
+bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) {
Suggestion: Proposed documentation
/**
 * Determines if a given expression can be lowered.
 * This function checks if the expression is a resharding operation and if it
 * involves tensor operations. It also checks specific conditions for
 * reduction and load/store operations.
 *
 * @param expr The expression to be checked.
 * @param ignore_inner_resharding Whether to ignore inner resharding checks.
 * @return True if the expression can be lowered, false otherwise.
 */
bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) {
  return false;
}

std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(
Suggestion: Proposed documentation
/**
 * Lower a MatmulOp to a collective-based pipelined GEMM communication.
 * This function handles the specific lowering of a matrix multiplication
 * operation into a series of communication and computation operations
 * suitable for pipelining.
 *
 * @param expr The MatmulOp expression to be lowered.
 * @return A vector of expressions representing the lowered communication and
 *   computation operations.
 */
std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(
@@ -16,14 +16,17 @@ namespace nvfuser {

 class HostIrLower {
  public:
-  static bool canLower(Expr* expr);
+  static bool canLower(Expr* expr, bool ignore_inner_resharding = false);
Suggestion: Proposed documentation
/**
 * Determines if a given expression can be lowered.
 * This function checks if the expression is a resharding operation and if it
 * involves tensor operations. It also checks specific conditions for
 * reduction and load/store operations.
 *
 * @param expr The expression to be checked.
 * @param ignore_inner_resharding Whether to ignore inner resharding checks.
 * @return True if the expression can be lowered, false otherwise.
 */
static bool canLower(Expr* expr, bool ignore_inner_resharding = false);
  // Lower a sharded Expr into a series of Communication.
  static std::vector<Expr*> lower(Expr* c);

  static std::unique_ptr<hir::HostIrContainer> lower(
      std::unique_ptr<Fusion> fusion,
      int64_t my_device_index);

 private:
  static std::vector<Expr*> lowerToCollectiveBasedPipelinedGemmComm(Expr* expr);
Suggestion: Proposed documentation
/**
 * Lower a MatmulOp to a collective-based pipelined GEMM communication.
 * This function handles the specific lowering of a matrix multiplication
 * operation into a series of communication and computation operations
 * suitable for pipelining.
 *
 * @param expr The MatmulOp expression to be lowered.
 * @return A vector of expressions representing the lowered communication and
 *   computation operations.
 */
static std::vector<Expr*> lowerToCollectiveBasedPipelinedGemmComm(Expr* expr);
@@ -74,6 +74,9 @@ struct HostIrEvaluatorParams {
   // Experimental: whether to cache the fusion executor. WAR: avoids
   // recompilation but implicitly assumes that the input shapes don't change
   // over iterations
   bool cache_fusion_executor = false;
+  // Number of additional CUDA streams to use at runtime for comm+compute
+  // pipelining
+  int64_t number_of_streams = 4;
Suggestion: Proposed documentation
/**
 * Number of additional CUDA streams to use at runtime for communication and
 * computation pipelining.
 */
int64_t number_of_streams = 4;
@@ -208,6 +208,8 @@ class Wait : public Expr {
   }
 };

+// Makes the current stream wait on the given stream. Non-blocking from the host
+// point of view.
 class Synchronize : public Expr {
Suggestion: Proposed documentation
/**
 * Makes the current stream wait on the given stream. This operation is
 * non-blocking from the host point of view.
 */
class Synchronize : public Expr {
@@ -672,6 +672,7 @@ enum class ParallelType {
   TIDz,
   TIDy,
   TIDx,
+  Stream,
Suggestion: Proposed documentation
/**
 * Stream parallel type.
 */
Stream,
PR Reviewer Guide 🔍 (meta/llama-3.1-405b-instruct)
Here are some key observations to aid the review process:

PR Code Suggestions ✨ (meta/llama-3.1-405b-instruct)
Stacked on top of GetCurrentStream #3605

What

Lower a MatmulOp sharded on the first inner axis into a pipelined AG+GEMM algorithm achieving fine-grained overlap. We introduce a new parallel type Stream to account for this scheduling.

More precisely, this patch enables lowering the fusion to the Host IR program (obtained from a dump, using NVFUSER_DUMP=host_ir).

The nsight profile shows that we do achieve overlap, in a way that is comparable to the ATen overlap experiments.