From 541ae26396646efe1d5e01a9a8c61387e2aa5248 Mon Sep 17 00:00:00 2001
From: Kalyan Kumar
Date: Mon, 15 Jul 2024 21:34:53 +0530
Subject: [PATCH] Tensor parallel distributed strategy without using deepspeed
 (#280)

* TP reference - ibm foundation-model-stack

* Code cleanup -removed unused code

---------

Co-authored-by: Kalyan
---
 examples/text-generation/run_generation.py   |  13 +
 examples/text-generation/utils.py            |  99 +++-
 optimum/habana/distributed/__init__.py       |  26 +
 optimum/habana/distributed/serialization.py  | 489 ++++++++++++++++++
 optimum/habana/distributed/strategy.py       | 108 ++++
 optimum/habana/distributed/tensorparallel.py | 103 ++++
 optimum/habana/distributed/tp.py             |  84 +++
 optimum/habana/distributed/tp_wrapping.py    |  33 ++
 .../models/llama/modeling_llama.py           | 154 +++++-
 9 files changed, 1105 insertions(+), 4 deletions(-)
 create mode 100644 optimum/habana/distributed/serialization.py
 create mode 100644 optimum/habana/distributed/strategy.py
 create mode 100644 optimum/habana/distributed/tensorparallel.py
 create mode 100644 optimum/habana/distributed/tp.py
 create mode 100644 optimum/habana/distributed/tp_wrapping.py

diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index eb133b8ba7..610ecf1a8b 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -323,6 +323,19 @@ def __call__(self, parser, namespace, values, option_string=None):
         action="store_true",
         help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
     )
+    parser.add_argument(
+        "--run_partial_dataset",
+        action="store_true",
+        help="Run the inference with dataset for specified --n_iterations(default:5)",
+    )
+    parser.add_argument(
+        "--distributed_strategy",
+        type=str,
+        choices=["tp", "none"],  # Add other strategies as needed
+        default="none",
+        help="Run multi card with the specified distributed strategy. 
Choices are 'tp' for Tensor Parallel Strategy or 'none'.", + ) + args = parser.parse_args() if args.torch_compile: diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index cd3d7068b3..67afcc7015 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -285,6 +285,102 @@ def setup_model(args, model_dtype, model_kwargs, logger): # assistant_model = get_torch_compiled_model(assistant_model) return model, assistant_model +def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger): + + from optimum.habana.distributed import serialization + from typing import Any, MutableMapping + + from optimum.habana.distributed import tp_wrapping + from optimum.habana.distributed.strategy import DistributedStrategy + from torch import nn + + class TensorParallelStrategy(DistributedStrategy): + def __init__(self, group=None, from_meta=False): + super().__init__(from_meta) + assert torch.distributed.is_initialized(), "must initialize a process group" + self.group = group if group is not None else torch.distributed.GroupMember.WORLD + + def distribute_module( + self, module: nn.Module, final_layers: bool = False + ) -> nn.Module: + return tp_wrapping.apply_tp(module, self.group) + + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + return tp_wrapping.apply_tp(block, layer, self.group) + + def __getstate__(self): + state = self.__dict__.copy() + state['group'] = None # Remove ProcessGroup from state + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.group = None # Restore to default state or reinitialize + + logger.info("Multi-device run.") + + assert args.assistant_model is None, "Assistant model must be None" + + from torch import distributed as dist + if args.device == 'hpu': + import habana_frameworks.torch.distributed.hccl + dist.init_process_group(backend='hccl') + else: + dist.init_process_group() + + torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD) + logger.info("Creating Model") + config = AutoConfig.from_pretrained(args.model_name_or_path,torch_dtype=model_dtype, **model_kwargs) + model_kwargs={} + model_kwargs["distributed_strategy"] = TensorParallelStrategy() + model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype, **model_kwargs) + + initial_device = torch.device("cpu") + source="hf" + checkpoint_sharding=None + lazy_sd: MutableMapping[str, Any] = {} + logger.info("Loading Checkpoints") + lazy_sd = serialization.load_state_dict( + args.model_name_or_path, + source=source, + distributed_strategy=args.distributed_strategy, + checkpoint_sharding=None, + initial_device=initial_device, + rank=args.global_rank, + world_size=args.world_size, + ) + architecture="llama" + if len(lazy_sd): + serialization.load_state_dict_into_model( + model, + lazy_sd, + architecture, + source, + args.distributed_strategy, + checkpoint_sharding, + initial_device, + args.local_rank, + args.world_size, + ) + + if args.quant_config: + model = setup_quantization(model, args) + + model = model.eval().to(args.device) + + if args.use_hpu_graphs: + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + + if check_habana_frameworks_version("1.13.0") and model.config.model_type == "falcon": + model = wrap_in_hpu_graph(model, hash_with_views=False) + else: + model = wrap_in_hpu_graph(model) + + if args.torch_compile and model.config.model_type == "llama": + model = get_torch_compiled_model(model) + + return model, args.assistant_model + def 
setup_distributed_model(args, model_dtype, model_kwargs, logger): import deepspeed @@ -548,7 +644,8 @@ def initialize_model(args, logger): model, assistant_model = ( setup_model(args, model_dtype, model_kwargs, logger) if not use_deepspeed - else setup_distributed_model(args, model_dtype, model_kwargs, logger) + else setup_distributed_model(args, model_dtype, model_kwargs, logger) if not args.distributed_strategy == "tp" + else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger) ) tokenizer, model, assistant_model = setup_tokenizer(args, model, assistant_model) generation_config = setup_generation_config(args, model, assistant_model, tokenizer) diff --git a/optimum/habana/distributed/__init__.py b/optimum/habana/distributed/__init__.py index 2dedd7333d..12edd6620e 100644 --- a/optimum/habana/distributed/__init__.py +++ b/optimum/habana/distributed/__init__.py @@ -1,2 +1,28 @@ from .distributed_runner import DistributedRunner from .fast_ddp import all_reduce_gradients +import os +import torch + +def rank_and_world(group=None): + """ + Returns (rank, world_size) from the optionally-specified group, otherwise + from the default group, or if non-distributed just returns (0, 1) + """ + if torch.distributed.is_initialized() and group is None: + group = torch.distributed.GroupMember.WORLD + + if group is None: + world_size = 1 + rank = 0 + else: + world_size = group.size() + rank = group.rank() + + return rank, world_size + + +_LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0)) + + +def local_rank(): + return _LOCAL_RANK diff --git a/optimum/habana/distributed/serialization.py b/optimum/habana/distributed/serialization.py new file mode 100644 index 0000000000..c543ab20bd --- /dev/null +++ b/optimum/habana/distributed/serialization.py @@ -0,0 +1,489 @@ +import collections +import os +from collections import ChainMap +from collections.abc import Iterable +from pathlib import Path +from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Union + +import torch + +from optimum.habana.distributed.tp import TPModule + + +__adapters: MutableMapping[str, MutableMapping[str, Callable[[Mapping], Mapping]]] = {} + + +def register_adapter( + architecture: str, + source: str, + adapter: Callable[[Mapping], Mapping], +): + """ + Registers a state dict adapter to be available to the (de) serialization + API. + + Args: + architecture: The name of the model architecture, e.g. 'llama' + source: A label representing the format of the weights to be converted. + E.g. 'hf' + adapter: the class of the adapter. The class must accept one constructor + parameter, which will be a state dict (`OrderedDict`) + """ + sources: MutableMapping[str, Callable[[Mapping], Mapping]] = {} + if architecture in __adapters: + sources = __adapters[architecture] + + if source in sources: + raise KeyError( + f"Variant {source} already registered for architecture {architecture}" + ) + + sources[source] = adapter + __adapters[architecture] = sources + + +def list_sources(architecture: str): + """ + Lists available sources (attribute formats) of a model architecture. + E.g. `models.list_variants('llama')` -> ['meta', 'fms', 'hf'] + Args: + architecture: one of the registered architectures returned by + `models.list_models()`. 
+ """ + if architecture not in __adapters: + return [] + return list(__adapters[architecture].keys()) + + +def _get_adapter( + architecture: str, source: Optional[str] +) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: + if ( + source is None + or architecture not in __adapters + or source not in __adapters[architecture] + ): + # if no adapter is registered, assume the attributes are already in + # fms format. + # should we raise an error here instead? + return lambda x: x + else: + return __adapters[architecture][source] + + +def get_adapted( + architecture: str, source: Optional[str], state_dict: Mapping[str, Any] +) -> Mapping[str, Any]: + """ + Convert a state dict to FMS format, using an adapter specified by name. + + Args: + architecture: one of the architectures from `models.list_models()`. + E.g. llama. + source: A reference to an attribute format + state_dict: the model.state_dict() to be converted/adapted. + """ + # sometimes we only load onto rank 0 so may not have a state_dict here. + if not len(state_dict): + return state_dict + adapter = _get_adapter(architecture, source) + adapted = adapter(state_dict) + return adapted + + +# `models` imports each model class, causing models and adapters to be registered. +# down here to avoid circular dependencies. +# from fms import models + + +def _get_safetensors_item(key, file: Path, device: torch.device) -> torch.Tensor: + from safetensors import safe_open # type: ignore[import-untyped] + + with torch.no_grad(): + with safe_open( + file, framework="pt", device=str(device) + ) as model_weights: # type: ignore[attr-defined] + return model_weights.get_tensor(key) + + +class LazySafetensorsDict(collections.UserDict): + def set_lazy_tensor(self, key, file, device): + super().__setitem__(key, lambda: _get_safetensors_item(key, file, device)) + + def __getitem__(self, key): + lazy_tensor = super().__getitem__(key) + if callable(lazy_tensor): + lazy_tensor = lazy_tensor() + super().__setitem__(key, lazy_tensor) + return lazy_tensor + + +def load_state_dict( + model_path: Union[str, Path], + *, + source: Optional[str] = None, + distributed_strategy: Optional[str] = None, + checkpoint_sharding: Optional[str] = None, + initial_device: torch.device = torch.device("cpu"), + rank: int = 0, + world_size: int = 1, +) -> MutableMapping[str, Any]: + """ + Validates that the file(s) found at a checkpoint path are compatible with + the intended (possibly distributed) use-case, and returns a lazy loading + state dict if possible (some formats may not support that). + + If model_path is a directory, it'll try to load models based on the source + (e.g. .bin for HF, .pth for Meta), and, if no source is specified or hasn't + been registered, it'll try .safetensors, .pth, and .bin. + + Args: + model_path: the path to find the weights. If not set, return None. + source: If the weights in the state dict didn't come from an FMS model, + `source` specifies which conversion function might be needed. + See `serialization.list_sources(architecture)` + distributed_strategy: the kind of possibly-distributed model in which we + intend to load these weights. E.g. tp, fsdp, None. Used for + validation. + checkpoint_sharding: the sharding format of the checkpoint. + E.g. layer, tp, fsdp. + initial_device: where the state dict will be loaded if not lazy. + If meta, return empty dict. 
+ """ + if model_path is None or initial_device.type == "meta": + return {} + if checkpoint_sharding == "fsdp" and distributed_strategy not in ["fsdp", "hsdp"]: + raise ValueError(f"FSDP checkpoints can only be loaded into an FSDP model") + if checkpoint_sharding == "tp" and distributed_strategy != "tp": + raise ValueError("TP checkpoints can only be loaded into a TP model") + + # Before creating the Path object, check if model_path has a glob pattern + if isinstance(model_path, str): + model_path, sep, glob_pattern = model_path.partition("*") + else: + sep = "" + glob_pattern = "" + glob_pattern = sep + glob_pattern + + model_path = Path(os.path.expanduser(model_path)) + + checkpoints = [] + + if model_path.is_dir(): + if glob_pattern != "": + glob_pattern_list = [glob_pattern] + elif source == "meta": + glob_pattern_list = ["*.pth", "*.safetensors"] + elif source == "hf": + glob_pattern_list = ["*.bin", "*.safetensors"] + else: + glob_pattern_list = ["*.safetensors", "*.pth", "*.bin"] + for glob_pattern_possibility in glob_pattern_list: + file_list = list(model_path.glob(glob_pattern_possibility)) + if len(file_list) > 0: + checkpoints = sorted(file_list) + break + + if model_path.is_file(): + checkpoints = [model_path] + + # Check if we found some files + assert ( + len(checkpoints) > 0 + ), f"Can't find the requested checkpoint data at {model_path}" + + if checkpoint_sharding is not None and checkpoint_sharding != "layer": + assert world_size == len( + checkpoints + ), f"Loading a {checkpoint_sharding}-sharded checkpoint with len={len(checkpoints)} but world size is {world_size}" + + checkpoints = [checkpoints[rank]] + + # if there's only one checkpoint for fsdp/hsdp, load it only into rank zero + # and it will be distributed by the FSDP `sync_module_states` parameter + if checkpoint_sharding is None and distributed_strategy in {"hsdp", "fsdp"}: + if rank == 0: + checkpoints = [checkpoints[0]] + else: + return {} + + checkpoint_sds = [] + if checkpoints[0].suffix == ".safetensors": + for ckp in checkpoints: + checkpoint_sds.append( + _load_safetensors_state_dict( + ckp, + initial_device, + ) + ) + else: + with torch.no_grad(): + checkpoint_sds = [ + torch.load(str(ckpt_path), map_location=initial_device, mmap=True) for ckpt_path in checkpoints + ] + return ChainMap(*checkpoint_sds) + + +def _load_safetensors_state_dict( + checkpoint: Path, + device: torch.device, +): + sd = LazySafetensorsDict() + + from safetensors import safe_open + + with safe_open(checkpoint, framework="pt", device=str(device)) as model_weights: # type: ignore[attr-defined] + sd_keys = list(model_weights.keys()) + for key in sd_keys: + sd.set_lazy_tensor(key, checkpoint, device) + return sd + + +class FusableWeightsMissingError(Exception): + missing_weights: List[str] = [] + + def __init__(self, missing_weights): + self.missing_weights = missing_weights + super().__init__() + + +def load_state_dict_into_model( + model: torch.nn.Module, + state_dict: MutableMapping[str, Any], + architecture: str, + source: str, + distributed_strategy: Optional[str] = None, + checkpoint_sharding: Optional[str] = None, + initial_device: torch.device = torch.device("cpu"), + rank: int = 0, + world_size: int = 0, +) -> None: + """ + This function loads state_dict into model in the most efficient way possible, + and it removes all weights that have been used in model from state_dict + in order to conserve memory. + + Args: + model: The model where the weights are being loaded. + state_dict: The dictionary with all the weights. 
If it has been mmaped + (for torch.load) or it is an instance of LazySafetensorsDict, + the weights are loaded lazily from disk. + architecture: the model architecture, e.g. llama. See `models.list_models()`. + source: If the weights in the state dict didn't come from an FMS model, + `source` specifies which conversion function might be needed. + See `serialization.list_sources(architecture)` + distributed_strategy: the kind of possibly-distributed model in which we + intend to load these weights. E.g. tp, fsdp, None. Used for weight + sharding. + checkpoint_sharding: the sharding format of the checkpoint. + E.g. layer, tp, fsdp. Used for weight sharding. + initial_device: where the weights will be loaded from disk. + """ + + # 1. Get the adapter from checkpoint sd to fms sd + adapter = _get_adapter(architecture, source) + + # 2. Decide if model needs sharding and how (for now only TP) + needs_tp_sharding = checkpoint_sharding != "tp" and distributed_strategy == "tp" + + # 3. Iterate over the weights and load them into the model + used_keys = set() + sd_keys = list(state_dict.keys()) + with torch.no_grad(): + for key in sd_keys: + if key in used_keys: + continue + used_keys.add(key) + try: + partial_sd = {key: state_dict[key]} + if partial_sd[key].device != initial_device: + partial_sd[key] = partial_sd[key].to(device=initial_device) + fms_partial_sd = adapter(partial_sd) + except FusableWeightsMissingError as e: + for weight in e.missing_weights: + used_keys.add(weight) + partial_sd[weight] = state_dict[weight] + if partial_sd[weight].device != initial_device: + partial_sd[weight] = partial_sd[weight].to( + device=initial_device + ) + fms_partial_sd = adapter(partial_sd) + _load_partial_state_dict( + model, fms_partial_sd, needs_tp_sharding, rank, world_size + ) + for p_key in partial_sd.keys(): + if isinstance(state_dict, ChainMap): + for child_sd in state_dict.maps: + child_sd.pop(p_key, None) + else: + state_dict.pop(p_key) + del partial_sd + del fms_partial_sd + + +def _copy_colwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_size): + """ + This function copies the correct shard of the weights for a colwise-TP'd module + according to the rank of the process and the world_size. + + Args + ==== + param: torch.nn.Parameter + Parameter that has had TP applied + tensor_value: torch.Tensor + tensor that needs sharding + rank: int + Rank of the current process + world_size: int + Total number of TP processes + """ + # Divide the weight matrix along the first dimension. + output_size_per_partition = param.shape[0] + if not is_bias: + tensor = tensor_value[ + (rank * output_size_per_partition) : ( + (rank + 1) * output_size_per_partition + ), + :, + ] + else: + tensor = tensor_value[ + (rank * output_size_per_partition) : ( + (rank + 1) * output_size_per_partition + ) + ] + param.copy_(tensor, non_blocking=True) + + +def _copy_rowwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_size): + """ + This function copies the correct shard of the weights for a rowwise-TP'd module + according to the rank of the process and the world_size. + + Args + ==== + param: torch.nn.Parameter + Parameter that has had TP applied + tensor_value: torch.Tensor + tensor that needs sharding + rank: int + Rank of the current process + world_size: int + Total number of TP processes + """ + # Divide the weight matrix along the last dimension. 
+ if not is_bias: + output_size_per_partition = param.shape[1] + tensor = tensor_value[ + :, + (rank * output_size_per_partition) : ( + (rank + 1) * output_size_per_partition + ), + ] + param.copy_(tensor, non_blocking=True) + else: + if rank == 0: + _copy_if_present(param, tensor_value) + else: + param.zero_() + + +def _copy_embedding(param: torch.nn.Parameter, tensor_value, rank, world_size): + """ + This function copies the correct shard of the weights for a TP'd embedding module + according to the rank of the process and the world_size. + + Args + ==== + param: torch.nn.Parameter + Parameter that has had TP applied + tensor_value: torch.Tensor + tensor that needs sharding + rank: int + Rank of the current process + world_size: int + Total number of TP processes + """ + # Divide the weight matrix along the last dimension. + output_size_per_partition = param.shape[1] + tensor = tensor_value[ + :, + (rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition), + ] + param.copy_(tensor, non_blocking=True) + + +def _copy_if_present(parameter, tensor_value): + parameter.copy_(tensor_value, non_blocking=True) + + +def _load_partial_state_dict( + model: torch.nn.Module, + state_dict, + needs_tp_sharding: bool, + rank=0, + world_size=1, +): + unused_params = [] + for key, tensor_value in state_dict.items(): + target_module = model + # Find where to put the weight and decide whether it needs TP'ing + key_steps = key.split(".") + prefix = "" + key_step = 0 + tp_module = None + # Navigate the model tree to find the module where the parameter is + # located and whether there is a TPModule in the way in case the + # parameter requires sharding + while key_step < len(key_steps) - 1: + try: + target_module = getattr(target_module, key_steps[key_step]) + if key_step > 0: + prefix += "." + prefix += key_steps[key_step] + key_step += 1 + if isinstance(target_module, Iterable): + target_module = target_module[int(key_steps[key_step])] # type: ignore[index] + prefix += "." 
+ key_steps[key_step] + key_step += 1 + if isinstance(target_module, TPModule): + tp_module = target_module + except AttributeError: + unused_params.append(key) + break + + # Check if target_module has the Parameter/buffer + try: + param = getattr(target_module, key_steps[-1]) + + # If TP sharding is not needed, copy the parameter + # into the model + if not needs_tp_sharding or tp_module is None: + _copy_if_present(param, tensor_value) + elif tp_module is not None: + # Handle TP sharding + if key_steps[-2] in tp_module.colwise_param_names(): + _copy_colwise( + param, + tensor_value, + key_steps[-1] == "bias", + rank, + world_size, + ) + if key_steps[-2] in tp_module.rowwise_param_names(): + _copy_rowwise( + param, + tensor_value, + key_steps[-1] == "bias", + rank, + world_size, + ) + if key_steps[-2] in tp_module.embedding_param_names(): + _copy_embedding( + param, + tensor_value, + rank, + world_size, + ) + except AttributeError: + unused_params.append(key) diff --git a/optimum/habana/distributed/strategy.py b/optimum/habana/distributed/strategy.py new file mode 100644 index 0000000000..4bc68bbca7 --- /dev/null +++ b/optimum/habana/distributed/strategy.py @@ -0,0 +1,108 @@ +from abc import abstractmethod +from typing import Any, List, Mapping + +import torch +import torch.distributed +from torch import nn + +class DistributedStrategy: + def __init__(self, from_meta=False): + self.from_meta = from_meta + + def distribute_module( + self, module: nn.Module, final_layers: bool = False + ) -> nn.Module: + """ + Optionally a distributed strategy may distribute modules that are not + numbered layers + """ + return module + + @abstractmethod + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + """ + Distribute each layer as-appropriate + """ + pass + + +class NotDistributed(DistributedStrategy): + def __init__(self, from_meta=False): + super().__init__(from_meta) + + def distribute_module( + self, module: nn.Module, final_layers: bool = False + ) -> nn.Module: + return module + + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + return block + + +NoOpStrategy = NotDistributed() + + +class DeviceMover(nn.Module): + def __init__(self, module: nn.Module, device): + super().__init__() + self.device = device + # make this wrapper module behave as if it was the wrapped module. 
+ attr = module.__dict__ + attr["module"] = module.to(device) + attr["device"] = device + self.__dict__ = attr + + def forward(self, *args, **kwargs): + device = self.device + args = [ + arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args + ] + kwargs = { + k: ( + kwargs[k].to(device) + if isinstance(kwargs[k], torch.Tensor) + else kwargs[k] + ) + for k in kwargs + } + return self.module(*args, **kwargs) + + +class UniformModelParallelStrategy(DistributedStrategy): + def __init__(self, devices: List[int], num_layers: int, from_meta=False): + super().__init__(from_meta) + num_dev = len(devices) + layers_per_dev = num_layers // num_dev + remainder = num_layers - (layers_per_dev * num_dev) + self.layer_to_device = [0] * num_layers + layer_id = 0 + for dev_idx in range(len(devices)): + for i in range(layers_per_dev): + self.layer_to_device[layer_id] = devices[dev_idx] + layer_id = layer_id + 1 + if remainder > 0: + self.layer_to_device[layer_id] = devices[dev_idx] + layer_id = layer_id + 1 + remainder -= 1 + + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + device = self.layer_to_device[layer] + if self.from_meta: + block.to_empty(device=device) # type: ignore[arg-type] + wrapped = DeviceMover(block, device) + return wrapped + + def distribute_module( + self, module: nn.Module, final_layers: bool = False + ) -> nn.Module: + if final_layers: + device = self.layer_to_device[len(self.layer_to_device) - 1] + else: + device = self.layer_to_device[0] + if self.from_meta: + return module.to_empty(device=device) # type: ignore[arg-type] + wrapped = DeviceMover(module, device) + return wrapped + + + diff --git a/optimum/habana/distributed/tensorparallel.py b/optimum/habana/distributed/tensorparallel.py new file mode 100644 index 0000000000..5d484b2fc5 --- /dev/null +++ b/optimum/habana/distributed/tensorparallel.py @@ -0,0 +1,103 @@ +# mypy: disable-error-code="method-assign,misc" + +import torch +import torch._inductor.ir as ir +import torch._inductor.lowering as lowering +import torch.distributed as dist +import torch.distributed._functional_collectives as funcol +from torch import nn + + +def apply_colwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank): + # Divide the weight matrix along the last dimension. + output_size_per_partition = mod.out_features // world_size + with torch.no_grad(): + par_mod.weight.copy_( + torch.split(mod.weight, output_size_per_partition, dim=0)[rank] + ) + if par_mod.bias is not None: + par_mod.bias.copy_(torch.split(mod.bias, output_size_per_partition)[rank]) + + +def apply_rowwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank): + # Divide the weight matrix along the last dimension. + output_size_per_partition = mod.in_features // world_size + with torch.no_grad(): + par_mod.weight.copy_( + torch.split(mod.weight, output_size_per_partition, dim=1)[rank] + ) + if par_mod.bias is not None: + if rank == 0: + par_mod.bias.copy_(mod.bias) + else: + par_mod.bias.zero_() + + +def apply_embedding_tp(par_mod: nn.Embedding, mod: nn.Embedding, world_size, rank): + # Divide the weight matrix along the last dimension. 
+ output_size_per_partition = mod.embedding_dim // world_size + with torch.no_grad(): + par_mod.weight.copy_( + torch.split(mod.weight, output_size_per_partition, dim=1)[rank] + ) + + +## Fixes for PT 2.2 collectives until PT 2.3 is released + + +# Fix 1: https://github.com/pytorch/pytorch/issues/121311 +def get_volatile_reads_fixed(self): + inp = self.inputs[0] + if isinstance(inp, ir._CollectiveKernel): + # Out-of-place single-output + return [inp.inputs[0]] + elif isinstance(inp, ir.MultiOutput): + # Out-of-place multi-output + coll = inp.inputs[0] + if isinstance(coll, ir._CollectiveKernel): + _, idx = inp.indices[0] + return [coll.inputs[idx]] + return [] # e.g. regular FallbackKernel + else: + # In-place requires no additional deps handling for volatile + # reads since the inputs are mutated. + return [] + + +ir._WaitKernel.get_volatile_reads = get_volatile_reads_fixed + +# Fix 2: These are fixed already in nightlies and will be in 2.3 +for overload in torch.ops._c10d_functional.all_reduce.overloads(): + other_fn = getattr(torch.ops._c10d_functional.all_reduce, overload) + if other_fn in lowering.lowerings: + del lowering.lowerings[other_fn] + +def _all_reduce(input_: torch.Tensor) -> torch.Tensor: + """All-reduce the input tensor across model parallel group.""" + world_size = dist.get_world_size() + + if world_size == 1: + return input_ + + # Starting PT 2.3, we can go back to funcol.all_reduce + return torch.ops._c10d_functional.wait_tensor( + torch.ops._c10d_functional.all_reduce(input_, "sum", "default") + ) + +class _ReduceFromModelParallelRegion(torch.autograd.Function): + """All-reduce the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _all_reduce(input_) + + @staticmethod + def forward(ctx, input_): + return _all_reduce(input_) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + +def reduce_from_tensor_model_parallel_region(input_): + return _ReduceFromModelParallelRegion.apply(input_) diff --git a/optimum/habana/distributed/tp.py b/optimum/habana/distributed/tp.py new file mode 100644 index 0000000000..31f33a79cc --- /dev/null +++ b/optimum/habana/distributed/tp.py @@ -0,0 +1,84 @@ +import itertools +from abc import ABCMeta, abstractmethod +from typing import List, Type + +import torch +import torch.nn as nn +from torch.distributed.distributed_c10d import ProcessGroup + +from optimum.habana.distributed.tensorparallel import ( + apply_colwise_tp, + apply_embedding_tp, + apply_rowwise_tp, +) + + +class TPModule(nn.Module, metaclass=ABCMeta): + """ + This is an abstract class that any nn.Module can implement to enable + Tensor Parallel. On top of inheriting from this class, the TP module + will have to implement list_colwise_weights, list_rowwise_weights, + list_embedding_weights, and import_module for their relevant weights. + Finally, the module must call setup_tp at the end of their __init__ + function. 
See examples in attention.py, feedforward.py and embedding.py + + """ + + rank: int + world_size: int + + def setup_tp(self, rank: int, world_size: int) -> None: + self.rank = rank + self.world_size = world_size + + def colwise_param_names(self) -> List[str]: + return [] + + def rowwise_param_names(self) -> List[str]: + return [] + + def embedding_param_names(self) -> List[str]: + return [] + + @staticmethod + @abstractmethod + def import_module(module, group: ProcessGroup): + pass + + def import_weights(self, module: nn.Module): + for weight in self.colwise_param_names(): + apply_colwise_tp( + getattr(self, weight), + getattr(module, weight), + self.world_size, + self.rank, + ) + for weight in self.rowwise_param_names(): + apply_rowwise_tp( + getattr(self, weight), + getattr(module, weight), + self.world_size, + self.rank, + ) + for weight in self.embedding_param_names(): + apply_embedding_tp( + getattr(self, weight), + getattr(module, weight), + self.world_size, + self.rank, + ) + tp_sharded_modules = list( + itertools.chain( + self.colwise_param_names(), + self.rowwise_param_names(), + self.embedding_param_names(), + ) + ) + with torch.no_grad(): + for mod_name, module in self.named_children(): + if not mod_name in tp_sharded_modules: + for param_name, param in module.named_parameters(recurse=False): + param.copy_( + getattr(getattr(module, mod_name), param_name), + non_blocking=True, + ) diff --git a/optimum/habana/distributed/tp_wrapping.py b/optimum/habana/distributed/tp_wrapping.py new file mode 100644 index 0000000000..402accb342 --- /dev/null +++ b/optimum/habana/distributed/tp_wrapping.py @@ -0,0 +1,33 @@ +import os +from torch import nn +from torch.distributed.distributed_c10d import ProcessGroup + +from optimum.habana.transformers.models.llama.modeling_llama import ( + GaudiLlamaMLP, + TPGaudiLlamaMLP, + GaudiLlamaAttention, + TPGaudiLlamaAttention +) + +# this probably belongs somewhere else but can't go in fms.distribtued b/c +# circular dependency. 
+def _tp_wrapped(module: nn.Module, layer: int, group: ProcessGroup): + if hasattr(module, "to_tp"): + return module.to_tp(group) + elif isinstance(module, GaudiLlamaAttention): + return TPGaudiLlamaAttention.import_module(module,layer, group) + elif isinstance(module, GaudiLlamaMLP): + return TPGaudiLlamaMLP.import_module(module, group) + else: + return module + + +def apply_tp(model: nn.Module, layer_idx: int, group: ProcessGroup): + wrapped = _tp_wrapped(model, layer_idx, group) + if wrapped is not model: + return wrapped + + for name, layer in model.named_children(): + tp_layer = apply_tp(layer, layer_idx, group) + setattr(model, name, tp_layer) + return model diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 0c51d6ce83..0c9b0d83b2 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -20,10 +20,20 @@ logger, ) +import copy +from torch.distributed.distributed_c10d import ProcessGroup from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) +from optimum.habana.distributed.tp import TPModule +from optimum.habana import distributed +from optimum.habana.distributed.tensorparallel import ( + reduce_from_tensor_model_parallel_region, +) + +from optimum.habana.distributed.strategy import DistributedStrategy +from optimum.habana.distributed.strategy import NoOpStrategy try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE @@ -188,6 +198,46 @@ def post_mlp_forward(self, x): return self.down_proj.post_all_reduce(x) return x +class TPGaudiLlamaMLP(GaudiLlamaMLP, TPModule): + def __init__( + self, + config, + group: Optional[ProcessGroup] = None, + ): + assert torch.distributed.is_initialized() + rank, world_size = distributed.rank_and_world(group) + hidden_dim = int(config.hidden_grow_factor * config.hidden_size) + assert ( + hidden_dim % world_size == 0 + ), "Hidden dim must be divisible by world size" + + self.config = copy.deepcopy(config) + self.config.intermediate_size = int((config.hidden_grow_factor / world_size) * config.hidden_size) + GaudiLlamaMLP.__init__( + self, + self.config + ) + self.setup_tp(rank, world_size) + + def colwise_param_names(self) -> List[str]: + return ["up_proj", "gate_proj"] + + def rowwise_param_names(self) -> List[str]: + return ["down_proj"] + + @staticmethod + def import_module(glu: GaudiLlamaMLP, group: ProcessGroup) -> "TPGaudiLlamaMLP": + config = copy.deepcopy(glu.config) + config.hidden_grow_factor = glu.config.intermediate_size / glu.config.hidden_size + tp_glu = TPGaudiLlamaMLP( + config = config, + group=group + ) + return tp_glu + + def pre_mlp_forward(self, x): + out_par = GaudiLlamaMLP.pre_mlp_forward(self, x) + return reduce_from_tensor_model_parallel_region(out_par) def gaudi_llama_repeat_kv( query_states: torch.Tensor, @@ -561,6 +611,94 @@ def post_attn_forward(self, attn_output): return attn_output +class TPGaudiLlamaAttention(GaudiLlamaAttention, TPModule): + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None, group: Optional[ProcessGroup] = None,): + super().__init__(config, layer_idx) + + assert torch.distributed.is_initialized() + rank, world_size = distributed.rank_and_world(group) + assert ( + config.num_attention_heads % world_size == 0 + ), "The number of heads must be divisible by world size" + self.config = copy.deepcopy(config) + + self.pre_tp_kvheads = 
config.num_key_value_heads + GaudiLlamaAttention.__init__(self, self.config , layer_idx) + self.config.num_attention_heads = self.config.num_attention_heads // world_size + self.config.num_key_value_heads = ( self.config.num_key_value_heads // world_size) if self.config.num_key_value_heads > 1 else self.config.num_key_value_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.hidden_size = self.config.hidden_size // world_size + self.num_heads = self.config.num_attention_heads + + self.q_proj = torch.nn.Linear(config.hidden_size, self.config.num_attention_heads * self.head_dim , bias=config.attention_bias) + self.k_proj = torch.nn.Linear(config.hidden_size, self.config.num_key_value_heads * self.head_dim , bias=config.attention_bias) + self.v_proj = torch.nn.Linear(config.hidden_size, self.config.num_key_value_heads * self.head_dim , bias=config.attention_bias) + self.o_proj = torch.nn.Linear(self.config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) + self.norm_factor = 1.0 / math.sqrt(self.head_dim) + self.setup_tp(rank, world_size) + + def colwise_param_names(self) -> List[str]: + colwise_weights = ["q_proj"] + if self.pre_tp_kvheads != 1: + colwise_weights.append("k_proj") + colwise_weights.append("v_proj") + return colwise_weights + + def rowwise_param_names(self) -> List[str]: + return ["o_proj"] + + @staticmethod + def import_module( + mha: GaudiLlamaAttention, layer_idx, group: ProcessGroup + ) -> "TPGaudiLlamaAttention": + tp_mha = TPGaudiLlamaAttention( + config = mha.config, + layer_idx=layer_idx, + group=group + ) + return tp_mha + + def pre_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + cache_idx: int = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + hidden_states, attn_weights, present_key_value = GaudiLlamaAttention.pre_attn_forward(self, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + token_idx, + attn_softmax_bf16, + reuse_cache, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, + flash_attention_fast_softmax, + cache_idx, + **kwargs + ) + + hidden_states = reduce_from_tensor_model_parallel_region(hidden_states) + return hidden_states, attn_weights, present_key_value class GaudiLlamaDecoderLayer(LlamaDecoderLayer): def __init__(self, config: LlamaConfig, layer_idx: int): super(LlamaDecoderLayer, self).__init__() @@ -733,10 +871,16 @@ def __init__(self, config: LlamaConfig): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + self.distributed_strategy = config.distributed_strategy + config.distributed_strategy = None self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = torch.nn.ModuleList( - [GaudiLlamaDecoderLayer(config, layer_idx) for layer_idx in 
range(config.num_hidden_layers)] - ) + layers = [] + for layer_idx in range(config.num_hidden_layers): + layer = GaudiLlamaDecoderLayer(config, layer_idx) + layer = self.distributed_strategy.distribute_layer(layer, layer_idx) + layers.append(layer) + self.layers = torch.nn.ModuleList(layers) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -971,6 +1115,10 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM): - add new args reuse_cache """ + def __init__(self, config, distributed_strategy: DistributedStrategy = NoOpStrategy): + config.distributed_strategy = distributed_strategy + super().__init__(config) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
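
For reference, below is a minimal, self-contained sketch of the sharding arithmetic that apply_colwise_tp, apply_rowwise_tp and reduce_from_tensor_model_parallel_region in this patch rely on. It is not part of the patch: it simulates the ranks in a single Python process (no HCCL, HPU or process group), and the sizes and tensor names are illustrative only. The point it demonstrates is that a column-sharded up projection followed by a row-sharded down projection recombines to the unsharded result through a sum all-reduce.

import torch

# Simulated TP setup: `world_size` ranks in one process (illustrative values only).
torch.manual_seed(0)
world_size = 4
hidden, intermediate = 16, 32

up = torch.nn.Linear(hidden, intermediate, bias=False)    # would be sharded column-wise
down = torch.nn.Linear(intermediate, hidden, bias=False)  # would be sharded row-wise
x = torch.randn(2, hidden)
reference = down(torch.nn.functional.silu(up(x)))

# Column-wise sharding splits the *output* features of `up` across ranks
# (out_features // world_size, dim=0, as in apply_colwise_tp); row-wise sharding
# splits the *input* features of `down` (in_features // world_size, dim=1, as in
# apply_rowwise_tp).
up_shards = torch.split(up.weight, intermediate // world_size, dim=0)
down_shards = torch.split(down.weight, intermediate // world_size, dim=1)

# Each simulated rank computes a partial output from its shards; summing the
# partials stands in for the all-reduce done by
# reduce_from_tensor_model_parallel_region.
partials = [
    torch.nn.functional.silu(x @ up_shards[r].t()) @ down_shards[r].t()
    for r in range(world_size)
]
recombined = torch.stack(partials).sum(dim=0)

assert torch.allclose(reference, recombined, atol=1e-5)
print("sharded partials recombine to the unsharded reference")

The gate projection of GaudiLlamaMLP (also column-sharded, per colwise_param_names) is omitted for brevity; it follows the same arithmetic as up_proj.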