Skip to content

Commit

Permalink
Have a single _FusedSequenceParallel class handle all dtypes (fairint…
Browse files Browse the repository at this point in the history
…ernal/xformers#1144)

* have a single _FusedSequenceParallel class handle all dtypes

* Fix flake8 linter + refactor

* Added type annotation

* use dtype.itemsize

* use uint8 for opaque bytes

* remove the extra parentheses

* remove useless buffer_metadata: total_num_bytes computation is cheap

* remove parentheses

* remove parentheses

* use the same sequence number for each dtype

* using staging.view(dtype)

* simplified uint8 handling in linear_and_reducescatter

* using linters versions from requirements-test.txt

* removed useless pair of parentheses

* added scattered_inputs elements dtype consistency check

* fixed my_matmul multiline call formatting mishap

* black format fix for _ensure_staging_is_large_enough call

* add test showcasing handling multiple dtypes

* refactored fused and non fused output comparison

* put subbatch_dims line back

__original_commit__ = fairinternal/xformers@4b4d8e7
  • Loading branch information
lvaleriu authored and xFormers Bot committed Jul 18, 2024
1 parent 2456ea3 commit 71308aa
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 46 deletions.
97 changes: 82 additions & 15 deletions tests/test_sequence_parallel_fused_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,15 @@
_xformers_seqpar_matmul_kernel.configs.pop()


def inner_sequence_parallel_fused(
seed: int,
kind: str,
def compare_fused_and_non_fused_ops(
my_rank: int,
world_size: int,
subgroup: torch.distributed.ProcessGroup,
step: str,
dims: Tuple[int, ...],
dtype: torch.dtype,
triton: bool,
):
my_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
subgroup = torch.distributed.new_group()

triton = True
if kind == "fallback":
os.environ["DISABLE_FUSED_SEQUENCE_PARALLEL"] = "1"
elif kind == "pytorch":
triton = False

torch.random.manual_seed(seed)

batch_dims = dims[:-2]
subbatch_dims = (batch_dims[0] // world_size,) + batch_dims[1:]
outer_dim = dims[-2]
Expand Down Expand Up @@ -128,6 +118,36 @@ def inner_sequence_parallel_fused(
torch.testing.assert_close(output_reference, output_fused, atol=0, rtol=0)


def inner_sequence_parallel_fused(
    seed: int,
    kind: str,
    step: str,
    dims: Tuple[int, ...],
    dtype: torch.dtype,
):
    """Worker entry point: compare fused and non-fused ops for one dtype.

    Args:
        seed: Value passed to ``torch.random.manual_seed`` before comparing.
        kind: Selects the code path. ``"fallback"`` sets the
            ``DISABLE_FUSED_SEQUENCE_PARALLEL`` environment variable to
            ``"1"``; ``"pytorch"`` passes ``triton=False`` to the comparison;
            any other value leaves ``triton=True``.
        step: Forwarded unchanged to ``compare_fused_and_non_fused_ops``.
        dims: Forwarded unchanged to ``compare_fused_and_non_fused_ops``.
        dtype: Forwarded unchanged to ``compare_fused_and_non_fused_ops``.
    """
    rank = torch.distributed.get_rank()
    num_ranks = torch.distributed.get_world_size()
    group = torch.distributed.new_group()

    # The environment variable must be set before the comparison runs.
    if kind == "fallback":
        os.environ["DISABLE_FUSED_SEQUENCE_PARALLEL"] = "1"
    # Only the "pytorch" kind turns Triton off.
    use_triton = kind != "pytorch"

    torch.random.manual_seed(seed)

    compare_fused_and_non_fused_ops(
        my_rank=rank,
        world_size=num_ranks,
        subgroup=group,
        step=step,
        dims=dims,
        dtype=dtype,
        triton=use_triton,
    )


@cuda_sm70_only
@pytest.mark.parametrize(
"kind",
Expand Down Expand Up @@ -166,3 +186,50 @@ def test_sequence_parallel_fused(
dims=dims,
dtype=dtype,
)


def inner_sequence_parallel_fused_triton_handle_all_dtypes(
    seed: int,
    step: str,
    dims: Tuple[int, ...],
):
    """Worker entry point: run the comparison for several dtypes in a row.

    The same process group is reused for bfloat16, float16 and float32 (in
    that order), always with ``triton=True``. The RNG is seeded once, before
    the first comparison.

    Args:
        seed: Value passed to ``torch.random.manual_seed``.
        step: Forwarded unchanged to ``compare_fused_and_non_fused_ops``.
        dims: Forwarded unchanged to ``compare_fused_and_non_fused_ops``.
    """
    # Arguments shared by every per-dtype comparison; only ``dtype`` varies.
    common_kwargs = dict(
        my_rank=torch.distributed.get_rank(),
        world_size=torch.distributed.get_world_size(),
        subgroup=torch.distributed.new_group(),
        step=step,
        dims=dims,
        triton=True,
    )

    torch.random.manual_seed(seed)

    for current_dtype in (torch.bfloat16, torch.float16, torch.float32):
        compare_fused_and_non_fused_ops(dtype=current_dtype, **common_kwargs)


@cuda_sm70_only
@pytest.mark.parametrize("step", ["all-gather", "reduce-scatter"])
@pytest.mark.parametrize(
    "dims",
    [
        pytest.param((2, 2, 512, 512, 256), id="nice-shapes"),
        pytest.param((2, 1023, 511, 257), id="ugly-shapes"),
    ],
)
def test_sequence_parallel_fused_triton_handle_all_dtypes(
    step: str,
    dims: Tuple[int, ...],
):
    """Launch two workers that run the multi-dtype fused-vs-reference check.

    A fresh 32-bit random seed is drawn here and handed to every worker via
    ``launch_subprocesses``, together with the parametrized ``step`` and
    ``dims``.
    """
    launch_subprocesses(
        2,  # world size
        inner_sequence_parallel_fused_triton_handle_all_dtypes,
        seed=random.getrandbits(32),
        step=step,
        dims=dims,
    )
60 changes: 29 additions & 31 deletions xformers/ops/sequence_parallel_fused_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

import os
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, overload
from typing import Any, Callable, Dict, List, Mapping, Optional, Union, overload

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -97,12 +97,10 @@ class _FusedSequenceParallel:
def __init__(
self,
device: torch.device,
dtype: torch.dtype,
group: dist.ProcessGroup,
num_stripes: int,
):
self.my_device = device
self.dtype = dtype
self.my_rank = group.rank()
self.world_size = group.size()
self.num_stripes = num_stripes
Expand Down Expand Up @@ -156,18 +154,22 @@ def __init__(

self.next_stream_idx = 0

def _ensure_staging_is_large_enough(self, num_elements: int, random_init: bool):
def _ensure_staging_is_large_enough(
self, num_elements: int, random_init: bool, dtype: torch.dtype
):
total_num_bytes = num_elements * dtype.itemsize

# Lazily size up the staging area as needed. (If it's the first call,
# this will always trigger, since staging starts empty). Once at steady
# state, staging will be of the right (max) size and never grow again.
if self.staging.numel() < self.world_size * num_elements:
if self.staging.numel() < self.world_size * total_num_bytes:
# When running with _memcpy=False (i.e., for benchmarks) we must
# ensure that the staging buffer doesn't contain all zeroes as that
# makes the matmuls go faster (better L2 compression or something).
self.staging = torch.empty(
(self.num_stripes, self.world_size, num_elements),
(self.num_stripes, self.world_size, total_num_bytes),
device=self.my_device,
dtype=self.dtype,
dtype=torch.uint8,
)
if random_init:
self.staging.normal_()
Expand Down Expand Up @@ -223,13 +225,14 @@ def allgather_and_linear(
):
"""Perform a fused all-gather followed by a linear layer"""

dtype = scattered_inputs[0].dtype
assert all(si.device == self.my_device for si in scattered_inputs)
assert all(si.dtype == self.dtype for si in scattered_inputs)
assert all(si.dtype == dtype for si in scattered_inputs)

scattered_input_numels = [si.numel() for si in scattered_inputs]
total_scattered_input_numel = sum(scattered_input_numels)
self._ensure_staging_is_large_enough(
total_scattered_input_numel, random_init=_memcpy is False
total_scattered_input_numel, random_init=_memcpy is False, dtype=dtype
)

stripe = self.next_stripe % self.num_stripes
Expand All @@ -242,7 +245,7 @@ def allgather_and_linear(
stagings = [
s.view((self.world_size,) + si.shape)
for s, si in zip(
self.staging[stripe, :, :total_scattered_input_numel].split(
self.staging.view(dtype)[stripe, :, :total_scattered_input_numel].split(
scattered_input_numels, dim=-1
),
scattered_inputs,
Expand All @@ -254,7 +257,7 @@ def allgather_and_linear(
else [
s.view(si.shape)
for s, si in zip(
bs[stripe, :total_scattered_input_numel].split(
bs.view(dtype)[stripe, :total_scattered_input_numel].split(
scattered_input_numels, dim=-1
),
scattered_inputs,
Expand Down Expand Up @@ -388,15 +391,16 @@ def linear_and_reducescatter(
):
"""Perform a fused linear layer followed by a reduce-scatter"""

dtype = gathered_outputs[0].dtype
assert all(go.device == self.my_device for go in gathered_outputs)
assert all(go.dtype == self.dtype for go in gathered_outputs)
assert all(go.dtype == dtype for go in gathered_outputs)
assert all(so.device == self.my_device for so in scattered_outputs)
assert all(so.dtype == self.dtype for so in scattered_outputs)
assert all(so.dtype == dtype for so in scattered_outputs)

scattered_output_numels = [so.numel() for so in scattered_outputs]
total_scattered_output_numel = sum(scattered_output_numels)
self._ensure_staging_is_large_enough(
total_scattered_output_numel, random_init=_memcpy is False
total_scattered_output_numel, random_init=_memcpy is False, dtype=dtype
)

stripe = self.next_stripe % self.num_stripes
Expand All @@ -409,9 +413,9 @@ def linear_and_reducescatter(
stagings = [
s.view((self.world_size,) + so.shape)
for s, so in zip(
self.staging[stripe, :, :total_scattered_output_numel].split(
scattered_output_numels, dim=-1
),
self.staging.view(dtype)[
stripe, :, :total_scattered_output_numel
].split(scattered_output_numels, dim=-1),
scattered_outputs,
)
]
Expand All @@ -421,7 +425,7 @@ def linear_and_reducescatter(
else [
s.view(so.shape)
for s, so in zip(
bs[stripe, :total_scattered_output_numel].split(
bs.view(dtype)[stripe, :total_scattered_output_numel].split(
scattered_output_numels, dim=-1
),
scattered_outputs,
Expand Down Expand Up @@ -550,7 +554,7 @@ def linear_and_reducescatter(
# We'd store this as an attribute on the PG object itself, but some PGs are
# pybind-bound classes and thus don't support it, so we simulate this as an
# external cache.
CACHE: Dict[Tuple[int, torch.dtype], Optional[_FusedSequenceParallel]] = {}
CACHE: Dict[int, Optional[_FusedSequenceParallel]] = {}


def _can_ranks_communicate_all_to_all_over_nvlink(group: dist.ProcessGroup) -> bool:
Expand All @@ -567,11 +571,11 @@ def _can_ranks_communicate_all_to_all_over_nvlink(group: dist.ProcessGroup) -> b


def _lazy_init(
device: torch.device, dtype: torch.dtype, group: dist.ProcessGroup, num_stripes: int
device: torch.device, group: dist.ProcessGroup, num_stripes: int
) -> Optional[_FusedSequenceParallel]:
world_size = group.size()
try:
obj = CACHE[(id(group), dtype)]
obj = CACHE[id(group)]
except KeyError:
if int(os.environ.get("DISABLE_FUSED_SEQUENCE_PARALLEL", "0")):
obj = None
Expand All @@ -580,8 +584,8 @@ def _lazy_init(
elif not _can_ranks_communicate_all_to_all_over_nvlink(group):
obj = None
else:
obj = _FusedSequenceParallel(device, dtype, group, num_stripes)
CACHE[(id(group), dtype)] = obj
obj = _FusedSequenceParallel(device, group, num_stripes)
CACHE[id(group)] = obj
return obj


Expand Down Expand Up @@ -782,9 +786,7 @@ def fused_allgather_and_anything(

gathered_input_shapes = [(world_size,) + si.shape for si in scattered_inputs]

obj = _lazy_init(
scattered_inputs[0].device, scattered_inputs[0].dtype, group, num_stripes
)
obj = _lazy_init(scattered_inputs[0].device, group, num_stripes)

if world_size == 1:
my_matmul(scattered_inputs, 0, _default_stream_factory)
Expand All @@ -807,7 +809,6 @@ def fused_allgather_and_anything(
# Fast path
else:
assert scattered_inputs[0].device == obj.my_device
assert scattered_inputs[0].dtype == obj.dtype
assert obj.num_stripes == num_stripes
obj.allgather_and_linear(
scattered_inputs,
Expand Down Expand Up @@ -997,9 +998,7 @@ def fused_anything_and_reducescatter(

gathered_output_shapes = [(world_size,) + so.shape for so in scattered_outputs]

obj = _lazy_init(
scattered_outputs[0].device, scattered_outputs[0].dtype, group, num_stripes
)
obj = _lazy_init(scattered_outputs[0].device, group, num_stripes)

if world_size == 1:
my_matmul(scattered_outputs, 0, _default_stream_factory)
Expand All @@ -1022,7 +1021,6 @@ def fused_anything_and_reducescatter(
# Fast path
else:
assert scattered_outputs[0].device == obj.my_device
assert scattered_outputs[0].dtype == obj.dtype
assert obj.num_stripes == num_stripes
gathered_outputs = [
scattered_outputs[0].new_empty(gos) for gos in gathered_output_shapes
Expand Down

0 comments on commit 71308aa

Please sign in to comment.