
Commit

Updating for review comments
Lunderberg committed Nov 7, 2023
1 parent b85f723 commit 6844257
Showing 2 changed files with 2 additions and 5 deletions.
5 changes: 2 additions & 3 deletions mlc_llm/core.py
@@ -380,9 +380,9 @@ def _parse_args(parsed) -> argparse.Namespace:
         parsed.quantization.name,
     ]
     if parsed.use_presharded_weights:
-        model_name.append(f'presharded-{parsed.num_shards}gpu')
+        model_name.append(f"presharded-{parsed.num_shards}gpu")
 
-    parsed.artifact_path = os.path.join(parsed.artifact_path, '-'.join(model_name))
+    parsed.artifact_path = os.path.join(parsed.artifact_path, "-".join(model_name))
 
     return parsed
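
For illustration, a minimal sketch of the directory name these lines build, using hypothetical stand-in values for parsed.model, parsed.quantization.name, parsed.num_shards, and the base artifact path (none of these values come from the diff):

    import os

    # Hypothetical stand-ins; the real values come from the parsed CLI arguments.
    model = "Llama-2-7b-chat-hf"
    quantization_name = "q4f16_1"
    num_shards = 2
    use_presharded_weights = True
    artifact_path = "dist"

    model_name = [model, quantization_name]
    if use_presharded_weights:
        model_name.append(f"presharded-{num_shards}gpu")

    artifact_path = os.path.join(artifact_path, "-".join(model_name))
    print(artifact_path)  # dist/Llama-2-7b-chat-hf-q4f16_1-presharded-2gpu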

@@ -828,7 +828,6 @@ def build_model_from_args(args: argparse.Namespace):
         sharding_module = create_shard_info_func(param_manager, args, model_config)
         mod.update(sharding_module)
 
-
     with open(cache_path, "wb") as outfile:
         pickle.dump(mod, outfile)
         print(f"Save a cached module to {cache_path}.")
2 changes: 0 additions & 2 deletions mlc_llm/utils.py
@@ -280,11 +280,9 @@ def save_params(params: List[tvm.nd.NDArray], artifact_path: str, num_presharded
     assert len(params) % num_presharded == 0
     num_weights = len(params) // num_presharded
-
-
     meta_data = {}
     param_dict = {}
     meta_data["ParamSize"] = len(params)
     total_size = 0.0
     for i, nd in enumerate(params):
         if num_presharded == 1:
             param_name = f"param_{i}"
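
For context, num_weights splits the flat parameter list evenly across shards; a minimal sketch of the implied index arithmetic, assuming shard-major ordering (the sharded naming branch is truncated above, so the ordering and all values here are assumptions):

    num_presharded = 2  # hypothetical shard count
    params = ["w0_s0", "w1_s0", "w2_s0", "w0_s1", "w1_s1", "w2_s1"]  # stand-in list

    assert len(params) % num_presharded == 0
    num_weights = len(params) // num_presharded

    for i, nd in enumerate(params):
        shard, weight = divmod(i, num_weights)  # flat index -> (shard, weight)
        print(f"param {i}: shard {shard}, weight {weight}")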
