fix dist attn reshape error #5366

Open
wants to merge 5 commits into base: master

Conversation

tkdcjf159

By default, DeepSpeed's DistributedAttention is constructed with scatter_idx = 2 and gather_idx = 0. However, if I set gather_idx to 1 and use a batch size greater than 1, an error occurs during the output all-to-all (which runs with the indices swapped, i.e. scatter_idx = 1 and gather_idx = 2), as illustrated below. To fix this, replace seq_world_size with -1 in the reshape so that the leading dimension is inferred and the batch dimension is no longer dropped.

import torch
import deepspeed.comm as dist  # torch.distributed exposes the same get_world_size / all_to_all_single API


def single_all_to_all(input, scatter_idx, gather_idx, group):
    # Example: input shape [2, 1024, 8, 16], scatter_idx = 1, gather_idx = 2, seq_world_size = 8
    seq_world_size = dist.get_world_size(group)
    inp_shape = list(input.shape)                                      # [2, 1024, 8, 16]
    inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size  # [2, 128, 8, 16]
    if scatter_idx < 2:
        # Previously the leading dimension was hard-coded to seq_world_size, i.e. the reshape
        # target was [8, 128, 8, 16], which drops the batch dimension and fails whenever
        # batch size > 1: 2 * 1024 * 8 * 16 != 8 * 128 * 8 * 16.
        # Using -1 lets the leading dimension absorb the batch dimension as well.
        input_t = input.reshape(
            [-1, inp_shape[scatter_idx]] + inp_shape[scatter_idx + 1:]
        ).contiguous()
    else:
        # Transpose groups of heads with the seq-len parallel dimension to scatter them
        input_t = input.reshape(
            [-1, seq_world_size, inp_shape[scatter_idx]] + inp_shape[scatter_idx + 1:]
        ).transpose(0, 1).contiguous()

    output = torch.empty_like(input_t)
    dist.all_to_all_single(output, input_t, group=group)

    # If scattering the seq-dim, transpose the heads back to the original dimension
    if scatter_idx < 2:
        output = output.transpose(0, 1).contiguous()

    return output.reshape(
        inp_shape[:gather_idx] +
        [inp_shape[gather_idx] * seq_world_size, ] +
        inp_shape[gather_idx + 1:]).contiguous()
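
For reference, the shape mismatch can be reproduced on a single process without any process group. The sketch below only exercises the reshape step; seq_world_size = 8 is hard-coded as an assumption (normally it comes from dist.get_world_size) and the tensor names are illustrative.

import torch

# Minimal single-process sketch of the failing reshape, assuming seq_world_size = 8
# and an input of shape [batch=2, seq=1024, heads=8, head_dim=16] with scatter_idx = 1.
seq_world_size = 8
scatter_idx = 1
x = torch.randn(2, 1024, 8, 16)

inp_shape = list(x.shape)
inp_shape[scatter_idx] //= seq_world_size  # [2, 128, 8, 16]

try:
    # Old behaviour: leading dimension forced to seq_world_size, so the batch dim is dropped
    x.reshape([seq_world_size, inp_shape[scatter_idx]] + inp_shape[scatter_idx + 1:])
except RuntimeError as err:
    print("old reshape fails:", err)  # 262144 elements do not fit [8, 128, 8, 16]

# Fixed behaviour: -1 infers the leading dimension as seq_world_size * batch
x_t = x.reshape([-1, inp_shape[scatter_idx]] + inp_shape[scatter_idx + 1:])
print(x_t.shape)  # torch.Size([16, 128, 8, 16])

Since all_to_all_single splits its input along dim 0 into world-size equal chunks, the inferred leading dimension of seq_world_size * batch (16 here) still divides evenly across the 8 ranks.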

@tkdcjf159
Author

@microsoft-github-policy-service agree company="Upstage"
