Skip to content

Commit

Permalink
- parallel output updated
Browse files Browse the repository at this point in the history
  • Loading branch information
dmahan93 committed Sep 25, 2024
1 parent 444fae3 commit c4ae1db
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
delta_model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
parallel_output=True if neox_args.train_impl != "rm" else False,
topology=mpu.get_topology(),
use_cache=use_cache,
)
Expand Down Expand Up @@ -812,7 +812,7 @@ def get_model(neox_args, use_cache=False):
model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
parallel_output=True if neox_args.train_impl != "rm" else False,
topology=mpu.get_topology(),
use_cache=use_cache,
)
Expand Down

0 comments on commit c4ae1db

Please sign in to comment.