diff --git a/megatron/training.py b/megatron/training.py index 9ce87c57d..72cc551a8 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -111,7 +111,7 @@ def save_base_shapes(neox_args, base_shapes, use_cache): delta_model = GPT2ModelPipe( neox_args=neox_args, num_tokentypes=0, - parallel_output=True, + parallel_output=True if neox_args.train_impl != "rm" else False, topology=mpu.get_topology(), use_cache=use_cache, ) @@ -812,7 +812,7 @@ def get_model(neox_args, use_cache=False): model = GPT2ModelPipe( neox_args=neox_args, num_tokentypes=0, - parallel_output=True, + parallel_output=True if neox_args.train_impl != "rm" else False, topology=mpu.get_topology(), use_cache=use_cache, )