Fix Megatron Text Generation Issues (microsoft#188)
* MPU import, local_rank, and config fixes

* Change _transpose_first_dim function to source attn module using self_attention

* Add p2p communication file

* Update generate_text.sh to export CUDA_DEVICE_MAX_CONNECTIONS, disable DS inference for now

* Update logits indexing in generation code to account for MoE

* Properly provide config arg to GPTModel in text gen server code

* Update all GPTModel initialization to use config

* Explicit config argument assignment

* Import mpu from megatron.core instead of megatron
lekurile authored Jul 28, 2023
1 parent 98bcc50 commit 9b42cdb
Showing 14 changed files with 551 additions and 254 deletions.
examples/detxoify_lm/finetune_gpt.py (4 additions, 0 deletions)
@@ -18,6 +18,7 @@
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
 from megatron.model import GPTModel
+from megatron.arguments import core_transformer_config_from_args
 from megatron.core.enums import ModelType
 from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
@@ -26,8 +27,11 @@
 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
 
+    config = core_transformer_config_from_args(args)
+
     print_rank_0('building GPT model ...')
     model = GPTModel(
+        config=config,
         num_tokentypes=0,
         parallel_output=True,
         pre_process=pre_process,
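
The pattern introduced across these files is the same everywhere: build a transformer config from the parsed arguments, then hand it to GPTModel as an explicit keyword. A minimal self-contained sketch of that pattern (get_args() is added here for completeness; the diffs assume an args object already in scope):

from megatron import get_args, print_rank_0
from megatron.arguments import core_transformer_config_from_args
from megatron.model import GPTModel

def model_provider(pre_process=True, post_process=True):
    """Build a GPT model, threading the config object through."""
    args = get_args()
    # Translate the parsed CLI arguments into a core transformer config.
    config = core_transformer_config_from_args(args)
    print_rank_0('building GPT model ...')
    return GPTModel(config=config,
                    num_tokentypes=0,
                    parallel_output=True,
                    pre_process=pre_process,
                    post_process=post_process)
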
examples/detxoify_lm/generate_samples_gpt.py (4 additions, 1 deletion)
@@ -17,14 +17,17 @@
 from megatron.initialize import initialize_megatron
 from megatron.model import GPTModel
 from megatron.training import get_model
+from megatron.arguments import core_transformer_config_from_args
 from megatron.text_generation import generate_and_post_process
 
 
 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
 
+    config = core_transformer_config_from_args(args)
+
     print_rank_0('building GPT model ...')
-    model = GPTModel(num_tokentypes=0, parallel_output=False,
+    model = GPTModel(config=config, num_tokentypes=0, parallel_output=False,
                      pre_process=pre_process, post_process=post_process)
 
     return model
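
Once the model is built this way, generation runs through megatron.text_generation. A rough usage sketch, assuming a model already constructed via get_model(model_provider) with a checkpoint loaded; the prompt and token count are placeholders, and the exact shape of the returned value follows upstream Megatron and can vary by version:

from megatron.text_generation import generate_and_post_process

# On rank 0 this returns the generated text plus metadata; other ranks
# join the collective calls and may return None.
result = generate_and_post_process(model,
                                   prompts=["DeepSpeed is"],
                                   tokens_to_generate=64)
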
examples_deepspeed/generate_text.sh (7 additions, 5 deletions)
@@ -1,8 +1,8 @@
 #!/bin/bash
 export TORCH_CUDA_ARCH_LIST=8.6+PTX
-CHECKPOINT_PATH=checkpoints/gpt2_345m
-VOCAB_FILE=gpt2-vocab.json
-MERGE_FILE=gpt2-merges.txt
+CHECKPOINT_PATH=dataset/checkpoints/gpt2_345m
+VOCAB_FILE=dataset/gpt2-vocab.json
+MERGE_FILE=dataset/gpt2-merges.txt
 b=8
 mp=1
 experts=1
@@ -14,8 +14,10 @@ use_tutel=""
 #use_tutel="--use-tutel"
 
 
-#ds_inference=""
-ds_inference="--ds-inference"
+ds_inference=""
+#ds_inference="--ds-inference"
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
 
 launch_cmd="deepspeed --num_nodes $nodes --num_gpus $gpus"
 L=24
megatron/checkpointing.py (2 additions, 2 deletions)
@@ -320,8 +320,8 @@ def _transpose_first_dim(t, num_splits, num_splits_first, model):
     # specific to self attention so should work for cross attention as well
     while hasattr(model, 'module'):
         model = model.module
-    #attention_module = model.language_model.encoder.layers[0].self_attention
-    attention_module = model.language_model.encoder.layers[0].attention
+    attention_module = model.language_model.encoder.layers[0].self_attention
+    #attention_module = model.language_model.encoder.layers[0].attention
     hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
     num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
     if num_splits_first:
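
The swap matters because current Megatron transformer layers expose their attention block as self_attention; the stale attention attribute is what broke checkpoint conversion. A small hypothetical helper (not part of the commit) that mirrors the fixed lookup:

def first_self_attention(model):
    """Unwrap DDP/DeepSpeed 'module' wrappers, then fetch the first
    transformer layer's attention block under its current name."""
    while hasattr(model, 'module'):
        model = model.module
    # 'self_attention' is the attribute the fixed code reads;
    # 'attention' is the pre-rename path that no longer exists.
    return model.language_model.encoder.layers[0].self_attention
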
megatron/p2p_communication.py (new file, 264 additions)
@@ -0,0 +1,264 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import reduce
import operator
import torch
from deepspeed.accelerator import get_accelerator
from megatron import get_args
from megatron.core import mpu


def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 use_ring_exchange=False):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.
    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                           API should be used.
    Returns:
        (tensor_recv_prev, tensor_recv_next)
    """
    args = get_args()

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    if args.scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
            mpu.get_tensor_model_parallel_world_size()
    else:
        tensor_chunk_shape = tensor_shape
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float
    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=get_accelerator().current_device_name(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=get_accelerator().current_device_name(),
                                       dtype=dtype)

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if args.scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    if use_ring_exchange:
        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                        tensor_recv_prev=tensor_recv_prev,
                                        tensor_send_next=tensor_send_next,
                                        tensor_recv_next=tensor_recv_next,
                                        group=mpu.get_pipeline_model_parallel_group())
    else:
        ops = []
        if tensor_send_prev is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if tensor_recv_prev is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if tensor_send_next is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if tensor_recv_next is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
    # To protect against race condition when using batch_isend_irecv().
    get_accelerator().synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if args.scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()

        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()

    return tensor_recv_prev, tensor_recv_next


def recv_forward(timers=None):
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=True,
            recv_next=False)
        if timers is not None:
            timers('forward-recv').stop()
    return input_tensor


def recv_backward(timers=None):
    """Receive tensor from next rank in pipeline (backward receive)."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('backward-recv').start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True)
        if timers is not None:
            timers('backward-recv').stop()
    return output_tensor_grad


def send_forward(output_tensor, timers=None):
    """Send tensor to next rank in pipeline (forward send)."""
    if not mpu.is_pipeline_last_stage():
        if timers is not None:
            timers('forward-send').start()
        _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=False)
        if timers is not None:
            timers('forward-send').stop()


def send_backward(input_tensor_grad, timers=None):
    """Send tensor to previous rank in pipeline (backward send)."""
    if not mpu.is_pipeline_first_stage():
        if timers is not None:
            timers('backward-send').start()
        _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=False,
            recv_next=False)
        if timers is not None:
            timers('backward-send').stop()


def send_forward_recv_backward(output_tensor, timers=None):
    """Batched send and recv with next rank in pipeline."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('forward-send-backward-recv').start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True)
        if timers is not None:
            timers('forward-send-backward-recv').stop()
    return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad, timers=None):
    """Batched send and recv with previous rank in pipeline."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('backward-send-forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
            recv_next=False)
        if timers is not None:
            timers('backward-send-forward-recv').stop()
    return input_tensor


def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers('forward-send-forward-recv').start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False)
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers('backward-send-backward-recv').start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next)
    if timers is not None:
        timers('backward-send-backward-recv').stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev,
        recv_next, timers=None):
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next)
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad
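
Together these helpers give a pipeline schedule its point-to-point plumbing. A schematic single-microbatch step built on them (hypothetical: forward_fn and backward_fn stand in for the schedule's compute callbacks, and a real schedule such as megatron/schedules.py interleaves many microbatches):

from megatron import p2p_communication

def naive_pipeline_step(forward_fn, backward_fn):
    # Activations arrive from the previous stage (None on the first stage).
    input_tensor = p2p_communication.recv_forward()
    output_tensor = forward_fn(input_tensor)
    # Pass activations onward; a no-op on the last stage.
    p2p_communication.send_forward(output_tensor)
    # The gradient w.r.t. our output comes back from the next stage.
    output_tensor_grad = p2p_communication.recv_backward()
    input_tensor_grad = backward_fn(input_tensor, output_tensor, output_tensor_grad)
    # Propagate the gradient w.r.t. our input to the previous stage.
    p2p_communication.send_backward(input_tensor_grad)
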
(Diffs for the remaining 9 changed files are not shown.)