
Commit

fixes 2
Almaz Dautov committed Nov 6, 2024
1 parent 838b71f commit 88830d6
Showing 5 changed files with 113 additions and 38 deletions.
4 changes: 2 additions & 2 deletions turbo_alignment/cli/train.py
@@ -147,8 +147,8 @@ def reinforce_training(
'''

vllm_engines = create_vllm_engines(
-num_engines=experiment_settings.trainer_settings.actor_type.vllm_num_engines,
-tensor_parallel_size=experiment_settings.trainer_settings.actor_type.vllm_tensor_parallel_size,
+num_engines=experiment_settings.trainer_settings.actor_settings.vllm_num_engines,
+tensor_parallel_size=experiment_settings.trainer_settings.actor_settings.vllm_tensor_parallel_size,
pretrain=experiment_settings.model_settings.model_path,
seed=0,
enable_prefix_caching=False,
3 changes: 3 additions & 0 deletions turbo_alignment/settings/pipelines/train/reinforce.py
@@ -8,6 +8,8 @@
from turbo_alignment.settings.tf.trainer import TrainerSettings
from turbo_alignment.settings.online import (
CriticType,
+vLLMActorSettings,
+HFActorSettings,
ActorType,
RewardProcessorType,
)
@@ -32,6 +34,7 @@ class REINFORCETrainerSettings(TrainerSettings):

actor_type: ActorType = ActorType.DISTRIBUTED_VLLM
critic_type: CriticType = CriticType.RAY_TRANSFORMERS
+actor_settings: vLLMActorSettings | HFActorSettings = vLLMActorSettings

reward_processor_type: RewardProcessorType = RewardProcessorType.RLOO

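Note on the new actor_settings field: the cli/train.py hunk above now reads the vLLM engine count and tensor-parallel degree from it instead of from actor_type. A minimal sketch of how it might be populated (constructor arguments are assumptions, since the vLLMActorSettings definition itself is not part of this diff):

    from turbo_alignment.settings.online import vLLMActorSettings

    # hypothetical values; vLLMActorSettings is assumed to be a pydantic-style settings class
    actor_settings = vLLMActorSettings(vllm_num_engines=2, vllm_tensor_parallel_size=1)

    # cli/train.py now takes both values from actor_settings rather than actor_type
    num_engines = actor_settings.vllm_num_engines
    tensor_parallel_size = actor_settings.vllm_tensor_parallel_size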
43 changes: 21 additions & 22 deletions turbo_alignment/trainers/online/ray/vllm_engine.py
@@ -1,42 +1,39 @@
import os
from typing import Dict, List

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from turbo_alignment.common.logging import get_project_logger

logger = get_project_logger()


@ray.remote
class LLMRayActor:
def __init__(self, *args, **kwargs):
import vllm

self.__version__ = vllm.__version__
-assert self.__version__ >= '0.4.1', 'OpenRLHF only supports vLLM >= 0.4.1'
+assert self.__version__ >= "0.4.1", "OpenRLHF only supports vLLM >= 0.4.1"

-self.use_gpu_executor = kwargs['tensor_parallel_size'] == 1
-self.tensor_parallel_size = kwargs['tensor_parallel_size']
+self.use_gpu_executor = kwargs["tensor_parallel_size"] == 1

# See https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py
if self.use_gpu_executor:
-from turbo_alignment.trainers.online.ray.vllm_worker_wrap import WorkerWrap
+from vllm_worker_wrap import WorkerWrap

vllm.worker.worker.Worker = WorkerWrap
else:
# RayGPUExecutor
# See the patch https://github.com/vllm-project/vllm/commit/479d69fad0538f04cb22bf13e76ff91cfeb8a4e5
-kwargs['worker_use_ray'] = True
+kwargs["worker_use_ray"] = True

-if vllm.__version__ > '0.4.1':
+if vllm.__version__ > "0.4.1":
RayWorkerWrapperPath = vllm.executor.ray_utils
else:
RayWorkerWrapperPath = vllm.engine.ray_utils

class RayWorkerWrapper(RayWorkerWrapperPath.RayWorkerWrapper):
def __init__(self, *args, **kwargs) -> None:
-kwargs['worker_module_name'] = 'openrlhf.trainer.ray.vllm_worker_wrap'
-kwargs['worker_class_name'] = 'WorkerWrap'
+kwargs["worker_module_name"] = "vllm_worker_wrap"
+kwargs["worker_class_name"] = "WorkerWrap"
super().__init__(*args, **kwargs)

RayWorkerWrapperPath.RayWorkerWrapper = RayWorkerWrapper
@@ -53,7 +50,7 @@ def init_process_group(self, master_address, master_port, rank_offset, world_siz
)
else:
return self.llm.llm_engine.model_executor._run_workers(
-'init_process_group', master_address, master_port, rank_offset, world_size, group_name, backend
+"init_process_group", master_address, master_port, rank_offset, world_size, group_name, backend
)

def update_weight(self, name, dtype, shape, empty_cache=False):
@@ -62,12 +59,12 @@ def update_weight(self, name, dtype, shape, empty_cache=False):
if self.use_gpu_executor:
return self.llm.llm_engine.model_executor.driver_worker.update_weight(name, dtype, shape, empty_cache)
else:
-return self.llm.llm_engine.model_executor._run_workers('update_weight', name, dtype, shape, empty_cache)
+return self.llm.llm_engine.model_executor._run_workers("update_weight", name, dtype, shape, empty_cache)

def stop_remote_worker_execution_loop(self):
# Fix error for using 2 communication group
# https://github.com/vllm-project/vllm/commit/eb6d3c264d0cd8e44dec16bca7947fbe96415ce9#diff-e1ad69e38e033accddfa5480ec808c4740eb39244d1ef51cc3407e20dde8cfd4
-if self.__version__ > '0.4.2':
+if self.__version__ > "0.4.2":
self.llm.llm_engine.model_executor.stop_remote_worker_execution_loop()


@@ -77,6 +74,7 @@ def create_vllm_engines(
pretrain: str,
seed: int,
enable_prefix_caching: bool,
+enforce_eager: bool,
max_model_len: int,
):
vllm_engines = []
@@ -86,7 +84,7 @@
scheduling_strategy = None

if tensor_parallel_size > 1:
-bundles = [{'GPU': 1, 'CPU': 1}] * tensor_parallel_size
+bundles = [{"GPU": 1, "CPU": 1}] * tensor_parallel_size
pg = placement_group(bundles)
ray.get(pg.ready())

@@ -103,17 +101,18 @@
pretrain,
trust_remote_code=True,
tensor_parallel_size=tensor_parallel_size,
-dtype='bfloat16',
+dtype="bfloat16",
seed=seed + i,
enable_prefix_caching=enable_prefix_caching,
+enforce_eager=enforce_eager,
max_model_len=max_model_len,
)
)

return vllm_engines


-if __name__ == '__main__':
-llm = LLMRayActor.remote('meta-llama/Llama-2-7b-chat-hf', tensor_parallel_size=4)
-output = ray.get(llm.generate.remote('San Franciso is a'))
-print(f'output: {output}')
+if __name__ == "__main__":
+llm = LLMRayActor.remote("meta-llama/Llama-2-7b-chat-hf", tensor_parallel_size=4)
+output = ray.get(llm.generate.remote("San Franciso is a"))
+print(f"output: {output}")
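For reference, a sketch of how the updated factory might be called once enforce_eager is threaded through from the settings (values are illustrative, not taken from this commit):

    import ray

    ray.init()
    engines = create_vllm_engines(
        num_engines=2,
        tensor_parallel_size=1,
        pretrain="meta-llama/Llama-2-7b-chat-hf",
        seed=0,
        enable_prefix_caching=False,
        enforce_eager=True,  # new flag: run vLLM eagerly instead of capturing CUDA graphs
        max_model_len=4096,
    )
    # each engine is an LLMRayActor handle, so generation goes through Ray as in the __main__ example
    outputs = ray.get([eng.generate.remote("San Francisco is a") for eng in engines])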
97 changes: 85 additions & 12 deletions turbo_alignment/trainers/online/ray/vllm_worker_wrap.py
@@ -1,38 +1,111 @@
import importlib
import inspect

import torch
from vllm.worker.worker import Worker

from turbo_alignment.common.distributed import init_process_group
from turbo_alignment.common.logging import get_project_logger
from datetime import timedelta
from typing import Any, Optional, Union

import torch
import torch.distributed
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
Store,
_new_process_group_helper,
_store_based_barrier,
_world,
default_pg_timeout,
rendezvous,
)


# Copy from pytorch to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
def init_process_group(
backend: Union[str, Backend] = None,
init_method: Optional[str] = None,
timeout: Optional[timedelta] = None,
world_size: int = -1,
rank: int = -1,
store: Optional[Store] = None,
group_name: str = None,
pg_options: Optional[Any] = None,
):
assert (store is None) or (init_method is None), "Cannot specify both init_method and store."

if store is not None:
assert world_size > 0, "world_size must be positive if using store"
assert rank >= 0, "rank must be non-negative if using store"
elif init_method is None:
init_method = "env://"

if backend:
backend = Backend(backend)
else:
backend = Backend("undefined")

if timeout is None:
timeout = default_pg_timeout

# backward compatible API
if store is None:
rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
store, rank, world_size = next(rendezvous_iterator)
store.set_timeout(timeout)

# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
store = PrefixStore(group_name, store)

# NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
# https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
# We need to determine the appropriate parameter name based on PyTorch version
pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
pg, _ = _new_process_group_helper(
world_size,
rank,
[],
backend,
store,
group_name=group_name,
**{pg_options_param_name: pg_options},
timeout=timeout,
)

_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

return pg

logger = get_project_logger()


class WorkerWrap(Worker):
-def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend='nccl'):
+def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl"):
"""Init torch process group for model weights update"""
-assert torch.distributed.is_initialized(), 'default torch process group must be initialized'
-assert group_name != '', 'group name must not be empty'
+assert torch.distributed.is_initialized(), f"default torch process group must be initialized"
+assert group_name != "", f"group name must not be empty"

rank = torch.distributed.get_rank() + rank_offset
self._model_update_group = init_process_group(
backend=backend,
-init_method=f'tcp://{master_address}:{master_port}',
+init_method=f"tcp://{master_address}:{master_port}",
world_size=world_size,
rank=rank,
group_name=group_name,
)
print(
-f'init_process_group: master_address={master_address}, master_port={master_port}, ',
-f'rank={rank}, world_size={world_size}, group_name={group_name}',
+f"init_process_group: master_address={master_address}, master_port={master_port}, ",
+f"rank={rank}, world_size={world_size}, group_name={group_name}",
)

def update_weight(self, name, dtype, shape, empty_cache=False):
"""Broadcast weight to all vllm workers from source rank 0 (actor model)"""
if torch.distributed.get_rank() == 0:
-print(f'update weight: {name}, dtype: {dtype}, shape: {shape}')
+print(f"update weight: {name}, dtype: {dtype}, shape: {shape}")

-assert dtype == self.model_config.dtype, f'mismatch dtype: src {dtype}, dst {self.model_config.dtype}'
-weight = torch.empty(shape, dtype=dtype, device='cuda')
+assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
+weight = torch.empty(shape, dtype=dtype, device="cuda")
torch.distributed.broadcast(weight, 0, group=self._model_update_group)

self.model_runner.model.load_weights(weights=[(name, weight)])
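Read together with vllm_engine.py above, the intended weight-sync flow is: the trainer joins the same process group as rank 0 and broadcasts each updated tensor, while every vLLM worker receives it inside WorkerWrap.update_weight and applies it via load_weights. A rough trainer-side sketch (the trainer code is not part of this commit, so values such as the host, port and engine counts are illustrative):

    import ray
    import torch.distributed
    from turbo_alignment.trainers.online.ray.vllm_worker_wrap import init_process_group

    num_engines, tensor_parallel_size = 2, 1            # must match create_vllm_engines(...)
    master_address, master_port = "10.0.0.1", 29500     # trainer host/port, rank 0 of the group
    world_size = num_engines * tensor_parallel_size + 1  # +1 for the trainer itself

    # 1) ask every engine to join a dedicated "model update" group (non-blocking),
    #    then join it locally as rank 0 via the module-level helper defined above
    refs = [
        eng.init_process_group.remote(
            master_address, master_port,
            rank_offset=i * tensor_parallel_size + 1,
            world_size=world_size, group_name="vllm_update",
        )
        for i, eng in enumerate(engines)  # `engines` comes from create_vllm_engines
    ]
    group = init_process_group(
        backend="nccl",
        init_method=f"tcp://{master_address}:{master_port}",
        world_size=world_size, rank=0, group_name="vllm_update",
    )
    ray.get(refs)

    # 2) for each parameter: announce (name, dtype, shape) to the workers, then broadcast
    #    the data from rank 0; tensors must live on GPU for the NCCL backend
    for name, param in model.named_parameters():  # `model` is the policy being trained
        handles = [eng.update_weight.remote(name, dtype=param.dtype, shape=param.shape) for eng in engines]
        torch.distributed.broadcast(param.data, 0, group=group)
        ray.get(handles)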
4 changes: 2 additions & 2 deletions turbo_alignment/trainers/online/reinforce.py
@@ -61,10 +61,10 @@ class REINFORCETrainingArguments(TrainingArguments):
temperature: float | None = None
whiten_rewards: bool = False

-actor_type: ActorType = ActorType.DISTRIBUTED_VLLM
-critic_type: CriticType = CriticType.RAY_TRANSFORMERS
+actor_settings: vLLMActorSettings | HFActorSettings = vLLMActorSettings

+critic_type: CriticType = CriticType.LOCAL_TRANSFORMERS

reward_processor_type: RewardProcessorType = RewardProcessorType.RLOO

class REINFORCETrainer(MultiGPUCherryPicksTrainer):
