diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index c1b4748d0e..67afcc7015 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -43,6 +43,7 @@
     set_seed,
 )
 
+
 def adjust_batch(batch, size):
     curr_size = batch["input_ids"].shape[1]
     if curr_size >= size:
@@ -328,6 +329,7 @@ def __setstate__(self, state):
     dist.init_process_group()
     torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
 
+    logger.info("Creating Model")
     config = AutoConfig.from_pretrained(args.model_name_or_path,torch_dtype=model_dtype, **model_kwargs)
     model_kwargs={}
     model_kwargs["distributed_strategy"] = TensorParallelStrategy()
@@ -337,6 +339,7 @@ def __setstate__(self, state):
     source="hf"
     checkpoint_sharding=None
     lazy_sd: MutableMapping[str, Any] = {}
+    logger.info("Loading Checkpoints")
     lazy_sd = serialization.load_state_dict(
         args.model_name_or_path,
         source=source,
diff --git a/optimum/habana/distributed/strategy.py b/optimum/habana/distributed/strategy.py
index 3d77db0cb6..4bc68bbca7 100644
--- a/optimum/habana/distributed/strategy.py
+++ b/optimum/habana/distributed/strategy.py
@@ -5,9 +5,6 @@
 import torch.distributed
 from torch import nn
 
-#from optimum.habana.distributed import tp_wrapping
-
-
 class DistributedStrategy:
     def __init__(self, from_meta=False):
         self.from_meta = from_meta
@@ -91,7 +88,6 @@ def __init__(self, devices: List[int], num_layers: int, from_meta=False):
     def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module:
         device = self.layer_to_device[layer]
         if self.from_meta:
-            # https://github.com/pytorch/pytorch/pull/113647
             block.to_empty(device=device)  # type: ignore[arg-type]
         wrapped = DeviceMover(block, device)
         return wrapped
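Reviewer note (not part of the patch): the strategy.py hunks above keep the DistributedStrategy.distribute_layer contract that the Llama changes below rely on. A minimal sketch of that contract, using a hypothetical SingleDeviceStrategy as a stand-in for the real pipeline and tensor-parallel strategies:

    from torch import nn


    class SingleDeviceStrategy:
        """Illustrative stand-in for a DistributedStrategy implementation."""

        def __init__(self, device="cpu", from_meta=False):
            self.device = device
            self.from_meta = from_meta

        def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module:
            if self.from_meta:
                # Materialize parameters that were allocated on the meta device,
                # mirroring the block.to_empty(device=...) call kept above.
                block.to_empty(device=self.device)
            return block.to(self.device)


    if __name__ == "__main__":
        layer = SingleDeviceStrategy().distribute_layer(nn.Linear(8, 8), layer=0)
        print(next(layer.parameters()).device)  # cpu

The model only ever calls distribute_layer(block, layer_index), so strategies stay free to move, wrap, or shard each block however they like.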
diff --git a/optimum/habana/distributed/tensorparallel.py b/optimum/habana/distributed/tensorparallel.py
index 8951b2398c..5d484b2fc5 100644
--- a/optimum/habana/distributed/tensorparallel.py
+++ b/optimum/habana/distributed/tensorparallel.py
@@ -17,7 +17,6 @@ def apply_colwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank):
     )
     if par_mod.bias is not None:
         par_mod.bias.copy_(torch.split(mod.bias, output_size_per_partition)[rank])
-    # print(f"For rank {rank}, we have the following weights: Base weight {mod.weight} bias {mod.bias}; Par weight {par_mod.weight}, bias {par_mod.bias}")
 
 
 def apply_rowwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank):
@@ -32,7 +31,6 @@ def apply_rowwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank):
         par_mod.bias.copy_(mod.bias)
     else:
         par_mod.bias.zero_()
-    # print(f"For rank {rank}, we have the following weights: Base weight {mod.weight}, bias {mod.bias}; Par weight {par_mod.weight}, bias {par_mod.bias}")
 
 
 def apply_embedding_tp(par_mod: nn.Embedding, mod: nn.Embedding, world_size, rank):
@@ -42,7 +40,6 @@ def apply_embedding_tp(par_mod: nn.Embedding, mod: nn.Embedding, world_size, ran
     par_mod.weight.copy_(
         torch.split(mod.weight, output_size_per_partition, dim=1)[rank]
     )
-    # print(f"For rank {rank}, we have the following weights: Base weight {mod.weight} bias {mod.bias}; Par weight {par_mod.weight}, bias {par_mod.bias}")
 
 
 ## Fixes for PT 2.2 collectives until PT 2.3 is released
@@ -75,65 +72,6 @@ def get_volatile_reads_fixed(self):
     if other_fn in lowering.lowerings:
         del lowering.lowerings[other_fn]
 
-
-@lowering.register_lowering(torch.ops._c10d_functional.all_reduce)
-def _all_reduce_fixed(inp, reduce_op, group_name):
-    inp = torch.clone(inp)
-    ir._CollectiveKernel.create_inplace(
-        torch.ops._c10d_functional.all_reduce_.default,
-        ir.ExternKernel.require_contiguous(inp),
-        reduce_op,
-        group_name,
-    )
-    return inp
-
-
-for overload in torch.ops._c10d_functional.all_gather_into_tensor.overloads():
-    other_fn = getattr(torch.ops._c10d_functional.all_gather_into_tensor, overload)
-    if other_fn in lowering.lowerings:
-        del lowering.lowerings[other_fn]
-
-
-@lowering.register_lowering(torch.ops._c10d_functional.all_gather_into_tensor)
-def _all_gather_into_tensor(inp, group_size, group_name):
-    return ir.TensorBox.create(
-        ir._CollectiveKernel.create_out_of_place(
-            torch.ops._c10d_functional.all_gather_into_tensor.default,
-            ir.ExternKernel.require_contiguous(inp),
-            group_size,
-            group_name,
-        )
-    )
-
-
-def _all_gather(input_: torch.Tensor) -> torch.Tensor:
-    """Gather the input tensor across model parallel group."""
-    world_size = dist.get_world_size()
-
-    if world_size == 1:
-        return input_
-
-    # The transposes here are to avoid excessive recompilation due to split()
-    # specializing the dimension where the all_gather is happening
-    last_dim = input_.dim() - 1
-    # Starting PT 2.3, we can go back to funcol.all_gather_tensor
-    # TODO SW-180411 WA
-    # return (
-    #     torch.ops._c10d_functional.wait_tensor(
-    #         torch.ops._c10d_functional.all_gather_into_tensor(
-    #             input_.transpose(0, last_dim).contiguous(), world_size, "default"
-    #         )
-    #     )
-    #     .transpose(0, last_dim)
-    #     .contiguous()
-    # )
-    shape = list(input_.transpose(0, last_dim).size())
-    shape[0] *= world_size
-    output = torch.empty(shape, dtype=input_.dtype, device=input_.device)
-    dist.all_gather_into_tensor(output, input_.transpose(0, last_dim).contiguous())
-    return output.transpose(0, last_dim).contiguous()
-
-
 def _all_reduce(input_: torch.Tensor) -> torch.Tensor:
     """All-reduce the input tensor across model parallel group."""
     world_size = dist.get_world_size()
@@ -146,43 +84,6 @@ def _all_reduce(input_: torch.Tensor) -> torch.Tensor:
         torch.ops._c10d_functional.all_reduce(input_, "sum", "default")
     )
 
-
-def _split(input_: torch.Tensor, rank, world_size) -> torch.Tensor:
-    """Split the tensor along its last dimension and keep the
-    corresponding slice."""
-
-    if world_size == 1:
-        return input_
-
-    # Split along last dimension.
-    # Get the size and dimension.
-    last_dim = input_.dim() - 1
-    last_dim_size = input_.size()[last_dim] // world_size
-    # Split.
-    input_list = torch.split(input_, last_dim_size, dim=last_dim)
-
-    # Note: torch.split does not create contiguous tensors by default.
-    output = input_list[rank].contiguous()
-
-    return output
-
-
-class _CopyToModelParallelRegion(torch.autograd.Function):
-    """Pass the input to the model parallel region."""
-
-    @staticmethod
-    def symbolic(graph, input_):
-        return input_
-
-    @staticmethod
-    def forward(ctx, input_):
-        return input_
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _all_reduce(grad_output)
-
-
 class _ReduceFromModelParallelRegion(torch.autograd.Function):
     """All-reduce the input from the model parallel region."""
@@ -198,32 +99,5 @@ def forward(ctx, input_):
     def backward(ctx, grad_output):
         return grad_output
 
-
-class _AllGatherFromModelParallelRegion(torch.autograd.Function):
-    """Gather the input from the model parallel region."""
-
-    @staticmethod
-    def symbolic(graph, input_):
-        return _all_gather(input_)
-
-    @staticmethod
-    def forward(ctx, input_, rank, world_size):
-        ctx.rank = rank
-        ctx.world_size = world_size
-        return _all_gather(input_)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return _split(grad_output, ctx.rank, ctx.world_size)
-
-
-def copy_to_tensor_model_parallel_region(input_):
-    return _CopyToModelParallelRegion.apply(input_)
-
-
 def reduce_from_tensor_model_parallel_region(input_):
     return _ReduceFromModelParallelRegion.apply(input_)
-
-
-def all_gather_from_tensor_model_parallel_region(input_, rank, world_size):
-    return _AllGatherFromModelParallelRegion.apply(input_, rank, world_size)
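Reviewer note (not part of the patch): with _all_gather, _split, and the copy/gather autograd wrappers removed, the only collective left on this path is the forward all-reduce in _ReduceFromModelParallelRegion, whose backward is the identity. A single-process sketch of why that pairing is correct for a row-parallel linear output; a plain sum stands in for dist.all_reduce so no process group is needed:

    import torch


    class _ReduceSketch(torch.autograd.Function):
        """Forward: sum the per-rank partial outputs (the all-reduce).
        Backward: identity, since every rank needs the same upstream gradient."""

        @staticmethod
        def forward(ctx, partial_outputs):
            ctx.num_ranks = partial_outputs.shape[0]
            return partial_outputs.sum(dim=0)

        @staticmethod
        def backward(ctx, grad_output):
            # No collective needed: each "rank" gets an unchanged copy of the gradient.
            return grad_output.unsqueeze(0).expand(ctx.num_ranks, -1)


    if __name__ == "__main__":
        partials = torch.randn(2, 4, requires_grad=True)  # rank 0 and rank 1 partial sums
        _ReduceSketch.apply(partials).sum().backward()
        print(partials.grad)  # all ones: the gradient passes through unchanged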
diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 4fd20cd670..d04ba0bcc0 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -19,6 +19,7 @@
     apply_rotary_pos_emb,
     logger,
 )
+
 import copy
 from torch.distributed.distributed_c10d import ProcessGroup
 from ...modeling_attn_mask_utils import (
@@ -28,13 +29,11 @@
 from optimum.habana.distributed.tp import TPModule
 from optimum.habana import distributed
 from optimum.habana.distributed.tensorparallel import (
-    copy_to_tensor_model_parallel_region,
     reduce_from_tensor_model_parallel_region,
 )
 from optimum.habana.distributed.strategy import DistributedStrategy
-from optimum.habana.distributed.strategy import NotDistributed
-NoOpStrategy = NotDistributed()
+from optimum.habana.distributed.strategy import NoOpStrategy
 
 try:
     from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
@@ -913,13 +912,14 @@ def __init__(self, config: LlamaConfig):
         super(LlamaModel, self).__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.distributed_strategy = config.distributed_strategy
         config.distributed_strategy = None
         self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         layers = []
-        for i in range(config.num_hidden_layers):
-            layer = GaudiLlamaDecoderLayer(config, i)
-            layer = self.distributed_strategy.distribute_layer(layer, i)
+        for layer_idx in range(config.num_hidden_layers):
+            layer = GaudiLlamaDecoderLayer(config, layer_idx)
+            layer = self.distributed_strategy.distribute_layer(layer, layer_idx)
             layers.append(layer)
         self.layers = torch.nn.ModuleList(layers)
@@ -929,7 +929,6 @@ def __init__(self, config: LlamaConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
-
     def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
         for layer in self.layers:
             layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
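Reviewer note (not part of the patch): the constructor hunks above stash the strategy on self, clear config.distributed_strategy so the strategy object is not carried around with the config, and wrap each decoder layer as it is built. A minimal sketch of that pattern; ToyConfig, ToyLayer, ToyModel, and PassthroughStrategy are illustrative stand-ins, not optimum-habana types:

    from dataclasses import dataclass
    from typing import Any

    from torch import nn


    @dataclass
    class ToyConfig:
        num_hidden_layers: int = 2
        hidden_size: int = 16
        distributed_strategy: Any = None


    class ToyLayer(nn.Module):
        def __init__(self, config: ToyConfig, layer_idx: int):
            super().__init__()
            self.layer_idx = layer_idx
            self.proj = nn.Linear(config.hidden_size, config.hidden_size)


    class PassthroughStrategy:
        def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module:
            return block  # a real strategy would shard or move the block here


    class ToyModel(nn.Module):
        def __init__(self, config: ToyConfig):
            super().__init__()
            # Take ownership of the strategy, then detach it from the config.
            self.distributed_strategy = config.distributed_strategy
            config.distributed_strategy = None
            layers = []
            for layer_idx in range(config.num_hidden_layers):
                layer = ToyLayer(config, layer_idx)
                layers.append(self.distributed_strategy.distribute_layer(layer, layer_idx))
            self.layers = nn.ModuleList(layers)


    if __name__ == "__main__":
        print(ToyModel(ToyConfig(distributed_strategy=PassthroughStrategy())))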
@@ -1157,6 +1156,7 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM):
     - add new args attn_softmax_bf16
     - add new args reuse_cache
     """
+
     def __init__(self, config, distributed_strategy: DistributedStrategy = NoOpStrategy):
         config.distributed_strategy = distributed_strategy
         super().__init__(config)
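Reviewer note (not part of the patch): the constructor keeps the single-device path unchanged (NoOpStrategy is the default) and lets callers opt into tensor parallelism by passing a strategy, as the example script does. A hedged usage sketch; it assumes an optimum-habana install exposing these classes, that TensorParallelStrategy lives in optimum.habana.distributed.strategy, and, for the TP case, an already-initialized torch.distributed process group:

    from transformers import LlamaConfig

    from optimum.habana.distributed.strategy import TensorParallelStrategy
    from optimum.habana.transformers.models.llama.modeling_llama import GaudiLlamaForCausalLM

    config = LlamaConfig(num_hidden_layers=2, hidden_size=256, intermediate_size=512, num_attention_heads=4)

    # Default: NoOpStrategy, behaves like the existing single-device path.
    model = GaudiLlamaForCausalLM(config)

    # Tensor parallel: assumes dist.init_process_group(...) has already been called,
    # as in examples/text-generation/utils.py above.
    tp_model = GaudiLlamaForCausalLM(config, distributed_strategy=TensorParallelStrategy())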