- Add reinforce
- Add reinforce leave one out
- Add model weight sharing via pointers
- Add online dataset
dmahan93 committed Sep 19, 2024
1 parent 54af72a commit 053f67e
Showing 5 changed files with 389 additions and 10 deletions.
megatron/data/data_utils.py (51 additions, 1 deletion)
@@ -24,6 +24,7 @@
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.gpt2_dataset import GPT2Dataset
from megatron.data.pairwise_dataset import PairwiseDataset
from megatron.data.online_dataset import OnlineDataset
from megatron.data.samplers import DistributedBatchSampler


@@ -487,7 +488,56 @@ def build_train_valid_test_data_iterators(neox_args):
        pipe_load = True

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0 and pipe_load:
    if (
        pipe_load
        and (neox_args.dataset_impl == "online")
        and (mpu.get_model_parallel_rank() == 0)
    ):
        # Can skip most of the work...
        train_iters = neox_args.train_iters
        eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters
        test_iters = neox_args.eval_iters
        # Build datasets...
        print(
            f"train_iters: {train_iters}, eval_iters: {eval_iters}, test_iters: {test_iters}"
        )
        train_datasets = OnlineDataset(
            leave_one_out=neox_args.reinforce_leave_one_out,
            data_split="train",
            num_samples=train_iters * neox_args.train_batch_size,
            seq_length=neox_args.seq_length,
            dataserver_ips=neox_args.online_dataserver_ips,
            dataserver_ports=neox_args.online_dataserver_ports,
        )
        valid_datasets = OnlineDataset(
            leave_one_out=neox_args.reinforce_leave_one_out,
            data_split="valid",
            num_samples=eval_iters * neox_args.train_batch_size,
            seq_length=neox_args.seq_length,
            dataserver_ips=neox_args.online_dataserver_ips,
            dataserver_ports=neox_args.online_dataserver_ports,
        )
        test_datasets = OnlineDataset(
            leave_one_out=neox_args.reinforce_leave_one_out,
            data_split="test",
            num_samples=test_iters * neox_args.train_batch_size,
            seq_length=neox_args.seq_length,
            dataserver_ips=neox_args.online_dataserver_ips,
            dataserver_ports=neox_args.online_dataserver_ports,
        )
        # print length of datasets
        # Build dataloaders.
        train_dataloader = make_data_loader(train_datasets, neox_args=neox_args)
        valid_dataloader = make_data_loader(valid_datasets, neox_args=neox_args)
        test_dataloader = make_data_loader(test_datasets, neox_args=neox_args)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and neox_args.train_iters > 0
        do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
        do_test = test_dataloader is not None and neox_args.eval_iters > 0
        # Need to broadcast num_tokens and num_type_tokens.
        flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
    elif mpu.get_model_parallel_rank() == 0 and pipe_load:
        # Number of train/valid/test samples.
        train_iters = neox_args.train_iters
        eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters
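For a sense of scale of the `num_samples` requests in the online branch above: `eval_iters` here is the total number of evaluation iterations across the whole run (one evaluation of `neox_args.eval_iters` batches per `eval_interval`, plus one), and each split asks the data server for enough samples to cover every scheduled iteration at the training batch size. A quick sanity check with hypothetical values:

```python
# Hypothetical settings, only to illustrate how the sample counts above are sized.
train_iters, eval_interval, eval_iters_per_eval = 1000, 100, 10
train_batch_size = 32

total_eval_iters = (train_iters // eval_interval + 1) * eval_iters_per_eval  # 110
train_samples = train_iters * train_batch_size       # 32,000 samples requested for "train"
valid_samples = total_eval_iters * train_batch_size  # 3,520 samples requested for "valid"
```

Note that, as written, the validation and test splits are also sized with `train_batch_size`.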
megatron/data/online_dataset.py (127 additions)
@@ -0,0 +1,127 @@
# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Online dataset."""
from typing import Union, List

import numpy as np
import torch
import torch.utils.data
import socket
import pickle
from megatron.mpu.initialize import get_data_parallel_src_rank


class OnlineDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        num_samples,
        seq_length,
        leave_one_out=False,
        data_split="train",
        dataserver_ips: Union[str, List[str]] = "localhost",
        dataserver_ports: Union[int, List[int]] = 10000,
    ):
        self.num_samples = num_samples
        self.global_rank = get_data_parallel_src_rank()
        self.leave_one_out = leave_one_out
        self.reward_buffer = []
        self.online_batching_data = []
        self.data_split = data_split
        self.seq_length = seq_length
        self.dataserver_ips = dataserver_ips
        self.dataserver_ports = dataserver_ports

    def __len__(self):
        # dummy value since it's decided by the Online Trainer
        return self.num_samples

    def update_online_batches(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if isinstance(self.dataserver_ips, str):
            ipaddr = self.dataserver_ips
        else:
            ipaddr = self.dataserver_ips[self.global_rank]
        if isinstance(self.dataserver_ports, int):
            # single data server port shared by all ranks
            port = self.dataserver_ports
        else:
            # in case we want to use different ports for different ranks, e.g. per machine sampling
            port = self.dataserver_ports[self.global_rank]
        s.connect((ipaddr, port))
        s.send(self.data_split.encode())
        data = b""
        while True:
            chunk = s.recv(4096)
            if not chunk:
                break
            data += chunk
        batch_data = pickle.loads(data)
        s.close()
        print(f"Received {len(batch_data)} samples from the server.")
        for data in batch_data:
            if self.leave_one_out:
                rewards = list()
                for i in range(len(data["rewards"])):
                    rewards.append(
                        data["rewards"][i]
                        - np.mean(
                            [
                                data["rewards"][j]
                                for j in range(len(data["rewards"]))
                                if j != i
                            ]
                        )
                    )
                data["raw_rewards"] = data["rewards"]
                data["rewards"] = rewards
            else:
                moving_average = 0
                if len(self.reward_buffer) > 0:
                    moving_average = np.mean(self.reward_buffer)
                self.reward_buffer.append(np.mean(data["rewards"]))
                if len(self.reward_buffer) > 100:
                    self.reward_buffer.pop(0)
                # For metrics...
                data["raw_rewards"] = data["rewards"]
                data["rewards"] = [r - moving_average for r in data["rewards"]]
            for i in range(len(data["completions"])):
                self.online_batching_data.append(
                    [
                        data["prefix"],
                        data["completions"][i],
                        data["rewards"][i],
                        data["raw_rewards"][i],
                    ]
                )

    def __getitem__(self, idx):
        if len(self.online_batching_data) == 0:
            self.update_online_batches()
        batch = self.online_batching_data.pop(0)
        text = batch[0] + batch[1]
        label = [-100 for _ in batch[0]] + batch[1]
        # +1 because of causal masking
        if len(text) <= self.seq_length:
            text = text + [0] * ((self.seq_length + 1) - len(text))
            label = label + [-100] * ((self.seq_length + 1) - len(label))
        return {
            "text": np.array(text, dtype=np.int64),
            "label": np.array(label, dtype=np.int64),
            "reward": np.array([batch[2]], dtype=np.float32),
            "raw_reward": np.array([batch[3]], dtype=np.float32),
        }
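`update_online_batches` implies a very small wire protocol: the dataset connects, sends the split name ("train", "valid", or "test"), then reads one pickled list of sample dicts until the server closes the connection; each dict is expected to carry a tokenized `prefix`, a list of `completions`, and one reward per completion. A minimal sketch of a compatible server follows — the `HOST`/`PORT` constants and the `make_batch` helper are illustrative placeholders, not part of this commit (a real server would generate completions from the served policy weights and score them with a reward model):

```python
# Sketch of a data server that OnlineDataset.update_online_batches could talk to.
import pickle
import socket

HOST, PORT = "localhost", 10000  # must match online_dataserver_ips / online_dataserver_ports


def make_batch(split, n=4):
    # Placeholder data: a real server would sample prompts, generate completions
    # with the current policy, and score them with a reward model.
    return [
        {
            "prefix": [1, 2, 3],              # prompt token ids
            "completions": [[4, 5], [6, 7]],  # K sampled completions per prompt
            "rewards": [0.3, 0.7],            # one reward per completion
        }
        for _ in range(n)
    ]


def serve():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((HOST, PORT))
        s.listen(1)
        while True:
            conn, _ = s.accept()
            split = conn.recv(4096).decode()  # "train", "valid", or "test"
            conn.sendall(pickle.dumps(make_batch(split)))
            conn.close()  # the client reads until EOF, so closing ends the message


if __name__ == "__main__":
    serve()
```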
megatron/model/weight_server.py (64 additions)
@@ -0,0 +1,64 @@
from typing import Union, List

import torch
import socket
import pickle


def send_tensor(state_dict_key, data, sock, end: bool):
    storage = data.storage()
    (
        storage_device,
        storage_handle,
        storage_size_bytes,
        storage_offset_bytes,
        ref_counter_handle,
        ref_counter_offset,
        event_handle,
        event_sync_required,
    ) = storage._share_cuda_()
    sock.send(
        pickle.dumps(
            {
                "state_dict_key": state_dict_key,
                "dtype": data.dtype,
                "tensor_size": data.shape,
                "tensor_stride": data.stride(),
                "tensor_offset": data.storage_offset(),  # !Not sure about this one.
                "storage_cls": type(storage),
                "storage_device": storage_device,
                "storage_handle": storage_handle,
                "storage_size_bytes": storage_size_bytes,
                "storage_offset_bytes": storage_offset_bytes,
                "requires_grad": False,
                "ref_counter_handle": ref_counter_handle,
                "ref_counter_offset": ref_counter_offset,
                "event_handle": event_handle,
                "event_sync_required": event_sync_required,
                "end": end,
            }
        )
    )


def send_state_dict(state_dict, sock):
    for i, key in enumerate(state_dict.keys()):
        print(key)
        end = i == len(state_dict.keys()) - 1
        send_tensor(key, state_dict[key], sock, end)
        sock.recv(4096)


def start_server(model, ports: Union[int, List[int]] = 6000):
    global_rank = torch.distributed.get_rank()
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    if type(ports) == int:
        port = ports + global_rank
    else:
        port = ports[global_rank]
    s.bind(("localhost", port))
    s.listen(1)
    conn, addr = s.accept()
    state_dict = model.state_dict()
    send_state_dict(state_dict, conn)
    conn.close()
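`send_tensor` ships CUDA IPC handles rather than tensor bytes, so a consumer process on the same node can map the trainer's weights in place instead of copying them. The metadata keys above mirror the keyword arguments of `torch.multiprocessing.reductions.rebuild_cuda_tensor`, which suggests a client along the lines of the sketch below; the `receive_state_dict` name and the per-tensor ack framing are assumptions for illustration, not part of this commit:

```python
# Hypothetical client: rebuild the served CUDA tensors from the pickled IPC metadata.
import pickle
import socket

import torch
from torch.multiprocessing.reductions import rebuild_cuda_tensor


def receive_state_dict(host="localhost", port=6000):
    state_dict = {}
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.connect((host, port))
    while True:
        meta = pickle.loads(s.recv(4096))  # one small metadata dict per tensor
        # Every key except the bookkeeping ones matches a rebuild_cuda_tensor kwarg.
        kwargs = {k: v for k, v in meta.items() if k not in ("state_dict_key", "end")}
        state_dict[meta["state_dict_key"]] = rebuild_cuda_tensor(torch.Tensor, **kwargs)
        s.send(b"ack")  # unblock the server before it sends the next tensor
        if meta["end"]:
            break
    s.close()
    return state_dict
```

Because the rebuilt tensors alias the trainer's CUDA storage, later weight updates are visible to the client without any further transfers — the "model weight sharing via pointers" mentioned in the commit message.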
megatron/neox_arguments/neox_args.py (47 additions, 4 deletions)
@@ -502,6 +502,28 @@ class NeoXArgsModel(NeoXArgsTemplate):
    Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column)
    """

    serve_model_weights: bool = False
    """
    If true, serve model weight pointers over a socket connection
    """

    weight_server_port: Union[int, List[int]] = 6000
    """
    Port(s) to serve model weights over
    If an integer is provided, each GPU serves on weight_server_port + its global rank
    If a list is provided, the ports are used in order, e.g. rank 0 uses weight_server_port[0]
    """

    online_dataserver_ips: Union[str, List[str]] = "localhost"
    """
    IP address(es) to connect to for online data serving, defaults to localhost
    """

    online_dataserver_ports: Union[int, List[int]] = 10000
    """
    Port(s) to connect to for online data serving, defaults to 10000
    """


@dataclass
class NeoXArgsOptimizer(NeoXArgsTemplate):
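Taken together with the training-side options added further down (`dataset_impl="online"`, `train_impl="reinforce"`, `reinforce_leave_one_out`), these arguments describe both halves of the online loop: where this trainer serves its weight pointers, and where it fetches freshly generated samples. A hedged sketch of how the pieces might be combined, shown as a plain Python dict purely for illustration (the values are placeholders, not recommendations):

```python
# Illustrative overrides only; each key corresponds to a NeoX argument in this diff.
online_reinforce_overrides = {
    "dataset_impl": "online",          # pull batches from OnlineDataset
    "train_impl": "reinforce",         # REINFORCE training loop
    "reinforce_leave_one_out": True,   # RLOO advantage baseline
    "serve_model_weights": True,       # expose weight pointers over a socket
    "weight_server_port": 6000,        # rank r serves on 6000 + r
    "online_dataserver_ips": "localhost",
    "online_dataserver_ports": 10000,  # single data server shared by all ranks
}
```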
@@ -1047,14 +1069,14 @@ class NeoXArgsTraining(NeoXArgsTemplate):
    warning: pack_until_overflow is very naive and will likely have issues with pretraining scale datasets
    """

    dataset_impl: Literal["gpt2", "pairwise"] = "gpt2"
    dataset_impl: Literal["gpt2", "pairwise", "online"] = "gpt2"
    """
    Dataset implementation, can be one of "gpt2" or "pairwise"
    Dataset implementation, can be one of "gpt2", "pairwise", or "online"
    """

    train_impl: Literal["normal", "dpo", "rm", "kto"] = "normal"
    train_impl: Literal["normal", "dpo", "rm", "kto", "reinforce"] = "normal"
    """
    Training implementation, can be one of "normal", "dpo", "kto", or "rm"
    Training implementation, can be one of "normal", "dpo", "kto", "reinforce", or "rm"
    """

    dpo_fp32: bool = True
@@ -1092,6 +1114,27 @@ class NeoXArgsTraining(NeoXArgsTemplate):
    Beta value for KTO
    """

    fp32_reinforce: bool = True
    """
    Whether to cast logits to fp32 for Reinforce loss calculation.
    """

    use_full_kl: bool = True
    """
    Use full KL divergence in Reinforce loss calculation.
    """

    kl_div_beta: float = 0.1
    """
    Beta value for KL divergence in Reinforce loss calculation.
    """

    reinforce_leave_one_out: bool = False
    """
    Whether to use REINFORCE leave-one-out (RLOO) for training
    (from https://arxiv.org/abs/2402.14740 and https://api.semanticscholar.org/CorpusID:198489118)
    """

    allow_chopped: bool = True
    """
    WARNING: if your packing impl is packed, this is ignored.
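`reinforce_leave_one_out` selects the baseline computed in `OnlineDataset.update_online_batches`: each completion's reward is centered by the mean reward of the other completions sampled for the same prompt, rather than by the running mean over recent batches. A small worked example with hypothetical rewards:

```python
rewards = [1.0, 0.0, 0.5]  # hypothetical rewards for 3 completions of one prompt

# Leave-one-out advantage: subtract the mean of the other completions' rewards.
advantages = [
    r - sum(o for j, o in enumerate(rewards) if j != i) / (len(rewards) - 1)
    for i, r in enumerate(rewards)
]
print(advantages)  # [0.75, -0.75, 0.0]
```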