Support latest TransformerEngine #98

Merged · 21 commits · Oct 18, 2023
2 changes: 1 addition & 1 deletion .github/workflows/build-image.yaml
@@ -14,7 +14,7 @@ on:
 jobs:
   docker:
     name: Docker build ${{ matrix.name }}
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-20.04
     permissions:
       contents: read
       packages: write
3 changes: 0 additions & 3 deletions .github/workflows/unit-tests.yaml
@@ -13,9 +13,6 @@ jobs:
     strategy:
       matrix:
         include:
-          # 1.13.0a0+d0d6b1f
-          - torch: "1.13"
-            nvcr: 22.09-py3
           # 1.14.0a0+410ce96
           - torch: "1.14"
             nvcr: 22.12-py3
3 changes: 2 additions & 1 deletion dockerfile/torch1.14-cuda11.8.dockerfile
@@ -48,8 +48,9 @@ RUN cd third_party/msccl && \
         -gencode=arch=compute_90,code=sm_90" && \
     make install
 # cache TE build to save time in CI
+ENV MAX_JOBS=1
 RUN python3 -m pip install --upgrade pip && \
-    python3 -m pip install git+https://github.com/NVIDIA/TransformerEngine.git@v0.7
+    python3 -m pip install flash-attn==1.0.9 git+https://github.com/NVIDIA/TransformerEngine.git@v0.11

 ADD . .
 RUN python3 -m pip install . && \
3 changes: 2 additions & 1 deletion dockerfile/torch2.1-cuda12.1.dockerfile
@@ -48,8 +48,9 @@ RUN cd third_party/msccl && \
         -gencode=arch=compute_90,code=sm_90" && \
     make install
 # cache TE build to save time in CI
+ENV MAX_JOBS=1
 RUN python3 -m pip install --upgrade pip && \
-    python3 -m pip install git+https://github.com/NVIDIA/TransformerEngine.git@v0.7
+    python3 -m pip install flash-attn==1.0.9 git+https://github.com/NVIDIA/TransformerEngine.git@v0.11

 ADD . .
 RUN python3 -m pip install . && \
2 changes: 2 additions & 0 deletions msamp/operators/gemm/gemm.py
@@ -140,6 +140,7 @@ def fp8_gemm(
                 workspace.shape[0],
                 accumulate,
                 use_split_accumulator,
+                0,
             )
         else:
             # do gemm on device that doesn't support fp8.
@@ -165,6 +166,7 @@ def fp8_gemm(
                 workspace.shape[0],
                 accumulate,
                 False,
+                0,
             )

         if pN > 0 or pM > 0:
13 changes: 13 additions & 0 deletions msamp/te/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Expose the interface of MS-AMP te package."""

from msamp.te import extension
from msamp.te import modules
from msamp.te.replacer import TeReplacer

del extension
del modules

__all__ = ['TeReplacer']
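
A minimal usage sketch of the exposed interface (illustrative only, not part of this PR; the toy te.Linear model and the TeReplacer.replace entry point are assumptions based on the export above and the replacer module):

# Hypothetical usage sketch; TeReplacer.replace and the toy model are assumptions.
import transformer_engine.pytorch as te
from msamp.te import TeReplacer

model = te.Linear(768, 768).cuda()  # any TransformerEngine module
model = TeReplacer.replace(model)   # swap TE weights to MS-AMP scaling parameters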
130 changes: 130 additions & 0 deletions msamp/te/extension.py
@@ -0,0 +1,130 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""MS-AMP te.extension module."""

import torch
import transformer_engine.pytorch as te
import transformer_engine_extensions as tex

from msamp.common.dtype import Dtypes
from msamp.common.tensor import ScalingTensor


class TeExtensionOverrider:
    """Overrides selected TransformerEngine extension functions so that they accept ScalingTensor inputs."""
    dtype_map = {
        tex.DType.kFloat8E4M3: Dtypes.kfloat8_e4m3,
        tex.DType.kFloat8E5M2: Dtypes.kfloat8_e5m2,
        tex.DType.kBFloat16: Dtypes.kbfloat16,
        tex.DType.kFloat16: Dtypes.kfloat16,
        tex.DType.kFloat32: Dtypes.kfloat32,
    }

    original_fused_cast_transpose = tex.fused_cast_transpose
    original_cast_to_fp8 = te.cpp_extensions.cast_to_fp8
    original_fp8_cast_transpose_fused = te.cpp_extensions.fp8_cast_transpose_fused

    @staticmethod
    @torch.no_grad()
    def fused_cast_transpose(input, scale, amax, scale_inv, input_cast, input_transpose, otype):
        """Fused cast and transpose, with ScalingTensor support.

        Args:
            input (torch.Tensor or ScalingTensor): Input tensor.
            scale (torch.Tensor): Scale tensor.
            amax (torch.Tensor): Amax tensor.
            scale_inv (torch.Tensor): Scale inverse tensor.
            input_cast (torch.Tensor): Cast output tensor, written in place.
            input_transpose (torch.Tensor): Transposed output tensor, written in place.
            otype (tex.DType): Output type.
        """
        if isinstance(input, ScalingTensor):
            qtype = TeExtensionOverrider.dtype_map[otype]
            if input_transpose is not None:
                sv = input.cast(qtype)
                # The data should be contiguous, and TE does not check it.
                st = sv.t().contiguous()
                v, t = sv.value, st.value
                input_transpose.data = t
            else:
                sv = input.cast(qtype)
                v = sv.value

            if input_cast is not None:
                input_cast.data = v
            scale_inv.copy_(sv.meta.scale_inv)
        else:
            TeExtensionOverrider.original_fused_cast_transpose(
                input, scale, amax, scale_inv, input_cast, input_transpose, otype
            )

    @staticmethod
    @torch.no_grad()
    def fp8_cast_transpose_fused(inp, fp8_meta_tensor, fp8_tensor, dtype, cast_out=None, transpose_out=None):
        """Cast and transpose with FP8 output, with ScalingTensor support.

        Args:
            inp (torch.Tensor or ScalingTensor): Input tensor.
            fp8_meta_tensor (tex.FP8TensorMeta): FP8 meta tensor.
            fp8_tensor (Union[tex.FP8FwdTensors, tex.FP8BwdTensors]): FP8 tensor index.
            dtype (tex.DType): Output type.
            cast_out (torch.Tensor, optional): Cast output tensor.
            transpose_out (torch.Tensor, optional): Transposed output tensor.

        Returns:
            Union[Tuple[torch.Tensor, torch.Tensor], None]: The cast and transposed tensors,
                if returned by the underlying implementation.
        """
        if isinstance(inp, ScalingTensor):
            qtype = TeExtensionOverrider.dtype_map[dtype]
            sv = inp.cast(qtype)
            v = sv.value
            t = sv.t().contiguous().value
            if transpose_out is not None:
                transpose_out.data = t
            if cast_out is not None:
                cast_out.data = v
            fp8_meta_tensor.scale_inv[fp8_tensor].copy_(sv.meta.scale_inv)
            return v, t

        return TeExtensionOverrider.original_fp8_cast_transpose_fused(
            inp, fp8_meta_tensor, fp8_tensor, dtype, cast_out, transpose_out
        )

    @staticmethod
    @torch.no_grad()
    def cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out=None):
        """Cast to FP8, with ScalingTensor support.

        Args:
            inp (torch.Tensor or ScalingTensor): Input tensor.
            fp8_meta_tensor (tex.FP8TensorMeta): FP8 meta tensor.
            fp8_tensor (Union[tex.FP8FwdTensors, tex.FP8BwdTensors]): FP8 tensor index.
            otype (tex.DType): Output type.
            out (torch.Tensor, optional): Output tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        if isinstance(inp, ScalingTensor):
            qtype = TeExtensionOverrider.dtype_map[otype]
            sv = inp.cast(qtype)
            v = sv.value
            if out is not None:
                out.data = v
            fp8_meta_tensor.scale_inv[fp8_tensor].copy_(sv.meta.scale_inv)
            return v

        if out is None:
            return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype)
        return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out)

    @staticmethod
    def override():
        """Override TransformerEngine extension functions."""
        tex.fused_cast_transpose = TeExtensionOverrider.fused_cast_transpose
        te.cpp_extensions.cast_to_fp8 = TeExtensionOverrider.cast_to_fp8
        te.cpp_extensions.fp8_cast_transpose_fused = TeExtensionOverrider.fp8_cast_transpose_fused


TeExtensionOverrider.override()
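
Because the module calls TeExtensionOverrider.override() at import time, importing msamp.te (as the new __init__.py above does) is enough to install the patched functions. A small sanity-check sketch (illustrative only, not part of this PR; it simply re-reads the names assigned in override() above):

# Illustrative check: importing msamp.te applies the overrides on import.
import transformer_engine.pytorch as te
import msamp.te  # noqa: F401 -- triggers TeExtensionOverrider.override()

from msamp.te.extension import TeExtensionOverrider

assert te.cpp_extensions.cast_to_fp8 is TeExtensionOverrider.cast_to_fp8
assert te.cpp_extensions.fp8_cast_transpose_fused is TeExtensionOverrider.fp8_cast_transpose_fused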