diff --git a/tools/ckpts/convert_module_to_hf.py b/tools/ckpts/convert_module_to_hf.py
index f3f43c308..9a5823cb9 100644
--- a/tools/ckpts/convert_module_to_hf.py
+++ b/tools/ckpts/convert_module_to_hf.py
@@ -225,7 +225,9 @@ def convert(input_checkpoint_path, loaded_config, output_checkpoint_path):
             "mlp.dense_4h_to_h.bias",
             "attention.dense.bias",
         ]:
-            state_dict[key] = sum([t[key] for t in loaded_tp_ranks])
+            state_dict[key] = sum([t[key] for t in loaded_tp_ranks]) / len(
+                loaded_tp_ranks
+            )
 
         # Just take one
         state_dict["attention.rotary_emb.inv_freq"] = loaded_tp_ranks[0][
diff --git a/tools/ckpts/convert_sequential_to_hf.py b/tools/ckpts/convert_sequential_to_hf.py
index f0a505ac3..69ad58786 100644
--- a/tools/ckpts/convert_sequential_to_hf.py
+++ b/tools/ckpts/convert_sequential_to_hf.py
@@ -238,7 +238,9 @@ def convert(input_checkpoint_path, loaded_config, output_checkpoint_path):
             "mlp.dense_4h_to_h.bias",
             "attention.dense.bias",
         ]:
-            state_dict[key] = sum(get_state(loaded_tp_ranks, key, layer_i + 2))
+            state_dict[key] = sum(get_state(loaded_tp_ranks, key, layer_i + 2)) / len(
+                loaded_tp_ranks
+            )
 
         # Just take one
         state_dict["attention.rotary_emb.inv_freq"] = get_state(
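
Context for the change: in Megatron-style tensor parallelism, output biases such as mlp.dense_4h_to_h.bias and attention.dense.bias are replicated on every TP rank, so summing the per-rank copies scales the bias by the TP world size; dividing by len(loaded_tp_ranks) recovers the original value. Below is a minimal sketch of the before/after behavior, assuming (as the fix implies) that these biases are identical across ranks. The name loaded_tp_ranks and the key string come from the diff; the tensor values and TP degree are fabricated for illustration only.

import torch

# Pretend checkpoint shards from 4 tensor-parallel ranks; the replicated
# bias is identical on every rank.
tp = 4
bias = torch.tensor([0.1, -0.2, 0.3])
loaded_tp_ranks = [{"attention.dense.bias": bias.clone()} for _ in range(tp)]

key = "attention.dense.bias"

# Old behavior: summing replicated copies scales the bias by the TP degree.
summed = sum([t[key] for t in loaded_tp_ranks])
assert torch.allclose(summed, bias * tp)

# Fixed behavior: averaging across ranks recovers the original bias.
averaged = sum([t[key] for t in loaded_tp_ranks]) / len(loaded_tp_ranks)
assert torch.allclose(averaged, bias)

Note that column-parallel weights (and truly sharded biases) still need concatenation rather than averaging; the average applies only to tensors that are replicated across ranks, which is why the "Just take one" strategy remains correct for attention.rotary_emb.inv_freq.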