From 521a97374ddcf2c97eb93252df06987e811f21ae Mon Sep 17 00:00:00 2001
From: LiYuRio
Date: Thu, 19 Sep 2024 19:10:17 +0800
Subject: [PATCH] support pp-sharding reshard

---
 paddlenlp/trainer/utils/reshard/pp_reshard.py | 35 +++++++++++++++++--
 1 file changed, 32 insertions(+), 3 deletions(-)

diff --git a/paddlenlp/trainer/utils/reshard/pp_reshard.py b/paddlenlp/trainer/utils/reshard/pp_reshard.py
index 5c98e6069212..0caa5eb666c6 100644
--- a/paddlenlp/trainer/utils/reshard/pp_reshard.py
+++ b/paddlenlp/trainer/utils/reshard/pp_reshard.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 from collections import OrderedDict

 from paddle.distributed.fleet.model import PipelineParallel
@@ -46,6 +45,25 @@ def get_index_layer_func():
     return _GLOBAL_INDEX_LAYER_FUNC


+_GLOBAL_SNAME_TO_TNAME_FUNC = None
+
+
+def register_sname_to_tname_func(func):
+    global _GLOBAL_SNAME_TO_TNAME_FUNC
+    _GLOBAL_SNAME_TO_TNAME_FUNC = func
+
+
+def has_register_sname_to_tname_func():
+    global _GLOBAL_SNAME_TO_TNAME_FUNC
+    return _GLOBAL_SNAME_TO_TNAME_FUNC is not None
+
+
+def get_sname_to_tname_func():
+    global _GLOBAL_SNAME_TO_TNAME_FUNC
+    assert _GLOBAL_SNAME_TO_TNAME_FUNC is not None, "sname to tname func is not registered yet"
+    return _GLOBAL_SNAME_TO_TNAME_FUNC
+
+
 class LayerNameScope:
     """
     layer name scope for a layer, layer name of the same kind of layer will be named consecutively
@@ -206,6 +224,7 @@ def __init__(self):
         self._segments = OrderedDict()
         self._layer_to_segment = OrderedDict()
         self._param_to_tname = OrderedDict()
+        self._wname_to_rname = OrderedDict()

     def add_segment(self, start_index, end_index):
         segment = PipeLineSegment(start_index, end_index)
@@ -218,19 +237,24 @@ def add_layer(self, layer_index, layer_name, param_names):
         segment = self._layer_to_segment[layer_index]
         segment.add_layer(layer_name, param_names)

-    def build_name_mapping(self):
+    def build_name_mapping(self, sname_to_tname=None):
         for (k, segment) in self._segments.items():
             for (i, layer) in segment.layers.items():
                 for param in layer.params.items():
                     (param_name, tensor_name) = param
                     # map to a new name
                     n_name = self._rename_mgr.get_new_param_name(layer.name, tensor_name)
+                    if sname_to_tname is not None:
+                        if param_name in sname_to_tname.keys():
+                            self._wname_to_rname[param_name] = sname_to_tname[param_name]
                     # logger.info(f"{param_name} {tensor_name}=>{n_name}")
                     self._param_to_tname[param_name] = (tensor_name, n_name)

     def map_name(self, param_name, t_name):
         assert param_name in self._param_to_tname
         tensor_name, n_name = self._param_to_tname[param_name]
+        if param_name in self._wname_to_rname:
+            n_name = self._wname_to_rname[param_name]
         assert tensor_name == t_name
         return n_name

@@ -261,6 +285,11 @@ def __init__(
         self._index_layers()

         stage_segments = self._segment()
+        if has_register_sname_to_tname_func():
+            self._sname_to_tname = get_sname_to_tname_func()(pp_model)
+        else:
+            self._sname_to_tname = None
+
         for (i, stage_seg) in enumerate(stage_segments):
             pipe_stage = PipeLineStage()
             self._stages.append(pipe_stage)
@@ -275,7 +304,7 @@ def __init__(
             self._layer_name_to_stage[layer_name] = i

         for stage in self._stages:
-            stage.build_name_mapping()
+            stage.build_name_mapping(self._sname_to_tname)

     def _index_layers(self):
         for layer_name in self._param_names_by_layer.keys():
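
Usage sketch of the new registration hook (illustrative only, not part of the patch; the mapping function body and the parameter/tensor names are hypothetical, and a real mapping would be derived from the concrete PipelineParallel model passed in as pp_model):

    from paddlenlp.trainer.utils.reshard import pp_reshard


    def sname_to_tname(pp_model):
        # Hypothetical example: map a source (structured) parameter name to the
        # target tensor name that map_name() should return for it during reshard.
        return {
            "embedding_0.w_0": "shared_layers.embed.word_embeddings.weight",
        }


    # Register the hook before the pipeline reshard metadata is built; if a hook
    # is registered, build_name_mapping() records these overrides and map_name()
    # returns the overridden name instead of the auto-generated one.
    pp_reshard.register_sname_to_tname_func(sname_to_tname)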