
Commit

Code cleanup - removed unused code
kalyanjkk committed Jul 15, 2024
1 parent a9275fd commit 657f3c1
Showing 4 changed files with 10 additions and 137 deletions.
3 changes: 3 additions & 0 deletions examples/text-generation/utils.py
@@ -43,6 +43,7 @@
set_seed,
)


def adjust_batch(batch, size):
curr_size = batch["input_ids"].shape[1]
if curr_size >= size:
@@ -328,6 +329,7 @@ def __setstate__(self, state):
dist.init_process_group()

torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
logger.info("Creating Model")
config = AutoConfig.from_pretrained(args.model_name_or_path,torch_dtype=model_dtype, **model_kwargs)
model_kwargs={}
model_kwargs["distributed_strategy"] = TensorParallelStrategy()
@@ -337,6 +339,7 @@ def __setstate__(self, state):
source="hf"
checkpoint_sharding=None
lazy_sd: MutableMapping[str, Any] = {}
logger.info("Loading Checkpoints")
lazy_sd = serialization.load_state_dict(
args.model_name_or_path,
source=source,
4 changes: 0 additions & 4 deletions optimum/habana/distributed/strategy.py
@@ -5,9 +5,6 @@
import torch.distributed
from torch import nn

#from optimum.habana.distributed import tp_wrapping


class DistributedStrategy:
def __init__(self, from_meta=False):
self.from_meta = from_meta
@@ -91,7 +88,6 @@ def __init__(self, devices: List[int], num_layers: int, from_meta=False):
def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module:
device = self.layer_to_device[layer]
if self.from_meta:
# https://github.com/pytorch/pytorch/pull/113647
block.to_empty(device=device) # type: ignore[arg-type]
wrapped = DeviceMover(block, device)
return wrapped
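
Side note (not part of the commit): the hunk above shows the DistributedStrategy base class and a distribute_layer implementation that moves each block via DeviceMover. Below is a minimal sketch of that contract; SingleDeviceStrategy and its device argument are hypothetical, and only __init__(from_meta) and distribute_layer(block, layer) are taken from the code shown here.

    import torch
    from torch import nn

    from optimum.habana.distributed.strategy import DistributedStrategy


    class SingleDeviceStrategy(DistributedStrategy):
        """Hypothetical strategy: place every decoder layer on one fixed device."""

        def __init__(self, device: torch.device, from_meta: bool = False):
            super().__init__(from_meta)
            self.device = device

        def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module:
            # Same contract as above: given a block and its index, return the
            # module to actually use (here simply moved to the target device).
            if self.from_meta:
                block.to_empty(device=self.device)  # materialize meta tensors first
            return block.to(self.device)

This mirrors how the modeling_llama.py diff further down calls distribute_layer(layer, layer_idx) once per decoder layer.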
126 changes: 0 additions & 126 deletions optimum/habana/distributed/tensorparallel.py
@@ -17,7 +17,6 @@ def apply_colwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank):
)
if par_mod.bias is not None:
par_mod.bias.copy_(torch.split(mod.bias, output_size_per_partition)[rank])
# print(f"For rank {rank}, we have the following weights: Base weight {mod.weight} bias {mod.bias}; Par weight {par_mod.weight}, bias {par_mod.bias}")


def apply_rowwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank):
@@ -32,7 +31,6 @@ def apply_rowwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank):
par_mod.bias.copy_(mod.bias)
else:
par_mod.bias.zero_()
# print(f"For rank {rank}, we have the following weights: Base weight {mod.weight}, bias {mod.bias}; Par weight {par_mod.weight}, bias {par_mod.bias}")


def apply_embedding_tp(par_mod: nn.Embedding, mod: nn.Embedding, world_size, rank):
@@ -42,7 +40,6 @@ def apply_embedding_tp(par_mod: nn.Embedding, mod: nn.Embedding, world_size, rank):
par_mod.weight.copy_(
torch.split(mod.weight, output_size_per_partition, dim=1)[rank]
)
# print(f"For rank {rank}, we have the following weights: Base weight {mod.weight} bias {mod.bias}; Par weight {par_mod.weight}, bias {par_mod.bias}")


## Fixes for PT 2.2 collectives until PT 2.3 is released
@@ -75,65 +72,6 @@ def get_volatile_reads_fixed(self):
if other_fn in lowering.lowerings:
del lowering.lowerings[other_fn]


@lowering.register_lowering(torch.ops._c10d_functional.all_reduce)
def _all_reduce_fixed(inp, reduce_op, group_name):
inp = torch.clone(inp)
ir._CollectiveKernel.create_inplace(
torch.ops._c10d_functional.all_reduce_.default,
ir.ExternKernel.require_contiguous(inp),
reduce_op,
group_name,
)
return inp


for overload in torch.ops._c10d_functional.all_gather_into_tensor.overloads():
other_fn = getattr(torch.ops._c10d_functional.all_gather_into_tensor, overload)
if other_fn in lowering.lowerings:
del lowering.lowerings[other_fn]


@lowering.register_lowering(torch.ops._c10d_functional.all_gather_into_tensor)
def _all_gather_into_tensor(inp, group_size, group_name):
return ir.TensorBox.create(
ir._CollectiveKernel.create_out_of_place(
torch.ops._c10d_functional.all_gather_into_tensor.default,
ir.ExternKernel.require_contiguous(inp),
group_size,
group_name,
)
)


def _all_gather(input_: torch.Tensor) -> torch.Tensor:
"""Gather the input tensor across model parallel group."""
world_size = dist.get_world_size()

if world_size == 1:
return input_

# The transposes here are to avoid excessive recompilation due to split()
# specializing the dimension where the all_gather is happening
last_dim = input_.dim() - 1
# Starting PT 2.3, we can go back to funcol.all_gather_tensor
# TODO SW-180411 WA
# return (
# torch.ops._c10d_functional.wait_tensor(
# torch.ops._c10d_functional.all_gather_into_tensor(
# input_.transpose(0, last_dim).contiguous(), world_size, "default"
# )
# )
# .transpose(0, last_dim)
# .contiguous()
# )
shape = list(input_.transpose(0, last_dim).size())
shape[0] *= world_size
output = torch.empty(shape, dtype=input_.dtype, device=input_.device)
dist.all_gather_into_tensor(output, input_.transpose(0, last_dim).contiguous())
return output.transpose(0, last_dim).contiguous()


def _all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
world_size = dist.get_world_size()
@@ -146,43 +84,6 @@ def _all_reduce(input_: torch.Tensor) -> torch.Tensor:
torch.ops._c10d_functional.all_reduce(input_, "sum", "default")
)


def _split(input_: torch.Tensor, rank, world_size) -> torch.Tensor:
"""Split the tensor along its last dimension and keep the
corresponding slice."""

if world_size == 1:
return input_

# Split along last dimension.
# Get the size and dimension.
last_dim = input_.dim() - 1
last_dim_size = input_.size()[last_dim] // world_size
# Split.
input_list = torch.split(input_, last_dim_size, dim=last_dim)

# Note: torch.split does not create contiguous tensors by default.
output = input_list[rank].contiguous()

return output


class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""

@staticmethod
def symbolic(graph, input_):
return input_

@staticmethod
def forward(ctx, input_):
return input_

@staticmethod
def backward(ctx, grad_output):
return _all_reduce(grad_output)


class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""

@@ -198,32 +99,5 @@ def forward(ctx, input_):
def backward(ctx, grad_output):
return grad_output


class _AllGatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from the model parallel region."""

@staticmethod
def symbolic(graph, input_):
return _all_gather(input_)

@staticmethod
def forward(ctx, input_, rank, world_size):
ctx.rank = rank
ctx.world_size = world_size
return _all_gather(input_)

@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.rank, ctx.world_size)


def copy_to_tensor_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)


def reduce_from_tensor_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)


def all_gather_from_tensor_model_parallel_region(input_, rank, world_size):
return _AllGatherFromModelParallelRegion.apply(input_, rank, world_size)
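
For reference, a self-contained sketch of the column-wise sharding pattern that the retained apply_colwise_tp helper implements (weights split along the output dimension, one slice per rank). The helper name shard_linear_colwise and the sizes below are made up; only the torch.split(...)[rank] copy pattern comes from the code above.

    import torch
    import torch.nn as nn


    def shard_linear_colwise(full: nn.Linear, world_size: int, rank: int) -> nn.Linear:
        # Split the weight (and bias) along the output dimension and keep
        # this rank's slice, mirroring apply_colwise_tp above.
        out_per_rank = full.out_features // world_size
        shard = nn.Linear(full.in_features, out_per_rank, bias=full.bias is not None)
        with torch.no_grad():
            shard.weight.copy_(torch.split(full.weight, out_per_rank, dim=0)[rank])
            if full.bias is not None:
                shard.bias.copy_(torch.split(full.bias, out_per_rank)[rank])
        return shard


    # Toy check: sharding a 16 -> 32 projection across 4 "ranks" and
    # concatenating the per-rank outputs reproduces the full layer.
    full = nn.Linear(16, 32)
    shards = [shard_linear_colwise(full, world_size=4, rank=r) for r in range(4)]
    x = torch.randn(2, 16)
    assert torch.allclose(torch.cat([s(x) for s in shards], dim=-1), full(x), atol=1e-6)

The row-wise case is the mirror image: the input dimension is split, each rank produces a partial output, and the partials are summed with an all-reduce, which is the reduce_from_tensor_model_parallel_region path this commit keeps while dropping the unused copy and all-gather variants.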
14 changes: 7 additions & 7 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -19,6 +19,7 @@
apply_rotary_pos_emb,
logger,
)

import copy
from torch.distributed.distributed_c10d import ProcessGroup
from ...modeling_attn_mask_utils import (
@@ -28,13 +29,11 @@
from optimum.habana.distributed.tp import TPModule
from optimum.habana import distributed
from optimum.habana.distributed.tensorparallel import (
copy_to_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
)

from optimum.habana.distributed.strategy import DistributedStrategy
from optimum.habana.distributed.strategy import NotDistributed
NoOpStrategy = NotDistributed()
from optimum.habana.distributed.strategy import NoOpStrategy

try:
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
@@ -913,13 +912,14 @@ def __init__(self, config: LlamaConfig):
super(LlamaModel, self).__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.distributed_strategy = config.distributed_strategy
config.distributed_strategy = None
self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
layers = []
for i in range(config.num_hidden_layers):
layer = GaudiLlamaDecoderLayer(config, i)
layer = self.distributed_strategy.distribute_layer(layer, i)
for layer_idx in range(config.num_hidden_layers):
layer = GaudiLlamaDecoderLayer(config, layer_idx)
layer = self.distributed_strategy.distribute_layer(layer, layer_idx)
layers.append(layer)
self.layers = torch.nn.ModuleList(layers)

@@ -929,7 +929,6 @@ def __init__(self, config: LlamaConfig):
# Initialize weights and apply final processing
self.post_init()


def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
for layer in self.layers:
layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
@@ -1157,6 +1156,7 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM):
- add new args attn_softmax_bf16
- add new args reuse_cache
"""

def __init__(self, config, distributed_strategy: DistributedStrategy = NoOpStrategy):
config.distributed_strategy = distributed_strategy
super().__init__(config)
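
Rounding out the picture, a rough usage sketch (not part of this commit) of the constructor signature shown above, where distributed_strategy defaults to NoOpStrategy. The checkpoint name is a placeholder and the import path for TensorParallelStrategy is assumed; only GaudiLlamaForCausalLM(config, distributed_strategy=...) and the NoOpStrategy default are taken from the diff.

    from transformers import AutoConfig

    from optimum.habana.transformers.models.llama.modeling_llama import GaudiLlamaForCausalLM

    config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder checkpoint

    # Single device: no strategy passed, so the NoOpStrategy default applies and
    # distribute_layer leaves every decoder layer untouched.
    model = GaudiLlamaForCausalLM(config)

    # Tensor parallel (import path assumed): mirrors the utils.py hunk above, where
    # model_kwargs["distributed_strategy"] = TensorParallelStrategy() is set before loading.
    from optimum.habana.distributed.strategy import TensorParallelStrategy

    model_tp = GaudiLlamaForCausalLM(config, distributed_strategy=TensorParallelStrategy())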
