Tensor parallel distributed strategy without using deepspeed #1121

Merged (10 commits, Jul 30, 2024)
31 changes: 31 additions & 0 deletions examples/text-generation/README.md
@@ -264,6 +264,37 @@ set the following environment variables before running the command: `PT_ENABLE_I

You will also need to add `--torch_compile` to your command.

### Running with tensor-parallel strategy

> [!NOTE]
> This strategy includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details.

> [!WARNING]
> `torch.compile` with the tensor parallel strategy is an experimental feature. It has not been validated for all models.

To enable `torch.compile` with the tensor parallel strategy, set the following environment variables before running the
command: `PT_ENABLE_INT64_SUPPORT=1` and `PT_HPU_LAZY_MODE=0`. This enables the tensor parallel strategy without DeepSpeed.

You will also need to add `--torch_compile` and `--parallel_strategy="tp"` to your command.

Here is an example:
```bash
PT_ENABLE_INT64_SUPPORT=1 PT_HPU_LAZY_MODE=0 python ../gaudi_spawn.py --world_size 8 run_generation.py \
--model_name_or_path meta-llama/Llama-2-70b-hf \
--trim_logits \
--use_kv_cache \
--attn_softmax_bf16 \
--bf16 \
--bucket_internal \
--bucket_size=128 \
--use_flash_attention \
--flash_attention_recompute \
--batch_size 246 \
--max_input_tokens 2048 \
--max_new_tokens 2048 \
--torch_compile \
--parallel_strategy="tp"
```

### Running with FP8

8 changes: 8 additions & 0 deletions examples/text-generation/run_generation.py
@@ -287,6 +287,14 @@ def setup_parser(parser):
action="store_true",
help="Whether to trust the execution of code from datasets/models defined on the Hub. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.",
)
parser.add_argument(
"--parallel_strategy",
type=str,
choices=["tp", "none"], # Add other strategies as needed
default="none",
help="Run multi card with the specified parallel strategy. Choices are 'tp' for Tensor Parallel Strategy or 'none'.",
)

args = parser.parse_args()

if args.torch_compile:
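For reference, here is a minimal sketch (not part of the diff, using only the standard `argparse` module) of how the new `--parallel_strategy` option behaves: it defaults to `"none"` and rejects values outside the declared choices.

```python
# Minimal sketch (not part of the PR): behavior of the new --parallel_strategy flag.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--parallel_strategy",
    type=str,
    choices=["tp", "none"],
    default="none",
)

print(parser.parse_args([]).parallel_strategy)  # -> "none"
print(parser.parse_args(["--parallel_strategy", "tp"]).parallel_strategy)  # -> "tp"
# parser.parse_args(["--parallel_strategy", "dp"]) prints a usage error and exits:
#   argument --parallel_strategy: invalid choice: 'dp' (choose from 'tp', 'none')
```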
70 changes: 69 additions & 1 deletion examples/text-generation/utils.py
@@ -248,6 +248,72 @@ def setup_model(args, model_dtype, model_kwargs, logger):
return model, assistant_model


def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger, cache_dir):
from typing import Any, MutableMapping

from optimum.habana.distributed import serialization
from optimum.habana.distributed.strategy import TensorParallelStrategy

logger.info("Multi-device run.")

assert args.quant_config == "", "Fp8 is not enabled, unset QUANT_CONFIG"
assert args.assistant_model is None, "Assistant model must be None"

from torch import distributed as dist

if args.device == "hpu":
dist.init_process_group(backend="hccl")
else:
assert False, "Supports TP only on HPU"

torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
logger.info("Creating Model")
config = AutoConfig.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs)
model_kwargs = {}
model_kwargs["parallel_strategy"] = TensorParallelStrategy()
model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype, **model_kwargs)

initial_device = torch.device("cpu")
source = "hf"
checkpoint_sharding = None
lazy_sd: MutableMapping[str, Any] = {}
logger.info("Loading Checkpoints")
lazy_sd = serialization.load_state_dict(
cache_dir,
source=source,
distributed_strategy=args.parallel_strategy,
checkpoint_sharding=None,
initial_device=initial_device,
rank=args.global_rank,
world_size=args.world_size,
)
architecture = "llama"
if len(lazy_sd):
serialization.load_state_dict_into_model(
model,
lazy_sd,
architecture,
source,
args.parallel_strategy,
checkpoint_sharding,
initial_device,
args.local_rank,
args.world_size,
)

model = model.eval().to(args.device)

if args.use_hpu_graphs:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

model = wrap_in_hpu_graph(model)

if args.torch_compile and model.config.model_type == "llama":
model = get_torch_compiled_model(model)

return model, args.assistant_model


def setup_distributed_model(args, model_dtype, model_kwargs, logger):
import deepspeed

@@ -500,7 +566,7 @@ def initialize_model(args, logger):
setup_env(args)
setup_device(args)
set_seed(args.seed)
get_repo_root(args.model_name_or_path, local_rank=args.local_rank, token=args.token)
cache_dir = get_repo_root(args.model_name_or_path, local_rank=args.local_rank, token=args.token)
if args.assistant_model is not None:
get_repo_root(args.assistant_model, local_rank=args.local_rank, token=args.token)
use_deepspeed = args.world_size > 0
@@ -522,6 +588,8 @@ def initialize_model(args, logger):
setup_model(args, model_dtype, model_kwargs, logger)
if not use_deepspeed
else setup_distributed_model(args, model_dtype, model_kwargs, logger)
if not args.parallel_strategy == "tp"
else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger, cache_dir)
)
tokenizer, model, assistant_model = setup_tokenizer(args, model, assistant_model)
generation_config = setup_generation_config(args, model, assistant_model, tokenizer)
29 changes: 29 additions & 0 deletions optimum/habana/distributed/__init__.py
@@ -1,2 +1,31 @@
import os

import torch

from .distributed_runner import DistributedRunner
from .fast_ddp import all_reduce_gradients


def rank_and_world(group=None):
"""
Returns (rank, world_size) from the optionally-specified group, otherwise
from the default group, or if non-distributed just returns (0, 1)
"""
if torch.distributed.is_initialized() and group is None:
group = torch.distributed.GroupMember.WORLD

if group is None:
world_size = 1
rank = 0
else:
world_size = group.size()
rank = group.rank()

return rank, world_size


_LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0))


def local_rank():
return _LOCAL_RANK
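
Below is a minimal usage sketch (not part of the diff) of the helpers defined above; it assumes `optimum-habana` is installed so that `rank_and_world` and `local_rank` are importable from `optimum.habana.distributed`.

```python
# Minimal usage sketch (not part of the PR): querying rank and world size with
# the helpers added in optimum/habana/distributed/__init__.py.
import torch.distributed as dist

from optimum.habana.distributed import local_rank, rank_and_world

rank, world_size = rank_and_world()  # (0, 1) when torch.distributed is not initialized
print(f"rank={rank}, world_size={world_size}, local_rank={local_rank()}")

# With an initialized default process group, the values come from dist.group.WORLD.
if dist.is_initialized():
    assert (rank, world_size) == (dist.get_rank(), dist.get_world_size())
```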