diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md
index 440b18713c..fad94bdbcb 100755
--- a/examples/text-generation/README.md
+++ b/examples/text-generation/README.md
@@ -264,6 +264,37 @@ set the following environment variables before running the command: `PT_ENABLE_I
 You will also need to add `--torch_compile` in your command.
 
+### Running with tensor-parallel strategy
+
+> [!NOTE]
+> This strategy includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details.
+
+> [!WARNING]
+> torch.compile with the tensor parallel strategy is an experimental feature. It has not been validated for all models.
+
+To enable torch.compile with the tensor parallel strategy, please set the following environment variables before running the
+command: `PT_ENABLE_INT64_SUPPORT=1` and `PT_HPU_LAZY_MODE=0`. This enables the tensor parallel strategy without DeepSpeed.
+
+You will also need to add `--torch_compile` and `--parallel_strategy="tp"` to your command.
+
+Here is an example:
+```bash
+PT_ENABLE_INT64_SUPPORT=1 PT_HPU_LAZY_MODE=0 python ../gaudi_spawn.py --world_size 8 run_generation.py \
+--model_name_or_path meta-llama/Llama-2-70b-hf \
+--trim_logits \
+--use_kv_cache \
+--attn_softmax_bf16 \
+--bf16 \
+--bucket_internal \
+--bucket_size=128 \
+--use_flash_attention \
+--flash_attention_recompute \
+--batch_size 246 \
+--max_input_tokens 2048 \
+--max_new_tokens 2048 \
+--torch_compile \
+--parallel_strategy="tp"
+```
 
 ### Running with FP8
 
diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 1a020555e5..c41664ebf3 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -287,6 +287,14 @@ def setup_parser(parser):
         action="store_true",
         help="Whether to trust the execution of code from datasets/models defined on the Hub. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.",
     )
+    parser.add_argument(
+        "--parallel_strategy",
+        type=str,
+        choices=["tp", "none"],  # Add other strategies as needed
+        default="none",
+        help="Run multi-card inference with the specified parallel strategy. Choices are 'tp' (tensor parallel) or 'none'.",
+    )
+
     args = parser.parse_args()
 
     if args.torch_compile:
diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index fa1946b914..e7c3bb1d46 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -248,6 +248,72 @@ def setup_model(args, model_dtype, model_kwargs, logger):
     return model, assistant_model
 
+def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger, cache_dir):
+    from typing import Any, MutableMapping
+
+    from optimum.habana.distributed import serialization
+    from optimum.habana.distributed.strategy import TensorParallelStrategy
+
+    logger.info("Multi-device run.")
+
+    assert args.quant_config == "", "FP8 quantization is not supported with the TP strategy, please unset QUANT_CONFIG"
+    assert args.assistant_model is None, "An assistant model is not supported with the TP strategy"
+
+    from torch import distributed as dist
+
+    if args.device == "hpu":
+        dist.init_process_group(backend="hccl")
+    else:
+        assert False, "TP is only supported on HPU"
+
+    torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
+    logger.info("Creating Model")
+    config = AutoConfig.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs)
+    model_kwargs = {}
+    model_kwargs["parallel_strategy"] = TensorParallelStrategy()
+    model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype, **model_kwargs)
+
+    initial_device = torch.device("cpu")
+    source = "hf"
+    checkpoint_sharding = None
+    lazy_sd: MutableMapping[str, Any] = {}
+    logger.info("Loading Checkpoints")
+    lazy_sd = serialization.load_state_dict(
+        cache_dir,
+        source=source,
+        distributed_strategy=args.parallel_strategy,
+        checkpoint_sharding=None,
+        initial_device=initial_device,
+        rank=args.global_rank,
+        world_size=args.world_size,
+    )
+    architecture = "llama"
+    if len(lazy_sd):
+        serialization.load_state_dict_into_model(
+            model,
+            lazy_sd,
+            architecture,
+            source,
+            args.parallel_strategy,
+            checkpoint_sharding,
+            initial_device,
+            args.local_rank,
+            args.world_size,
+        )
+
+    model = model.eval().to(args.device)
+
+    if args.use_hpu_graphs:
+        from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+
+        model = wrap_in_hpu_graph(model)
+
+    if args.torch_compile and model.config.model_type == "llama":
+        model = get_torch_compiled_model(model)
+
+    return model, args.assistant_model
+
+
 def setup_distributed_model(args, model_dtype, model_kwargs, logger):
     import deepspeed
@@ -500,7 +566,7 @@ def initialize_model(args, logger):
     setup_env(args)
     setup_device(args)
     set_seed(args.seed)
-    get_repo_root(args.model_name_or_path, local_rank=args.local_rank, token=args.token)
+    cache_dir = get_repo_root(args.model_name_or_path, local_rank=args.local_rank, token=args.token)
     if args.assistant_model is not None:
         get_repo_root(args.assistant_model, local_rank=args.local_rank, token=args.token)
     use_deepspeed = args.world_size > 0
@@ -522,6 +588,8 @@ def initialize_model(args, logger):
         setup_model(args, model_dtype, model_kwargs, logger)
         if not use_deepspeed
        else setup_distributed_model(args, model_dtype, model_kwargs, logger)
+        if args.parallel_strategy != "tp"
+        else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger, cache_dir)
     )
     tokenizer, model, assistant_model = setup_tokenizer(args, model, assistant_model)
     generation_config = setup_generation_config(args, model, assistant_model, tokenizer)
diff --git a/optimum/habana/distributed/__init__.py b/optimum/habana/distributed/__init__.py
index 2dedd7333d..af269ee68c 100644
---
a/optimum/habana/distributed/__init__.py +++ b/optimum/habana/distributed/__init__.py @@ -1,2 +1,31 @@ +import os + +import torch + from .distributed_runner import DistributedRunner from .fast_ddp import all_reduce_gradients + + +def rank_and_world(group=None): + """ + Returns (rank, world_size) from the optionally-specified group, otherwise + from the default group, or if non-distributed just returns (0, 1) + """ + if torch.distributed.is_initialized() and group is None: + group = torch.distributed.GroupMember.WORLD + + if group is None: + world_size = 1 + rank = 0 + else: + world_size = group.size() + rank = group.rank() + + return rank, world_size + + +_LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0)) + + +def local_rank(): + return _LOCAL_RANK diff --git a/optimum/habana/distributed/serialization.py b/optimum/habana/distributed/serialization.py new file mode 100644 index 0000000000..bf59fb2445 --- /dev/null +++ b/optimum/habana/distributed/serialization.py @@ -0,0 +1,475 @@ +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + +import collections +import os +from collections import ChainMap +from collections.abc import Iterable +from pathlib import Path +from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Union + +import torch + +from .tp import TPModule + + +__adapters: MutableMapping[str, MutableMapping[str, Callable[[Mapping], Mapping]]] = {} + + +def register_adapter( + architecture: str, + source: str, + adapter: Callable[[Mapping], Mapping], +): + """ + Registers a state dict adapter to be available to the (de) serialization + API. + + Args: + architecture: The name of the model architecture, e.g. 'llama' + source: A label representing the format of the weights to be converted. + E.g. 'hf' + adapter: the class of the adapter. The class must accept one constructor + parameter, which will be a state dict (`OrderedDict`) + """ + sources: MutableMapping[str, Callable[[Mapping], Mapping]] = {} + if architecture in __adapters: + sources = __adapters[architecture] + + if source in sources: + raise KeyError(f"Variant {source} already registered for architecture {architecture}") + + sources[source] = adapter + __adapters[architecture] = sources + + +def list_sources(architecture: str): + """ + Lists available sources (attribute formats) of a model architecture. + E.g. `models.list_variants('llama')` -> ['meta', 'fms', 'hf'] + Args: + architecture: one of the registered architectures returned by + `models.list_models()`. 
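+    Returns an empty list if no adapter sources have been registered for `architecture`.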
+ """ + if architecture not in __adapters: + return [] + return list(__adapters[architecture].keys()) + + +def _get_adapter(architecture: str, source: Optional[str]) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: + if source is None or architecture not in __adapters or source not in __adapters[architecture]: + # if no adapter is registered, assume the attributes are already in + # fms format. + # should we raise an error here instead? + return lambda x: x + else: + return __adapters[architecture][source] + + +def get_adapted(architecture: str, source: Optional[str], state_dict: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Convert a state dict to FMS format, using an adapter specified by name. + + Args: + architecture: one of the architectures from `models.list_models()`. + E.g. llama. + source: A reference to an attribute format + state_dict: the model.state_dict() to be converted/adapted. + """ + # sometimes we only load onto rank 0 so may not have a state_dict here. + if not len(state_dict): + return state_dict + adapter = _get_adapter(architecture, source) + adapted = adapter(state_dict) + return adapted + + +def _get_safetensors_item(key, file: Path, device: torch.device) -> torch.Tensor: + from safetensors import safe_open # type: ignore[import-untyped] + + with torch.no_grad(): + with safe_open(file, framework="pt", device=str(device)) as model_weights: # type: ignore[attr-defined] + return model_weights.get_tensor(key) + + +class LazySafetensorsDict(collections.UserDict): + def set_lazy_tensor(self, key, file, device): + super().__setitem__(key, lambda: _get_safetensors_item(key, file, device)) + + def __getitem__(self, key): + lazy_tensor = super().__getitem__(key) + if callable(lazy_tensor): + lazy_tensor = lazy_tensor() + super().__setitem__(key, lazy_tensor) + return lazy_tensor + + +def load_state_dict( + model_path: Union[str, Path], + *, + source: Optional[str] = None, + distributed_strategy: Optional[str] = None, + checkpoint_sharding: Optional[str] = None, + initial_device: torch.device = torch.device("cpu"), + rank: int = 0, + world_size: int = 1, +) -> MutableMapping[str, Any]: + """ + Validates that the file(s) found at a checkpoint path are compatible with + the intended (possibly distributed) use-case, and returns a lazy loading + state dict if possible (some formats may not support that). + + If model_path is a directory, it'll try to load models based on the source + (e.g. .bin for HF, .pth for Meta), and, if no source is specified or hasn't + been registered, it'll try .safetensors, .pth, and .bin. + + Args: + model_path: the path to find the weights. If not set, return None. + source: If the weights in the state dict didn't come from an FMS model, + `source` specifies which conversion function might be needed. + See `serialization.list_sources(architecture)` + distributed_strategy: the kind of possibly-distributed model in which we + intend to load these weights. E.g. tp, fsdp, None. Used for + validation. + checkpoint_sharding: the sharding format of the checkpoint. + E.g. layer, tp, fsdp. + initial_device: where the state dict will be loaded if not lazy. + If meta, return empty dict. 
+ """ + if model_path is None or initial_device.type == "meta": + return {} + if checkpoint_sharding == "fsdp" and distributed_strategy not in ["fsdp", "hsdp"]: + raise ValueError("FSDP checkpoints can only be loaded into an FSDP model") + if checkpoint_sharding == "tp" and distributed_strategy != "tp": + raise ValueError("TP checkpoints can only be loaded into a TP model") + + # Before creating the Path object, check if model_path has a glob pattern + if isinstance(model_path, str): + model_path, sep, glob_pattern = model_path.partition("*") + else: + sep = "" + glob_pattern = "" + glob_pattern = sep + glob_pattern + + model_path = Path(os.path.expanduser(model_path)) + + checkpoints = [] + + if model_path.is_dir(): + if glob_pattern != "": + glob_pattern_list = [glob_pattern] + elif source == "meta": + glob_pattern_list = ["*.pth", "*.safetensors"] + elif source == "hf": + glob_pattern_list = ["*.bin", "*.safetensors"] + else: + glob_pattern_list = ["*.safetensors", "*.pth", "*.bin"] + for glob_pattern_possibility in glob_pattern_list: + file_list = list(model_path.glob(glob_pattern_possibility)) + if len(file_list) > 0: + checkpoints = sorted(file_list) + break + + if model_path.is_file(): + checkpoints = [model_path] + + # Check if we found some files + assert len(checkpoints) > 0, f"Can't find the requested checkpoint data at {model_path}" + + if checkpoint_sharding is not None and checkpoint_sharding != "layer": + assert ( + world_size == len(checkpoints) + ), f"Loading a {checkpoint_sharding}-sharded checkpoint with len={len(checkpoints)} but world size is {world_size}" + + checkpoints = [checkpoints[rank]] + + # if there's only one checkpoint for fsdp/hsdp, load it only into rank zero + # and it will be distributed by the FSDP `sync_module_states` parameter + if checkpoint_sharding is None and distributed_strategy in {"hsdp", "fsdp"}: + if rank == 0: + checkpoints = [checkpoints[0]] + else: + return {} + + checkpoint_sds = [] + if checkpoints[0].suffix == ".safetensors": + for ckp in checkpoints: + checkpoint_sds.append( + _load_safetensors_state_dict( + ckp, + initial_device, + ) + ) + else: + with torch.no_grad(): + checkpoint_sds = [ + torch.load(str(ckpt_path), map_location=initial_device, mmap=True) for ckpt_path in checkpoints + ] + return ChainMap(*checkpoint_sds) + + +def _load_safetensors_state_dict( + checkpoint: Path, + device: torch.device, +): + sd = LazySafetensorsDict() + + from safetensors import safe_open + + with safe_open(checkpoint, framework="pt", device=str(device)) as model_weights: # type: ignore[attr-defined] + sd_keys = list(model_weights.keys()) + for key in sd_keys: + sd.set_lazy_tensor(key, checkpoint, device) + return sd + + +class FusableWeightsMissingError(Exception): + missing_weights: List[str] = [] + + def __init__(self, missing_weights): + self.missing_weights = missing_weights + super().__init__() + + +def load_state_dict_into_model( + model: torch.nn.Module, + state_dict: MutableMapping[str, Any], + architecture: str, + source: str, + distributed_strategy: Optional[str] = None, + checkpoint_sharding: Optional[str] = None, + initial_device: torch.device = torch.device("cpu"), + rank: int = 0, + world_size: int = 0, +) -> None: + """ + This function loads state_dict into model in the most efficient way possible, + and it removes all weights that have been used in model from state_dict + in order to conserve memory. + + Args: + model: The model where the weights are being loaded. + state_dict: The dictionary with all the weights. 
If it has been mmaped + (for torch.load) or it is an instance of LazySafetensorsDict, + the weights are loaded lazily from disk. + architecture: the model architecture, e.g. llama. See `models.list_models()`. + source: If the weights in the state dict didn't come from an FMS model, + `source` specifies which conversion function might be needed. + See `serialization.list_sources(architecture)` + distributed_strategy: the kind of possibly-distributed model in which we + intend to load these weights. E.g. tp, fsdp, None. Used for weight + sharding. + checkpoint_sharding: the sharding format of the checkpoint. + E.g. layer, tp, fsdp. Used for weight sharding. + initial_device: where the weights will be loaded from disk. + """ + + # 1. Get the adapter from checkpoint sd to fms sd + adapter = _get_adapter(architecture, source) + + # 2. Decide if model needs sharding and how (for now only TP) + needs_tp_sharding = checkpoint_sharding != "tp" and distributed_strategy == "tp" + + # 3. Iterate over the weights and load them into the model + used_keys = set() + sd_keys = list(state_dict.keys()) + with torch.no_grad(): + for key in sd_keys: + if key in used_keys: + continue + used_keys.add(key) + try: + partial_sd = {key: state_dict[key]} + if partial_sd[key].device != initial_device: + partial_sd[key] = partial_sd[key].to(device=initial_device) + fms_partial_sd = adapter(partial_sd) + except FusableWeightsMissingError as e: + for weight in e.missing_weights: + used_keys.add(weight) + partial_sd[weight] = state_dict[weight] + if partial_sd[weight].device != initial_device: + partial_sd[weight] = partial_sd[weight].to(device=initial_device) + fms_partial_sd = adapter(partial_sd) + _load_partial_state_dict(model, fms_partial_sd, needs_tp_sharding, rank, world_size) + for p_key in partial_sd.keys(): + if isinstance(state_dict, ChainMap): + for child_sd in state_dict.maps: + child_sd.pop(p_key, None) + else: + state_dict.pop(p_key) + del partial_sd + del fms_partial_sd + + +def _copy_colwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_size): + """ + This function copies the correct shard of the weights for a colwise-TP'd module + according to the rank of the process and the world_size. + + Args + ==== + param: torch.nn.Parameter + Parameter that has had TP applied + tensor_value: torch.Tensor + tensor that needs sharding + rank: int + Rank of the current process + world_size: int + Total number of TP processes + """ + # Divide the weight matrix along the first dimension. + output_size_per_partition = param.shape[0] + if not is_bias: + tensor = tensor_value[ + (rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition), + :, + ] + else: + tensor = tensor_value[(rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition)] + param.copy_(tensor, non_blocking=True) + + +def _copy_rowwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_size): + """ + This function copies the correct shard of the weights for a rowwise-TP'd module + according to the rank of the process and the world_size. + + Args + ==== + param: torch.nn.Parameter + Parameter that has had TP applied + tensor_value: torch.Tensor + tensor that needs sharding + rank: int + Rank of the current process + world_size: int + Total number of TP processes + """ + # Divide the weight matrix along the last dimension. 
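+    # e.g. for a full weight of shape (out, in) and world_size N, rank r copies the
+    # column slice [r * in/N, (r + 1) * in/N). The bias is kept only on rank 0 (and
+    # zeroed elsewhere) so that the all-reduce following a row-parallel matmul does
+    # not add it world_size times.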
+ if not is_bias: + output_size_per_partition = param.shape[1] + tensor = tensor_value[ + :, + (rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition), + ] + param.copy_(tensor, non_blocking=True) + else: + if rank == 0: + _copy_if_present(param, tensor_value) + else: + param.zero_() + + +def _copy_embedding(param: torch.nn.Parameter, tensor_value, rank, world_size): + """ + This function copies the correct shard of the weights for a TP'd embedding module + according to the rank of the process and the world_size. + + Args + ==== + param: torch.nn.Parameter + Parameter that has had TP applied + tensor_value: torch.Tensor + tensor that needs sharding + rank: int + Rank of the current process + world_size: int + Total number of TP processes + """ + # Divide the weight matrix along the last dimension. + output_size_per_partition = param.shape[1] + tensor = tensor_value[ + :, + (rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition), + ] + param.copy_(tensor, non_blocking=True) + + +def _copy_if_present(parameter, tensor_value): + parameter.copy_(tensor_value, non_blocking=True) + + +def _load_partial_state_dict( + model: torch.nn.Module, + state_dict, + needs_tp_sharding: bool, + rank=0, + world_size=1, +): + unused_params = [] + for key, tensor_value in state_dict.items(): + target_module = model + # Find where to put the weight and decide whether it needs TP'ing + key_steps = key.split(".") + prefix = "" + key_step = 0 + tp_module = None + # Navigate the model tree to find the module where the parameter is + # located and whether there is a TPModule in the way in case the + # parameter requires sharding + while key_step < len(key_steps) - 1: + try: + target_module = getattr(target_module, key_steps[key_step]) + if key_step > 0: + prefix += "." + prefix += key_steps[key_step] + key_step += 1 + if isinstance(target_module, Iterable): + target_module = target_module[int(key_steps[key_step])] # type: ignore[index] + prefix += "." + key_steps[key_step] + key_step += 1 + if isinstance(target_module, TPModule): + tp_module = target_module + except AttributeError: + unused_params.append(key) + break + + # Check if target_module has the Parameter/buffer + try: + param = getattr(target_module, key_steps[-1]) + + # If TP sharding is not needed, copy the parameter + # into the model + if not needs_tp_sharding or tp_module is None: + _copy_if_present(param, tensor_value) + elif tp_module is not None: + # Handle TP sharding + if key_steps[-2] in tp_module.colwise_param_names(): + _copy_colwise( + param, + tensor_value, + key_steps[-1] == "bias", + rank, + world_size, + ) + if key_steps[-2] in tp_module.rowwise_param_names(): + _copy_rowwise( + param, + tensor_value, + key_steps[-1] == "bias", + rank, + world_size, + ) + if key_steps[-2] in tp_module.embedding_param_names(): + _copy_embedding( + param, + tensor_value, + rank, + world_size, + ) + except AttributeError: + unused_params.append(key) diff --git a/optimum/habana/distributed/strategy.py b/optimum/habana/distributed/strategy.py new file mode 100644 index 0000000000..91b3f00232 --- /dev/null +++ b/optimum/habana/distributed/strategy.py @@ -0,0 +1,134 @@ +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + +from abc import abstractmethod +from typing import List + +import torch +import torch.distributed +from torch import nn + + +class DistributedStrategy: + def __init__(self, from_meta=False): + self.from_meta = from_meta + + def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module: + """ + Optionally a distributed strategy may distribute modules that are not + numbered layers + """ + return module + + @abstractmethod + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + """ + Distribute each layer as-appropriate + """ + pass + + +class NotDistributed(DistributedStrategy): + def __init__(self, from_meta=False): + super().__init__(from_meta) + + def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module: + return module + + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + return block + + +NoOpStrategy = NotDistributed() + + +class DeviceMover(nn.Module): + def __init__(self, module: nn.Module, device): + super().__init__() + self.device = device + # make this wrapper module behave as if it was the wrapped module. + attr = module.__dict__ + attr["module"] = module.to(device) + attr["device"] = device + self.__dict__ = attr + + def forward(self, *args, **kwargs): + device = self.device + args = [arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: (kwargs[k].to(device) if isinstance(kwargs[k], torch.Tensor) else kwargs[k]) for k in kwargs} + return self.module(*args, **kwargs) + + +class UniformModelParallelStrategy(DistributedStrategy): + def __init__(self, devices: List[int], num_layers: int, from_meta=False): + super().__init__(from_meta) + num_dev = len(devices) + layers_per_dev = num_layers // num_dev + remainder = num_layers - (layers_per_dev * num_dev) + self.layer_to_device = [0] * num_layers + layer_id = 0 + for dev_idx in range(len(devices)): + for i in range(layers_per_dev): + self.layer_to_device[layer_id] = devices[dev_idx] + layer_id = layer_id + 1 + if remainder > 0: + self.layer_to_device[layer_id] = devices[dev_idx] + layer_id = layer_id + 1 + remainder -= 1 + + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + device = self.layer_to_device[layer] + if self.from_meta: + block.to_empty(device=device) # type: ignore[arg-type] + wrapped = DeviceMover(block, device) + return wrapped + + def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module: + if final_layers: + device = self.layer_to_device[len(self.layer_to_device) - 1] + else: + device = self.layer_to_device[0] + if self.from_meta: + return module.to_empty(device=device) # type: ignore[arg-type] + wrapped = DeviceMover(module, device) + return wrapped + + +class TensorParallelStrategy(DistributedStrategy): + def __init__(self, group=None, from_meta=False): + super().__init__(from_meta) + assert torch.distributed.is_initialized(), "must initialize a process 
group" + self.group = group if group is not None else torch.distributed.GroupMember.WORLD + + def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module: + from optimum.habana.distributed import tp_wrapping + + return tp_wrapping.apply_tp(module, self.group) + + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + from optimum.habana.distributed import tp_wrapping + + return tp_wrapping.apply_tp(block, layer, self.group) + + def __getstate__(self): + state = self.__dict__.copy() + state["group"] = None # Remove ProcessGroup from state + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.group = None # Restore to default state or reinitialize diff --git a/optimum/habana/distributed/tensorparallel.py b/optimum/habana/distributed/tensorparallel.py new file mode 100644 index 0000000000..09a205e451 --- /dev/null +++ b/optimum/habana/distributed/tensorparallel.py @@ -0,0 +1,112 @@ +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + +import torch +import torch._inductor.ir as ir +import torch._inductor.lowering as lowering +import torch.distributed as dist +from torch import nn + + +def apply_colwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank): + # Divide the weight matrix along the last dimension. + output_size_per_partition = mod.out_features // world_size + with torch.no_grad(): + par_mod.weight.copy_(torch.split(mod.weight, output_size_per_partition, dim=0)[rank]) + if par_mod.bias is not None: + par_mod.bias.copy_(torch.split(mod.bias, output_size_per_partition)[rank]) + + +def apply_rowwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank): + # Divide the weight matrix along the last dimension. + output_size_per_partition = mod.in_features // world_size + with torch.no_grad(): + par_mod.weight.copy_(torch.split(mod.weight, output_size_per_partition, dim=1)[rank]) + if par_mod.bias is not None: + if rank == 0: + par_mod.bias.copy_(mod.bias) + else: + par_mod.bias.zero_() + + +def apply_embedding_tp(par_mod: nn.Embedding, mod: nn.Embedding, world_size, rank): + # Divide the weight matrix along the last dimension. + output_size_per_partition = mod.embedding_dim // world_size + with torch.no_grad(): + par_mod.weight.copy_(torch.split(mod.weight, output_size_per_partition, dim=1)[rank]) + + +## Fixes for PT 2.2 collectives until PT 2.3 is released + + +# Fix 1: https://github.com/pytorch/pytorch/issues/121311 +def get_volatile_reads_fixed(self): + inp = self.inputs[0] + if isinstance(inp, ir._CollectiveKernel): + # Out-of-place single-output + return [inp.inputs[0]] + elif isinstance(inp, ir.MultiOutput): + # Out-of-place multi-output + coll = inp.inputs[0] + if isinstance(coll, ir._CollectiveKernel): + _, idx = inp.indices[0] + return [coll.inputs[idx]] + return [] # e.g. 
regular FallbackKernel + else: + # In-place requires no additional deps handling for volatile + # reads since the inputs are mutated. + return [] + + +ir._WaitKernel.get_volatile_reads = get_volatile_reads_fixed + +# Fix 2: These are fixed already in nightlies and will be in 2.3 +for overload in torch.ops._c10d_functional.all_reduce.overloads(): + other_fn = getattr(torch.ops._c10d_functional.all_reduce, overload) + if other_fn in lowering.lowerings: + del lowering.lowerings[other_fn] + + +def _all_reduce(input_: torch.Tensor) -> torch.Tensor: + """All-reduce the input tensor across model parallel group.""" + world_size = dist.get_world_size() + + if world_size == 1: + return input_ + + # Starting PT 2.3, we can go back to funcol.all_reduce + return torch.ops._c10d_functional.wait_tensor(torch.ops._c10d_functional.all_reduce(input_, "sum", "default")) + + +class _ReduceFromModelParallelRegion(torch.autograd.Function): + """All-reduce the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _all_reduce(input_) + + @staticmethod + def forward(ctx, input_): + return _all_reduce(input_) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +def reduce_from_tensor_model_parallel_region(input_): + return _ReduceFromModelParallelRegion.apply(input_) diff --git a/optimum/habana/distributed/tp.py b/optimum/habana/distributed/tp.py new file mode 100644 index 0000000000..c4f156fa61 --- /dev/null +++ b/optimum/habana/distributed/tp.py @@ -0,0 +1,101 @@ +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + +import itertools +from abc import ABCMeta, abstractmethod +from typing import List + +import torch +import torch.nn as nn +from torch.distributed.distributed_c10d import ProcessGroup + +from .tensorparallel import ( + apply_colwise_tp, + apply_embedding_tp, + apply_rowwise_tp, +) + + +class TPModule(nn.Module, metaclass=ABCMeta): + """ + This is an abstract class that any nn.Module can implement to enable + Tensor Parallel. On top of inheriting from this class, the TP module + will have to implement list_colwise_weights, list_rowwise_weights, + list_embedding_weights, and import_module for their relevant weights. + Finally, the module must call setup_tp at the end of their __init__ + function. 
See examples in attention.py, feedforward.py and embedding.py + + """ + + rank: int + world_size: int + + def setup_tp(self, rank: int, world_size: int) -> None: + self.rank = rank + self.world_size = world_size + + def colwise_param_names(self) -> List[str]: + return [] + + def rowwise_param_names(self) -> List[str]: + return [] + + def embedding_param_names(self) -> List[str]: + return [] + + @staticmethod + @abstractmethod + def import_module(module, group: ProcessGroup): + pass + + def import_weights(self, module: nn.Module): + for weight in self.colwise_param_names(): + apply_colwise_tp( + getattr(self, weight), + getattr(module, weight), + self.world_size, + self.rank, + ) + for weight in self.rowwise_param_names(): + apply_rowwise_tp( + getattr(self, weight), + getattr(module, weight), + self.world_size, + self.rank, + ) + for weight in self.embedding_param_names(): + apply_embedding_tp( + getattr(self, weight), + getattr(module, weight), + self.world_size, + self.rank, + ) + tp_sharded_modules = list( + itertools.chain( + self.colwise_param_names(), + self.rowwise_param_names(), + self.embedding_param_names(), + ) + ) + with torch.no_grad(): + for mod_name, module in self.named_children(): + if mod_name not in tp_sharded_modules: + for param_name, param in module.named_parameters(recurse=False): + param.copy_( + getattr(getattr(module, mod_name), param_name), + non_blocking=True, + ) diff --git a/optimum/habana/distributed/tp_wrapping.py b/optimum/habana/distributed/tp_wrapping.py new file mode 100644 index 0000000000..761fa7bff4 --- /dev/null +++ b/optimum/habana/distributed/tp_wrapping.py @@ -0,0 +1,48 @@ +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. 
+# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + +from torch import nn +from torch.distributed.distributed_c10d import ProcessGroup + +from ..transformers.models.llama.modeling_llama import ( + GaudiLlamaAttention, + GaudiLlamaMLP, + TPGaudiLlamaAttention, + TPGaudiLlamaMLP, +) + + +def _tp_wrapped(module: nn.Module, layer: int, group: ProcessGroup): + if hasattr(module, "to_tp"): + return module.to_tp(group) + elif isinstance(module, GaudiLlamaAttention): + return TPGaudiLlamaAttention.import_module(module, layer, group) + elif isinstance(module, GaudiLlamaMLP): + return TPGaudiLlamaMLP.import_module(module, group) + else: + return module + + +def apply_tp(model: nn.Module, layer_idx: int, group: ProcessGroup): + wrapped = _tp_wrapped(model, layer_idx, group) + if wrapped is not model: + return wrapped + + for name, layer in model.named_children(): + tp_layer = apply_tp(layer, layer_idx, group) + setattr(model, name, tp_layer) + return model diff --git a/optimum/habana/transformers/models/llama/configuration_llama.py b/optimum/habana/transformers/models/llama/configuration_llama.py index dcba1c0738..12ad78e29a 100644 --- a/optimum/habana/transformers/models/llama/configuration_llama.py +++ b/optimum/habana/transformers/models/llama/configuration_llama.py @@ -27,6 +27,7 @@ def __init__( attention_dropout=0.0, mlp_bias=False, fused_qkv=False, + parallel_strategy=None, **kwargs, ): super().__init__( @@ -55,3 +56,4 @@ def __init__( self.mlp_bias = mlp_bias self.fused_qkv = fused_qkv + self.parallel_strategy = parallel_strategy diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 1cbd714df5..4630678a97 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -1,3 +1,4 @@ +import copy import math import os import warnings @@ -5,6 +6,7 @@ import torch import torch.nn.functional as F +from torch.distributed.distributed_c10d import ProcessGroup from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -20,6 +22,12 @@ logger, ) +from .... 
import distributed +from ....distributed.strategy import DistributedStrategy, NoOpStrategy +from ....distributed.tensorparallel import ( + reduce_from_tensor_model_parallel_region, +) +from ....distributed.tp import TPModule from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) @@ -189,6 +197,40 @@ def post_mlp_forward(self, x): return x +class TPGaudiLlamaMLP(GaudiLlamaMLP, TPModule): + def __init__( + self, + config, + group: Optional[ProcessGroup] = None, + ): + assert torch.distributed.is_initialized() + rank, world_size = distributed.rank_and_world(group) + hidden_dim = int(config.hidden_grow_factor * config.hidden_size) + assert hidden_dim % world_size == 0, "Hidden dim must be divisible by world size" + + self.config = copy.deepcopy(config) + self.config.intermediate_size = int((config.hidden_grow_factor / world_size) * config.hidden_size) + GaudiLlamaMLP.__init__(self, self.config) + self.setup_tp(rank, world_size) + + def colwise_param_names(self) -> List[str]: + return ["up_proj", "gate_proj"] + + def rowwise_param_names(self) -> List[str]: + return ["down_proj"] + + @staticmethod + def import_module(glu: GaudiLlamaMLP, group: ProcessGroup) -> "TPGaudiLlamaMLP": + config = copy.deepcopy(glu.config) + config.hidden_grow_factor = glu.config.intermediate_size / glu.config.hidden_size + tp_glu = TPGaudiLlamaMLP(config=config, group=group) + return tp_glu + + def pre_mlp_forward(self, x): + out_par = GaudiLlamaMLP.pre_mlp_forward(self, x) + return reduce_from_tensor_model_parallel_region(out_par) + + def gaudi_llama_repeat_kv( query_states: torch.Tensor, key_states: torch.Tensor, @@ -545,6 +587,105 @@ def post_attn_forward(self, attn_output): return attn_output +class TPGaudiLlamaAttention(GaudiLlamaAttention, TPModule): + def __init__( + self, + config: LlamaConfig, + layer_idx: Optional[int] = None, + group: Optional[ProcessGroup] = None, + ): + super().__init__(config, layer_idx) + + assert torch.distributed.is_initialized() + rank, world_size = distributed.rank_and_world(group) + assert config.num_attention_heads % world_size == 0, "The number of heads must be divisible by world size" + self.config = copy.deepcopy(config) + + self.pre_tp_kvheads = config.num_key_value_heads + GaudiLlamaAttention.__init__(self, self.config, layer_idx) + self.config.num_attention_heads = self.config.num_attention_heads // world_size + self.config.num_key_value_heads = ( + (self.config.num_key_value_heads // world_size) + if self.config.num_key_value_heads > 1 + else self.config.num_key_value_heads + ) + self.head_dim = config.hidden_size // config.num_attention_heads + self.hidden_size = self.config.hidden_size // world_size + self.num_heads = self.config.num_attention_heads + + self.q_proj = torch.nn.Linear( + config.hidden_size, self.config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = torch.nn.Linear( + config.hidden_size, self.config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = torch.nn.Linear( + config.hidden_size, self.config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = torch.nn.Linear( + self.config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.norm_factor = 1.0 / math.sqrt(self.head_dim) + self.setup_tp(rank, world_size) + + def colwise_param_names(self) -> List[str]: + colwise_weights = ["q_proj"] + if self.pre_tp_kvheads != 1: + colwise_weights.append("k_proj") + colwise_weights.append("v_proj") + 
return colwise_weights + + def rowwise_param_names(self) -> List[str]: + return ["o_proj"] + + @staticmethod + def import_module(mha: GaudiLlamaAttention, layer_idx, group: ProcessGroup) -> "TPGaudiLlamaAttention": + tp_mha = TPGaudiLlamaAttention(config=mha.config, layer_idx=layer_idx, group=group) + return tp_mha + + def pre_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + cache_idx: int = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + hidden_states, attn_weights, present_key_value = GaudiLlamaAttention.pre_attn_forward( + self, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + token_idx, + attn_softmax_bf16, + reuse_cache, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, + flash_attention_fast_softmax, + cache_idx, + **kwargs, + ) + + hidden_states = reduce_from_tensor_model_parallel_region(hidden_states) + return hidden_states, attn_weights, present_key_value + + class GaudiLlamaDecoderLayer(LlamaDecoderLayer): def __init__(self, config: LlamaConfig, layer_idx: int): super(LlamaDecoderLayer, self).__init__() @@ -716,11 +857,17 @@ def __init__(self, config: LlamaConfig): super(LlamaModel, self).__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = torch.nn.ModuleList( - [GaudiLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) + layers = [] + for layer_idx in range(config.num_hidden_layers): + layer = GaudiLlamaDecoderLayer(config, layer_idx) + if config.parallel_strategy is not None: + layer = config.parallel_strategy.distribute_layer(layer, layer_idx) + layers.append(layer) + self.layers = torch.nn.ModuleList(layers) + # parallel_strategy is not JSON serializable + config.parallel_strategy = None + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -955,6 +1102,10 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM): - add new args reuse_cache """ + def __init__(self, config, parallel_strategy: DistributedStrategy = NoOpStrategy): + config.parallel_strategy = parallel_strategy + super().__init__(config) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 4e116242f5..b32288ffac 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -71,6 +71,9 @@ "torch_compile_distributed": [ ("meta-llama/Llama-2-7b-hf", 39.72973199515235), ], + "distributed_tp": [ + ("meta-llama/Llama-2-7b-hf", 1345.2369318328463), + ], } else: # Gaudi1 CI baselines @@ -101,6 +104,7 @@ ], "torch_compile": 
[], "torch_compile_distributed": [], + "distributed_tp": [], } @@ -116,6 +120,7 @@ def _test_text_generation( fp8: bool = False, max_input_tokens: int = 0, max_output_tokens: int = 100, + parallel_strategy: str = None, ): command = ["python3"] path_to_example_dir = Path(__file__).resolve().parent.parent / "examples" @@ -127,6 +132,11 @@ def _test_text_generation( "--use_deepspeed", f"--world_size {world_size}", ] + elif parallel_strategy == "tp": + command += [ + f"{path_to_example_dir / 'gaudi_spawn.py'}", + f"--world_size {world_size}", + ] command += [ f"{path_to_example_dir / 'text-generation' / 'run_generation.py'}", @@ -148,11 +158,14 @@ def _test_text_generation( if "starcoder2" in model_name.lower(): command += ["--flash_attention_recompute"] - if reuse_cache or torch_compile: + if (reuse_cache or torch_compile) and not parallel_strategy == "tp": command += ["--reuse_cache"] if torch_compile: command += ["--torch_compile"] + if parallel_strategy == "tp": + command += ["--use_flash_attention"] + command += ["--flash_attention_recompute"] env_variables["PT_ENABLE_INT64_SUPPORT"] = "1" env_variables["PT_HPU_LAZY_MODE"] = "0" else: @@ -194,6 +207,10 @@ def _test_text_generation( f"--max_input_tokens {max_input_tokens}", "--limit_hpu_graphs", ] + if parallel_strategy is not None: + command += [ + f"--parallel_strategy={parallel_strategy}", + ] with TemporaryDirectory() as tmp_dir: command.append(f"--output_dir {tmp_dir}") @@ -294,6 +311,21 @@ def test_text_generation_torch_compile_distributed(model_name: str, baseline: fl _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, torch_compile=True) +@pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["distributed_tp"]) +def test_text_generation_distributed_tp(model_name: str, baseline: float, token: str): + world_size = 8 + _test_text_generation( + model_name, + baseline, + token, + batch_size=64, + max_input_tokens=128, + world_size=world_size, + torch_compile=True, + parallel_strategy="tp", + ) + + class TextGenPipeline(TestCase): def test_text_generation_pipeline_script(self): path_to_script = (