Fix fused seqpar after change in torch._scaled_mm
In pytorch/pytorch@fc2913f, PyTorch changed the API of torch._scaled_mm:
- `out=` now takes a single tensor (instead of a tuple of output tensor + amax)
- `scale_a` and `scale_b` are no longer optional
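
For reference, a minimal before/after sketch of the call (shapes, dtypes and scale values here are made up for illustration, and FP8 matmuls need an H100-class GPU):

```python
import torch

# Hypothetical FP8 operands; dimensions are multiples of 16, which FP8 GEMMs typically require.
a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(64, 32, device="cuda").to(torch.float8_e4m3fn)
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")
out = torch.empty(16, 64, dtype=torch.bfloat16, device="cuda")

# Before pytorch/pytorch@fc2913f: scales were optional and `out=` took a (result, amax) pair:
#   amax = torch.empty((), dtype=torch.float32, device="cuda")
#   torch._scaled_mm(a, b.t(), out_dtype=out.dtype, out=(out, amax))

# After the change: scales are mandatory and `out=` is just the result tensor.
torch._scaled_mm(a, b.t(), scale_a=scale_a, scale_b=scale_b, out_dtype=out.dtype, out=out)
```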

This triggered some failures in mypy (not sure why the CI was green: maybe it uses too old a PyTorch version?).
I guess this also caused failures at runtime, but we don't have tests for it since we don't have CI jobs for H100. And apparently no one is using that feature?
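
For what it's worth, a rough sketch of how the FP8 path of these ops would be exercised (the `group=` keyword and the import path are assumptions; only `scattered_input`, `weight`, `out_dtype`, `scale_scattered_input` and `scale_weight` come from the code in the diff):

```python
import torch
import torch.distributed as dist
from xformers.ops import fused_allgather_and_linear  # assumed import path

# Assumes torch.distributed is already initialized and an FP8-capable GPU is present.
group = dist.group.WORLD

scattered_input = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)
weight = torch.randn(64, 32, device="cuda").to(torch.float8_e4m3fn)

out = fused_allgather_and_linear(
    scattered_input,
    weight,
    group=group,  # assumed keyword
    out_dtype=torch.bfloat16,  # required when the inputs are FP8
    scale_scattered_input=torch.tensor(1.0, device="cuda"),
    scale_weight=torch.tensor(1.0, device="cuda"),
)
```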

ghstack-source-id: 2fe2da80d9400f20c68e009c8cc111e8bf98d627
Pull Request resolved: fairinternal/xformers#1169

__original_commit__ = fairinternal/xformers@31f7f3a
lw authored and xFormers Bot committed Jul 26, 2024
1 parent 2b8f5fc commit 0b9cb70
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions xformers/ops/sequence_parallel_fused_ops.py
@@ -4,7 +4,17 @@
# LICENSE file in the root directory of this source tree.

import os
from typing import Any, Callable, Dict, List, Mapping, Optional, Union, overload
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Sequence,
Union,
overload,
)

import torch
import torch.distributed as dist
@@ -684,13 +694,15 @@ def fused_allgather_and_linear(
assert (scale_scattered_input is None) == (scale_weight is None)
if scale_weight is not None:
assert isinstance(weight, list) == isinstance(scale_weight, list)
scales_weights = (
scales_weights: Sequence[Optional[torch.Tensor]] = (
scale_weight if isinstance(scale_weight, list) else [scale_weight]
)
assert len(weights) == len(scales_weights)
assert _is_fp8_dtype(scattered_input.dtype)
assert all(_is_fp8_dtype(w.dtype) for w in weights)
assert out_dtype is not None, "output_dtype is required with FP8"
else:
scales_weights = [torch.empty(1)] * len(weights)
scales_weights = [None] * len(weights)
assert all(w.ndim == 2 for w in weights)
assert scattered_input.ndim >= 2
assert all(scattered_input.shape[-1] == w.shape[-1] for w in weights)
@@ -727,15 +739,14 @@ def my_matmul(
) -> None:
for w, scale_weight, go in zip(weights, scales_weights, gathered_outputs):
with torch.cuda.stream(stream_factory()):
if _is_fp8_dtype(w.dtype):
output_amax = torch.empty((), dtype=torch.float32, device=w.device)
if scale_scattered_input is not None and scale_weight is not None:
torch._scaled_mm(
inputs[0],
w.t(),
out_dtype=go[src_rank].dtype,
scale_a=scale_scattered_input,
scale_b=scale_weight,
out=(go[src_rank], output_amax),
out=go[src_rank],
)
else:
torch.matmul(inputs[0], w.t(), out=go[src_rank])
@@ -890,13 +901,15 @@ def fused_linear_and_reducescatter(
assert (scale_gathered_input is None) == (scale_weight is None)
if scale_weight is not None:
assert isinstance(weight, list) == isinstance(scale_weight, list)
scales_weights = (
scales_weights: Sequence[Optional[torch.Tensor]] = (
scale_weight if isinstance(scale_weight, list) else [scale_weight]
)
assert len(weights) == len(scales_weights)
assert _is_fp8_dtype(gathered_input.dtype)
assert all(_is_fp8_dtype(w.dtype) for w in weights)
assert out_dtype is not None, "output_dtype is required with FP8"
else:
scales_weights = [torch.empty(1)] * len(weights)
scales_weights = [None] * len(weights)
assert all(w.ndim == 2 for w in weights)
assert gathered_input.ndim >= 2
assert all(gathered_input.shape[-1] == w.shape[-1] for w in weights)
@@ -939,15 +952,14 @@ def my_matmul(
) -> None:
for w, scale_weight, o in zip(weights, scales_weights, outputs):
with torch.cuda.stream(stream_factory()):
if _is_fp8_dtype(w.dtype):
output_amax = torch.empty((), dtype=torch.float32, device=o.device)
if scale_gathered_input is not None and scale_weight is not None:
torch._scaled_mm(
gathered_input[dst_rank],
w.t(),
out_dtype=o.dtype,
scale_a=scale_gathered_input,
scale_b=scale_weight,
out=(o, output_amax),
out=o,
)
else:
torch.matmul(gathered_input[dst_rank], w.t(), out=o)
