From 5673a2f13ae78c7a7dc118cdba9731baceda6d14 Mon Sep 17 00:00:00 2001
From: Xu Song
Date: Fri, 15 Sep 2023 22:07:56 +0800
Subject: [PATCH 1/2] Fix SequentialGeneration

---
 megatron/model/utils.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/megatron/model/utils.py b/megatron/model/utils.py
index 34e9c20a9..6beac5ca2 100644
--- a/megatron/model/utils.py
+++ b/megatron/model/utils.py
@@ -97,6 +97,7 @@ def __init__(
         self.activation_checkpoint_interval = activation_checkpoint_interval
         self.parent_class_name = parent_class_name
         self.activation_checkpoint_func = activation_checkpoint_func
+        self.batch_fn = None
 
     def _is_checkpointable(self, funcs):
         if self.parent_class_name == "GPT2ModelPipe":
@@ -106,6 +107,14 @@ def _is_checkpointable(self, funcs):
         params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
         return any(len(list(p)) > 0 for p in params)
 
+    def set_batch_fn(self, fn):
+        """Set a post-processing function to run on input data in forward().
+
+        Args:
+            fn (function): The function to run.
+        """
+        self.batch_fn = fn
+
     def inference_mode(self, use_cache=True):
         """
         Sets up the model for inference by turning on k/v caching (if specified) and setting `parallel output` of the final layer to false,
@@ -127,6 +136,9 @@ def forward(
         self, forward_input, curriculum_seqlen=None, labels=None, neox_args=None
     ):
+        if self.batch_fn:
+            forward_input = self.batch_fn(forward_input)
+
         if (
             curriculum_seqlen is not None
             and isinstance(forward_input, tuple)

From 5098970b6ee6f39602590b159aea16e4f54c98d5 Mon Sep 17 00:00:00 2001
From: Xu Song
Date: Fri, 15 Sep 2023 22:10:14 +0800
Subject: [PATCH 2/2] Fix SequentialGeneration

---
 megatron/training.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/megatron/training.py b/megatron/training.py
index 96a94a1d0..03491f70a 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -351,6 +351,16 @@ def get_batch_pipe(data, neox_args, curr_scheduler=None):
     return (tokens, position_ids, attention_mask), (labels, loss_mask)
 
 
+def get_batch_sequential(forward_input, neox_args):
+    """A variant of get_batch() that works on an already-fetched batch rather than a data iterator."""
+    attention_mask, _, _ = get_ltor_masks_and_position_ids(
+        data=forward_input[0],
+        eod_token=neox_args.tokenizer.eod,
+        eod_mask_loss=neox_args.eod_mask_loss,
+    )
+    return (forward_input[0], forward_input[1], attention_mask)  # (tokens, position_ids, attention_mask)
+
+
 def forward_step(
     data_iterator, model, neox_args, timers, return_logits=False, is_train=False
 ):
@@ -653,6 +663,13 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
                     get_batch_pipe, neox_args=neox_args, curr_scheduler=curr_scheduler
                 )
             )
+        else:
+            model.module.set_batch_fn(
+                partial(
+                    get_batch_sequential, neox_args=neox_args
+                )
+            )
+
     else:
         raise ValueError("Must be using deepspeed to run neox")
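
To illustrate the hook these two patches wire together, here is a minimal, self-contained sketch of the pattern: a batch_fn is registered on the model and applied to forward_input at the top of forward(). ToySequentialWrapper and toy_get_batch are hypothetical stand-ins for the real SequentialWrapper and get_batch_sequential above, and the toy mask logic only approximates get_ltor_masks_and_position_ids; this is not part of the patches themselves.

    from functools import partial

    import torch


    class ToySequentialWrapper(torch.nn.Module):
        """Hypothetical stand-in for the sequential wrapper patched above:
        an optional batch_fn pre-processes forward_input before the layers run."""

        def __init__(self):
            super().__init__()
            self.embed = torch.nn.Embedding(16, 8)
            self.batch_fn = None  # mirrors the new `self.batch_fn = None` in __init__

        def set_batch_fn(self, fn):
            # Mirrors the new set_batch_fn: register the hook, don't call it yet.
            self.batch_fn = fn

        def forward(self, forward_input):
            # Mirrors the new check at the top of forward().
            if self.batch_fn:
                forward_input = self.batch_fn(forward_input)
            tokens, position_ids, attention_mask = forward_input
            return self.embed(tokens), attention_mask


    def toy_get_batch(forward_input, eod_token):
        """Stand-in for get_batch_sequential: rebuild only the attention mask
        from the tokens; tokens and position_ids pass through unchanged."""
        tokens = forward_input[0]
        attention_mask = (tokens != eod_token).unsqueeze(1)  # toy mask, not left-to-right
        return (forward_input[0], forward_input[1], attention_mask)


    model = ToySequentialWrapper()
    # As in setup_model_and_optimizer: bind static args with partial, then register.
    model.set_batch_fn(partial(toy_get_batch, eod_token=0))

    tokens = torch.randint(1, 16, (2, 5))
    position_ids = torch.arange(5).expand(2, 5)
    logits, mask = model((tokens, position_ids, None))  # mask is filled in by the hook

The design point the patches appear to rely on is that the hook runs inside forward(), so generation code that calls the sequential model directly gets a fresh attention mask without needing a data iterator.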