Skip to content

Commit

Permalink
pass conversion test
Browse files (browse the repository at this point in the history)
  • Loading branch information
AI-WAIFU committed Oct 14, 2024
1 parent 456c45d commit f7eee21
Showing 1 changed file with 5 additions and 3 deletions.
8 changes (5 additions & 3 deletions) in tools/ckpts/convert_neox_to_hf.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -444,10 +444,12 @@ def reshard_and_split_qkv(

def get_mlp_naming_convention(loaded_tp_ranks, layer_idx, sequential):
"""Determine whether the checkpoint uses the legacy or new MLP naming convention."""
print(list(loaded_tp_ranks[0]["module"].keys()))
for state_dict in loaded_tp_ranks:
print("------------------------------")
print(state_dict.keys())
if any(
[
["mlp.linear1.weight" in key for key in list(state_dict["module"].keys())]
["mlp.linear1.weight" in key for key in list(state_dict.keys())]
for state_dict in loaded_tp_ranks
]
):
Expand All @@ -456,7 +458,7 @@ def get_mlp_naming_convention(loaded_tp_ranks, layer_idx, sequential):
[
[
"mlp.dense_h_to_4h.weight" in key
for key in list(state_dict["module"].keys())
for key in list(state_dict.keys())
]
for state_dict in loaded_tp_ranks
]
Expand Down

0 comments on commit f7eee21

Please sign in to comment.