diff --git a/.github/workflows/build-image.yaml b/.github/workflows/build-image.yaml index 3c3b8dbe..97a64edf 100644 --- a/.github/workflows/build-image.yaml +++ b/.github/workflows/build-image.yaml @@ -14,7 +14,7 @@ on: jobs: docker: name: Docker build ${{ matrix.name }} - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 permissions: contents: read packages: write diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml index dd4163ee..4377cfcd 100644 --- a/.github/workflows/unit-tests.yaml +++ b/.github/workflows/unit-tests.yaml @@ -13,9 +13,6 @@ jobs: strategy: matrix: include: - # 1.13.0a0+d0d6b1f - - torch: "1.13" - nvcr: 22.09-py3 # 1.14.0a0+410ce96 - torch: "1.14" nvcr: 22.12-py3 diff --git a/dockerfile/torch1.14-cuda11.8.dockerfile b/dockerfile/torch1.14-cuda11.8.dockerfile index 91cacc67..2f9ec815 100644 --- a/dockerfile/torch1.14-cuda11.8.dockerfile +++ b/dockerfile/torch1.14-cuda11.8.dockerfile @@ -48,8 +48,9 @@ RUN cd third_party/msccl && \ -gencode=arch=compute_90,code=sm_90" && \ make install # cache TE build to save time in CI +ENV MAX_JOBS=1 RUN python3 -m pip install --upgrade pip && \ - python3 -m pip install git+https://github.com/NVIDIA/TransformerEngine.git@v0.7 + python3 -m pip install flash-attn==1.0.9 git+https://github.com/NVIDIA/TransformerEngine.git@v0.11 ADD . . RUN python3 -m pip install . && \ diff --git a/dockerfile/torch2.1-cuda12.1.dockerfile b/dockerfile/torch2.1-cuda12.1.dockerfile index b33c0d8e..06310891 100644 --- a/dockerfile/torch2.1-cuda12.1.dockerfile +++ b/dockerfile/torch2.1-cuda12.1.dockerfile @@ -48,8 +48,9 @@ RUN cd third_party/msccl && \ -gencode=arch=compute_90,code=sm_90" && \ make install # cache TE build to save time in CI +ENV MAX_JOBS=1 RUN python3 -m pip install --upgrade pip && \ - python3 -m pip install git+https://github.com/NVIDIA/TransformerEngine.git@v0.7 + python3 -m pip install flash-attn==1.0.9 git+https://github.com/NVIDIA/TransformerEngine.git@v0.11 ADD . . RUN python3 -m pip install . && \ diff --git a/msamp/operators/gemm/gemm.py b/msamp/operators/gemm/gemm.py index 811251a7..e7834c20 100644 --- a/msamp/operators/gemm/gemm.py +++ b/msamp/operators/gemm/gemm.py @@ -140,6 +140,7 @@ def fp8_gemm( workspace.shape[0], accumulate, use_split_accumulator, + 0, ) else: # do gemm on device that doesn't supported fp8. @@ -165,6 +166,7 @@ def fp8_gemm( workspace.shape[0], accumulate, False, + 0, ) if pN > 0 or pM > 0: diff --git a/msamp/te/__init__.py b/msamp/te/__init__.py new file mode 100644 index 00000000..7b5d04c3 --- /dev/null +++ b/msamp/te/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Expose the interface of MS-AMP te package.""" + +from msamp.te import extension +from msamp.te import modules +from msamp.te.replacer import TeReplacer + +del extension +del modules + +__all__ = ['TeReplacer'] diff --git a/msamp/te/extension.py b/msamp/te/extension.py new file mode 100644 index 00000000..a67dd7a7 --- /dev/null +++ b/msamp/te/extension.py @@ -0,0 +1,130 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""MS-AMP te.extension module.""" + +import torch +import transformer_engine.pytorch as te +import transformer_engine_extensions as tex + +from msamp.common.dtype import Dtypes +from msamp.common.tensor import ScalingTensor + + +class TeExtensionOverrider: + """An Overrider to override some extension functions in transformer engine.""" + dtype_map = { + tex.DType.kFloat8E4M3: Dtypes.kfloat8_e4m3, + tex.DType.kFloat8E5M2: Dtypes.kfloat8_e5m2, + tex.DType.kBFloat16: Dtypes.kbfloat16, + tex.DType.kFloat16: Dtypes.kfloat16, + tex.DType.kFloat32: Dtypes.kfloat32, + } + + original_fused_cast_transpose = tex.fused_cast_transpose + original_cast_to_fp8 = te.cpp_extensions.cast_to_fp8 + original_fp8_cast_transpose_fused = te.cpp_extensions.fp8_cast_transpose_fused + + @staticmethod + @torch.no_grad() + def fused_cast_transpose(input, scale, amax, scale_inv, input_cast, input_transpose, otype): + """Fused cast and transpose, support ScalingTensor. + + Args: + input (torch.Tensor or ScalingTensor): Input tensor. + scale (torch.Tensor): Scale tensor. + amax (torch.Tensor): Amax tensor. + scale_inv (torch.Tensor): Scale inverse tensor. + input_cast (torch.Tensor): Casted input tensor. + input_transpose (torch.Tensor): Transposed input tensor. + otype (tex.DType): Output type. + """ + if isinstance(input, ScalingTensor): + qtype = TeExtensionOverrider.dtype_map[otype] + if input_transpose is not None: + sv = input.cast(qtype) + # data should be contiguous, and TE does not check it. + st = sv.t().contiguous() + v, t = sv.value, st.value + input_transpose.data = t + else: + sv = input.cast(qtype) + v = sv.value + + if input_cast is not None: + input_cast.data = v + scale_inv.copy_(sv.meta.scale_inv) + else: + TeExtensionOverrider.original_fused_cast_transpose( + input, scale, amax, scale_inv, input_cast, input_transpose, otype + ) + + @staticmethod + @torch.no_grad() + def fp8_cast_transpose_fused(inp, fp8_meta_tensor, fp8_tensor, dtype, cast_out=None, transpose_out=None): + """Cast + Transpose with FP8 output, support ScalingTensor. + + Args: + inp (torch.Tensor or ScalingTensor): Input tensor. + fp8_meta_tensor: tex.FP8TensorMeta + fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors] + dtype: tex.DType + cast_out (torch.Tensor, optional): Output tensor. + transpose_out (torch.Tensor, optional): Output tensor. + + Returns: + Union[Tuple[torch.Tensor, torch.Tensor], None]: Output tensor. + """ + if isinstance(inp, ScalingTensor): + qtype = TeExtensionOverrider.dtype_map[dtype] + sv = inp.cast(qtype) + v = sv.value + t = sv.t().contiguous().value + if transpose_out is not None: + transpose_out.data = t + if cast_out is not None: + cast_out.data = v + fp8_meta_tensor.scale_inv[fp8_tensor].copy_(sv.meta.scale_inv) + return v, t + + return TeExtensionOverrider.original_fp8_cast_transpose_fused( + inp, fp8_meta_tensor, fp8_tensor, dtype, cast_out, transpose_out + ) + + @staticmethod + @torch.no_grad() + def cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out=None): + """Cast to fp8, support ScalingTensor. + + Args: + inp (torch.Tensor or ScalingTensor): Input tensor. + fp8_meta_tensor (tex.FP8TensorMeta): Fp8 meta tensor. + fp8_tensor (Union[tex.FP8FwdTensors, tex.FP8BwdTensors): Fp8 tensor. + otype (tex.DType): Output type. + out (torch.Tensor, optional): Output tensor. + + Returns: + torch.Tensor: Output tensor. 
+ """ + if isinstance(inp, ScalingTensor): + qtype = TeExtensionOverrider.dtype_map[otype] + sv = inp.cast(qtype) + v = sv.value + if out is not None: + out.data = v + fp8_meta_tensor.scale_inv[fp8_tensor].copy_(sv.meta.scale_inv) + return v + + if out is None: + return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype) + return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out) + + @staticmethod + def override(): + """Override transformer engine extension functions.""" + tex.fused_cast_transpose = TeExtensionOverrider.fused_cast_transpose + te.cpp_extensions.cast_to_fp8 = TeExtensionOverrider.cast_to_fp8 + te.cpp_extensions.fp8_cast_transpose_fused = TeExtensionOverrider.fp8_cast_transpose_fused + + +TeExtensionOverrider.override() diff --git a/msamp/te/modules.py b/msamp/te/modules.py new file mode 100644 index 00000000..24ef0429 --- /dev/null +++ b/msamp/te/modules.py @@ -0,0 +1,276 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""MS-AMP te.modules module.""" + +import torch +import transformer_engine.pytorch as te +from transformer_engine.pytorch.module.base import TransformerEngineBaseModule + +from msamp.common.tensor import ScalingTensor +from msamp.nn import ScalingModule + + +def set_activation_dtype(self, inp): + """Set activation data type for AMP. + + Args: + self (TransformerEngineBaseModule): Module instance. + inp (torch.Tensor or ScalingTensor): Input tensor. + """ + # Native AMP (`torch.autocast`) gets highest priority + if torch.is_autocast_enabled(): + self.activation_dtype = torch.get_autocast_gpu_dtype() + return + + # All checks after this have already been performed once, thus skip + # We assume that user doesn't change input types across iterations + if hasattr(self, 'activation_dtype'): + return + + assert all( + ( + (inp.dtype == param.dtype) if param is not None and not isinstance(param, ScalingTensor) else True + for param in self.parameters() + ) + ), ('Data type for activations and weights must ' + 'match when outside of autocasted region') + assert all(((inp.dtype == buf.dtype) if buf is not None else True for buf in self.buffers()) + ), ('Data type for activations and buffers must ' + 'match when outside of autocasted region') + self.activation_dtype = inp.dtype + + +class MSAMPTransformerEngineBaseModule: + """A base module for MS-AMP transformer engine modules.""" + def set_fp8_weights(self): + """Initializes FP8 weights for the module as class attributes.""" + # when is_first_microbatch is not None + # call every microbatch + # cache weight_fp8, weight_t_fp8 for gradient accumulation + # set_fp8_weights will clean up the cache + if not self.is_msamp_module: + super().set_fp8_weights() + else: + for i, shape in enumerate(self.fp8_weight_shapes, start=1): + weight_cast_attr = f'weight{i}_fp8' + weight_transpose_attr = f'weight{i}_t_fp8' + + if (hasattr(self, weight_cast_attr) and getattr(self, weight_cast_attr).shape == shape): + return + + setattr( + self, + weight_cast_attr, + torch.empty( + (0, 0), + device=torch.cuda.current_device(), + dtype=torch.uint8, + ), + ) + setattr( + self, + weight_transpose_attr, + torch.empty( + (0, 0), + device=torch.cuda.current_device(), + dtype=torch.uint8, + ), + ) + + @property + def is_msamp_module(self): + """Whether this module is MS-AMP module.""" + if not hasattr(self, '_is_msamp_module'): + self._is_msamp_module = False + return self._is_msamp_module + + @is_msamp_module.setter + def is_msamp_module(self, 
value): + """Set whether this module is MS-AMP module. + + Args: + value (bool): True if this module is MS-AMP module. + """ + self._is_msamp_module = value + + def get_fp8_weights_empty_tensors(self, is_first_microbatch): + """Returns empty tensors to be later used to store fp8 version of weights and their transposes. + + Args: + is_first_microbatch (bool): Whether this is the first microbatch. + + Returns: + a list of fp8 weight tensors. + """ + # when is_first_microbatch is None, create empty tensors + if not self.is_msamp_module: + return super().get_fp8_weights_empty_tensors(is_first_microbatch) + # MS-AMP + old_fp8_weight_shapes = self.fp8_weight_shapes + self.fp8_weight_shapes = [(0, 0)] * len(old_fp8_weight_shapes) + # create empty tensor as placeholder + rtn = super().get_fp8_weights_empty_tensors(is_first_microbatch) + self.fp8_weight_shapes = old_fp8_weight_shapes + return rtn + + +class MSAMPLinear(MSAMPTransformerEngineBaseModule, te.Linear, ScalingModule): + """MS-AMP Linear module.""" + pass + + +class MSAMPLayerNormLinear(MSAMPTransformerEngineBaseModule, te.LayerNormLinear, ScalingModule): + """MS-AMP LayerNormLinear module.""" + pass + + +class MSAMPLayerNormMLP(MSAMPTransformerEngineBaseModule, te.LayerNormMLP, ScalingModule): + """MS-AMP LayerNormMLP module.""" + pass + + +class CtxWrapper: + """A wrapper of FunctionCtx which supports ScalingTenor.""" + def __init__(self, ctx): + """Init a CtxWrapper. + + Args: + ctx (FunctionCtx): Function context. + """ + self.__dict__['ctx'] = ctx + + def __getattr__(self, name): + """Get attribute by name. + + Args: + name (str): Attribute name. + + Returns: + Attribute value. + """ + return self.__dict__.get(name, getattr(self.__dict__['ctx'], name)) + + def __setattr__(self, name, value): + """Set attribute by name. + + Args: + name (str): Attribute name. + value (object): Attribute value. + """ + if name in self.__dict__: + self.__dict__[name] = value + else: + setattr(self.ctx, name, value) + + def save_for_backward(self, *args): + """Save tensors for backward. + + Args: + args (tuple): Tensors to save. 
+ """ + torch_args = [] + scaling_args = [] + for a in args: + if isinstance(a, ScalingTensor): + scaling_args.append(a) + torch_args.append(None) + else: + torch_args.append(a) + scaling_args.append(None) + self.ctx.save_for_backward(*torch_args) + self.ctx.scaling_args = scaling_args + + @property + def saved_tensors(self): + """Get saved tensors.""" + tensors = list(self.ctx.saved_tensors) + for i, v in enumerate(self.ctx.scaling_args): + if v is not None: + tensors[i] = v + return tensors + + +class TeModuleOverrider: + """An Overrider to override some modules and functions in Transformer Engine.""" + @classmethod + def override(cls): + """Override transformer engine modules and functions.""" + cls._override_funcions() + cls._override_classes() + + @classmethod + def _override_funcions(cls): + TransformerEngineBaseModule.set_activation_dtype = set_activation_dtype + cls._override_function(te.module.linear, '_Linear') + cls._override_function(te.module.layernorm_linear, '_LayerNormLinear') + cls._override_function(te.module.layernorm_mlp, '_LayerNormMLP') + + @classmethod + def _override_classes(cls): + """Override some classes in transformer engine.""" + te.Linear = MSAMPLinear + te.LayerNormLinear = MSAMPLayerNormLinear + te.LayerNormMLP = MSAMPLayerNormMLP + + te.attention.Linear = MSAMPLinear + te.attention.LayerNormLinear = MSAMPLayerNormLinear + + te.transformer.Linear = MSAMPLinear + te.transformer.LayerNormLinear = MSAMPLayerNormLinear + te.transformer.LayerNormMLP = MSAMPLayerNormMLP + + @staticmethod + def _override_function(mod, func_name): # noqa: C901 + """Override a function in a module. + + Args: + mod (module): Module. + func_name (str): Function name. + """ + old_func = getattr(mod, func_name) + assert issubclass(old_func, torch.autograd.Function), (func_name, old_func) + + class Func(torch.autograd.Function): + @staticmethod + def forward(ctx, place_holder, *args): + scaling_tensors = [] + for i, a in enumerate(args): + if isinstance(a, ScalingTensor): + scaling_tensors.append((i, a)) + if ctx is not None: + ctx.scaling_tensors = scaling_tensors + ctx = CtxWrapper(ctx) + return old_func.forward(ctx, *args) + + @staticmethod + def backward(ctx, *args): + ctx = CtxWrapper(ctx) + grads = list(old_func.backward(ctx, *args)) + for i, v in ctx.scaling_tensors: + if not v.requires_grad: + continue + assert grads[i] is not None + if v.grad is None: + v.grad = grads[i] + else: + v.grad += grads[i] + v.backward_grad_update(v.grad) + grads[i] = None + return (None, ) + tuple(grads) + + class Wrapper: + EMPTY_TENSOR = torch.tensor([], requires_grad=True) + + @staticmethod + def forward(ctx, *args): + return Func.forward(ctx, Wrapper.EMPTY_TENSOR.detach(), *args) + + @staticmethod + def apply(*args): + return Func.apply(Wrapper.EMPTY_TENSOR, *args) + + setattr(mod, func_name, Wrapper) + + +TeModuleOverrider.override() diff --git a/msamp/te/replacer.py b/msamp/te/replacer.py new file mode 100644 index 00000000..46e3cd83 --- /dev/null +++ b/msamp/te/replacer.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""MS-AMP te.replacer module.""" + +import torch + +from msamp.common.dtype import Dtypes +from msamp.nn import ScalingParameter +from msamp.te.modules import MSAMPLinear, MSAMPLayerNormLinear, MSAMPLayerNormMLP + + +class TeReplacer: + """A replacer to replace the weights with ScalingParameter in transformer engine modules.""" + module_weight_names = { + MSAMPLinear: ['weight'], + MSAMPLayerNormLinear: ['weight', 'query_weight', 'key_weight', 'value_weight'], + MSAMPLayerNormMLP: ['fc1_weight', 'fc2_weight'], + } + + @classmethod + def _replace(cls, model): + for mod in TeReplacer.module_weight_names: + if isinstance(model, mod): + mod.is_msamp_module = True + weight_names = TeReplacer.module_weight_names[mod] + for wname in weight_names: + if not hasattr(model, wname): + continue + weight = getattr(model, wname) + requires_grad = weight.requires_grad + sp = ScalingParameter(weight.data.cast(Dtypes.kfloat16), requires_grad=requires_grad) + # release the old weight + weight.data = torch.tensor([]) + setattr(model, wname, sp) + for child_name, child in list(model.named_children()): + setattr(model, child_name, cls._replace(child)) + return model + + @classmethod + def replace(cls, model): + """Replace the weights with ScalingParameter in transformer engine modules.""" + model = cls._replace(model) + fp8_named_weights = [(k, p) for k, p in model.named_parameters() if isinstance(p, ScalingParameter)] + fp8_names = [k for k, _ in fp8_named_weights] + torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(model, fp8_names) + # empty cache + torch.cuda.empty_cache() + return model diff --git a/pyproject.toml b/pyproject.toml index d3074a29..c2c214ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,8 @@ classifiers=[ ] dependencies = [ "torch", - "transformer-engine@git+https://github.com/NVIDIA/TransformerEngine.git@v0.7#egg=transformer-engine", + "transformer-engine@git+https://github.com/NVIDIA/TransformerEngine.git@v0.11#egg=transformer-engine", + "flash-attn==1.0.9", "colorlog>=6.7.0", "deepspeed==0.9.2", "mpi4py", diff --git a/tests/te/test_extension.py b/tests/te/test_extension.py new file mode 100644 index 00000000..84eea373 --- /dev/null +++ b/tests/te/test_extension.py @@ -0,0 +1,127 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Tests for msamp.te.extention module.""" + +import unittest + +import torch +import transformer_engine_extensions as tex +import transformer_engine.pytorch.cpp_extensions as texcpp + +from tests.helper import decorator +from msamp.common.dtype import Dtypes +from msamp.common.tensor import ScalingMeta +from msamp.te import extension # noqa: F401 + + +class TeExtensionOverriderTestCase(unittest.TestCase): + """Test TeExtension overrider.""" + def setUp(self): + """Hook method for setting up the test fixture before exercising it.""" + torch.manual_seed(1000) + self.size = (4, 4) + self.device = 'cuda' + + def tearDown(self): + """Hook method for deconstructing the test fixture after testing it.""" + pass + + @decorator.cuda_test + def test_fused_cast_transpose(self): + """Test fused_cast_transpose.""" + # cast with torch.tensor + input = torch.randn(self.size, device=self.device) + meta = ScalingMeta(Dtypes.kfloat8_e4m3) + meta.amax[0] = input.abs().max() + meta.reset_scaling_factor() + + input_cast_1 = torch.empty(self.size, device=self.device, dtype=torch.uint8) + transpose_cast_1 = torch.empty(self.size, device=self.device, dtype=torch.uint8) + + tex.fused_cast_transpose( + input, meta.scale, meta.amax, meta.scale_inv, input_cast_1, transpose_cast_1, tex.DType.kFloat8E4M3 + ) + assert torch.equal(input_cast_1.t(), transpose_cast_1) + + # cast with ScalingTensor + scaling_input = input.cast(Dtypes.kfloat32) + input_cast_2 = torch.empty(self.size, device=self.device, dtype=torch.uint8) + transpose_cast_2 = torch.empty(self.size, device=self.device, dtype=torch.uint8) + scale_inv = torch.ones((), device=self.device) + + tex.fused_cast_transpose( + scaling_input, None, None, scale_inv, input_cast_2, transpose_cast_2, tex.DType.kFloat8E4M3 + ) + assert torch.equal(input_cast_2.t(), transpose_cast_2) + + assert torch.equal(input_cast_1, input_cast_2) + + @decorator.cuda_test + def test_cast_to_fp8(self): + """Test cast_to_fp8.""" + # cast with torch.tensor + input = torch.randn(self.size, device=self.device) + scaling_meta = ScalingMeta(Dtypes.kfloat8_e4m3) + scaling_meta.amax[0] = input.abs().max() + scaling_meta.reset_scaling_factor() + scale = scaling_meta.scale.item() + + fp8_type = tex.DType.kFloat8E4M3 + meta = tex.FP8TensorMeta() + meta.scale = torch.ones(1, dtype=torch.float32, device=self.device) * scale + meta.scale_inv = torch.ones(1, dtype=torch.float32, device=self.device) / scale + meta.amax_history = torch.zeros(1, 1, dtype=torch.float32, device=self.device) + meta.amax_history[0][0] = scaling_meta.amax[0] + + ret1 = texcpp.cast_to_fp8(input, meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_type) + + # cast with ScalingTensor + scaling_input = input.cast(Dtypes.kfloat32) + fp8_type = tex.DType.kFloat8E4M3 + meta = tex.FP8TensorMeta() + meta.scale = torch.ones(1, dtype=torch.float32, device=self.device) * scale + meta.scale_inv = torch.ones(1, dtype=torch.float32, device=self.device) / scale + meta.amax_history = torch.zeros(1, 1, dtype=torch.float32, device=self.device) + ret2 = texcpp.cast_to_fp8(scaling_input, meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_type) + assert torch.equal(ret1, ret2) + + @decorator.cuda_test + def test_fp8_cast_transpose_fused(self): + """Test fp8_cast_transpose_fused.""" + # cast with torch.tensor + input = torch.randn(self.size, device=self.device) + scaling_meta = ScalingMeta(Dtypes.kfloat8_e4m3) + scaling_meta.amax[0] = input.abs().max() + scaling_meta.reset_scaling_factor() + scale = scaling_meta.scale.item() + + fp8_type = tex.DType.kFloat8E4M3 + meta = 
tex.FP8TensorMeta() + meta.scale = torch.ones(1, dtype=torch.float32, device=self.device) * scale + meta.scale_inv = torch.ones(1, dtype=torch.float32, device=self.device) / scale + meta.amax_history = torch.zeros(1, 1, dtype=torch.float32, device=self.device) + meta.amax_history[0][0] = scaling_meta.amax[0] + + cast_out1, transpose_out1 = texcpp.fp8_cast_transpose_fused( + input, meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_type, None, None + ) + assert torch.equal(cast_out1.t(), transpose_out1) + + # cast with ScalingTensor + scaling_input = input.cast(Dtypes.kfloat32) + fp8_type = tex.DType.kFloat8E4M3 + meta = tex.FP8TensorMeta() + meta.scale = torch.ones(1, dtype=torch.float32, device=self.device) * scale + meta.scale_inv = torch.ones(1, dtype=torch.float32, device=self.device) / scale + meta.amax_history = torch.zeros(1, 1, dtype=torch.float32, device=self.device) + + cast_out2 = torch.randn(self.size, device=self.device) + transpose_out2 = torch.randn(self.size, device=self.device) + texcpp.fp8_cast_transpose_fused( + scaling_input, meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_type, cast_out2, transpose_out2 + ) + assert torch.equal(cast_out2.t(), transpose_out2) + + assert torch.equal(cast_out1, cast_out2) + assert torch.equal(transpose_out1, transpose_out2) diff --git a/tests/te/test_modules.py b/tests/te/test_modules.py new file mode 100644 index 00000000..64efa856 --- /dev/null +++ b/tests/te/test_modules.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for msamp.te.modules module.""" + +import unittest + +import torch +import transformer_engine.pytorch as te + +from tests.helper import decorator +from msamp.te.modules import MSAMPLinear, MSAMPLayerNormLinear, MSAMPLayerNormMLP + + +class TeModuleOverriderTestCase(unittest.TestCase): + """Test TeModule overrider.""" + def setUp(self): + """Hook method for setting up the test fixture before exercising it.""" + torch.manual_seed(1000) + + def tearDown(self): + """Hook method for deconstructing the test fixture after testing it.""" + pass + + @decorator.cuda_test + def test_modules(self): + """Test modules overrided by MS-AMP.""" + te_linear = te.Linear(4, 4) + assert isinstance(te_linear, MSAMPLinear) + + te_layernorm_linear = te.LayerNormLinear(4, 4) + assert isinstance(te_layernorm_linear, MSAMPLayerNormLinear) + + te_layernorm_mlp = te.LayerNormMLP(4, 4) + assert isinstance(te_layernorm_mlp, MSAMPLayerNormMLP) diff --git a/tests/te/test_replacer.py b/tests/te/test_replacer.py new file mode 100644 index 00000000..a682a909 --- /dev/null +++ b/tests/te/test_replacer.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +"""Tests for msamp.te.replacer module.""" + +import unittest + +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Format, DelayedScaling + +from tests.helper import decorator +from msamp.nn import ScalingParameter +from msamp.te.replacer import TeReplacer + + +class TeReplacerTestCase(unittest.TestCase): + """Test TeExtention overrider.""" + def setUp(self): + """Hook method for setting up the test fixture before exercising it.""" + torch.manual_seed(1000) + self.hidden_size = 4096 + self.ffn_hidden_size = 16384 + self.num_attention_heads = 32 + self.dtype = torch.float16 + self.batch_size = 4 + self.sequence_length = 128 + + def tearDown(self): + """Hook method for deconstructing the test fixture after testing it.""" + pass + + @decorator.cuda_test + def test_replace(self): + """Test replace function in TeReplacer.""" + # fused attention need cuda version >= 12.1 + if torch.version.cuda < '12.1': + return + te_transformer = te.TransformerLayer( + self.hidden_size, self.ffn_hidden_size, self.num_attention_heads, fuse_qkv_params=True + ) + te_transformer.to(dtype=self.dtype).cuda() + + model = TeReplacer.replace(te_transformer) + msamp_module_cnt = 0 + + def _check_model(model): + if type(model) in TeReplacer.module_weight_names: + nonlocal msamp_module_cnt + msamp_module_cnt += 1 + weights = TeReplacer.module_weight_names[type(model)] + for weight in weights: + if not hasattr(model, weight): + continue + weight = getattr(model, weight) + assert isinstance(weight, ScalingParameter) + + for _, child in list(model.named_children()): + _check_model(child) + + _check_model(model) + assert msamp_module_cnt == 3 + + scaling_params = [p for p in model.parameters() if isinstance(p, ScalingParameter)] + assert len(scaling_params) == 4 + is_fp8_available, _ = te.fp8.is_fp8_available() + if is_fp8_available: + # Do a forward pass to make sure the model is working. + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max') + x = torch.rand(self.sequence_length, self.batch_size, self.hidden_size).cuda().to(dtype=self.dtype) + + with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + y = model(x, attention_mask=None) + assert y.shape == (self.sequence_length, self.batch_size, self.hidden_size) + y.sum().backward()