diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index 17352ec811..52a1041a1d 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -380,9 +380,9 @@ def _parse_args(parsed) -> argparse.Namespace:
         parsed.quantization.name,
     ]
     if parsed.use_presharded_weights:
-        model_name.append(f'presharded-{parsed.num_shards}gpu')
+        model_name.append(f"presharded-{parsed.num_shards}gpu")
 
-    parsed.artifact_path = os.path.join(parsed.artifact_path, '-'.join(model_name))
+    parsed.artifact_path = os.path.join(parsed.artifact_path, "-".join(model_name))
 
     return parsed
 
@@ -828,7 +828,6 @@ def build_model_from_args(args: argparse.Namespace):
             sharding_module = create_shard_info_func(param_manager, args, model_config)
             mod.update(sharding_module)
 
-
         with open(cache_path, "wb") as outfile:
             pickle.dump(mod, outfile)
         print(f"Save a cached module to {cache_path}.")
diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py
index be93628a41..b995de2956 100644
--- a/mlc_llm/utils.py
+++ b/mlc_llm/utils.py
@@ -280,11 +280,9 @@ def save_params(params: List[tvm.nd.NDArray], artifact_path: str, num_presharded
     assert len(params) % num_presharded == 0
     num_weights = len(params) // num_presharded
 
-    meta_data = {}
     param_dict = {}
     meta_data["ParamSize"] = len(params)
 
-    total_size = 0.0
     for i, nd in enumerate(params):
         if num_presharded == 1:
             param_name = f"param_{i}"