[Trainer] Add hybrid_parallel_topo_order options. (#6782)
ZHUI authored Aug 21, 2023
1 parent 8febfc2 commit e0a9f4e
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions paddlenlp/trainer/training_args.py
@@ -569,6 +569,18 @@ class TrainingArguments:
             )
         },
     )
+    hybrid_parallel_topo_order: str = field(
+        default=None,
+        metadata={
+            "help": (
+                "In hybrid parallelism, the order of communication groups may affect efficiency.\n"
+                "The following options are supported:\n"
+                "- pp_first: the topo order is dp, pp, sharding, mp.\n"
+                "- sharding_first: the topo order is dp, sharding, pp, mp.\n"
+                "Default is None, which falls back to pp_first."
+            )
+        },
+    )
     recompute: bool = field(
         default=False,
         metadata={
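Since TrainingArguments fields surface as command-line flags through PaddleNLP's PdArgumentParser, the new option can be selected per run. A minimal sketch, with illustrative parallel degrees (in a real run the degrees must match the number of launched processes; everything beyond the field names shown in this diff is an assumption):

    from paddlenlp.trainer import PdArgumentParser, TrainingArguments

    # Typical script entry point: the new field is parsed from a flag like
    #   --hybrid_parallel_topo_order sharding_first
    parser = PdArgumentParser(TrainingArguments)
    (training_args,) = parser.parse_args_into_dataclasses()

    # Direct construction works too; leaving the field unset keeps it None,
    # and __post_init__ falls back to "pp_first".
    training_args = TrainingArguments(
        output_dir="./checkpoints",
        tensor_parallel_degree=2,
        pipeline_parallel_degree=2,
        sharding_parallel_degree=2,
        hybrid_parallel_topo_order="sharding_first",
    )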
@@ -868,16 +880,19 @@ def __post_init__(self):
             if tensor_parallel_degree > 1:
                 strategy.tensor_parallel_configs = {"tensor_init_seed": self.seed}
 
-            if tensor_parallel_degree == 1 and sharding_parallel_degree == 1:
-                order = ["pp", "dp", "sharding", "mp"]
-            else:
+            if self.hybrid_parallel_topo_order is None:
+                self.hybrid_parallel_topo_order = "pp_first"
+            assert self.hybrid_parallel_topo_order in ["pp_first", "sharding_first"]
 
+            if self.hybrid_parallel_topo_order == "pp_first":
                 order = ["dp", "pp", "sharding", "mp"]
+            if self.hybrid_parallel_topo_order == "sharding_first":
+                order = ["dp", "sharding", "pp", "mp"]
 
             hybrid_configs = {
                 "dp_degree": self.data_parallel_degree,
                 "mp_degree": tensor_parallel_degree,
                 "pp_degree": pipeline_parallel_degree,
-                "order": order,
                 "sharding_degree": sharding_parallel_degree,
+                "order": order,
            }
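Concretely, the option only swaps the positions of "sharding" and "pp" in the order list handed to Paddle's hybrid topology. A minimal standalone sketch of the strategy that __post_init__ assembles, assuming 16 GPUs split as dp2 x sharding4 x pp2 x mp1 (degrees are illustrative, not defaults):

    from paddle.distributed import fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 2,
        "mp_degree": 1,
        "pp_degree": 2,
        "sharding_degree": 4,
        # "sharding_first" -> ["dp", "sharding", "pp", "mp"];
        # "pp_first"       -> ["dp", "pp", "sharding", "mp"].
        # The axis order decides how ranks are laid out in the process
        # mesh, and therefore which ranks each communication group spans.
        "order": ["dp", "sharding", "pp", "mp"],
    }
    fleet.init(is_collective=True, strategy=strategy)

Under a collective launch this mirrors what the trainer does internally once the configured degrees multiply to the world size.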
@@ -912,7 +927,6 @@ def __post_init__(self):
                         "The enable_stage1_tensor_fusion or enable_stage1_overlap is not supported "
                         "by current version of Paddle. Please try latest develop Paddle."
                     )
-            paddle.device.cuda.synchronize()
             start_time = time.time()
             fleet.init(is_collective=True, strategy=strategy)
             paddle.device.cuda.synchronize()
