From d94a59846ef656bfaa4f3d7afb882612f1d5a7db Mon Sep 17 00:00:00 2001 From: xiaoxiawu-microsoft Date: Mon, 29 Jul 2024 21:36:03 +0000 Subject: [PATCH 01/10] DeepSpeed sequence parallelism (aka Ulysses) integration with HF transformer --- src/transformers/deepspeed.py | 1 + src/transformers/integrations/__init__.py | 2 ++ src/transformers/integrations/deepspeed.py | 7 +++++++ .../modeling_flash_attention_utils.py | 18 ++++++++++++++++++ 4 files changed, 28 insertions(+) diff --git a/src/transformers/deepspeed.py b/src/transformers/deepspeed.py index 6fd22d8c5cb..56731bfaa34 100644 --- a/src/transformers/deepspeed.py +++ b/src/transformers/deepspeed.py @@ -38,4 +38,5 @@ is_deepspeed_zero3_enabled, set_hf_deepspeed_config, unset_hf_deepspeed_config, + is_deepspeed_sp_enabled, ) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 4c756a23ae0..a5c6bd37e71 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -43,6 +43,7 @@ "is_deepspeed_zero3_enabled", "set_hf_deepspeed_config", "unset_hf_deepspeed_config", + "is_deepspeed_sp_enabled", ], "eetq": ["replace_with_eetq_linear"], "fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"], @@ -125,6 +126,7 @@ is_deepspeed_zero3_enabled, set_hf_deepspeed_config, unset_hf_deepspeed_config, + is_deepspeed_sp_enabled, ) from .eetq import replace_with_eetq_linear from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py index aae1204acf4..3dea00a962f 100644 --- a/src/transformers/integrations/deepspeed.py +++ b/src/transformers/integrations/deepspeed.py @@ -440,3 +440,10 @@ def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_str raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}") else: raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}") + +def is_deepspeed_sp_enabled(): + if is_deepspeed_available(): + from deepspeed.utils import groups + return groups._get_sequence_parallel_world_size() > 1 + else: + return False diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 7bb3ee03c07..7612deef374 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -22,6 +22,12 @@ from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal +from .integrations.deepspeed import is_deepspeed_available, is_deepspeed_sp_enabled #DeepSpeed seq parallelism (aka Ulysses) + +if is_deepspeed_available(): + from deepspeed.sequence.layer import _SeqAllToAll + from deepspeed.utils import groups as ds_comm_groups + if is_flash_attn_2_available(): from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa @@ -220,6 +226,14 @@ def _flash_attention_forward( deterministic (`bool`, *optional*): Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. 
""" + if is_deepspeed_sp_enabled(): + spg = ds_comm_groups._get_sequence_parallel_group() + #qkv tensors are of shape (batch_size, seq_len, num_heads, head_dim) + #Gather on seq_len dimension, scatter on num_heads dimension + query_states = _SeqAllToAll.apply(spg, query_states, 2, 1) + key_states = _SeqAllToAll.apply(spg, key_states, 2, 1) + value_states = _SeqAllToAll.apply(spg, value_states, 2, 1) + if not use_top_left_mask: causal = is_causal else: @@ -298,4 +312,8 @@ def _flash_attention_forward( query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs ) + if is_deepspeed_sp_enabled(): + #Gather on num_heads dimension, scatter on seq_len dimension + attn_output = _SeqAllToAll.apply(spg, attn_output, 1, 2) + return attn_output From 2cd494e372d23c8ded747f16fa5ef9ec77236635 Mon Sep 17 00:00:00 2001 From: Sam Ade Jacobs Date: Wed, 7 Aug 2024 06:23:02 +0000 Subject: [PATCH 02/10] Add deepspeed sp unit test --- tests/deepspeed/test_deepspeed.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 7b50165babf..a331fea1a9e 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -122,7 +122,11 @@ def require_deepspeed_aio(test_case): if is_deepspeed_available(): from deepspeed.utils import logger as deepspeed_logger # noqa from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint - from transformers.integrations.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled # noqa + from transformers.integrations.deepspeed import ( + deepspeed_config, + is_deepspeed_zero3_enabled, + is_deepspeed_sp_enabled, + ) # noqa def get_launcher(distributed=False): @@ -1330,3 +1334,17 @@ def test_clm_from_config_zero3_fp16(self): with CaptureStderr() as cs: execute_subprocess_async(cmd, env=self.get_env()) self.assertIn("Detected DeepSpeed ZeRO-3", cs.err) + + @parameterized.expand([2, 4, 8, 16]) + @require_torch_multi_accelerator + def test_deepspeed_sp(self, sp_size): + #Check if deepspeed_sp is enabled + #Run deepspeed sp with 2 GPUs and different sp_size + self.assertFalse(is_deepspeed_sp_enabled()) + ds_args = [f"--sequence-length={sp_size}"] + script = [f"{self.test_file_dir_str}/test_ulysses.py"] + distributed = True + launcher = get_launcher(distributed) + + cmd = launcher + script + ds_args + execute_subprocess_async(cmd, env=self.get_env()) From ed1b2c73cdcdd61be52e7a4cd8f6baf446849ded Mon Sep 17 00:00:00 2001 From: Sam Ade Jacobs Date: Wed, 7 Aug 2024 06:24:18 +0000 Subject: [PATCH 03/10] Add deepspeed sp unit test --- tests/deepspeed/test_ulysses.py | 73 +++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 tests/deepspeed/test_ulysses.py diff --git a/tests/deepspeed/test_ulysses.py b/tests/deepspeed/test_ulysses.py new file mode 100644 index 00000000000..ea7da00ed3d --- /dev/null +++ b/tests/deepspeed/test_ulysses.py @@ -0,0 +1,73 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.distributed as dist +from deepspeed import initialize +from transformers import AutoModel +import sys + +from transformers.integrations.deepspeed import ( + is_deepspeed_sp_enabled, + ) # noqa + +from transformers.modeling_flash_attention_utils import _flash_attention_forward + + +#Call transformer flash attention with and without deepspeed sp enabled and compare they match +def test_transformer_flash_attention(seq_len=2)->None: + model = AutoModel.from_pretrained('bert-base-uncased') + batch_size = 2 + + #Test with deepspeed sp + sp_size = 2 + dp_size = 1 + ds_engine, _, _, _ = initialize(model=model, config_params={"train_batch_size": batch_size, + "data_parallel_size": dp_size, "sequence_parallel_size": sp_size},) + + assert is_deepspeed_sp_enabled() == True + + seq_len = seq_len + hidden_dim = 16 + num_heads = 4 + head_dim = hidden_dim // num_heads + #Create input tensors + input_tensor = torch.randn(batch_size, seq_len, num_heads, hidden_dim, device=ds_engine.device) + input_tensor = input_tensor.half() + attention_mask = None + q, k, v = input_tensor, input_tensor, input_tensor + + + output_tensor = _flash_attention_forward(q,k,v, attention_mask,query_length=seq_len,is_causal=False) + assert output_tensor is not None + assert output_tensor.shape == (batch_size, seq_len, num_heads, hidden_dim) + + #Now test without deepspeed sp + sp_size = 1 + dp_size = 2 + ds_engine, _, _, _ = initialize(model=model, config_params={"train_batch_size": batch_size, + "data_parallel_size": dp_size, "sequence_parallel_size": sp_size},) + assert is_deepspeed_sp_enabled() == False + + output_tensor_no_sp = _flash_attention_forward(q,k,v, attention_mask,query_length=seq_len,is_causal=False) + assert output_tensor_no_sp is not None + assert output_tensor_no_sp.shape == (batch_size, seq_len, num_heads, hidden_dim) + assert torch.allclose(output_tensor, output_tensor_no_sp) + + +if __name__ == '__main__': + torch.manual_seed(0) + seq_len = int((sys.argv[2]).split('=')[1]) + test_transformer_flash_attention(seq_len=seq_len) \ No newline at end of file From d660f11f2594b55f624308e4e1300e86b4523614 Mon Sep 17 00:00:00 2001 From: Sam Ade Jacobs Date: Fri, 9 Aug 2024 17:27:48 +0000 Subject: [PATCH 04/10] Properly document args to DS SeqAllToAll --- .../modeling_flash_attention_utils.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 7612deef374..165e9e4b4fc 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -229,10 +229,12 @@ def _flash_attention_forward( if is_deepspeed_sp_enabled(): spg = ds_comm_groups._get_sequence_parallel_group() #qkv tensors are of shape (batch_size, seq_len, num_heads, head_dim) - #Gather on seq_len dimension, scatter on num_heads dimension - query_states = _SeqAllToAll.apply(spg, query_states, 2, 1) - key_states = _SeqAllToAll.apply(spg, key_states, 2, 1) - value_states = _SeqAllToAll.apply(spg, value_states, 2, 1) + scatter_idx = 2 #Scatter on num_heads dimension + gather_idx = 1 #Gather on seq_len dimension + batch_dim_idx = 0 #Synonymous with the batch_first==true + query_states = _SeqAllToAll.apply(spg, query_states, scatter_idx, gather_idx, batch_dim_idx) + key_states = _SeqAllToAll.apply(spg, key_states, scatter_idx, 
gather_idx,batch_dim_idx)
+        value_states = _SeqAllToAll.apply(spg, value_states, scatter_idx, gather_idx,batch_dim_idx)

     if not use_top_left_mask:
         causal = is_causal
@@ -313,7 +315,9 @@ def _flash_attention_forward(
     )

     if is_deepspeed_sp_enabled():
-        #Gather on num_heads dimension, scatter on seq_len dimension
-        attn_output = _SeqAllToAll.apply(spg, attn_output, 1, 2)
+        scatter_idx = 1 #Scatter back on seq_len dimension
+        gather_idx = 2 #Gather on num_heads dimension
+        batch_dim_idx = 0
+        attn_output = _SeqAllToAll.apply(spg, attn_output, scatter_idx, gather_idx,batch_dim_idx)

     return attn_output

From 2805b7a31849996287b3eb3ee5a0a6aab6ec705f Mon Sep 17 00:00:00 2001
From: Sam Ade Jacobs
Date: Fri, 9 Aug 2024 22:22:24 +0000
Subject: [PATCH 05/10] Add DS seq parallelism doc

---
 docs/source/en/deepspeed.md | 46 +++++++++++++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/docs/source/en/deepspeed.md b/docs/source/en/deepspeed.md
index 7f7995c4664..be41fb83c14 100644
--- a/docs/source/en/deepspeed.md
+++ b/docs/source/en/deepspeed.md
@@ -1141,6 +1141,52 @@ Using multiple GPUs with ZeRO-3 for generation requires synchronizing the GPUs b
 For Transformers>=4.28, if `synced_gpus` is automatically set to `True` if multiple GPUs are detected during generation.

### Non-Trainer Sequence Parallelism
DeepSpeed sequence parallelism, also known as [DeepSpeed Ulysses](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/README.md), is compatible with HuggingFace Transformers by adding 'sequence_parallel_size' and 'data_parallel_size' to the DeepSpeed configuration. Additionally, it's required that the user’s script correctly shard the input data along the sequence dimension.

```py
ds_config = {
    'sequence_parallel_size': 2,
    'data_parallel_size': 1,
    ......
    ......
}

config = transformers.AutoConfig.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(model_name,
                                             config=config,
                                             attn_implementation="flash_attention_2")

model, _, _, _ = deepspeed.initialize(model=model,
                                      model_parameters=model.parameters(),
                                      config=ds_config,
                                      dist_init_required=True,)


spg = model.get_sequence_parallel_group()
seq_parallel_world_size = dist.get_world_size(spg)
seq_parallel_rank = dist.get_rank(spg)

for n, batch in enumerate(data_loader):
    seq_length = batch["input_ids"].size(1)
    assert seq_length % seq_parallel_world_size == 0
    sub_seq_length = seq_length // seq_parallel_world_size
    sub_seq_start = seq_parallel_rank * sub_seq_length
    sub_seq_end = (seq_parallel_rank + 1) * sub_seq_length

    batch["input_ids"] = batch["input_ids"][:, sub_seq_start:sub_seq_end]
    batch["labels"] = batch["labels"][:, sub_seq_start:sub_seq_end]

.......

```

HuggingFace Transformers internally invokes DeepSpeed Ulysses to take advantage of multi-GPU optimization during pretraining, post-training, and fine-tuning of long-context LLMs. DeepSpeed sequence parallelism is fully supported with FlashAttention. A detailed example script is available [here](https://github.com/microsoft/DeepSpeedExamples/blob/uly-hf/post_training/sequence_parallelism/test_ulysses.py).

Integration with [`Trainer`] is also underway; this documentation will be updated once the [`Trainer`] integration is available.
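
Under the hood, the Transformers flash-attention path exchanges the query, key, and value tensors across the sequence-parallel group: each rank starts with its local sequence shard and all attention heads, and ends up with the full sequence and a slice of the heads before FlashAttention runs; the attention output then goes through the inverse exchange. The sketch below illustrates that redistribution with plain `torch.distributed` collectives. It is a conceptual illustration only (the helper name and the use of `all_to_all_single` are ours, not the DeepSpeed implementation, which performs the exchange with `_SeqAllToAll`), and it assumes the number of attention heads is divisible by the sequence parallel size.

```py
import torch
import torch.distributed as dist


def shard_seq_to_shard_heads(x: torch.Tensor, group=None) -> torch.Tensor:
    """Illustrative only: turn a sequence-sharded tensor of shape
    (batch, seq_len / P, num_heads, head_dim) into a head-sharded tensor of shape
    (batch, seq_len, num_heads / P, head_dim) with an all-to-all over P ranks."""
    world_size = dist.get_world_size(group)
    batch, seq_shard, num_heads, head_dim = x.shape
    # Heads are scattered across the sequence-parallel ranks, so they must divide evenly.
    assert num_heads % world_size == 0, "num_heads must be divisible by sequence_parallel_size"

    # Split the head dimension into one chunk per destination rank and make that
    # axis leading, because all_to_all_single slices its input along dim 0.
    x = x.reshape(batch, seq_shard, world_size, num_heads // world_size, head_dim)
    x = x.permute(2, 0, 1, 3, 4).contiguous()

    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=group)

    # Dim 0 of the result indexes the source rank, i.e. the sequence chunk, so folding
    # it back into the sequence dimension yields the full, correctly ordered sequence.
    out = out.permute(1, 0, 2, 3, 4).contiguous()
    return out.reshape(batch, world_size * seq_shard, num_heads // world_size, head_dim)
```

In this PR, that exchange and its inverse on the attention output (see `_flash_attention_forward` above) are the only changes made to the modeling code.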
+ + ## Troubleshoot When you encounter an issue, you should consider whether DeepSpeed is the cause of the problem because often it isn't (unless it's super obviously and you can see DeepSpeed modules in the exception)! The first step should be to retry your setup without DeepSpeed, and if the problem persists, then you can report the issue. If the issue is a core DeepSpeed problem and unrelated to the Transformers integration, open an Issue on the [DeepSpeed repository](https://github.com/microsoft/DeepSpeed). From c0cce198e225ab6167b5a32e799f87d68f595bf0 Mon Sep 17 00:00:00 2001 From: Sam Ade Jacobs Date: Sun, 11 Aug 2024 16:16:42 +0000 Subject: [PATCH 06/10] Formatting --- src/transformers/integrations/__init__.py | 4 +- src/transformers/integrations/deepspeed.py | 2 + .../modeling_flash_attention_utils.py | 23 ++++--- tests/deepspeed/test_deepspeed.py | 14 ++-- tests/deepspeed/test_ulysses.py | 68 +++++++++++-------- 5 files changed, 63 insertions(+), 48 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index a5c6bd37e71..f51efce68e1 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -40,10 +40,10 @@ "deepspeed_load_checkpoint", "deepspeed_optim_sched", "is_deepspeed_available", + "is_deepspeed_sp_enabled", "is_deepspeed_zero3_enabled", "set_hf_deepspeed_config", "unset_hf_deepspeed_config", - "is_deepspeed_sp_enabled", ], "eetq": ["replace_with_eetq_linear"], "fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"], @@ -123,10 +123,10 @@ deepspeed_load_checkpoint, deepspeed_optim_sched, is_deepspeed_available, + is_deepspeed_sp_enabled, is_deepspeed_zero3_enabled, set_hf_deepspeed_config, unset_hf_deepspeed_config, - is_deepspeed_sp_enabled, ) from .eetq import replace_with_eetq_linear from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py index 3dea00a962f..252caeb0a6d 100644 --- a/src/transformers/integrations/deepspeed.py +++ b/src/transformers/integrations/deepspeed.py @@ -441,9 +441,11 @@ def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_str else: raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}") + def is_deepspeed_sp_enabled(): if is_deepspeed_available(): from deepspeed.utils import groups + return groups._get_sequence_parallel_world_size() > 1 else: return False diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 165e9e4b4fc..3d60502ba3d 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -20,9 +20,12 @@ import torch import torch.nn.functional as F +from .integrations.deepspeed import ( # DeepSpeed seq parallelism (aka Ulysses) + is_deepspeed_available, + is_deepspeed_sp_enabled, +) from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal -from .integrations.deepspeed import is_deepspeed_available, is_deepspeed_sp_enabled #DeepSpeed seq parallelism (aka Ulysses) if is_deepspeed_available(): from deepspeed.sequence.layer import _SeqAllToAll @@ -228,13 +231,13 @@ def _flash_attention_forward( """ if is_deepspeed_sp_enabled(): spg = ds_comm_groups._get_sequence_parallel_group() - #qkv tensors are of shape (batch_size, seq_len, num_heads, head_dim) - scatter_idx = 2 #Scatter on num_heads dimension - gather_idx = 1 #Gather on 
seq_len dimension - batch_dim_idx = 0 #Synonymous with the batch_first==true + # qkv tensors are of shape (batch_size, seq_len, num_heads, head_dim) + scatter_idx = 2 # Scatter on num_heads dimension + gather_idx = 1 # Gather on seq_len dimension + batch_dim_idx = 0 # Synonymous with the batch_first==true query_states = _SeqAllToAll.apply(spg, query_states, scatter_idx, gather_idx, batch_dim_idx) - key_states = _SeqAllToAll.apply(spg, key_states, scatter_idx, gather_idx,batch_dim_idx) - value_states = _SeqAllToAll.apply(spg, value_states, scatter_idx, gather_idx,batch_dim_idx) + key_states = _SeqAllToAll.apply(spg, key_states, scatter_idx, gather_idx, batch_dim_idx) + value_states = _SeqAllToAll.apply(spg, value_states, scatter_idx, gather_idx, batch_dim_idx) if not use_top_left_mask: causal = is_causal @@ -315,9 +318,9 @@ def _flash_attention_forward( ) if is_deepspeed_sp_enabled(): - scatter_idx = 1 #Scatter back on seq_len dimension - gather_idx = 2 #Gather on num_heads dimension + scatter_idx = 1 # Scatter back on seq_len dimension + gather_idx = 2 # Gather on num_heads dimension batch_dim_idx = 0 - attn_output = _SeqAllToAll.apply(spg, attn_output, scatter_idx, gather_idx,batch_dim_idx) + attn_output = _SeqAllToAll.apply(spg, attn_output, scatter_idx, gather_idx, batch_dim_idx) return attn_output diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index a331fea1a9e..bde09e95f33 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -123,10 +123,10 @@ def require_deepspeed_aio(test_case): from deepspeed.utils import logger as deepspeed_logger # noqa from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint from transformers.integrations.deepspeed import ( - deepspeed_config, - is_deepspeed_zero3_enabled, - is_deepspeed_sp_enabled, - ) # noqa + deepspeed_config, + is_deepspeed_zero3_enabled, + is_deepspeed_sp_enabled, + ) # noqa def get_launcher(distributed=False): @@ -1334,12 +1334,12 @@ def test_clm_from_config_zero3_fp16(self): with CaptureStderr() as cs: execute_subprocess_async(cmd, env=self.get_env()) self.assertIn("Detected DeepSpeed ZeRO-3", cs.err) - + @parameterized.expand([2, 4, 8, 16]) @require_torch_multi_accelerator def test_deepspeed_sp(self, sp_size): - #Check if deepspeed_sp is enabled - #Run deepspeed sp with 2 GPUs and different sp_size + # Check if deepspeed_sp is enabled + # Run deepspeed sp with 2 GPUs and different sp_size self.assertFalse(is_deepspeed_sp_enabled()) ds_args = [f"--sequence-length={sp_size}"] script = [f"{self.test_file_dir_str}/test_ulysses.py"] diff --git a/tests/deepspeed/test_ulysses.py b/tests/deepspeed/test_ulysses.py index ea7da00ed3d..06963f9c5c0 100644 --- a/tests/deepspeed/test_ulysses.py +++ b/tests/deepspeed/test_ulysses.py @@ -13,61 +13,71 @@ # limitations under the License. 
+import sys + import torch -import torch.distributed as dist from deepspeed import initialize -from transformers import AutoModel -import sys +from transformers import AutoModel from transformers.integrations.deepspeed import ( is_deepspeed_sp_enabled, - ) # noqa - +) # noqa from transformers.modeling_flash_attention_utils import _flash_attention_forward -#Call transformer flash attention with and without deepspeed sp enabled and compare they match -def test_transformer_flash_attention(seq_len=2)->None: - model = AutoModel.from_pretrained('bert-base-uncased') +# Call transformer flash attention with and without deepspeed sp enabled and compare they match +def test_transformer_flash_attention(seq_len=2) -> None: + model = AutoModel.from_pretrained("bert-base-uncased") batch_size = 2 - #Test with deepspeed sp + # Test with deepspeed sp sp_size = 2 dp_size = 1 - ds_engine, _, _, _ = initialize(model=model, config_params={"train_batch_size": batch_size, - "data_parallel_size": dp_size, "sequence_parallel_size": sp_size},) + ds_engine, _, _, _ = initialize( + model=model, + config_params={ + "train_batch_size": batch_size, + "data_parallel_size": dp_size, + "sequence_parallel_size": sp_size, + }, + ) - assert is_deepspeed_sp_enabled() == True + assert is_deepspeed_sp_enabled() seq_len = seq_len hidden_dim = 16 num_heads = 4 head_dim = hidden_dim // num_heads - #Create input tensors - input_tensor = torch.randn(batch_size, seq_len, num_heads, hidden_dim, device=ds_engine.device) - input_tensor = input_tensor.half() - attention_mask = None + # Create input tensors + input_tensor = torch.randn(batch_size, seq_len, num_heads, head_dim, device=ds_engine.device) + input_tensor = input_tensor.half() + attention_mask = None q, k, v = input_tensor, input_tensor, input_tensor - - output_tensor = _flash_attention_forward(q,k,v, attention_mask,query_length=seq_len,is_causal=False) + output_tensor = _flash_attention_forward(q, k, v, attention_mask, query_length=seq_len, is_causal=False) assert output_tensor is not None - assert output_tensor.shape == (batch_size, seq_len, num_heads, hidden_dim) + assert output_tensor.shape == (batch_size, seq_len, num_heads, head_dim) - #Now test without deepspeed sp + # Now test without deepspeed sp sp_size = 1 dp_size = 2 - ds_engine, _, _, _ = initialize(model=model, config_params={"train_batch_size": batch_size, - "data_parallel_size": dp_size, "sequence_parallel_size": sp_size},) - assert is_deepspeed_sp_enabled() == False + ds_engine, _, _, _ = initialize( + model=model, + config_params={ + "train_batch_size": batch_size, + "data_parallel_size": dp_size, + "sequence_parallel_size": sp_size, + }, + ) + assert not is_deepspeed_sp_enabled() - output_tensor_no_sp = _flash_attention_forward(q,k,v, attention_mask,query_length=seq_len,is_causal=False) + output_tensor_no_sp = _flash_attention_forward(q, k, v, attention_mask, query_length=seq_len, is_causal=False) assert output_tensor_no_sp is not None - assert output_tensor_no_sp.shape == (batch_size, seq_len, num_heads, hidden_dim) + assert output_tensor_no_sp.shape == (batch_size, seq_len, num_heads, head_dim) assert torch.allclose(output_tensor, output_tensor_no_sp) - -if __name__ == '__main__': + +if __name__ == "__main__": torch.manual_seed(0) - seq_len = int((sys.argv[2]).split('=')[1]) - test_transformer_flash_attention(seq_len=seq_len) \ No newline at end of file + seq_len = int((sys.argv[2]).split("=")[1]) + test_transformer_flash_attention(seq_len=seq_len) From 82ab867447315150e2aed2636a5a6d6855f149e4 Mon Sep 17 
00:00:00 2001 From: Sam Ade Jacobs Date: Mon, 19 Aug 2024 18:24:13 +0000 Subject: [PATCH 07/10] isort fix --- tests/deepspeed/test_ulysses.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/deepspeed/test_ulysses.py b/tests/deepspeed/test_ulysses.py index 06963f9c5c0..aa479c448dc 100644 --- a/tests/deepspeed/test_ulysses.py +++ b/tests/deepspeed/test_ulysses.py @@ -19,10 +19,9 @@ from deepspeed import initialize from transformers import AutoModel -from transformers.integrations.deepspeed import ( - is_deepspeed_sp_enabled, -) # noqa -from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.integrations.deepspeed import is_deepspeed_sp_enabled # noqa +from transformers.modeling_flash_attention_utils import \ + _flash_attention_forward # Call transformer flash attention with and without deepspeed sp enabled and compare they match From 0918a67344566bbfe7208fb9d62c5e34c2ed4358 Mon Sep 17 00:00:00 2001 From: Sam Ade Jacobs Date: Tue, 20 Aug 2024 18:33:11 +0000 Subject: [PATCH 08/10] Quality fix --- tests/deepspeed/test_ulysses.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/deepspeed/test_ulysses.py b/tests/deepspeed/test_ulysses.py index aa479c448dc..51dcb746229 100644 --- a/tests/deepspeed/test_ulysses.py +++ b/tests/deepspeed/test_ulysses.py @@ -20,8 +20,7 @@ from transformers import AutoModel from transformers.integrations.deepspeed import is_deepspeed_sp_enabled # noqa -from transformers.modeling_flash_attention_utils import \ - _flash_attention_forward +from transformers.modeling_flash_attention_utils import _flash_attention_forward # Call transformer flash attention with and without deepspeed sp enabled and compare they match From 9ea2571e79729616c86a6e5dc0ab942a8f4447fc Mon Sep 17 00:00:00 2001 From: Sam Ade Jacobs Date: Wed, 21 Aug 2024 10:47:05 -0700 Subject: [PATCH 09/10] Update test_deepspeed.py Make torch a requirement for deepspeed sp --- tests/deepspeed/test_deepspeed.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index bde09e95f33..5902de61ced 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -1336,6 +1336,7 @@ def test_clm_from_config_zero3_fp16(self): self.assertIn("Detected DeepSpeed ZeRO-3", cs.err) @parameterized.expand([2, 4, 8, 16]) + @require_torch_accelerator @require_torch_multi_accelerator def test_deepspeed_sp(self, sp_size): # Check if deepspeed_sp is enabled From 8766b91939b27546f015361c77452a502ce92278 Mon Sep 17 00:00:00 2001 From: Sam Ade Jacobs Date: Wed, 2 Oct 2024 14:44:02 -0700 Subject: [PATCH 10/10] Update deepspeed.md Provide more clarity on the need for sequence parallelism. --- docs/source/en/deepspeed.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/deepspeed.md b/docs/source/en/deepspeed.md index be41fb83c14..42958eed4f5 100644 --- a/docs/source/en/deepspeed.md +++ b/docs/source/en/deepspeed.md @@ -1142,7 +1142,7 @@ Using multiple GPUs with ZeRO-3 for generation requires synchronizing the GPUs b For Transformers>=4.28, if `synced_gpus` is automatically set to `True` if multiple GPUs are detected during generation. 
### Non-Trainer Sequence Parallelism
-DeepSpeed sequence parallelism, also known as [DeepSpeed Ulysses](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/README.md), is compatible with HuggingFace Transformers by adding 'sequence_parallel_size' and 'data_parallel_size' to the DeepSpeed configuration. Additionally, it's required that the user’s script correctly shard the input data along the sequence dimension.
+DeepSpeed sequence parallelism, also known as [DeepSpeed Ulysses](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/README.md), is a distributed training technique for long-context LLMs. With sequence parallelism, the sequence length (and model size) can scale with the number of GPUs, virtually without limit, instead of being constrained by the memory of a single GPU. It is compatible with HuggingFace Transformers: add `sequence_parallel_size` and `data_parallel_size` to the DeepSpeed configuration, and make sure the training script shards the input data along the sequence dimension.

```py
ds_config = {