make style
Updated README for distributed_strategy="tp"
Added test in tests/test_text_generation_example.py
Added a link to the original implementation for the referenced files
kalyanjkk committed Jul 17, 2024
1 parent 0bf2a63 commit 8ecceca
Showing 10 changed files with 280 additions and 191 deletions.
29 changes: 29 additions & 0 deletions examples/text-generation/README.md
@@ -264,6 +264,35 @@ set the following environment variables before running the command: `PT_ENABLE_I

You will also need to add `--torch_compile` in your command.

### Running with Tensor parallel strategy
#### Attribution

This repository includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details.

torch.compile with tensor parallel strategy is an experimental feature and has not been validated for all models. To enable
torch.compile with tensor parallel strategy, set the following environment variables before running the
command: `PT_ENABLE_INT64_SUPPORT=1` and `PT_HPU_LAZY_MODE=0`. This enables tensor parallel strategy without DeepSpeed.

You will also need to add `--torch_compile` and `--distributed_strategy="tp"` in your command.

Here is an example:
```bash
python ../gaudi_spawn.py --world_size 8 run_generation.py \
--model_name_or_path meta-llama/Llama-2-70b-hf \
--trim_logits \
--use_kv_cache \
--attn_softmax_bf16 \
--bf16 \
--bucket_internal \
--bucket_size=128 \
--use_flash_attention \
--flash_attention_recompute \
--batch_size 246 \
--max_input_tokens 2048 \
--max_new_tokens 2048 \
--torch_compile \
--distributed_strategy="tp"
```
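
The two environment variables mentioned above can be exported once in the shell before launching the command; a minimal sketch using the values stated above:
```bash
# Required for torch.compile with tensor parallel strategy (see above)
export PT_ENABLE_INT64_SUPPORT=1
export PT_HPU_LAZY_MODE=0
```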

### Running with FP8

65 changes: 18 additions & 47 deletions examples/text-generation/utils.py
@@ -245,59 +245,35 @@ def setup_model(args, model_dtype, model_kwargs, logger):
# assistant_model = get_torch_compiled_model(assistant_model)
return model, assistant_model

def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger):

from optimum.habana.distributed import serialization
def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger):
from typing import Any, MutableMapping

from optimum.habana.distributed import tp_wrapping
from optimum.habana.distributed.strategy import DistributedStrategy
from torch import nn

class TensorParallelStrategy(DistributedStrategy):
def __init__(self, group=None, from_meta=False):
super().__init__(from_meta)
assert torch.distributed.is_initialized(), "must initialize a process group"
self.group = group if group is not None else torch.distributed.GroupMember.WORLD

def distribute_module(
self, module: nn.Module, final_layers: bool = False
) -> nn.Module:
return tp_wrapping.apply_tp(module, self.group)

def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module:
return tp_wrapping.apply_tp(block, layer, self.group)

def __getstate__(self):
state = self.__dict__.copy()
state['group'] = None # Remove ProcessGroup from state
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.group = None # Restore to default state or reinitialize
from optimum.habana.distributed import serialization
from optimum.habana.distributed.strategy import TensorParallelStrategy

logger.info("Multi-device run.")

assert args.quant_config == "", "Fp8 is not enabled, unset QUANT_CONFIG"
assert args.assistant_model is None, "Assistant model must be None"

from torch import distributed as dist
if args.device == 'hpu':
import habana_frameworks.torch.distributed.hccl
dist.init_process_group(backend='hccl')

if args.device == "hpu":
dist.init_process_group(backend="hccl")
else:
dist.init_process_group()
assert False, "Supports TP only on HPU"

torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
logger.info("Creating Model")
config = AutoConfig.from_pretrained(args.model_name_or_path,torch_dtype=model_dtype, **model_kwargs)
model_kwargs={}
config = AutoConfig.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs)
model_kwargs = {}
model_kwargs["distributed_strategy"] = TensorParallelStrategy()
model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype, **model_kwargs)

initial_device = torch.device("cpu")
source="hf"
checkpoint_sharding=None
source = "hf"
checkpoint_sharding = None
lazy_sd: MutableMapping[str, Any] = {}
logger.info("Loading Checkpoints")
lazy_sd = serialization.load_state_dict(
@@ -309,7 +285,7 @@ def __setstate__(self, state):
rank=args.global_rank,
world_size=args.world_size,
)
architecture="llama"
architecture = "llama"
if len(lazy_sd):
serialization.load_state_dict_into_model(
model,
@@ -323,18 +299,12 @@ def __setstate__(self, state):
args.world_size,
)

if args.quant_config:
model = setup_quantization(model, args)

model = model.eval().to(args.device)

if args.use_hpu_graphs:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

if check_habana_frameworks_version("1.13.0") and model.config.model_type == "falcon":
model = wrap_in_hpu_graph(model, hash_with_views=False)
else:
model = wrap_in_hpu_graph(model)
model = wrap_in_hpu_graph(model)

if args.torch_compile and model.config.model_type == "llama":
model = get_torch_compiled_model(model)
@@ -604,7 +574,8 @@ def initialize_model(args, logger):
model, assistant_model = (
setup_model(args, model_dtype, model_kwargs, logger)
if not use_deepspeed
else setup_distributed_model(args, model_dtype, model_kwargs, logger) if not args.distributed_strategy == "tp"
else setup_distributed_model(args, model_dtype, model_kwargs, logger)
if not args.distributed_strategy == "tp"
else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger)
)
tokenizer, model, assistant_model = setup_tokenizer(args, model, assistant_model)
7 changes: 5 additions & 2 deletions optimum/habana/distributed/__init__.py
@@ -1,8 +1,11 @@
from .distributed_runner import DistributedRunner
from .fast_ddp import all_reduce_gradients
import os

import torch

from .distributed_runner import DistributedRunner
from .fast_ddp import all_reduce_gradients


def rank_and_world(group=None):
"""
Returns (rank, world_size) from the optionally-specified group, otherwise
78 changes: 32 additions & 46 deletions optimum/habana/distributed/serialization.py
@@ -1,3 +1,20 @@
# Copyright 2024 The Foundation Model Stack Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file has been modified from its original version.
# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack

import collections
import os
from collections import ChainMap
@@ -7,7 +24,7 @@

import torch

from optimum.habana.distributed.tp import TPModule
from .tp import TPModule


__adapters: MutableMapping[str, MutableMapping[str, Callable[[Mapping], Mapping]]] = {}
@@ -34,9 +51,7 @@ def register_adapter(
sources = __adapters[architecture]

if source in sources:
raise KeyError(
f"Variant {source} already registered for architecture {architecture}"
)
raise KeyError(f"Variant {source} already registered for architecture {architecture}")

sources[source] = adapter
__adapters[architecture] = sources
@@ -55,14 +70,8 @@ def list_sources(architecture: str):
return list(__adapters[architecture].keys())


def _get_adapter(
architecture: str, source: Optional[str]
) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]:
if (
source is None
or architecture not in __adapters
or source not in __adapters[architecture]
):
def _get_adapter(architecture: str, source: Optional[str]) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]:
if source is None or architecture not in __adapters or source not in __adapters[architecture]:
# if no adapter is registered, assume the attributes are already in
# fms format.
# should we raise an error here instead?
@@ -71,9 +80,7 @@ def _get_adapter(
return __adapters[architecture][source]


def get_adapted(
architecture: str, source: Optional[str], state_dict: Mapping[str, Any]
) -> Mapping[str, Any]:
def get_adapted(architecture: str, source: Optional[str], state_dict: Mapping[str, Any]) -> Mapping[str, Any]:
"""
Convert a state dict to FMS format, using an adapter specified by name.
@@ -91,18 +98,11 @@ def get_adapted(
return adapted


# `models` imports each model class, causing models and adapters to be registered.
# down here to avoid circular dependencies.
# from fms import models


def _get_safetensors_item(key, file: Path, device: torch.device) -> torch.Tensor:
from safetensors import safe_open # type: ignore[import-untyped]

with torch.no_grad():
with safe_open(
file, framework="pt", device=str(device)
) as model_weights: # type: ignore[attr-defined]
with safe_open(file, framework="pt", device=str(device)) as model_weights: # type: ignore[attr-defined]
return model_weights.get_tensor(key)


@@ -153,7 +153,7 @@ def load_state_dict(
if model_path is None or initial_device.type == "meta":
return {}
if checkpoint_sharding == "fsdp" and distributed_strategy not in ["fsdp", "hsdp"]:
raise ValueError(f"FSDP checkpoints can only be loaded into an FSDP model")
raise ValueError("FSDP checkpoints can only be loaded into an FSDP model")
if checkpoint_sharding == "tp" and distributed_strategy != "tp":
raise ValueError("TP checkpoints can only be loaded into a TP model")

@@ -188,13 +186,11 @@ def load_state_dict(
checkpoints = [model_path]

# Check if we found some files
assert (
len(checkpoints) > 0
), f"Can't find the requested checkpoint data at {model_path}"
assert len(checkpoints) > 0, f"Can't find the requested checkpoint data at {model_path}"

if checkpoint_sharding is not None and checkpoint_sharding != "layer":
assert world_size == len(
checkpoints
assert (
world_size == len(checkpoints)
), f"Loading a {checkpoint_sharding}-sharded checkpoint with len={len(checkpoints)} but world size is {world_size}"

checkpoints = [checkpoints[rank]]
@@ -304,13 +302,9 @@ def load_state_dict_into_model(
used_keys.add(weight)
partial_sd[weight] = state_dict[weight]
if partial_sd[weight].device != initial_device:
partial_sd[weight] = partial_sd[weight].to(
device=initial_device
)
partial_sd[weight] = partial_sd[weight].to(device=initial_device)
fms_partial_sd = adapter(partial_sd)
_load_partial_state_dict(
model, fms_partial_sd, needs_tp_sharding, rank, world_size
)
_load_partial_state_dict(model, fms_partial_sd, needs_tp_sharding, rank, world_size)
for p_key in partial_sd.keys():
if isinstance(state_dict, ChainMap):
for child_sd in state_dict.maps:
@@ -341,17 +335,11 @@ def _copy_colwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_
output_size_per_partition = param.shape[0]
if not is_bias:
tensor = tensor_value[
(rank * output_size_per_partition) : (
(rank + 1) * output_size_per_partition
),
(rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition),
:,
]
else:
tensor = tensor_value[
(rank * output_size_per_partition) : (
(rank + 1) * output_size_per_partition
)
]
tensor = tensor_value[(rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition)]
param.copy_(tensor, non_blocking=True)


@@ -376,9 +364,7 @@ def _copy_rowwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_
output_size_per_partition = param.shape[1]
tensor = tensor_value[
:,
(rank * output_size_per_partition) : (
(rank + 1) * output_size_per_partition
),
(rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition),
]
param.copy_(tensor, non_blocking=True)
else:
