From c4ae1db14b97f1e8013ed197a3850deb902b9cef Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Wed, 25 Sep 2024 09:49:44 -0500 Subject: [PATCH] - parallel output updated --- megatron/training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/training.py b/megatron/training.py index 9ce87c57d..72cc551a8 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -111,7 +111,7 @@ def save_base_shapes(neox_args, base_shapes, use_cache): delta_model = GPT2ModelPipe( neox_args=neox_args, num_tokentypes=0, - parallel_output=True, + parallel_output=True if neox_args.train_impl != "rm" else False, topology=mpu.get_topology(), use_cache=use_cache, ) @@ -812,7 +812,7 @@ def get_model(neox_args, use_cache=False): model = GPT2ModelPipe( neox_args=neox_args, num_tokentypes=0, - parallel_output=True, + parallel_output=True if neox_args.train_impl != "rm" else False, topology=mpu.get_topology(), use_cache=use_cache, )