From f7eee21d4180408fadc4ad80f5e0f5c12c0d86be Mon Sep 17 00:00:00 2001 From: AI_WAIFU Date: Mon, 14 Oct 2024 14:35:27 +0000 Subject: [PATCH] pass conversion test --- tools/ckpts/convert_neox_to_hf.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tools/ckpts/convert_neox_to_hf.py b/tools/ckpts/convert_neox_to_hf.py index 8dfe02d54..ae480dd2d 100644 --- a/tools/ckpts/convert_neox_to_hf.py +++ b/tools/ckpts/convert_neox_to_hf.py @@ -444,10 +444,12 @@ def reshard_and_split_qkv( def get_mlp_naming_convention(loaded_tp_ranks, layer_idx, sequential): """Determine whether the checkpoint uses the legacy or new MLP naming convention.""" - print(list(loaded_tp_ranks[0]["module"].keys())) + for state_dict in loaded_tp_ranks: + print("------------------------------") + print(state_dict.keys()) if any( [ - ["mlp.linear1.weight" in key for key in list(state_dict["module"].keys())] + ["mlp.linear1.weight" in key for key in list(state_dict.keys())] for state_dict in loaded_tp_ranks ] ): @@ -456,7 +458,7 @@ def get_mlp_naming_convention(loaded_tp_ranks, layer_idx, sequential): [ [ "mlp.dense_h_to_4h.weight" in key - for key in list(state_dict["module"].keys()) + for key in list(state_dict.keys()) ] for state_dict in loaded_tp_ranks ]