diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 63f4122e2..23b43e65f 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -344,7 +344,6 @@ def __init__(
                 neox_args,
                 self.attention_type,
                 self.num_attention_heads_per_partition,
-                mpu=mpu,
             )
         else:
             if self.use_flash_attention:
diff --git a/megatron/model/utils.py b/megatron/model/utils.py
index 6beac5ca2..96ba3f086 100644
--- a/megatron/model/utils.py
+++ b/megatron/model/utils.py
@@ -238,7 +238,7 @@ def _set_use_cache(modules, value: bool):
     recursive_setattr(modules, "use_cache", value, assert_type=bool)
 
 
-def configure_sparse_attention(neox_args, attention_type, num_attention_heads, mpu):
+def configure_sparse_attention(neox_args, attention_type, num_attention_heads):
     from deepspeed.ops.sparse_attention import (
         SparseSelfAttention,
         VariableSparsityConfig,
@@ -337,7 +337,6 @@ def configure_sparse_attention(neox_args, attention_type, num_attention_heads, m
         sparsity_config=sparsity_config,
         max_seq_length=neox_args.seq_length,
         attn_mask_mode="add",
-        mpu=mpu,
     )
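After this change, `configure_sparse_attention` builds DeepSpeed's `SparseSelfAttention` from the sparsity config, sequence length, and mask mode only, with no `mpu` keyword. The sketch below is a minimal illustration of that construction, assuming a DeepSpeed build that ships the sparse-attention ops; the `num_heads` and `max_seq_length` values are placeholders, not values taken from this patch.

```python
from deepspeed.ops.sparse_attention import SparseSelfAttention, VariableSparsityConfig

# Placeholder values; in GPT-NeoX these come from neox_args and the
# per-partition head count computed in ParallelSelfAttention.
sparsity_config = VariableSparsityConfig(num_heads=16)

sparse_attn = SparseSelfAttention(
    sparsity_config=sparsity_config,
    max_seq_length=2048,
    attn_mask_mode="add",  # same mask mode as in configure_sparse_attention
)
```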