diff --git a/hardware/cambricon_MLU/a44556c0/a44556c0.patch b/hardware/cambricon_MLU/a44556c0/a44556c0.patch
new file mode 100644
index 000000000..ebd83b2e2
--- /dev/null
+++ b/hardware/cambricon_MLU/a44556c0/a44556c0.patch
@@ -0,0 +1,532 @@
+From 7494e53570af98c61bf5181e730a56d666c4127e Mon Sep 17 00:00:00 2001
+From: cifar10 <731376883@qq.com>
+Date: Mon, 23 Sep 2024 15:46:07 +0800
+Subject: [PATCH] [cambricon] support llama3-70b
+
+---
+ megatron/megatron/__init__.py                 | 49 +++++++++++++++++++
+ .../core/fusions/fused_bias_dropout.py        |  5 +-
+ .../megatron/core/fusions/fused_bias_geglu.py |  8 +--
+ .../core/fusions/fused_bias_swiglu.py         | 10 ++--
+ .../megatron/core/fusions/fused_softmax.py    |  4 +-
+ megatron/megatron/core/jit.py                 | 10 +++-
+ .../common/embeddings/rotary_pos_embedding.py | 13 +++--
+ .../core/tensor_parallel/cross_entropy.py     | 41 +++++++++++-----
+ .../megatron/core/tensor_parallel/layers.py   |  8 ++-
+ .../megatron/core/tensor_parallel/random.py   |  6 ++-
+ .../core/transformer/moe/grouped_gemm_util.py |  8 +--
+ .../core/transformer/moe/moe_layer.py         |  5 +-
+ megatron/megatron/core/utils.py               |  9 ++--
+ .../megatron/legacy/model/fused_layer_norm.py |  5 +-
+ .../megatron/legacy/model/fused_softmax.py    |  4 +-
+ megatron/megatron/legacy/model/transformer.py |  5 +-
+ megatron/megatron/training/arguments.py       |  7 ++-
+ megatron/megatron/training/initialize.py      |  6 ++-
+ 18 files changed, 144 insertions(+), 59 deletions(-)
+
+diff --git a/megatron/megatron/__init__.py b/megatron/megatron/__init__.py
+index e69de29b..785a548f 100644
+--- a/megatron/megatron/__init__.py
++++ b/megatron/megatron/__init__.py
+@@ -0,0 +1,49 @@
++import os
++from functools import wraps
++import torch
++
++try:
++    import torch_mlu
++    import torch._dynamo
++    torch._dynamo.config.suppress_errors = True
++    from torch_mlu.utils.model_transfer import transfer
++except:
++    pass
++
++
++if hasattr(torch.distributed, "all_gather_into_tensor") and \
++    hasattr(torch.distributed, "reduce_scatter_tensor"):
++    torch.distributed._all_gather_base = torch.distributed.all_gather_into_tensor
++    torch.distributed._reduce_scatter_base = torch.distributed.reduce_scatter_tensor
++
++def wrapper_type(fn):
++    @wraps(fn)
++    def decorated(*args, **kwargs):
++        output = fn(*args, **kwargs)
++        if isinstance(output, str):
++            if output == 'torch.mlu.FloatTensor':
++                output = 'torch.cuda.FloatTensor'
++            elif output == 'torch.mlu.BFloat16Tensor':
++                output = 'torch.cuda.BFloat16Tensor'
++            elif output == 'torch.mlu.HalfTensor':
++                output = 'torch.cuda.HalfTensor'
++        return output
++
++    return decorated
++
++def wrapper_backend(fn):
++    @wraps(fn)
++    def decorated(*args, **kwargs):
++        output = fn(*args, **kwargs)
++        if isinstance(output, str):
++            if output == 'cncl':
++                output = 'nccl'
++        return output
++
++    return decorated
++
++
++os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
++
++torch.Tensor.type = wrapper_type(torch.Tensor.type)
++torch.distributed.get_backend = wrapper_backend(torch.distributed.get_backend)
+diff --git a/megatron/megatron/core/fusions/fused_bias_dropout.py b/megatron/megatron/core/fusions/fused_bias_dropout.py
+index c7fa8419..daf94f74 100644
+--- a/megatron/megatron/core/fusions/fused_bias_dropout.py
++++ b/megatron/megatron/core/fusions/fused_bias_dropout.py
+@@ -4,6 +4,7 @@ from typing import Optional, Tuple
+ import torch
+ 
+ from megatron.core.jit import jit_fuser
++from apex.contrib.fused_bias_dropout.fused_bias_dropout import get_bias_dropout_add as get_bias_dropout_add_mlu
+ 
+ 
+ def _bias_dropout_add_func(x_with_bias, residual, prob, training):
+@@ -47,9 +48,9 @@ def bias_dropout_add_unfused(training):
+ 
+ @jit_fuser
+ def bias_dropout_add_fused_train(
+-    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float
++    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float,
+ ) -> torch.Tensor:
+-    return _bias_dropout_add_func(x_with_bias, residual, prob, True)
++    return get_bias_dropout_add_mlu(True, True)(x_with_bias, residual, prob)
+ 
+ 
+ @jit_fuser
+diff --git a/megatron/megatron/core/fusions/fused_bias_geglu.py b/megatron/megatron/core/fusions/fused_bias_geglu.py
+index 70ef3488..ae6c6a22 100644
+--- a/megatron/megatron/core/fusions/fused_bias_geglu.py
++++ b/megatron/megatron/core/fusions/fused_bias_geglu.py
+@@ -3,6 +3,7 @@
+ import torch
+ 
+ from megatron.core.jit import jit_fuser
++from apex.contrib.activation import fused_glu
+ 
+ ###### BIAS GELU FUSION/ NO AUTOGRAD ################
+ # 1/sqrt(2*pi)-> 0.3989423
+@@ -77,9 +78,8 @@ def bias_geglu_impl(input, bias):
+     ori_shape = input.shape
+     assert len(ori_shape) in [2, 3]
+     input = input.view(-1, ori_shape[-1])
+-    if bias is not None:
+-        output = BiasGeGLUFunction.apply(input, bias)
+-    else:
+-        output = GeGLUFunction.apply(input)
++    if bias is None:
++        bias = torch.Tensor()
++    output = fused_glu.fused_geglu(input, bias)
+ 
+     return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)
+diff --git a/megatron/megatron/core/fusions/fused_bias_swiglu.py b/megatron/megatron/core/fusions/fused_bias_swiglu.py
+index fd3ac3ec..56998ec6 100644
+--- a/megatron/megatron/core/fusions/fused_bias_swiglu.py
++++ b/megatron/megatron/core/fusions/fused_bias_swiglu.py
+@@ -4,6 +4,7 @@ import torch
+ import torch.nn.functional as F
+ 
+ from megatron.core.jit import jit_fuser
++from apex.contrib.activation import fused_glu
+ 
+ ###### BIAS SWIGLU FUSION/ NO AUTOGRAD ################
+ 
+@@ -77,13 +78,12 @@ def bias_swiglu_impl(input, bias, fp8_input_store=False):
+     ori_shape = input.shape
+     assert len(ori_shape) in [2, 3]
+     input = input.view(-1, ori_shape[-1])
+-    if bias is not None:
+-        output = BiasSwiGLUFunction.apply(input, bias, fp8_input_store)
+-    else:
+-        output = SwiGLUFunction.apply(input, fp8_input_store)
++    assert not fp8_input_store, "MLU does not support fp8_input_store=True"
++    if bias is None:
++        bias = torch.Tensor()
++    output = fused_glu.fused_swiglu(input, bias)
+ 
+     return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)
+ 
+-
+ # bias_swiglu_impl = BiasSwiGLUFunction.apply
+ # swiglu_impl = SwiGLUFunction.apply
+diff --git a/megatron/megatron/core/fusions/fused_softmax.py b/megatron/megatron/core/fusions/fused_softmax.py
+index c7bfbb76..b8b295bd 100644
+--- a/megatron/megatron/core/fusions/fused_softmax.py
++++ b/megatron/megatron/core/fusions/fused_softmax.py
+@@ -153,12 +153,12 @@ class FusedScaleMaskSoftmax(nn.Module):
+         if (
+             self.scaled_masked_softmax_fusion  # user want to fuse
+             and self.input_in_float16  # input must be fp16
+-            and 16 < sk <= 4096  # sk must be 16 ~ 2048
++            and 16 < sk <= 2048  # sk must be 16 ~ 2048
+             and sq % 4 == 0  # sq must be divisor of 4
+             and sk % 4 == 0  # sk must be divisor of 4
+             and attn_batches % 4 == 0  # np * b must be divisor of 4
+         ):
+-            if 0 <= sk <= 4096:
++            if 0 <= sk <= 2048:
+                 batch_per_block = self.get_batch_per_block(sq, sk, b, np)
+ 
+                 if self.attn_mask_type == AttnMaskType.causal:
+diff --git a/megatron/megatron/core/jit.py b/megatron/megatron/core/jit.py
+index 8bb18d39..2887471c 100644
+--- a/megatron/megatron/core/jit.py
++++ b/megatron/megatron/core/jit.py
+@@ -5,7 +5,13 @@ import torch
+ TORCH_MAJOR = int(torch.__version__.split(".")[0])
+ TORCH_MINOR = int(torch.__version__.split(".")[1])
+ 
+-jit_fuser = torch.jit.script
++def fake_torch_compile(func):
++    def wrapper(*args, **kwargs):
++        return func(*args, **kwargs)
++    return wrapper
++
++jit_fuser = fake_torch_compile
+ # nvFuser is deprecated in PyTorch JIT starting from 2.2
+-if (TORCH_MAJOR > 2) or (TORCH_MAJOR == 2 and TORCH_MINOR >= 2):
++if (TORCH_MAJOR > 2) or (TORCH_MAJOR == 2 and TORCH_MINOR >= 5):
+     jit_fuser = torch.compile
++
+diff --git a/megatron/megatron/core/models/common/embeddings/rotary_pos_embedding.py b/megatron/megatron/core/models/common/embeddings/rotary_pos_embedding.py
+index 1fcdc7d2..bfda4fd4 100644
+--- a/megatron/megatron/core/models/common/embeddings/rotary_pos_embedding.py
++++ b/megatron/megatron/core/models/common/embeddings/rotary_pos_embedding.py
+@@ -258,13 +258,18 @@ def apply_rotary_pos_emb(
+             apply_rotary_pos_emb.printed_fused_warning = True
+     if config.apply_rope_fusion:
+         if cu_seqlens is None:
+-            return fused_apply_rotary_pos_emb(t, freqs, transpose_output_memory=True)
++            output = fused_apply_rotary_pos_emb(t, freqs, transpose_output_memory=True)
+         else:
+-            return fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs)
++            output = fused_apply_rotary_pos_emb_thd(t, cu_seqlens, freqs)
+     else:
+         if cu_seqlens is None:
+-            return apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved)
++            output = apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved)
+         else:
+-            return apply_rotary_pos_emb_thd(
++            output = apply_rotary_pos_emb_thd(
+                 t, cu_seqlens, freqs, rotary_interleaved=config.rotary_interleaved
+             )
++    # MLU devices tend to use the NHWC layout rather than the NCHW layout.
++    # This modification affects the stride but not the correctness of the results.
++    output = output.contiguous()
++    output = output.reshape(output.shape)
++    return output
+diff --git a/megatron/megatron/core/tensor_parallel/cross_entropy.py b/megatron/megatron/core/tensor_parallel/cross_entropy.py
+index 0066d126..3d8a3e86 100644
+--- a/megatron/megatron/core/tensor_parallel/cross_entropy.py
++++ b/megatron/megatron/core/tensor_parallel/cross_entropy.py
+@@ -113,31 +113,40 @@
+ 
+         return grad_input
+ 
+-
+ class _VocabParallelCrossEntropy(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):
+ 
+-        vocab_parallel_logits, logits_max = VocabParallelCrossEntropy.calculate_logits_max(
+-            vocab_parallel_logits
+-        )
++        # Maximum value along vocab dimension across all GPUs.
++        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
+         torch.distributed.all_reduce(
+             logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
+         )
++        # Subtract the maximum value.
++        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
+ 
+-        # Get the partition's vocab indices
++        # Get the partition's vocab indecies
+         get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
+         partition_vocab_size = vocab_parallel_logits.size()[-1]
+         rank = get_tensor_model_parallel_rank()
+         world_size = get_tensor_model_parallel_world_size()
+         vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
+ 
+-        (target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits) = (
+-            VocabParallelCrossEntropy.calculate_predicted_logits(
+-                vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
+-            )
+-        )
++        # Create a mask of valid vocab ids (1 means it needs to be masked).
++        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
++        masked_target = target.clone() - vocab_start_index
++        masked_target[target_mask] = 0
+ 
++        # Get predicted-logits = logits[target].
++        # For Simplicity, we convert logits to a 2-D tensor with size
++        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
++        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
++        masked_target_1d = masked_target.view(-1)
++        arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
++        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
++        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
++        predicted_logits = predicted_logits_1d.view_as(target)
++        predicted_logits[target_mask] = 0.0
+         # All reduce is needed to get the chunks from other GPUs.
+         torch.distributed.all_reduce(
+             predicted_logits,
+@@ -145,15 +154,21 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
+             group=get_tensor_model_parallel_group(),
+         )
+ 
++        # Sum of exponential of logits along vocab dimension across all GPUs.
++        exp_logits = vocab_parallel_logits
++        torch.exp(vocab_parallel_logits, out=exp_logits)
++        sum_exp_logits = exp_logits.sum(dim=-1)
+         torch.distributed.all_reduce(
+             sum_exp_logits,
+             op=torch.distributed.ReduceOp.SUM,
+             group=get_tensor_model_parallel_group(),
+         )
+ 
+-        exp_logits, loss = VocabParallelCrossEntropy.calculate_cross_entropy_loss(
+-            exp_logits, predicted_logits, sum_exp_logits
+-        )
++        # Loss = log(sum(exp(logits))) - predicted-logit.
++        loss = torch.log(sum_exp_logits) - predicted_logits
++
++        # Normalize and optionally smooth logits
++        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+ 
+         vocab_size = exp_logits.size(-1)
+         if label_smoothing > 0:
+diff --git a/megatron/megatron/core/tensor_parallel/layers.py b/megatron/megatron/core/tensor_parallel/layers.py
+index 5707a0b5..4a72b3d2 100644
+--- a/megatron/megatron/core/tensor_parallel/layers.py
++++ b/megatron/megatron/core/tensor_parallel/layers.py
+@@ -227,13 +227,17 @@ class VocabParallelEmbedding(torch.nn.Module):
+             masked_input[input_mask] = 0
+         else:
+             masked_input = input_
++
+         # Get the embeddings.
+         if self.deterministic_mode:
+-            output_parallel = self.weight[masked_input]
++            output_parallel = F.embedding(
++                masked_input, self.weight, padding_idx=None, scale_grad_by_freq=False, sparse=False)
+         else:
+             # F.embedding currently has a non-deterministic backward function
+-            output_parallel = F.embedding(masked_input, self.weight)
++            output_parallel = F.embedding(
++                masked_input, self.weight, padding_idx=None, scale_grad_by_freq=False, sparse=False)
+         # Mask the output embedding.
++
+         if self.tensor_model_parallel_size > 1:
+             output_parallel[input_mask, :] = 0.0
+ 
+diff --git a/megatron/megatron/core/tensor_parallel/random.py b/megatron/megatron/core/tensor_parallel/random.py
+index ee074df9..7fe03396 100644
+--- a/megatron/megatron/core/tensor_parallel/random.py
++++ b/megatron/megatron/core/tensor_parallel/random.py
+@@ -8,6 +8,7 @@ import logging
+ from importlib.metadata import version
+ 
+ import torch
++import torch_mlu
+ from pkg_resources import packaging
+ from torch import _C
+ from torch.cuda import _lazy_call
+@@ -62,8 +63,9 @@ def _set_cuda_rng_state(new_state, device=-1):
+             default_generator = torch.cuda.default_generators[idx]
+             default_generator.set_state(new_state)
+ 
+-    _lazy_call(cb)
+-
++    #TODO(mlu): mlu's implementation should be called instead of cuda's implementation.
++    #_lazy_call(cb)
++    torch_mlu.mlu._lazy_call(cb)
+ 
+ def get_expert_parallel_rng_tracker_name():
+     global _EXPERT_PARALLEL_RNG_TRACKER_NAME
+diff --git a/megatron/megatron/core/transformer/moe/grouped_gemm_util.py b/megatron/megatron/core/transformer/moe/grouped_gemm_util.py
+index e7ef79d7..05644862 100644
+--- a/megatron/megatron/core/transformer/moe/grouped_gemm_util.py
++++ b/megatron/megatron/core/transformer/moe/grouped_gemm_util.py
+@@ -1,10 +1,12 @@
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+ 
+ try:
+-    import grouped_gemm
++    from apex.contrib import grouped_gemm
+ except ImportError:
+-    grouped_gemm = None
+-
++    try:
++        import grouped_gemm
++    except ImportError:
++        grouped_gemm = None
+ 
+ def grouped_gemm_is_available():
+     return grouped_gemm is not None
+diff --git a/megatron/megatron/core/transformer/moe/moe_layer.py b/megatron/megatron/core/transformer/moe/moe_layer.py
+index 1ea61ba3..62480151 100644
+--- a/megatron/megatron/core/transformer/moe/moe_layer.py
++++ b/megatron/megatron/core/transformer/moe/moe_layer.py
+@@ -71,10 +71,7 @@ class MoELayer(BaseMoELayer):
+         super(MoELayer, self).__init__(config=config, layer_number=layer_number)
+         self.router = TopKRouter(config=self.config)
+         if self.config.moe_grouped_gemm:
+-            if isinstance(self.submodules, MLPSubmodules):
+-                self.experts = TEGroupedMLP(self.num_local_experts, self.config, self.submodules)
+-            else:
+-                self.experts = GroupedMLP(self.num_local_experts, self.config)
++            self.experts = GroupedMLP(self.num_local_experts, self.config)
+         else:
+             assert isinstance(self.submodules, MLPSubmodules)
+             self.experts = SequentialMLP(self.num_local_experts, self.config, self.submodules)
+diff --git a/megatron/megatron/core/utils.py b/megatron/megatron/core/utils.py
+index 062372d9..ce8fb5e9 100644
+--- a/megatron/megatron/core/utils.py
++++ b/megatron/megatron/core/utils.py
+@@ -854,11 +854,10 @@ class StragglerDetector:
+         elif ls_bs != ls_be:
+             logger.warning(f"get_batch Start/Stop out of sync {ls_bs}/{ls_be}")
+         else:
+-            temp = torch.cuda.temperature()
+-            power = torch.cuda.power_draw()
+-            util = torch.cuda.utilization()
+-            clock = torch.cuda.clock_rate()
+-            torch.cuda.synchronize()
++
++            logger.warning("MLU does not support the torch.cuda.temperature() interface.")
++            util = 0
++
+             # Process Events
+             for i in range(ls_ev):
+                 e_ev = self.start_gemm_ev[i].elapsed_time(self.stop_gemm_ev[i])
+diff --git a/megatron/megatron/legacy/model/fused_layer_norm.py b/megatron/megatron/legacy/model/fused_layer_norm.py
+index 3a788dd3..83ba6d6a 100644
+--- a/megatron/megatron/legacy/model/fused_layer_norm.py
++++ b/megatron/megatron/legacy/model/fused_layer_norm.py
+@@ -95,10 +95,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
+                 "fused_layer_norm_affine is not available, please install apex from https://github.com/NVIDIA/apex"
+             return fused_layer_norm_affine(input, weight, self.bias, self.normalized_shape, eps=self.eps)
+         else:
+-            if 'memory_efficient' in inspect.getfullargspec(FastLayerNormFN.forward).args:
+-                output = FastLayerNormFN.apply(input, weight, self.bias, self.eps, False)
+-            else:
+-                output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
++            output = FastLayerNormFN.apply(input, weight, self.bias, self.eps, False)
+             # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
+             # a populated '_base' field). This will result in schedule.py's
+             # deallocate_output_tensor() throwing an error, so a viewless tensor is
+diff --git a/megatron/megatron/legacy/model/fused_softmax.py b/megatron/megatron/legacy/model/fused_softmax.py
+index 58f900bd..759f09f4 100644
+--- a/megatron/megatron/legacy/model/fused_softmax.py
++++ b/megatron/megatron/legacy/model/fused_softmax.py
+@@ -173,12 +173,12 @@ class FusedScaleMaskSoftmax(nn.Module):
+         if (
+             self.scaled_masked_softmax_fusion  # user want to fuse
+             and self.input_in_float16  # input must be fp16
+-            and 16 < sk <= 16384  # sk must be 16 ~ 16384
++            and 16 < sk <= 2048  # sk must be 16 ~ 2048
+             and sq % 4 == 0  # sq must be divisor of 4
+             and sk % 4 == 0  # sk must be divisor of 4
+             and attn_batches % 4 == 0  # np * b must be divisor of 4
+         ):
+-            if 0 <= sk <= 16384:
++            if 0 <= sk <= 2048:
+                 batch_per_block = self.get_batch_per_block(sq, sk, b, np)
+ 
+                 if self.attn_mask_type == AttnMaskType.causal:
+diff --git a/megatron/megatron/legacy/model/transformer.py b/megatron/megatron/legacy/model/transformer.py
+index 652f017a..2532860f 100644
+--- a/megatron/megatron/legacy/model/transformer.py
++++ b/megatron/megatron/legacy/model/transformer.py
+@@ -35,6 +35,8 @@ try:
+ except ImportError:
+     rearrange = None
+ 
++from apex.contrib.fused_bias_dropout.fused_bias_dropout import get_bias_dropout_add as get_bias_dropout_add_mlu
++
+ 
+ # Try FlashAttn2 first
+ try:
+@@ -898,7 +900,8 @@ def bias_dropout_add_fused_train(x: torch.Tensor,
+                                  bias: Optional[torch.Tensor],
+                                  residual: torch.Tensor,
+                                  prob: float) -> torch.Tensor:
+-    return bias_dropout_add(x, bias, residual, prob, True)
++    bias = bias.contiguous() if bias is not None else torch.Tensor()
++    return get_bias_dropout_add_mlu(True, True)((x, bias), residual, prob)
+ 
+ 
+ @jit_fuser
+diff --git a/megatron/megatron/training/arguments.py b/megatron/megatron/training/arguments.py
+index e20f178b..8c4cdff4 100644
+--- a/megatron/megatron/training/arguments.py
++++ b/megatron/megatron/training/arguments.py
+@@ -599,7 +599,10 @@ def validate_args(args, defaults={}):
+     if args.moe_grouped_gemm:
+         assert args.bf16, 'Currently GroupedGEMM for MoE only supports bf16 dtype.'
+         dc = torch.cuda.get_device_capability()
+-        assert dc[0] >= 8, "Unsupported compute capability for GroupedGEMM kernels."
++        try:
++            import torch_mlu
++        except:
++            assert dc[0] >= 8, "Unsupported compute capability for GroupedGEMM kernels."
+ 
+     if args.weight_decay_incr_style == 'constant':
+         assert args.start_weight_decay is None
+@@ -1762,7 +1765,7 @@ def _add_distributed_args(parser):
+                        help='overlap pipeline parallel communication with forward and backward chunks',
+                        dest='overlap_p2p_comm')
+     group.add_argument('--distributed-backend', default='nccl',
+-                       choices=['nccl', 'gloo'],
++                       choices=['nccl', 'gloo', 'cncl'],
+                        help='Which backend to use for distributed training.')
+     group.add_argument('--distributed-timeout-minutes', type=int, default=10,
+                        help='Timeout minutes for torch.distributed.')
+diff --git a/megatron/megatron/training/initialize.py b/megatron/megatron/training/initialize.py
+index dbfe2da9..c8baf97d 100644
+--- a/megatron/megatron/training/initialize.py
++++ b/megatron/megatron/training/initialize.py
+@@ -145,6 +145,7 @@ def _compile_dependencies():
+             "seconds".format(time.time() - start_time),
+             flush=True,
+         )
++    return
+ 
+     # ==================
+     # Load fused kernels
+@@ -224,7 +225,7 @@ def _initialize_tp_communicators():
+ 
+     #We create a MPI process group, which is needed to bootstrap the pipelined
+     #tensor-model-parallel communication overlap
+-    torch.distributed.new_group(backend='mpi')
++    #torch.distributed.new_group(backend='mpi')
+ 
+     te_module.base.initialize_ub(shape = input_shape, tp_size = args.tensor_model_parallel_size,
+                                  use_fp8 = (args.fp8 is not None) , ub_cfgs = ub_cfgs,)
+@@ -354,7 +355,7 @@ def set_jit_fusion_options():
+         torch._C._jit_override_can_fuse_on_cpu(False)
+         torch._C._jit_override_can_fuse_on_gpu(False)
+         torch._C._jit_set_texpr_fuser_enabled(False)
+-        torch._C._jit_set_nvfuser_enabled(True)
++        #torch._C._jit_set_nvfuser_enabled(True)
+         torch._C._debug_set_autodiff_subgraph_inlining(False)
+     else:
+         # legacy pytorch fuser
+@@ -367,6 +368,7 @@
+ 
+ 
+ def _warmup_jit_function():
++    return
+     """Compilie JIT functions before the main training steps"""
+     args = get_args()
+     if args.bf16:
+-- 
+2.34.1