Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG FIX] Fix bugs in stream manager. #172

Merged
merged 7 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions cuda/balancing.cu
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ std::vector<torch::Tensor> _swipe_once(
}
long *d_lec = _h2d(lec, n_worker), *d_gec = _cudamalloc<long>(n_worker);
fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
smgr->syncTorch();
long *gec = _d2h(d_gec, n_worker);

/* Limit number of incoming samples */
Expand All @@ -123,17 +124,20 @@ std::vector<torch::Tensor> _swipe_once(
/* Send limit information back */
_h2d(gec, d_gec, n_worker);
fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
smgr->syncTorch();
_d2h(d_lec, lec, n_worker);

auto d_dropcount = _h2d(drop_count, n_worker);
ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
smgr->ncclcomm, smgr->stream());
smgr->ncclcomm, smgr->torchStream());
smgr->syncTorch();
_d2h(d_dropcount, drop_count, n_worker);

auto d_gcap = _cudamalloc<long>(n_worker);
_h2d(&cap, d_gcap + rank, 1);
ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
smgr->ncclcomm, smgr->stream());
smgr->ncclcomm, smgr->torchStream());
smgr->syncTorch();
auto gcap = _d2h(d_gcap, n_worker);

/* Re-assign and update counters */
Expand Down
6 changes: 2 additions & 4 deletions cuda/balancing.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap,
CudaStreamManager* smgr) {
dim3 grid_dim(CEIL(n_worker, 1024), n_expert);
dim3 block_dim(1024);
limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
limit_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->torchStream()>>>(
ec, cap, eca, n_expert, n_worker);
smgr->sync(1);
}

__global__
Expand All @@ -51,8 +50,7 @@ void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, long* new_gate_idx,
CudaStreamManager* smgr) {
dim3 grid_dim(CEIL(batch_size, 1024));
dim3 block_dim(1024);
prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->stream(0)>>>(
prune_gate_by_capacity_kernel<<<grid_dim, block_dim, 0, smgr->torchStream()>>>(
gate_idx, new_gate_idx, ec, batch_size, n_expert, n_worker
);
smgr->sync(1);
}
3 changes: 1 addition & 2 deletions cuda/fastermoe/smart_schedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,9 @@ void _reduce_grad(
long expert_size) {
auto smgr = getCudaStreamManager(t.device().index());

auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
cudaEvent_t evt_stash;
cudaEventCreate(&evt_stash);
cudaEventRecord(evt_stash, torch_stream);
cudaEventRecord(evt_stash, smgr->torchStream());
FMOE_SWE(smgr->stream(0), evt_stash);
cudaEventDestroy(evt_stash);

Expand Down
62 changes: 39 additions & 23 deletions cuda/fastermoe/smart_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ void fmoe_cuda_fused_forward_impl(
long d_model,
long num_expert, long rank, long world_size, long expert_size,
long pipeline_gran, CudaStreamManager* smgr) {
auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
smgr->syncTorch();

int *local_ptr = new int[num_expert * world_size + 1];
int *global_ptr = new int[num_expert * world_size + 1];
Expand All @@ -139,9 +139,11 @@ void fmoe_cuda_fused_forward_impl(

cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
cudaEvent_t *output_torch_ready = new cudaEvent_t[n_groups];
for (long i = 0; i < n_groups; ++i) {
cudaEventCreate(input_ready + i);
cudaEventCreate(output_ready + i);
cudaEventCreate(output_torch_ready + i);
}

// S_0 ... S_n
Expand All @@ -157,11 +159,11 @@ void fmoe_cuda_fused_forward_impl(
local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
global_input_buf + global_ptr[gidx_recv] * d_model,
global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
d_model, smgr->stream(0), smgr->ncclcomm);
d_model, smgr->stream(num_expert), smgr->ncclcomm);
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
cudaEventRecord(input_ready[step], smgr->stream(0));
cudaEventRecord(input_ready[step], smgr->stream(num_expert));
}

// Broadcast shadowed experts
Expand All @@ -173,22 +175,23 @@ void fmoe_cuda_fused_forward_impl(
if (stored_models[i]) {
if (i / num_expert == rank) {
cudaEventCreate(&evt_get);
cudaEventRecord(evt_get, torch_stream);
FMOE_SWE(smgr->stream(0), evt_get);
cudaEventRecord(evt_get, smgr->stream(0));
FMOE_SWE(smgr->stream(num_expert), evt_get);
cudaEventDestroy(evt_get);
}
NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
expert_size * sizeof(scalar_t), ncclChar,
i / num_expert, smgr->ncclcomm, smgr->stream(0)));
i / num_expert, smgr->ncclcomm, smgr->stream(num_expert)));
cudaEventCreate(evt_shadow + si);
cudaEventRecord(evt_shadow[si], smgr->stream(0));
cudaEventRecord(evt_shadow[si], smgr->stream(num_expert));
++si;
}
}

