Commit

Merge branch 'fix-inference' into 'main'
Fix inference after T5 pipeline merge

See merge request ADLR/megatron-lm!332
jaredcasper committed Oct 7, 2021
2 parents cdc614c + f2c35bb commit b31e129
Showing 3 changed files with 23 additions and 10 deletions.
6 changes: 6 additions & 0 deletions megatron/model/language_model.py
@@ -357,6 +357,12 @@ def __init__(self,

def set_input_tensor(self, input_tensor):
""" See megatron.model.transformer.set_input_tensor()"""

+# This is usually handled in schedules.py but some inference code still
+# gives us non-lists or None
+if not isinstance(input_tensor, list):
+    input_tensor = [input_tensor]
+
if self.add_encoder and self.add_decoder:
assert len(input_tensor) == 1, \
'input_tensor should only be length 1 for stage with both encoder and decoder'
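For reference, a self-contained sketch of the normalization the new check performs (illustrative only, not part of the commit; StageStub and the tensor sizes are made up). set_input_tensor now tolerates the bare tensor or None that some inference code passes, in addition to the list that schedules.py normally provides.

# Illustrative sketch, not Megatron code: the stub below only mimics
# the list-normalization added to set_input_tensor above.
import torch

class StageStub:
    def set_input_tensor(self, input_tensor):
        # Normalize whatever the caller passes into the list form that
        # schedules.py normally provides.
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
        self.input_tensor = input_tensor

stage = StageStub()
stage.set_input_tensor(torch.randn(4, 2, 8))    # bare tensor from inference code -> [tensor]
stage.set_input_tensor(None)                    # no upstream activation -> [None]
stage.set_input_tensor([torch.randn(4, 2, 8)])  # list from schedules.py -> unchanged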
25 changes: 16 additions & 9 deletions megatron/p2p_communication.py
@@ -53,6 +53,13 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
# if needed.
tensor_recv_prev = None
tensor_recv_next = None

+# Some legacy inference code doesn't set the tensor shape, do so now
+# for the normal values for gpt/bert. This could be removed if inference
+# code is changed to provide tensor_shape.
+if tensor_shape is None:
+    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+
override_scatter_gather_tensors_in_pipeline = False
if args.scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
@@ -143,7 +150,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
return tensor_recv_prev, tensor_recv_next


-def recv_forward(tensor_shape, dtype_=None, timers=None):
+def recv_forward(tensor_shape=None, dtype_=None, timers=None):
"""Receive tensor from previous rank in pipeline (forward receive)."""

if mpu.is_pipeline_first_stage():
@@ -163,7 +170,7 @@ def recv_backward(tensor_shape, timers=None):
return input_tensor


-def recv_backward(tensor_shape, timers=None):
+def recv_backward(tensor_shape=None, timers=None):
"""Receive tensor from next rank in pipeline (backward receive)."""
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
@@ -181,7 +188,7 @@ def recv_backward(tensor_shape, timers=None):
return output_tensor_grad


-def send_forward(output_tensor, tensor_shape, dtype_=None, timers=None):
+def send_forward(output_tensor, tensor_shape=None, dtype_=None, timers=None):
"""Send tensor to next rank in pipeline (forward send)."""

if not mpu.is_pipeline_last_stage():
@@ -198,7 +205,7 @@ def send_forward(output_tensor, tensor_shape, dtype_=None, timers=None):
timers('forward-send').stop()


-def send_backward(input_tensor_grad, tensor_shape, timers=None):
+def send_backward(input_tensor_grad, tensor_shape=None, timers=None):
"""Send tensor to previous rank in pipeline (backward send)."""
if not mpu.is_pipeline_first_stage():
if timers is not None:
@@ -213,7 +220,7 @@ def send_backward(input_tensor_grad, tensor_shape, timers=None):
timers('backward-send').stop()


-def send_forward_recv_backward(output_tensor, tensor_shape, timers=None):
+def send_forward_recv_backward(output_tensor, tensor_shape=None, timers=None):
"""Batched send and recv with next rank in pipeline."""
if mpu.is_pipeline_last_stage():
output_tensor_grad = None
@@ -231,7 +238,7 @@ def send_forward_recv_backward(output_tensor, tensor_shape, timers=None):
return output_tensor_grad


-def send_backward_recv_forward(input_tensor_grad, tensor_shape, timers=None):
+def send_backward_recv_forward(input_tensor_grad, tensor_shape=None, timers=None):
"""Batched send and recv with previous rank in pipeline."""
if mpu.is_pipeline_first_stage():
input_tensor = None
@@ -249,7 +256,7 @@ def send_backward_recv_forward(input_tensor_grad, tensor_shape, timers=None):
return input_tensor


-def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape, timers=None):
+def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape=None, timers=None):
"""Batched recv from previous rank and send to next rank in pipeline."""
if timers is not None:
timers('forward-send-forward-recv').start()
@@ -264,7 +271,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, tensor_shape, timers=Non
return input_tensor


-def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape, timers=None):
+def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape=None, timers=None):
"""Batched recv from next rank and send to previous rank in pipeline."""
if timers is not None:
timers('backward-send-backward-recv').start()
@@ -281,7 +288,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, tensor_shape, time

def send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, recv_prev,
-recv_next, tensor_shape, timers=None):
+recv_next, tensor_shape=None, timers=None):
"""Batched send and recv with previous and next ranks in pipeline."""
if timers is not None:
timers('forward-backward-send-forward-backward-recv').start()
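To illustrate the new behavior (a self-contained sketch, not Megatron code; _Args stands in for megatron's get_args() and the sizes are made up): all of the send/recv helpers above now accept tensor_shape=None, and _communicate falls back to the standard GPT/BERT activation shape (seq_length, micro_batch_size, hidden_size) when the caller omits it, so legacy inference code that predates the T5 pipeline changes keeps working.

# Illustrative sketch of the fallback added to _communicate above.
from dataclasses import dataclass
from functools import reduce
import operator

@dataclass
class _Args:
    # Stand-in for megatron's get_args(); values are illustrative.
    seq_length: int = 1024
    micro_batch_size: int = 4
    hidden_size: int = 2048

def resolve_tensor_shape(tensor_shape, args):
    # Same fallback _communicate now applies when legacy inference code
    # calls e.g. recv_forward() without a tensor_shape.
    if tensor_shape is None:
        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    return tensor_shape

shape = resolve_tensor_shape(None, _Args())
numel = reduce(operator.mul, shape, 1)  # flattened size used by the scatter/gather path
print(shape, numel)                     # (1024, 4, 2048) 8388608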
2 changes: 1 addition & 1 deletion megatron/training.py
@@ -193,7 +193,7 @@ def update_train_iters(args):
print_rank_0('setting training iterations to {}'.format(args.train_iters))


-def get_model(model_provider_func, model_type, wrap_with_ddp=True):
+def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
"""Build the model."""
args = get_args()
args.model_type = model_type
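The training.py change follows the same pattern: model_type now defaults to ModelType.encoder_or_decoder, so older GPT/BERT inference scripts that call get_model with only a model provider keep working, while T5-style callers pass the type explicitly. A minimal sketch of the two call styles (illustrative stubs, not Megatron code; the second enum member is assumed for illustration):

# Illustrative stubs only; ModelTypeStub stands in for megatron's ModelType enum.
import enum

class ModelTypeStub(enum.Enum):
    encoder_or_decoder = 1
    encoder_and_decoder = 2  # assumed member, for the T5-style call below

def get_model_stub(model_provider_func, model_type=ModelTypeStub.encoder_or_decoder):
    # Mirrors the new default: build the model and record the requested topology.
    return model_provider_func(), model_type

def provider():
    return "model"

print(get_model_stub(provider))                                    # legacy call, defaulted
print(get_model_stub(provider, ModelTypeStub.encoder_and_decoder)) # explicit, T5-style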
