Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
Remove the code changes that were needed only for testing with a world size of 1
Browse files: browse the repository at this point in the history.
  • Loading branch information
dianaml0 committed Dec 23, 2022
1 parent 59a933e commit 931791c
Showing 1 changed file with 7 additions and 15 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -330,9 +330,7 @@ def backward(ctx, grad_output):
actv_out = gelu(fc1_out) if activation_fn_name == "gelu" else relu(fc1_out)

# Now wait for reduce scatter
- world_size = get_tensor_model_parallel_world_size()
- if world_size != 1:
-     handle.wait()
+ handle.wait()

ffn_layer_norm_output, handle = _gather_along_first_dim(
ffn_layer_norm_output, async_op=True, cached_buffer_name="mpu"
Expand All @@ -341,8 +339,7 @@ def backward(ctx, grad_output):
grad_fc2_input = grad_output.matmul(fc2_weight)

if ctx.recompute_fc1:
- if world_size != 1:
-     handle.wait()
+ handle.wait()
assert fc1_out is None
fc1_out = torch.matmul(ffn_layer_norm_output, fc1_weight.t())
actv_out = gelu(fc1_out) if activation_fn_name == "gelu" else relu(fc1_out)
Expand All @@ -362,8 +359,7 @@ def backward(ctx, grad_output):
grad_fc2_weight = grad_output.t().matmul(actv_out)

grad_fc1_input = grad_actv_input.matmul(fc1_weight)
- if world_size != 1:
-     handle.wait()
+ handle.wait()

grad_actv_input = SequeuceParallelTransformerBlock._collapse_first_dimensions(
grad_actv_input
Expand All @@ -380,8 +376,7 @@ def backward(ctx, grad_output):

grad_fc1_weight = grad_actv_input.t().matmul(ffn_layer_norm_output)

- if world_size != 1:
-     handle.wait()
+ handle.wait()

grad_attention_output = fused_layer_norm_cuda.backward(
grad_fc1_input.contiguous(),
Expand Down Expand Up @@ -425,8 +420,7 @@ def backward(ctx, grad_output):
q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype
)

- if world_size != 1:
-     handle.wait()
+ handle.wait()

grad_out_proj_input = grad_attention_output.matmul(out_proj_weight)
grad_attention_output = (
Expand Down Expand Up @@ -477,8 +471,7 @@ def backward(ctx, grad_output):
cached_buffer_name="mpu",
)
grad_input = grad_kvq_proj_output.matmul(kvq_proj_weight)
- if world_size != 1:
-     handle.wait()
+ handle.wait()

grad_input, handle = _reduce_scatter_along_first_dim(grad_input, async_op=True)
mha_layer_norm_output = (
Expand All @@ -492,8 +485,7 @@ def backward(ctx, grad_output):
)
)
grad_kvq_weight = grad_kvq_proj_output.t().matmul(mha_layer_norm_output)
- if world_size != 1:
-     handle.wait()
+ handle.wait()

grad_input = fused_layer_norm_cuda.backward(
grad_input.contiguous(),
Expand Down

0 comments on commit 931791c

Please sign in to comment.