diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md
index 440b18713c..efee8aac45 100755
--- a/examples/text-generation/README.md
+++ b/examples/text-generation/README.md
@@ -264,6 +264,35 @@ set the following environment variables before running the command: `PT_ENABLE_I
 You will also need to add `--torch_compile` in your command.
 
+### Running with Tensor parallel strategy
+#### Attribution
+
+This repository includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details.
+
+torch.compile with tensor parallel strategy is an experimental feature and has not been validated for all models. To enable
+torch.compile with tensor parallel strategy, please set the following environment variables before running the
+command: `PT_ENABLE_INT64_SUPPORT=1` and `PT_HPU_LAZY_MODE=0`. This enables tensor parallel strategy without DeepSpeed.
+
+You will also need to add `--torch_compile` and `--distributed_strategy="tp"` in your command.
+
+Here is an example:
+```bash
+python ../gaudi_spawn.py --world_size 8 run_generation.py \
+--model_name_or_path meta-llama/Llama-2-70b-hf \
+--trim_logits \
+--use_kv_cache \
+--attn_softmax_bf16 \
+--bf16 \
+--bucket_internal \
+--bucket_size=128 \
+--use_flash_attention \
+--flash_attention_recompute \
+--batch_size 246 \
+--max_input_tokens 2048 \
+--max_new_tokens 2048 \
+--torch_compile \
+--distributed_strategy="tp"
+```
 
 ### Running with FP8
 
diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index f3ccfc12c5..91f6d23a52 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -245,59 +245,35 @@ def setup_model(args, model_dtype, model_kwargs, logger):
         # assistant_model = get_torch_compiled_model(assistant_model)
     return model, assistant_model
 
-def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger):
-    from optimum.habana.distributed import serialization
+def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger):
     from typing import Any, MutableMapping
-    from optimum.habana.distributed import tp_wrapping
-    from optimum.habana.distributed.strategy import DistributedStrategy
-    from torch import nn
-
-    class TensorParallelStrategy(DistributedStrategy):
-        def __init__(self, group=None, from_meta=False):
-            super().__init__(from_meta)
-            assert torch.distributed.is_initialized(), "must initialize a process group"
-            self.group = group if group is not None else torch.distributed.GroupMember.WORLD
-
-        def distribute_module(
-            self, module: nn.Module, final_layers: bool = False
-        ) -> nn.Module:
-            return tp_wrapping.apply_tp(module, self.group)
-
-        def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module:
-            return tp_wrapping.apply_tp(block, layer, self.group)
-
-        def __getstate__(self):
-            state = self.__dict__.copy()
-            state['group'] = None  # Remove ProcessGroup from state
-            return state
-
-        def __setstate__(self, state):
-            self.__dict__.update(state)
-            self.group = None  # Restore to default state or reinitialize
+    from optimum.habana.distributed import serialization
+    from optimum.habana.distributed.strategy import TensorParallelStrategy
 
     logger.info("Multi-device run.")
+
+    assert args.quant_config == "", "Fp8 is not enabled, unset QUANT_CONFIG"
     assert args.assistant_model is None, "Assistant model must be None"
-
+
     from torch import distributed as dist
-    if args.device == 'hpu':
-        import 
habana_frameworks.torch.distributed.hccl - dist.init_process_group(backend='hccl') + + if args.device == "hpu": + dist.init_process_group(backend="hccl") else: - dist.init_process_group() - + assert False, "Supports TP only on HPU" + torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD) logger.info("Creating Model") - config = AutoConfig.from_pretrained(args.model_name_or_path,torch_dtype=model_dtype, **model_kwargs) - model_kwargs={} + config = AutoConfig.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs) + model_kwargs = {} model_kwargs["distributed_strategy"] = TensorParallelStrategy() model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype, **model_kwargs) initial_device = torch.device("cpu") - source="hf" - checkpoint_sharding=None + source = "hf" + checkpoint_sharding = None lazy_sd: MutableMapping[str, Any] = {} logger.info("Loading Checkpoints") lazy_sd = serialization.load_state_dict( @@ -309,7 +285,7 @@ def __setstate__(self, state): rank=args.global_rank, world_size=args.world_size, ) - architecture="llama" + architecture = "llama" if len(lazy_sd): serialization.load_state_dict_into_model( model, @@ -323,18 +299,12 @@ def __setstate__(self, state): args.world_size, ) - if args.quant_config: - model = setup_quantization(model, args) - model = model.eval().to(args.device) if args.use_hpu_graphs: from habana_frameworks.torch.hpu import wrap_in_hpu_graph - if check_habana_frameworks_version("1.13.0") and model.config.model_type == "falcon": - model = wrap_in_hpu_graph(model, hash_with_views=False) - else: - model = wrap_in_hpu_graph(model) + model = wrap_in_hpu_graph(model) if args.torch_compile and model.config.model_type == "llama": model = get_torch_compiled_model(model) @@ -604,7 +574,8 @@ def initialize_model(args, logger): model, assistant_model = ( setup_model(args, model_dtype, model_kwargs, logger) if not use_deepspeed - else setup_distributed_model(args, model_dtype, model_kwargs, logger) if not args.distributed_strategy == "tp" + else setup_distributed_model(args, model_dtype, model_kwargs, logger) + if not args.distributed_strategy == "tp" else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger) ) tokenizer, model, assistant_model = setup_tokenizer(args, model, assistant_model) diff --git a/optimum/habana/distributed/__init__.py b/optimum/habana/distributed/__init__.py index 12edd6620e..af269ee68c 100644 --- a/optimum/habana/distributed/__init__.py +++ b/optimum/habana/distributed/__init__.py @@ -1,8 +1,11 @@ -from .distributed_runner import DistributedRunner -from .fast_ddp import all_reduce_gradients import os + import torch +from .distributed_runner import DistributedRunner +from .fast_ddp import all_reduce_gradients + + def rank_and_world(group=None): """ Returns (rank, world_size) from the optionally-specified group, otherwise diff --git a/optimum/habana/distributed/serialization.py b/optimum/habana/distributed/serialization.py index c543ab20bd..bf59fb2445 100644 --- a/optimum/habana/distributed/serialization.py +++ b/optimum/habana/distributed/serialization.py @@ -1,3 +1,20 @@ +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + import collections import os from collections import ChainMap @@ -7,7 +24,7 @@ import torch -from optimum.habana.distributed.tp import TPModule +from .tp import TPModule __adapters: MutableMapping[str, MutableMapping[str, Callable[[Mapping], Mapping]]] = {} @@ -34,9 +51,7 @@ def register_adapter( sources = __adapters[architecture] if source in sources: - raise KeyError( - f"Variant {source} already registered for architecture {architecture}" - ) + raise KeyError(f"Variant {source} already registered for architecture {architecture}") sources[source] = adapter __adapters[architecture] = sources @@ -55,14 +70,8 @@ def list_sources(architecture: str): return list(__adapters[architecture].keys()) -def _get_adapter( - architecture: str, source: Optional[str] -) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: - if ( - source is None - or architecture not in __adapters - or source not in __adapters[architecture] - ): +def _get_adapter(architecture: str, source: Optional[str]) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: + if source is None or architecture not in __adapters or source not in __adapters[architecture]: # if no adapter is registered, assume the attributes are already in # fms format. # should we raise an error here instead? @@ -71,9 +80,7 @@ def _get_adapter( return __adapters[architecture][source] -def get_adapted( - architecture: str, source: Optional[str], state_dict: Mapping[str, Any] -) -> Mapping[str, Any]: +def get_adapted(architecture: str, source: Optional[str], state_dict: Mapping[str, Any]) -> Mapping[str, Any]: """ Convert a state dict to FMS format, using an adapter specified by name. @@ -91,18 +98,11 @@ def get_adapted( return adapted -# `models` imports each model class, causing models and adapters to be registered. -# down here to avoid circular dependencies. 
-# from fms import models - - def _get_safetensors_item(key, file: Path, device: torch.device) -> torch.Tensor: from safetensors import safe_open # type: ignore[import-untyped] with torch.no_grad(): - with safe_open( - file, framework="pt", device=str(device) - ) as model_weights: # type: ignore[attr-defined] + with safe_open(file, framework="pt", device=str(device)) as model_weights: # type: ignore[attr-defined] return model_weights.get_tensor(key) @@ -153,7 +153,7 @@ def load_state_dict( if model_path is None or initial_device.type == "meta": return {} if checkpoint_sharding == "fsdp" and distributed_strategy not in ["fsdp", "hsdp"]: - raise ValueError(f"FSDP checkpoints can only be loaded into an FSDP model") + raise ValueError("FSDP checkpoints can only be loaded into an FSDP model") if checkpoint_sharding == "tp" and distributed_strategy != "tp": raise ValueError("TP checkpoints can only be loaded into a TP model") @@ -188,13 +188,11 @@ def load_state_dict( checkpoints = [model_path] # Check if we found some files - assert ( - len(checkpoints) > 0 - ), f"Can't find the requested checkpoint data at {model_path}" + assert len(checkpoints) > 0, f"Can't find the requested checkpoint data at {model_path}" if checkpoint_sharding is not None and checkpoint_sharding != "layer": - assert world_size == len( - checkpoints + assert ( + world_size == len(checkpoints) ), f"Loading a {checkpoint_sharding}-sharded checkpoint with len={len(checkpoints)} but world size is {world_size}" checkpoints = [checkpoints[rank]] @@ -304,13 +302,9 @@ def load_state_dict_into_model( used_keys.add(weight) partial_sd[weight] = state_dict[weight] if partial_sd[weight].device != initial_device: - partial_sd[weight] = partial_sd[weight].to( - device=initial_device - ) + partial_sd[weight] = partial_sd[weight].to(device=initial_device) fms_partial_sd = adapter(partial_sd) - _load_partial_state_dict( - model, fms_partial_sd, needs_tp_sharding, rank, world_size - ) + _load_partial_state_dict(model, fms_partial_sd, needs_tp_sharding, rank, world_size) for p_key in partial_sd.keys(): if isinstance(state_dict, ChainMap): for child_sd in state_dict.maps: @@ -341,17 +335,11 @@ def _copy_colwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_ output_size_per_partition = param.shape[0] if not is_bias: tensor = tensor_value[ - (rank * output_size_per_partition) : ( - (rank + 1) * output_size_per_partition - ), + (rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition), :, ] else: - tensor = tensor_value[ - (rank * output_size_per_partition) : ( - (rank + 1) * output_size_per_partition - ) - ] + tensor = tensor_value[(rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition)] param.copy_(tensor, non_blocking=True) @@ -376,9 +364,7 @@ def _copy_rowwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_ output_size_per_partition = param.shape[1] tensor = tensor_value[ :, - (rank * output_size_per_partition) : ( - (rank + 1) * output_size_per_partition - ), + (rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition), ] param.copy_(tensor, non_blocking=True) else: diff --git a/optimum/habana/distributed/strategy.py b/optimum/habana/distributed/strategy.py index 4bc68bbca7..91b3f00232 100644 --- a/optimum/habana/distributed/strategy.py +++ b/optimum/habana/distributed/strategy.py @@ -1,17 +1,33 @@ +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + from abc import abstractmethod -from typing import Any, List, Mapping +from typing import List import torch import torch.distributed from torch import nn + class DistributedStrategy: def __init__(self, from_meta=False): self.from_meta = from_meta - def distribute_module( - self, module: nn.Module, final_layers: bool = False - ) -> nn.Module: + def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module: """ Optionally a distributed strategy may distribute modules that are not numbered layers @@ -30,9 +46,7 @@ class NotDistributed(DistributedStrategy): def __init__(self, from_meta=False): super().__init__(from_meta) - def distribute_module( - self, module: nn.Module, final_layers: bool = False - ) -> nn.Module: + def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module: return module def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: @@ -54,17 +68,8 @@ def __init__(self, module: nn.Module, device): def forward(self, *args, **kwargs): device = self.device - args = [ - arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args - ] - kwargs = { - k: ( - kwargs[k].to(device) - if isinstance(kwargs[k], torch.Tensor) - else kwargs[k] - ) - for k in kwargs - } + args = [arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: (kwargs[k].to(device) if isinstance(kwargs[k], torch.Tensor) else kwargs[k]) for k in kwargs} return self.module(*args, **kwargs) @@ -92,9 +97,7 @@ def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: wrapped = DeviceMover(block, device) return wrapped - def distribute_module( - self, module: nn.Module, final_layers: bool = False - ) -> nn.Module: + def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module: if final_layers: device = self.layer_to_device[len(self.layer_to_device) - 1] else: @@ -105,4 +108,27 @@ def distribute_module( return wrapped +class TensorParallelStrategy(DistributedStrategy): + def __init__(self, group=None, from_meta=False): + super().__init__(from_meta) + assert torch.distributed.is_initialized(), "must initialize a process group" + self.group = group if group is not None else torch.distributed.GroupMember.WORLD + + def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module: + from optimum.habana.distributed import tp_wrapping + + return tp_wrapping.apply_tp(module, self.group) + + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + from optimum.habana.distributed import tp_wrapping + + return tp_wrapping.apply_tp(block, layer, self.group) + + def __getstate__(self): + state = self.__dict__.copy() + state["group"] = None # Remove ProcessGroup from state + return state + def __setstate__(self, state): + self.__dict__.update(state) + self.group = None # Restore to default state or reinitialize diff --git 
a/optimum/habana/distributed/tensorparallel.py b/optimum/habana/distributed/tensorparallel.py index 5d484b2fc5..09a205e451 100644 --- a/optimum/habana/distributed/tensorparallel.py +++ b/optimum/habana/distributed/tensorparallel.py @@ -1,10 +1,24 @@ -# mypy: disable-error-code="method-assign,misc" +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack import torch import torch._inductor.ir as ir import torch._inductor.lowering as lowering import torch.distributed as dist -import torch.distributed._functional_collectives as funcol from torch import nn @@ -12,9 +26,7 @@ def apply_colwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank): # Divide the weight matrix along the last dimension. output_size_per_partition = mod.out_features // world_size with torch.no_grad(): - par_mod.weight.copy_( - torch.split(mod.weight, output_size_per_partition, dim=0)[rank] - ) + par_mod.weight.copy_(torch.split(mod.weight, output_size_per_partition, dim=0)[rank]) if par_mod.bias is not None: par_mod.bias.copy_(torch.split(mod.bias, output_size_per_partition)[rank]) @@ -23,9 +35,7 @@ def apply_rowwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank): # Divide the weight matrix along the last dimension. output_size_per_partition = mod.in_features // world_size with torch.no_grad(): - par_mod.weight.copy_( - torch.split(mod.weight, output_size_per_partition, dim=1)[rank] - ) + par_mod.weight.copy_(torch.split(mod.weight, output_size_per_partition, dim=1)[rank]) if par_mod.bias is not None: if rank == 0: par_mod.bias.copy_(mod.bias) @@ -37,9 +47,7 @@ def apply_embedding_tp(par_mod: nn.Embedding, mod: nn.Embedding, world_size, ran # Divide the weight matrix along the last dimension. 
output_size_per_partition = mod.embedding_dim // world_size with torch.no_grad(): - par_mod.weight.copy_( - torch.split(mod.weight, output_size_per_partition, dim=1)[rank] - ) + par_mod.weight.copy_(torch.split(mod.weight, output_size_per_partition, dim=1)[rank]) ## Fixes for PT 2.2 collectives until PT 2.3 is released @@ -72,6 +80,7 @@ def get_volatile_reads_fixed(self): if other_fn in lowering.lowerings: del lowering.lowerings[other_fn] + def _all_reduce(input_: torch.Tensor) -> torch.Tensor: """All-reduce the input tensor across model parallel group.""" world_size = dist.get_world_size() @@ -80,9 +89,8 @@ def _all_reduce(input_: torch.Tensor) -> torch.Tensor: return input_ # Starting PT 2.3, we can go back to funcol.all_reduce - return torch.ops._c10d_functional.wait_tensor( - torch.ops._c10d_functional.all_reduce(input_, "sum", "default") - ) + return torch.ops._c10d_functional.wait_tensor(torch.ops._c10d_functional.all_reduce(input_, "sum", "default")) + class _ReduceFromModelParallelRegion(torch.autograd.Function): """All-reduce the input from the model parallel region.""" @@ -99,5 +107,6 @@ def forward(ctx, input_): def backward(ctx, grad_output): return grad_output + def reduce_from_tensor_model_parallel_region(input_): return _ReduceFromModelParallelRegion.apply(input_) diff --git a/optimum/habana/distributed/tp.py b/optimum/habana/distributed/tp.py index 31f33a79cc..c4f156fa61 100644 --- a/optimum/habana/distributed/tp.py +++ b/optimum/habana/distributed/tp.py @@ -1,12 +1,29 @@ +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + import itertools from abc import ABCMeta, abstractmethod -from typing import List, Type +from typing import List import torch import torch.nn as nn from torch.distributed.distributed_c10d import ProcessGroup -from optimum.habana.distributed.tensorparallel import ( +from .tensorparallel import ( apply_colwise_tp, apply_embedding_tp, apply_rowwise_tp, @@ -76,7 +93,7 @@ def import_weights(self, module: nn.Module): ) with torch.no_grad(): for mod_name, module in self.named_children(): - if not mod_name in tp_sharded_modules: + if mod_name not in tp_sharded_modules: for param_name, param in module.named_parameters(recurse=False): param.copy_( getattr(getattr(module, mod_name), param_name), diff --git a/optimum/habana/distributed/tp_wrapping.py b/optimum/habana/distributed/tp_wrapping.py index 402accb342..761fa7bff4 100644 --- a/optimum/habana/distributed/tp_wrapping.py +++ b/optimum/habana/distributed/tp_wrapping.py @@ -1,21 +1,36 @@ -import os +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + from torch import nn from torch.distributed.distributed_c10d import ProcessGroup -from optimum.habana.transformers.models.llama.modeling_llama import ( - GaudiLlamaMLP, - TPGaudiLlamaMLP, +from ..transformers.models.llama.modeling_llama import ( GaudiLlamaAttention, - TPGaudiLlamaAttention + GaudiLlamaMLP, + TPGaudiLlamaAttention, + TPGaudiLlamaMLP, ) -# this probably belongs somewhere else but can't go in fms.distribtued b/c -# circular dependency. + def _tp_wrapped(module: nn.Module, layer: int, group: ProcessGroup): if hasattr(module, "to_tp"): return module.to_tp(group) elif isinstance(module, GaudiLlamaAttention): - return TPGaudiLlamaAttention.import_module(module,layer, group) + return TPGaudiLlamaAttention.import_module(module, layer, group) elif isinstance(module, GaudiLlamaMLP): return TPGaudiLlamaMLP.import_module(module, group) else: @@ -23,7 +38,7 @@ def _tp_wrapped(module: nn.Module, layer: int, group: ProcessGroup): def apply_tp(model: nn.Module, layer_idx: int, group: ProcessGroup): - wrapped = _tp_wrapped(model, layer_idx, group) + wrapped = _tp_wrapped(model, layer_idx, group) if wrapped is not model: return wrapped diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 715244adfc..24b7e8cf2e 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -1,3 +1,4 @@ +import copy import math import os import warnings @@ -5,6 +6,7 @@ import torch import torch.nn.functional as F +from torch.distributed.distributed_c10d import ProcessGroup from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -20,20 +22,16 @@ logger, ) -import copy -from torch.distributed.distributed_c10d import ProcessGroup +from .... 
import distributed +from ....distributed.strategy import DistributedStrategy, NoOpStrategy +from ....distributed.tensorparallel import ( + reduce_from_tensor_model_parallel_region, +) +from ....distributed.tp import TPModule from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) -from optimum.habana.distributed.tp import TPModule -from optimum.habana import distributed -from optimum.habana.distributed.tensorparallel import ( - reduce_from_tensor_model_parallel_region, -) - -from optimum.habana.distributed.strategy import DistributedStrategy -from optimum.habana.distributed.strategy import NoOpStrategy try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE @@ -198,6 +196,7 @@ def post_mlp_forward(self, x): return self.down_proj.post_all_reduce(x) return x + class TPGaudiLlamaMLP(GaudiLlamaMLP, TPModule): def __init__( self, @@ -207,16 +206,11 @@ def __init__( assert torch.distributed.is_initialized() rank, world_size = distributed.rank_and_world(group) hidden_dim = int(config.hidden_grow_factor * config.hidden_size) - assert ( - hidden_dim % world_size == 0 - ), "Hidden dim must be divisible by world size" + assert hidden_dim % world_size == 0, "Hidden dim must be divisible by world size" self.config = copy.deepcopy(config) self.config.intermediate_size = int((config.hidden_grow_factor / world_size) * config.hidden_size) - GaudiLlamaMLP.__init__( - self, - self.config - ) + GaudiLlamaMLP.__init__(self, self.config) self.setup_tp(rank, world_size) def colwise_param_names(self) -> List[str]: @@ -229,16 +223,14 @@ def rowwise_param_names(self) -> List[str]: def import_module(glu: GaudiLlamaMLP, group: ProcessGroup) -> "TPGaudiLlamaMLP": config = copy.deepcopy(glu.config) config.hidden_grow_factor = glu.config.intermediate_size / glu.config.hidden_size - tp_glu = TPGaudiLlamaMLP( - config = config, - group=group - ) + tp_glu = TPGaudiLlamaMLP(config=config, group=group) return tp_glu def pre_mlp_forward(self, x): out_par = GaudiLlamaMLP.pre_mlp_forward(self, x) return reduce_from_tensor_model_parallel_region(out_par) + def gaudi_llama_repeat_kv( query_states: torch.Tensor, key_states: torch.Tensor, @@ -596,28 +588,43 @@ def post_attn_forward(self, attn_output): class TPGaudiLlamaAttention(GaudiLlamaAttention, TPModule): - def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None, group: Optional[ProcessGroup] = None,): + def __init__( + self, + config: LlamaConfig, + layer_idx: Optional[int] = None, + group: Optional[ProcessGroup] = None, + ): super().__init__(config, layer_idx) assert torch.distributed.is_initialized() rank, world_size = distributed.rank_and_world(group) - assert ( - config.num_attention_heads % world_size == 0 - ), "The number of heads must be divisible by world size" + assert config.num_attention_heads % world_size == 0, "The number of heads must be divisible by world size" self.config = copy.deepcopy(config) self.pre_tp_kvheads = config.num_key_value_heads - GaudiLlamaAttention.__init__(self, self.config , layer_idx) - self.config.num_attention_heads = self.config.num_attention_heads // world_size - self.config.num_key_value_heads = ( self.config.num_key_value_heads // world_size) if self.config.num_key_value_heads > 1 else self.config.num_key_value_heads + GaudiLlamaAttention.__init__(self, self.config, layer_idx) + self.config.num_attention_heads = self.config.num_attention_heads // world_size + self.config.num_key_value_heads = ( + (self.config.num_key_value_heads // world_size) + if 
self.config.num_key_value_heads > 1 + else self.config.num_key_value_heads + ) self.head_dim = config.hidden_size // config.num_attention_heads self.hidden_size = self.config.hidden_size // world_size self.num_heads = self.config.num_attention_heads - self.q_proj = torch.nn.Linear(config.hidden_size, self.config.num_attention_heads * self.head_dim , bias=config.attention_bias) - self.k_proj = torch.nn.Linear(config.hidden_size, self.config.num_key_value_heads * self.head_dim , bias=config.attention_bias) - self.v_proj = torch.nn.Linear(config.hidden_size, self.config.num_key_value_heads * self.head_dim , bias=config.attention_bias) - self.o_proj = torch.nn.Linear(self.config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) + self.q_proj = torch.nn.Linear( + config.hidden_size, self.config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = torch.nn.Linear( + config.hidden_size, self.config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = torch.nn.Linear( + config.hidden_size, self.config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = torch.nn.Linear( + self.config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) self.norm_factor = 1.0 / math.sqrt(self.head_dim) self.setup_tp(rank, world_size) @@ -632,14 +639,8 @@ def rowwise_param_names(self) -> List[str]: return ["o_proj"] @staticmethod - def import_module( - mha: GaudiLlamaAttention, layer_idx, group: ProcessGroup - ) -> "TPGaudiLlamaAttention": - tp_mha = TPGaudiLlamaAttention( - config = mha.config, - layer_idx=layer_idx, - group=group - ) + def import_module(mha: GaudiLlamaAttention, layer_idx, group: ProcessGroup) -> "TPGaudiLlamaAttention": + tp_mha = TPGaudiLlamaAttention(config=mha.config, layer_idx=layer_idx, group=group) return tp_mha def pre_attn_forward( @@ -661,8 +662,8 @@ def pre_attn_forward( cache_idx: int = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - - hidden_states, attn_weights, present_key_value = GaudiLlamaAttention.pre_attn_forward(self, + hidden_states, attn_weights, present_key_value = GaudiLlamaAttention.pre_attn_forward( + self, hidden_states, attention_mask, position_ids, @@ -678,11 +679,13 @@ def pre_attn_forward( flash_attention_causal_mask, flash_attention_fast_softmax, cache_idx, - **kwargs - ) + **kwargs, + ) hidden_states = reduce_from_tensor_model_parallel_region(hidden_states) return hidden_states, attn_weights, present_key_value + + class GaudiLlamaDecoderLayer(LlamaDecoderLayer): def __init__(self, config: LlamaConfig, layer_idx: int): super(LlamaDecoderLayer, self).__init__() @@ -864,7 +867,7 @@ def __init__(self, config: LlamaConfig): layer = self.distributed_strategy.distribute_layer(layer, layer_idx) layers.append(layer) self.layers = torch.nn.ModuleList(layers) - + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -1100,9 +1103,9 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM): """ def __init__(self, config, distributed_strategy: DistributedStrategy = NoOpStrategy): - config.distributed_strategy = distributed_strategy + config.distributed_strategy = distributed_strategy super().__init__(config) - + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) @@ -1146,7 +1149,6 @@ def forward( global has_fused_rope has_fused_rope = False 
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 3ab56b26f5..ddd5faf908 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -71,6 +71,9 @@ "torch_compile_distributed": [ ("meta-llama/Llama-2-7b-hf", 39.72973199515235), ], + "distributed_tp": [ + ("/mnt/weka/data/pytorch//llama2/Llama-2-7b-hf/", 1856.8140409694543), + ], } else: # Gaudi1 CI baselines @@ -101,6 +104,7 @@ ], "torch_compile": [], "torch_compile_distributed": [], + "distributed_tp": [], } @@ -116,6 +120,7 @@ def _test_text_generation( fp8: bool = False, max_input_tokens: int = 0, max_output_tokens: int = 100, + distributed_strategy: str = None, ): command = ["python3"] path_to_example_dir = Path(__file__).resolve().parent.parent / "examples" @@ -127,6 +132,11 @@ def _test_text_generation( "--use_deepspeed", f"--world_size {world_size}", ] + elif distributed_strategy == "tp": + command += [ + f"{path_to_example_dir / 'gaudi_spawn.py'}", + f"--world_size {world_size}", + ] command += [ f"{path_to_example_dir / 'text-generation' / 'run_generation.py'}", @@ -142,11 +152,13 @@ def _test_text_generation( if "falcon" in model_name.lower(): command += ["--use_flash_attention", "--flash_attention_causal_mask"] - if reuse_cache or torch_compile: + if (reuse_cache or torch_compile) and not distributed_strategy == "tp": command += ["--reuse_cache"] if torch_compile: command += ["--torch_compile"] + command += ["--use_flash_attention"] + command += ["--flash_attention_recompute"] env_variables["PT_ENABLE_INT64_SUPPORT"] = "1" env_variables["PT_HPU_LAZY_MODE"] = "0" else: @@ -188,6 +200,10 @@ def _test_text_generation( f"--max_input_tokens {max_input_tokens}", "--limit_hpu_graphs", ] + if distributed_strategy is not None: + command += [ + f"--distributed_strategy={distributed_strategy}", + ] with TemporaryDirectory() as tmp_dir: command.append(f"--output_dir {tmp_dir}") @@ -289,6 +305,21 @@ def test_text_generation_torch_compile_distributed(model_name: str, baseline: fl _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, torch_compile=True) +@pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["distributed_tp"]) +def test_text_generation_distributed_tp(model_name: str, baseline: float, token: str): + world_size = 8 + _test_text_generation( + model_name, + baseline, + token, + batch_size=64, + max_input_tokens=128, + world_size=world_size, + torch_compile=True, + distributed_strategy="tp", + ) + + class TextGenPipeline(TestCase): def test_text_generation_pipeline_script(self): path_to_script = (