From c863374023960778e746e59e6a0e491da16e3ff0 Mon Sep 17 00:00:00 2001
From: zms1999
Date: Fri, 25 Aug 2023 13:11:04 +0800
Subject: [PATCH 1/6] [BUG FIX] make smart scheduling great again, fix bugs in
 stream management

---
 cuda/fastermoe/smart_schedule.h | 38 ++++++++++++++++++---------------
 cuda/stream_manager.cpp         |  6 +++++-
 2 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/cuda/fastermoe/smart_schedule.h b/cuda/fastermoe/smart_schedule.h
index 71ed050e..c5cb1ef8 100644
--- a/cuda/fastermoe/smart_schedule.h
+++ b/cuda/fastermoe/smart_schedule.h
@@ -157,11 +157,11 @@ void fmoe_cuda_fused_forward_impl(
                 local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                 global_input_buf + global_ptr[gidx_recv] * d_model,
                 global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
-                d_model, smgr->stream(0), smgr->ncclcomm);
+                d_model, smgr->stream(num_expert), smgr->ncclcomm);
         }
         NCCL_SAFE_CALL(ncclGroupEnd());
     }
-    cudaEventRecord(input_ready[step], smgr->stream(0));
+    cudaEventRecord(input_ready[step], smgr->stream(num_expert));
 }
 
 // Broadcast shadowed experts
@@ -173,21 +173,22 @@ void fmoe_cuda_fused_forward_impl(
         if (stored_models[i]) {
             if (i / num_expert == rank) {
                 cudaEventCreate(&evt_get);
-                cudaEventRecord(evt_get, torch_stream);
-                FMOE_SWE(smgr->stream(0), evt_get);
+                cudaEventRecord(evt_get, smgr->stream(0));
+                FMOE_SWE(smgr->stream(num_expert), evt_get);
                 cudaEventDestroy(evt_get);
             }
             NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr(),
                         expert_size * sizeof(scalar_t), ncclChar,
-                        i / num_expert, smgr->ncclcomm, smgr->stream(0)));
+                        i / num_expert, smgr->ncclcomm, smgr->stream(num_expert)));
             cudaEventCreate(evt_shadow + si);
-            cudaEventRecord(evt_shadow[si], smgr->stream(0));
+            cudaEventRecord(evt_shadow[si], smgr->stream(num_expert));
             ++si;
         }
     }
 
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
+        FMOE_SWE(smgr->stream(0), input_ready[step]);
         FMOE_SWE(torch_stream, input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
@@ -198,12 +199,13 @@ void fmoe_cuda_fused_forward_impl(
                     global_input_buf, global_output_buf,
                     (long) ei, step * num_expert + ei,
                     offset, micro_batch_size, d_model, smgr);
         }
-        cudaEventRecord(output_ready[step], torch_stream);
+        cudaEventRecord(output_ready[step], smgr->stream(0));
     }
 
     // Compute over shadowed experts
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
+            FMOE_SWE(smgr->stream(0), evt_shadow[si]);
             FMOE_SWE(torch_stream, evt_shadow[si]);
             stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
             long offset = local_ptr[i];
@@ -218,7 +220,7 @@ void fmoe_cuda_fused_forward_impl(
 
     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(smgr->stream(0), output_ready[step]);
+        FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             NCCL_SAFE_CALL(ncclGroupStart());
@@ -230,12 +232,12 @@ void fmoe_cuda_fused_forward_impl(
                 global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                 output_buf + local_ptr[idx_recv] * d_model,
                 local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
-                d_model, smgr->stream(0), smgr->ncclcomm);
+                d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
     }
-    smgr->sync(1);
+    smgr->sync(num_expert + 1);
 
     delete [] local_ptr;
     delete [] global_ptr;
@@ -308,11 +310,11 @@ void fmoe_cuda_fused_backward_impl(
                 local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
                 global_grad_out + global_ptr[gidx_recv] * d_model,
                 global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
-                d_model, smgr->stream(0), smgr->ncclcomm);
+                d_model, smgr->stream(num_expert), smgr->ncclcomm);
         }
         NCCL_SAFE_CALL(ncclGroupEnd());
     }
-    cudaEventRecord(input_ready[step], smgr->stream(0));
+    cudaEventRecord(input_ready[step], smgr->stream(num_expert));
 }
 
 // Shadowed experts backward and reduce
@@ -328,7 +330,7 @@ void fmoe_cuda_fused_backward_impl(
             collect_fn(si, i / num_expert, 0);
             if (i / num_expert == rank) {
                 cudaEventCreate(evt_reduce + i % num_expert);
-                cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
+                cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(num_expert));
             }
             ++si;
         }
@@ -337,6 +339,7 @@ void fmoe_cuda_fused_backward_impl(
 
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
+        FMOE_SWE(smgr->stream(0), input_ready[step]);
         FMOE_SWE(torch_stream, input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
@@ -348,13 +351,14 @@ void fmoe_cuda_fused_backward_impl(
                     global_grad_out, global_grad_in,
                     (long) ei, step * num_expert + ei,
                     offset, micro_batch_size, d_model, smgr);
         }
-        cudaEventRecord(output_ready[step], torch_stream);
+        cudaEventRecord(output_ready[step], smgr->stream(0));
     }
 
     // Collect gradients for shadowed experts
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
             if (i / num_expert == rank) {
+                FMOE_SWE(smgr->stream(0), evt_reduce[i % num_expert]);
                 FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
                 set_grad_fn(si, i % num_expert);
             }
@@ -364,7 +368,7 @@ void fmoe_cuda_fused_backward_impl(
 
     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
-        FMOE_SWE(smgr->stream(0), output_ready[step]);
+        FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             NCCL_SAFE_CALL(ncclGroupStart());
@@ -376,13 +380,13 @@ void fmoe_cuda_fused_backward_impl(
                 global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
                 grad_in + local_ptr[idx_recv] * d_model,
                 local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
-                d_model, smgr->stream(0), smgr->ncclcomm);
+                d_model, smgr->stream(num_expert), smgr->ncclcomm);
             }
             NCCL_SAFE_CALL(ncclGroupEnd());
         }
     }
-    smgr->sync(1);
+    smgr->sync(num_expert + 1);
     checkCudaErrors(cudaGetLastError());
 
     delete [] local_ptr;

diff --git a/cuda/stream_manager.cpp b/cuda/stream_manager.cpp
index 4f52cdfc..459dec2f 100644
--- a/cuda/stream_manager.cpp
+++ b/cuda/stream_manager.cpp
@@ -45,7 +45,11 @@ void CudaStreamManager::setup(const int device) {
     streams = new cudaStream_t[SMGR_N_STREAMS];
     handles = new cublasHandle_t[SMGR_N_STREAMS];
     for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
-        checkCudaErrors(cudaStreamCreate(streams + i));
+        // Do not use cudaStreamCreate() here: streams created that way
+        // synchronize implicitly with the legacy default stream. See
+        // https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html
+        checkCudaErrors(cudaStreamCreateWithFlags(streams + i,
+                    cudaStreamNonBlocking));
        checkCudaErrors(cublasCreate(handles + i));
        cublasSetStream(handles[i], streams[i]);
     }
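Note on the patch above: it splits the schedule across two kinds of streams. Expert computation stays on the manager's compute streams (stream(0) upward), while all NCCL traffic, that is the S_* sends, the R_* receives, and the shadow broadcasts, moves to the dedicated stream(num_expert); the two sides are ordered purely by the input_ready/output_ready events. The switch to cudaStreamCreateWithFlags matters for the same reason: a stream from plain cudaStreamCreate() synchronizes implicitly with the legacy default stream, which silently serialized communication against unrelated work. Below is a minimal sketch of the record/wait edge the patch relies on, assuming FMOE_SWE is a thin wrapper over cudaStreamWaitEvent; the stream names are illustrative, not from the patch:

    #include <cuda_runtime.h>

    // comp produces a buffer; comm must not read it too early.
    void chain_streams(cudaStream_t comp, cudaStream_t comm) {
        cudaEvent_t ready;
        cudaEventCreate(&ready);
        // ... enqueue compute kernels that write the buffer on comp ...
        cudaEventRecord(ready, comp);         // mark: buffer is ready
        cudaStreamWaitEvent(comm, ready, 0);  // comm stalls, the host does not
        // ... enqueue ncclSend/ncclRecv that read the buffer on comm ...
        cudaEventDestroy(ready);
    }

Because the wait is enqueued on the device, the host thread keeps issuing work, which is what lets the S_i, C_i, and R_i groups of different pipeline steps overlap.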
From ff28081c8d77c92363b37f70beae744f443f80a5 Mon Sep 17 00:00:00 2001
From: zms1999
Date: Fri, 25 Aug 2023 14:23:41 +0800
Subject: [PATCH 2/6] [BUG FIX] wait torch stream

---
 cuda/fastermoe/smart_schedule.h | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/cuda/fastermoe/smart_schedule.h b/cuda/fastermoe/smart_schedule.h
index c5cb1ef8..022bf4d7 100644
--- a/cuda/fastermoe/smart_schedule.h
+++ b/cuda/fastermoe/smart_schedule.h
@@ -139,9 +139,11 @@ void fmoe_cuda_fused_forward_impl(
 
     cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
     cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
+    cudaEvent_t *output_torch_ready = new cudaEvent_t[n_groups];
     for (long i = 0; i < n_groups; ++i) {
         cudaEventCreate(input_ready + i);
         cudaEventCreate(output_ready + i);
+        cudaEventCreate(output_torch_ready + i);
     }
 
     // S_0 ... S_n
@@ -200,6 +202,7 @@ void fmoe_cuda_fused_forward_impl(
                     (long) ei, step * num_expert + ei,
                     offset, micro_batch_size, d_model, smgr);
         }
         cudaEventRecord(output_ready[step], smgr->stream(0));
+        cudaEventRecord(output_torch_ready[step], torch_stream);
     }
 
     // Compute over shadowed experts
@@ -221,6 +224,7 @@ void fmoe_cuda_fused_forward_impl(
     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
         FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
+        FMOE_SWE(smgr->stream(num_expert), output_torch_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             NCCL_SAFE_CALL(ncclGroupStart());
@@ -246,12 +250,14 @@ void fmoe_cuda_fused_forward_impl(
     for (long i = 0; i < n_groups; ++i) {
         cudaEventDestroy(input_ready[i]);
         cudaEventDestroy(output_ready[i]);
+        cudaEventDestroy(output_torch_ready[i]);
     }
     for (unsigned i = 0; i < params.size(); ++i) {
         cudaEventDestroy(evt_shadow[i]);
     }
     delete [] input_ready;
     delete [] output_ready;
+    delete [] output_torch_ready;
 }
 
@@ -292,9 +298,11 @@ void fmoe_cuda_fused_backward_impl(
 
     cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
     cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
+    cudaEvent_t *output_torch_ready = new cudaEvent_t[n_groups];
     for (long i = 0; i < n_groups; ++i) {
         cudaEventCreate(input_ready + i);
         cudaEventCreate(output_ready + i);
+        cudaEventCreate(output_torch_ready + i);
     }
 
     // S_0 ... S_n
@@ -352,6 +360,7 @@ void fmoe_cuda_fused_backward_impl(
                     (long) ei, step * num_expert + ei,
                     offset, micro_batch_size, d_model, smgr);
         }
         cudaEventRecord(output_ready[step], smgr->stream(0));
+        cudaEventRecord(output_torch_ready[step], torch_stream);
     }
 
     // Collect gradients for shadowed experts
@@ -369,6 +378,7 @@ void fmoe_cuda_fused_backward_impl(
     // R_0 ... R_n
     for (long step = 0; step < n_groups; ++step) {
         FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
+        FMOE_SWE(smgr->stream(num_expert), output_torch_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             NCCL_SAFE_CALL(ncclGroupStart());
@@ -396,9 +406,11 @@ void fmoe_cuda_fused_backward_impl(
     for (long i = 0; i < n_groups; ++i) {
         cudaEventDestroy(input_ready[i]);
         cudaEventDestroy(output_ready[i]);
+        cudaEventDestroy(output_torch_ready[i]);
     }
     delete [] input_ready;
     delete [] output_ready;
+    delete [] output_torch_ready;
     for (long i = 0; i < num_expert; ++i) {
         if (stored_models[i + rank * num_expert]) {
             cudaEventDestroy(evt_reduce[i]);
         }

From 2bd187cb3fb4f0d044d8ea636a75e27025ada96d Mon Sep 17 00:00:00 2001
From: zms1999
Date: Sat, 2 Sep 2023 19:00:13 +0800
Subject: [PATCH 3/6] [BUG FIX] sync torch stream before nccl send/recv

---
 cuda/fastermoe/smart_schedule.h | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/cuda/fastermoe/smart_schedule.h b/cuda/fastermoe/smart_schedule.h
index 022bf4d7..4d6a0ba9 100644
--- a/cuda/fastermoe/smart_schedule.h
+++ b/cuda/fastermoe/smart_schedule.h
@@ -123,6 +123,7 @@ void fmoe_cuda_fused_forward_impl(
         long num_expert, long rank, long world_size, long expert_size,
         long pipeline_gran, CudaStreamManager* smgr) {
     auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
+    cudaStreamSynchronize(torch_stream);
 
     int *local_ptr = new int[num_expert * world_size + 1];
     int *global_ptr = new int[num_expert * world_size + 1];
@@ -282,6 +283,7 @@ void fmoe_cuda_fused_backward_impl(
         long num_expert, long rank, long world_size,
         long pipeline_gran, CudaStreamManager* smgr) {
     auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
+    cudaStreamSynchronize(torch_stream);
 
     int *local_ptr = new int[num_expert * world_size + 1];
     int *global_ptr = new int[num_expert * world_size + 1];
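The two patches above close the remaining ordering gaps against PyTorch itself. Part of the expert computation runs on PyTorch's current stream (the expert_fn callbacks), so patch 2 records an extra output_torch_ready event there and makes the R_* senders wait on it; patch 3 additionally blocks the host at function entry until everything already queued on the torch stream has finished, so the buffers handed to NCCL are complete. A sketch contrasting the two mechanisms; side_stream is a hypothetical stand-in for the manager's communication stream:

    #include <cuda_runtime.h>
    #include <c10/cuda/CUDAStream.h>

    void order_against_torch(cudaStream_t side_stream) {
        cudaStream_t torch_stream = c10::cuda::getCurrentCUDAStream().stream();

        // Patch 3's mechanism: host-side, coarse but simple. Nothing can
        // race with buffers produced on torch_stream before this point.
        cudaStreamSynchronize(torch_stream);

        // Patch 2's mechanism: a device-side edge with no host stall.
        // side_stream refuses to run past the event; the CPU keeps going.
        cudaEvent_t ready;
        cudaEventCreate(&ready);
        cudaEventRecord(ready, torch_stream);
        cudaStreamWaitEvent(side_stream, ready, 0);
        cudaEventDestroy(ready);
    }

The host synchronization costs one pipeline bubble per fused call, which is tolerable at entry; inside the per-step loops only the event form is used.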
From 4f9f77f86eb56b130d471fe39edbf66305df4a31 Mon Sep 17 00:00:00 2001
From: Rick Ho
Date: Mon, 11 Sep 2023 15:26:03 +0800
Subject: [PATCH 4/6] use torchStream everywhere

---
 cuda/balancing.cuh              |  6 ++----
 cuda/fastermoe/smart_schedule.h | 18 ++++++++----------
 cuda/global_exchange.cpp        |  5 ++---
 cuda/global_exchange.h          | 10 ++++------
 cuda/local_exchange.cuh         |  6 ++----
 cuda/parallel_linear.cuh        |  2 ++
 cuda/stream_manager.cpp         |  8 ++++++++
 cuda/stream_manager.h           |  2 ++
 8 files changed, 30 insertions(+), 27 deletions(-)

diff --git a/cuda/balancing.cuh b/cuda/balancing.cuh
index e80d0b60..17100733 100644
--- a/cuda/balancing.cuh
+++ b/cuda/balancing.cuh
@@ -25,9 +25,8 @@ void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap,
         CudaStreamManager* smgr) {
     dim3 grid_dim(CEIL(n_worker, 1024), n_expert);
     dim3 block_dim(1024);
-    limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
+    limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->torchStream()>>>(
             ec, cap, eca, n_expert, n_worker);
-    smgr->sync(1);
 }
 
 __global__
@@ -51,8 +50,7 @@ void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, long* new_gate_idx,
         CudaStreamManager* smgr) {
     dim3 grid_dim(CEIL(batch_size, 1024));
     dim3 block_dim(1024);
-    prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
+    prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->torchStream()>>>(
             gate_idx, new_gate_idx, ec, batch_size, n_expert, n_worker
     );
-    smgr->sync(1);
 }

diff --git a/cuda/fastermoe/smart_schedule.h b/cuda/fastermoe/smart_schedule.h
index 4d6a0ba9..dac8005b 100644
--- a/cuda/fastermoe/smart_schedule.h
+++ b/cuda/fastermoe/smart_schedule.h
@@ -122,8 +122,7 @@ void fmoe_cuda_fused_forward_impl(
         long d_model,
         long num_expert, long rank, long world_size, long expert_size,
         long pipeline_gran, CudaStreamManager* smgr) {
-    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
-    cudaStreamSynchronize(torch_stream);
+    smgr->syncTorch();
 
     int *local_ptr = new int[num_expert * world_size + 1];
     int *global_ptr = new int[num_expert * world_size + 1];
@@ -192,7 +191,7 @@ void fmoe_cuda_fused_forward_impl(
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
         FMOE_SWE(smgr->stream(0), input_ready[step]);
-        FMOE_SWE(torch_stream, input_ready[step]);
+        FMOE_SWE(smgr->torchStream(), input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             long offset = global_ptr[ei * world_size + from_base];
@@ -203,14 +202,14 @@ void fmoe_cuda_fused_forward_impl(
                     (long) ei, step * num_expert + ei,
                     offset, micro_batch_size, d_model, smgr);
         }
         cudaEventRecord(output_ready[step], smgr->stream(0));
-        cudaEventRecord(output_torch_ready[step], torch_stream);
+        cudaEventRecord(output_torch_ready[step], smgr->torchStream());
     }
 
     // Compute over shadowed experts
     for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
         if (stored_models[i]) {
             FMOE_SWE(smgr->stream(0), evt_shadow[si]);
-            FMOE_SWE(torch_stream, evt_shadow[si]);
+            FMOE_SWE(smgr->torchStream(), evt_shadow[si]);
             stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
             long offset = local_ptr[i];
             long micro_batch_size = local_expert_count[i];
@@ -282,8 +281,7 @@ void fmoe_cuda_fused_backward_impl(
         long d_model,
         long num_expert, long rank, long world_size,
         long pipeline_gran, CudaStreamManager* smgr) {
-    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
-    cudaStreamSynchronize(torch_stream);
+    smgr->syncTorch();
 
     int *local_ptr = new int[num_expert * world_size + 1];
     int *global_ptr = new int[num_expert * world_size + 1];
@@ -348,7 +346,7 @@ void fmoe_cuda_fused_backward_impl(
 
     // C_0 ... C_n
     for (long step = 0; step < n_groups; ++step) {
         FMOE_SWE(smgr->stream(0), input_ready[step]);
-        FMOE_SWE(torch_stream, input_ready[step]);
+        FMOE_SWE(smgr->torchStream(), input_ready[step]);
         for (int ei = 0; ei < num_expert; ++ei) {
             GEN_BASE(step);
             long offset = global_ptr[ei * world_size + from_base];
@@ -362,7 +360,7 @@ void fmoe_cuda_fused_backward_impl(
                     (long) ei, step * num_expert + ei,
                     offset, micro_batch_size, d_model, smgr);
         }
         cudaEventRecord(output_ready[step], smgr->stream(0));
-        cudaEventRecord(output_torch_ready[step], torch_stream);
+        cudaEventRecord(output_torch_ready[step], smgr->torchStream());
     }
 
     // Collect gradients for shadowed experts
@@ -370,7 +368,7 @@ void fmoe_cuda_fused_backward_impl(
         if (stored_models[i]) {
             if (i / num_expert == rank) {
                 FMOE_SWE(smgr->stream(0), evt_reduce[i % num_expert]);
-                FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
+                FMOE_SWE(smgr->torchStream(), evt_reduce[i % num_expert]);
                 set_grad_fn(si, i % num_expert);
             }
             ++si;

diff --git a/cuda/global_exchange.cpp b/cuda/global_exchange.cpp
index 2fbde1ff..c700b1a4 100644
--- a/cuda/global_exchange.cpp
+++ b/cuda/global_exchange.cpp
@@ -19,17 +19,16 @@ void fmoe_cuda_expert_exchange_impl(
                 ncclInt64,
                 i,
                 smgr->ncclcomm,
-                smgr->stream(0)));
+                smgr->torchStream()));
         NCCL_SAFE_CALL(ncclRecv(
                 global_expert_count + n_expert * i,
                 n_expert,
                 ncclInt64,
                 i,
                 smgr->ncclcomm,
-                smgr->stream(0)));
+                smgr->torchStream()));
     }
     NCCL_SAFE_CALL(ncclGroupEnd());
-    smgr->sync(1);
 }
 
 torch::Tensor _expert_exchange(

diff --git a/cuda/global_exchange.h b/cuda/global_exchange.h
index a9f612fe..f3f2d02f 100644
--- a/cuda/global_exchange.h
+++ b/cuda/global_exchange.h
@@ -36,7 +36,7 @@ void fmoe_cuda_global_scatter_impl(
                         ncclChar,
                         j,
                         smgr->ncclcomm,
-                        smgr->stream(0)));
+                        smgr->torchStream()));
             }
             if (global_expert_count[idx]) {
                 NCCL_SAFE_CALL(ncclRecv(
@@ -45,14 +45,13 @@ void fmoe_cuda_global_scatter_impl(
                         ncclChar,
                         j,
                         smgr->ncclcomm,
-                        smgr->stream(0)));
+                        smgr->torchStream()));
                 recv_ptr += global_expert_count[idx];
             }
         }
         NCCL_SAFE_CALL(ncclGroupEnd());
     }
     delete [] expert_ptr;
-    smgr->sync(1);
 }
 
 template<typename scalar_t>
@@ -82,7 +81,7 @@ void fmoe_cuda_global_gather_impl(
                         ncclChar,
                         j,
                         smgr->ncclcomm,
-                        smgr->stream(0)));
+                        smgr->torchStream()));
                 send_ptr += global_expert_count[idx];
             }
             if (local_expert_count[idx]) {
                 NCCL_SAFE_CALL(ncclRecv(
@@ -92,13 +91,12 @@ void fmoe_cuda_global_gather_impl(
                         ncclChar,
                         j,
                         smgr->ncclcomm,
-                        smgr->stream(0)));
+                        smgr->torchStream()));
             }
         }
         NCCL_SAFE_CALL(ncclGroupEnd());
     }
     delete [] expert_ptr;
-    smgr->sync(1);
 }

diff --git a/cuda/local_exchange.cuh b/cuda/local_exchange.cuh
index e9f0aafc..2a5d1fae 100644
--- a/cuda/local_exchange.cuh
+++ b/cuda/local_exchange.cuh
@@ -21,9 +21,8 @@ void fmoe_cuda_assign_pos_impl(
         CudaStreamManager* smgr) {
     size_t numel = batch_size * topk;
     assign_pos_kernel
-        <<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>>
+        <<<CEIL(numel, 256), 256, 0, smgr->torchStream()>>>
         (cum_count, gate, pos, numel, topk);
-    smgr->sync(1);
 }
 
 #define PERTHREAD_EXPERTS 256
@@ -74,7 +73,6 @@ void fmoe_cuda_expert_count_impl(
         const size_t batch_size, const size_t n_expert,
         CudaStreamManager* smgr) {
     expert_count_kernel
-        <<<CEIL(batch_size, PERTHREAD_EXPERTS), PERTHREAD_EXPERTS, 0, smgr->stream(0)>>>
+        <<<CEIL(batch_size, PERTHREAD_EXPERTS), PERTHREAD_EXPERTS, 0, smgr->torchStream()>>>
         (gate_idx, expert_count, batch_size, n_expert);
-    smgr->sync(1);
 }

diff --git a/cuda/parallel_linear.cuh b/cuda/parallel_linear.cuh
index 47bd7175..a9db3de6 100644
--- a/cuda/parallel_linear.cuh
+++ b/cuda/parallel_linear.cuh
@@ -65,6 +65,7 @@ void fmoe_cuda_linear_forward_impl(
         CudaStreamManager* smgr) {
     scalar_t alpha = 1, beta = has_bias ? 1 : 0;
+    smgr->syncTorch();
 
     for (int i = 0, ptr = 0; i < num_expert; ++i) {
         if (expert_count[i] == 0) {
             continue;
         }
@@ -102,6 +103,7 @@ void fmoe_cuda_linear_backward_impl(
         const size_t out_feat, const size_t num_expert,
         CudaStreamManager* smgr) {
+    smgr->syncTorch();
     scalar_t alpha = 1, beta = 0;
 
     // bias

diff --git a/cuda/stream_manager.cpp b/cuda/stream_manager.cpp
index 459dec2f..70ba9ca5 100644
--- a/cuda/stream_manager.cpp
+++ b/cuda/stream_manager.cpp
@@ -19,6 +19,10 @@ cudaStream_t CudaStreamManager::stream(size_t idx) {
     return this->streams[idx % SMGR_N_STREAMS];
 }
 
+cudaStream_t CudaStreamManager::torchStream() {
+    return c10::cuda::getCurrentCUDAStream().stream();
+}
+
 cublasHandle_t CudaStreamManager::handle(size_t idx) {
     if (this->use_default) {
         return at::cuda::getCurrentCUDABlasHandle();
@@ -27,6 +31,10 @@ cublasHandle_t CudaStreamManager::handle(size_t idx) {
 }
 
+void CudaStreamManager::syncTorch() {
+    cudaStreamSynchronize(this->torchStream());
+}
+
 void CudaStreamManager::sync(int idx) {
     if (this->use_default) {
         return;
     }

diff --git a/cuda/stream_manager.h b/cuda/stream_manager.h
index 50856852..e187c015 100644
--- a/cuda/stream_manager.h
+++ b/cuda/stream_manager.h
@@ -34,8 +34,10 @@ class CudaStreamManager {
     void setup(int);
     void sync(int=0);
+    void syncTorch();
     void destroy();
 
+    cudaStream_t torchStream();
     cudaStream_t stream(size_t=0);
     cublasHandle_t handle(size_t=0);
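Patch 4's theme: helpers that launch a single kernel or one NCCL group now run directly on PyTorch's current stream through the new torchStream() accessor, and their trailing smgr->sync(1) calls disappear, because ordering against surrounding PyTorch ops now comes from stream semantics rather than a host join. The pattern, reduced to a self-contained sketch with illustrative names:

    #include <cuda_runtime.h>
    #include <c10/cuda/CUDAStream.h>

    __global__ void scale_kernel(float* x, float a, size_t n) {
        size_t i = blockIdx.x * (size_t) blockDim.x + threadIdx.x;
        if (i < n) x[i] *= a;
    }

    void scale_on_torch_stream(float* d_x, float a, size_t n) {
        cudaStream_t s = c10::cuda::getCurrentCUDAStream().stream();
        scale_kernel<<<(unsigned) ((n + 255) / 256), 256, 0, s>>>(d_x, a, n);
        // No sync needed: later PyTorch ops on this stream are ordered
        // behind the kernel automatically.
    }

parallel_linear.cuh goes the other way and gains a syncTorch() at entry, presumably because its cuBLAS calls still run on the manager's own streams, so inputs produced on the torch stream must be complete first.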
From 226e077965a50272a1f988e110b140e88e83f240 Mon Sep 17 00:00:00 2001
From: Rick Ho
Date: Mon, 11 Sep 2023 17:45:20 +0800
Subject: [PATCH 5/6] bug fix of swipe

---
 cuda/balancing.cu | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/cuda/balancing.cu b/cuda/balancing.cu
index 599ad79a..7601cefd 100644
--- a/cuda/balancing.cu
+++ b/cuda/balancing.cu
@@ -104,6 +104,7 @@ std::vector<torch::Tensor> _swipe_once(
     }
     long *d_lec = _h2d(lec, n_worker), *d_gec = _cudamalloc<long>(n_worker);
     fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
+    smgr->syncTorch();
     long *gec = _d2h(d_gec, n_worker);
 
     /* Limit number of incoming samples */
@@ -123,17 +124,20 @@ std::vector<torch::Tensor> _swipe_once(
     /* Send limit information back */
     _h2d(gec, d_gec, n_worker);
     fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
+    smgr->syncTorch();
     _d2h(d_lec, lec, n_worker);
 
     auto d_dropcount = _h2d(drop_count, n_worker);
     ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
-            smgr->ncclcomm, smgr->stream());
+            smgr->ncclcomm, smgr->torchStream());
+    smgr->syncTorch();
     _d2h(d_dropcount, drop_count, n_worker);
 
     auto d_gcap = _cudamalloc<long>(n_worker);
     _h2d(&cap, d_gcap + rank, 1);
     ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
-            smgr->ncclcomm, smgr->stream());
+            smgr->ncclcomm, smgr->torchStream());
+    smgr->syncTorch();
     auto gcap = _d2h(d_gcap, n_worker);
 
     /* Re-assign and update counters */
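The swipe fix applies the same accessor to the eager, host-driven balancing path. The host reads each collective's result immediately (_d2h copies gec, drop_count, and gcap back), so every collective enqueued on torchStream() is now followed by syncTorch() before the copy. Reduced to its essence, with illustrative names; _d2h in the patch plays the role of the memcpy:

    #include <cuda_runtime.h>
    #include <c10/cuda/CUDAStream.h>

    long read_back(const long* d_val) {
        cudaStream_t s = c10::cuda::getCurrentCUDAStream().stream();
        // An ncclAllReduce/ncclAllGather was just enqueued on s. It runs
        // asynchronously, so wait before the host touches the result.
        cudaStreamSynchronize(s);
        long h_val = 0;
        cudaMemcpy(&h_val, d_val, sizeof(long), cudaMemcpyDeviceToHost);
        return h_val;
    }

Without the sync, the cudaMemcpy can read the buffer while the collective is still in flight: a blocking memcpy synchronizes with the default stream, not with a non-blocking torch stream.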
From 945004e7c5ef5a19d51c0d6e56a3881e0675b58c Mon Sep 17 00:00:00 2001
From: Rick Ho
Date: Mon, 11 Sep 2023 18:53:17 +0800
Subject: [PATCH 6/6] fix shadow

---
 cuda/fastermoe/smart_schedule.cpp | 3 +--
 fmoe/fastermoe/schedule.py        | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/cuda/fastermoe/smart_schedule.cpp b/cuda/fastermoe/smart_schedule.cpp
index ada1a2b9..de4d231b 100644
--- a/cuda/fastermoe/smart_schedule.cpp
+++ b/cuda/fastermoe/smart_schedule.cpp
@@ -44,10 +44,9 @@ void _reduce_grad(
         long expert_size) {
     auto smgr = getCudaStreamManager(t.device().index());
 
-    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
     cudaEvent_t evt_stash;
     cudaEventCreate(&evt_stash);
-    cudaEventRecord(evt_stash, torch_stream);
+    cudaEventRecord(evt_stash, smgr->torchStream());
     FMOE_SWE(smgr->stream(0), evt_stash);
     cudaEventDestroy(evt_stash);
 
diff --git a/fmoe/fastermoe/schedule.py b/fmoe/fastermoe/schedule.py
index 4b5c2d9f..fc84d68a 100644
--- a/fmoe/fastermoe/schedule.py
+++ b/fmoe/fastermoe/schedule.py
@@ -37,7 +37,7 @@ def _expert_forward(x, y, expert_idx, store_idx):
     try:
         # To skip torch autograd's version check.
         with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
-            y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64))
+            y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64), expert_idx)
     except Exception as e:
         # Ignore the error and fall back for compatibility with older
         # versions of PyTorch
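Taken together, the series converges on one discipline: every cross-stream hand-off is an explicit event edge, and the only host joins left sit at the entry and exit of the fused functions and before host reads. The per-group skeleton that the forward and backward loops implement, stripped of the MoE specifics (all names illustrative):

    #include <cuda_runtime.h>

    void run_pipeline(cudaStream_t comp, cudaStream_t comm, int n_groups) {
        for (int i = 0; i < n_groups; ++i) {
            cudaEvent_t in_ready, out_ready;
            cudaEventCreate(&in_ready);
            cudaEventCreate(&out_ready);

            // S_i: receive group i's input on the communication stream.
            cudaEventRecord(in_ready, comm);

            // C_i: compute may start only once its input has arrived.
            cudaStreamWaitEvent(comp, in_ready, 0);
            // ... expert kernels for group i run on comp ...
            cudaEventRecord(out_ready, comp);

            // R_i: send results back only after the compute finished.
            cudaStreamWaitEvent(comm, out_ready, 0);
            // ... send of group i's output runs on comm ...

            cudaEventDestroy(in_ready);
            cudaEventDestroy(out_ready);
        }
    }

Since the host only enqueues, group i+1's communication overlaps group i's computation, which is the overlap the first patch in the series set out to restore.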