// C_0 ... C_n
for (long step = 0; step < n_groups; ++step) {
FMOE_SWE(torch_stream, input_ready[step]);
FMOE_SWE(smgr->stream(0), input_ready[step]);
FMOE_SWE(smgr->torchStream(), input_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
long offset = global_ptr[ei * world_size + from_base];
Expand All @@ -198,13 +201,15 @@ void fmoe_cuda_fused_forward_impl(
global_input_buf, global_output_buf,
(long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
}
cudaEventRecord(output_ready[step], torch_stream);
cudaEventRecord(output_ready[step], smgr->stream(0));
cudaEventRecord(output_torch_ready[step], smgr->torchStream());
}

// Compute over shadowed experts
for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
if (stored_models[i]) {
FMOE_SWE(torch_stream, evt_shadow[si]);
FMOE_SWE(smgr->stream(0), evt_shadow[si]);
FMOE_SWE(smgr->torchStream(), evt_shadow[si]);
stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
long offset = local_ptr[i];
long micro_batch_size = local_expert_count[i];
Expand All @@ -218,7 +223,8 @@ void fmoe_cuda_fused_forward_impl(

// R_0 ... R_n
for (long step = 0; step < n_groups; ++step) {
FMOE_SWE(smgr->stream(0), output_ready[step]);
FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
FMOE_SWE(smgr->stream(num_expert), output_torch_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
NCCL_SAFE_CALL(ncclGroupStart());
Expand All @@ -230,12 +236,12 @@ void fmoe_cuda_fused_forward_impl(
global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
output_buf + local_ptr[idx_recv] * d_model,
local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
d_model, smgr->stream(0), smgr->ncclcomm);
d_model, smgr->stream(num_expert), smgr->ncclcomm);
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
}
smgr->sync(1);
smgr->sync(num_expert + 1);

delete [] local_ptr;
delete [] global_ptr;
Expand All @@ -244,12 +250,14 @@ void fmoe_cuda_fused_forward_impl(
for (long i = 0; i < n_groups; ++i) {
cudaEventDestroy(input_ready[i]);
cudaEventDestroy(output_ready[i]);
cudaEventDestroy(output_torch_ready[i]);
}
for (unsigned i = 0; i < params.size(); ++i) {
cudaEventDestroy(evt_shadow[i]);
}
delete [] input_ready;
delete [] output_ready;
delete [] output_torch_ready;
}


Expand All @@ -273,7 +281,7 @@ void fmoe_cuda_fused_backward_impl(
long d_model,
long num_expert, long rank, long world_size,
long pipeline_gran, CudaStreamManager* smgr) {
auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
smgr->syncTorch();

int *local_ptr = new int[num_expert * world_size + 1];
int *global_ptr = new int[num_expert * world_size + 1];
Expand All @@ -290,9 +298,11 @@ void fmoe_cuda_fused_backward_impl(

cudaEvent_t *input_ready = new cudaEvent_t[n_groups];
cudaEvent_t *output_ready = new cudaEvent_t[n_groups];
cudaEvent_t *output_torch_ready = new cudaEvent_t[n_groups];
for (long i = 0; i < n_groups; ++i) {
cudaEventCreate(input_ready + i);
cudaEventCreate(output_ready + i);
cudaEventCreate(output_torch_ready + i);
}

// S_0 ... S_n
Expand All @@ -308,11 +318,11 @@ void fmoe_cuda_fused_backward_impl(
local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
global_grad_out + global_ptr[gidx_recv] * d_model,
global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
d_model, smgr->stream(0), smgr->ncclcomm);
d_model, smgr->stream(num_expert), smgr->ncclcomm);
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
cudaEventRecord(input_ready[step], smgr->stream(0));
cudaEventRecord(input_ready[step], smgr->stream(num_expert));
}

// Shadowed experts backward and reduce
Expand All @@ -328,7 +338,7 @@ void fmoe_cuda_fused_backward_impl(
collect_fn(si, i / num_expert, 0);
if (i / num_expert == rank) {
cudaEventCreate(evt_reduce + i % num_expert);
cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
cudaEventRecord(evt_reduce[i % num_expert], smgr->stream(num_expert));
}
++si;
}
Expand All @@ -337,7 +347,8 @@ void fmoe_cuda_fused_backward_impl(

// C_0 ... C_n
for (long step = 0; step < n_groups; ++step) {
FMOE_SWE(torch_stream, input_ready[step]);
FMOE_SWE(smgr->stream(0), input_ready[step]);
FMOE_SWE(smgr->torchStream(), input_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
long offset = global_ptr[ei * world_size + from_base];
Expand All @@ -348,14 +359,16 @@ void fmoe_cuda_fused_backward_impl(
global_grad_out, global_grad_in,
(long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
}
cudaEventRecord(output_ready[step], torch_stream);
cudaEventRecord(output_ready[step], smgr->stream(0));
cudaEventRecord(output_torch_ready[step], smgr->torchStream());
}

// Collect gradients for shadowed experts
for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
if (stored_models[i]) {
if (i / num_expert == rank) {
FMOE_SWE(torch_stream, evt_reduce[i % num_expert]);
FMOE_SWE(smgr->stream(0), evt_reduce[i % num_expert]);
FMOE_SWE(smgr->torchStream(), evt_reduce[i % num_expert]);
set_grad_fn(si, i % num_expert);
}
++si;
Expand All @@ -364,7 +377,8 @@ void fmoe_cuda_fused_backward_impl(

// R_0 ... R_n
for (long step = 0; step < n_groups; ++step) {
FMOE_SWE(smgr->stream(0), output_ready[step]);
FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
FMOE_SWE(smgr->stream(num_expert), output_torch_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
NCCL_SAFE_CALL(ncclGroupStart());
Expand All @@ -376,13 +390,13 @@ void fmoe_cuda_fused_backward_impl(
global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
grad_in + local_ptr[idx_recv] * d_model,
local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
d_model, smgr->stream(0), smgr->ncclcomm);
d_model, smgr->stream(num_expert), smgr->ncclcomm);
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
}

smgr->sync(1);
smgr->sync(num_expert + 1);
checkCudaErrors(cudaGetLastError());

delete [] local_ptr;
Expand All @@ -392,9 +406,11 @@ void fmoe_cuda_fused_backward_impl(
for (long i = 0; i < n_groups; ++i) {
cudaEventDestroy(input_ready[i]);
cudaEventDestroy(output_ready[i]);
cudaEventDestroy(output_torch_ready[i]);
}
delete [] input_ready;
delete [] output_ready;
delete [] output_torch_ready;
for (long i = 0; i < num_expert; ++i) {
if (stored_models[i + rank * num_expert]) {
cudaEventDestroy(evt_reduce[i]);
Expand Down
5 changes: 2 additions & 3 deletions cuda/global_exchange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,16 @@ void fmoe_cuda_expert_exchange_impl(
ncclInt64,
i,
smgr->ncclcomm,
smgr->stream(0)));
smgr->torchStream()));
NCCL_SAFE_CALL(ncclRecv(
global_expert_count + n_expert * i,
n_expert,
ncclInt64,
i,
smgr->ncclcomm,
smgr->stream(0)));
smgr->torchStream()));
}
NCCL_SAFE_CALL(ncclGroupEnd());
smgr->sync(1);
}

torch::Tensor _expert_exchange(
Expand Down
10 changes: 4 additions & 6 deletions cuda/global_exchange.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void fmoe_cuda_global_scatter_impl(
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
smgr->torchStream()));
}
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
Expand All @@ -45,14 +45,13 @@ void fmoe_cuda_global_scatter_impl(
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
smgr->torchStream()));
recv_ptr += global_expert_count[idx];
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(1);
}

template<typename scalar_t>
Expand Down Expand Up @@ -82,7 +81,7 @@ void fmoe_cuda_global_gather_impl(
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
smgr->torchStream()));
send_ptr += global_expert_count[idx];
}
if (local_expert_count[idx]) {
Expand All @@ -92,13 +91,12 @@ void fmoe_cuda_global_gather_impl(
ncclChar,
j,
smgr->ncclcomm,
smgr->stream(0)));
smgr->torchStream()));
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
smgr->sync(1);
}


Expand Down
6 changes: 2 additions & 4 deletions cuda/local_exchange.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ void fmoe_cuda_assign_pos_impl(
CudaStreamManager* smgr) {
size_t numel = batch_size * topk;
assign_pos_kernel
<<<CEIL(numel, 256), 256, 0, smgr->stream(0)>>>
<<<CEIL(numel, 256), 256, 0, smgr->torchStream()>>>
(cum_count, gate, pos, numel, topk);
smgr->sync(1);
}

#define PERTHREAD_EXPERTS 256
Expand Down Expand Up @@ -74,7 +73,6 @@ void fmoe_cuda_expert_count_impl(
const size_t batch_size, const size_t n_expert,
CudaStreamManager* smgr) {
expert_count_kernel
<<<CEIL(n_expert, PERTHREAD_EXPERTS), 256, 0, smgr->stream(0)>>>
<<<CEIL(n_expert, PERTHREAD_EXPERTS), 256, 0, smgr->torchStream()>>>
(gate_idx, expert_count, batch_size, n_expert);
smgr->sync(1);
}
Loading