From eb076c8fa9b89c61525199b85753818447ca73bb Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Thu, 9 Nov 2023 21:58:19 +0000
Subject: [PATCH] Update all uses of "convert_weight_only" to "convert_weights_only"

---
 mlc_llm/core.py                    | 12 ++++++------
 mlc_llm/relax_model/stablelm_3b.py | 27 +++++++++++++------
 .../test_build_model_from_args.py  |  2 +-
 3 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index bd4c29a422..174942b246 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -192,10 +192,10 @@ class BuildArgs:
     )
     convert_weights_only: bool = field(
         default=False,
-        dest="convert_weights_only",
         metadata={
-            "help": "Whether to only convert model weights and not build the model.",
+            "dest": "convert_weights_only",
             "action": "store_true",
+            "help": "Whether to only convert model weights and not build the model.",
         },
     )
     build_model_only: bool = field(
@@ -750,7 +750,7 @@ def build_model_from_args(args: argparse.Namespace):
             "and it is highly recommended to use q4f16_1 instead"
         )
     if args.num_shards > 1:
-        if (not args.build_model_only) and (not args.convert_weight_only):
+        if (not args.build_model_only) and (not args.convert_weights_only):
             raise ValueError(
                 "`num_shards` should be used together with "
-                "`--build-model-only` and `--convert-weight-only`"
+                "`--build-model-only` and `--convert-weights-only`"
@@ -774,7 +774,7 @@ def build_model_from_args(args: argparse.Namespace):
     with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f:
         config = json.load(i_f)
 
-    if not use_cache or args.convert_weight_only:
+    if not use_cache or args.convert_weights_only:
         model_generators = {
             "llama": llama,
             "mistral": mistral,
@@ -869,7 +869,7 @@ def build_model_from_args(args: argparse.Namespace):
             max_window_size=model_config.max_sequence_length,
         )
 
-    if args.convert_weight_only:
+    if args.convert_weights_only:
         exit(0)
 
     mod = mod_transform_before_build(mod, param_manager, args, model_config)
diff --git a/mlc_llm/relax_model/stablelm_3b.py b/mlc_llm/relax_model/stablelm_3b.py
index 89c15a7955..d0b5f8a385 100644
--- a/mlc_llm/relax_model/stablelm_3b.py
+++ b/mlc_llm/relax_model/stablelm_3b.py
@@ -40,7 +40,7 @@ def __init__(
         combine_matmul=True,
         num_shards=1,
         build_model_only=False,
-        convert_weight_only=False,
+        convert_weights_only=False,
         **kwargs,
     ):
         self.dtype = dtype
@@ -376,17 +376,21 @@ def forward(
             all_seq_len_shape=all_seq_len_shape,
         )
         if self.self_attn.num_shards > 1:
-            residual = nn.emit(residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype))
+            residual = nn.emit(
+                residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype)
+            )
         hidden_states = nn.emit(residual + hidden_states)
         if self.self_attn.num_shards > 1:
             hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum"))
-
+
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         if self.mlp.num_shards > 1:
-            residual = nn.emit(residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype))
+            residual = nn.emit(
+                residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype)
+            )
         hidden_states = nn.emit(residual + hidden_states)
         if self.mlp.num_shards > 1:
             hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum"))
@@ -444,7 +448,9 @@ def forward(self, input_ids: relax.Expr):
 
 
 class StableLM3bModell(nn.Module):
-    def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False):
+    def __init__(
+        self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False
+    ):
         rotary_embedding = RotaryEmbedding(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -461,7 +467,10 @@ def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_em
 
         self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype)
         self.layers = ModuleList(
-            [StableLM3bDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)]
+            [
+                StableLM3bDecoderLayer(config, rotary_embedding)
+                for _ in range(config.num_hidden_layers)
+            ]
         )
         self.norm = LayerNorm(config.hidden_size, dtype=config.dtype, eps=config.norm_eps)
 
@@ -530,7 +539,9 @@ def forward(
 
 
 class StableLM3bForCausalLM(nn.Module):
-    def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False):
+    def __init__(
+        self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False
+    ):
         self.model = StableLM3bModell(config, vocab_size_var, sep_embed)
         self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False)
 
@@ -779,7 +790,7 @@ def get_model(args, hf_config):
         combine_matmul=True,
         num_shards=args.num_shards,
         build_model_only=args.build_model_only,
-        convert_weight_only=args.convert_weight_only,
+        convert_weights_only=args.convert_weights_only,
     )
     if max_seq_len != -1:
         config.max_sequence_length = max_seq_len
diff --git a/tests/legacy-python/test_build_model_from_args.py b/tests/legacy-python/test_build_model_from_args.py
index c7990d63df..b342e035bb 100644
--- a/tests/legacy-python/test_build_model_from_args.py
+++ b/tests/legacy-python/test_build_model_from_args.py
@@ -27,7 +27,7 @@ def setUp(self):
         self.mock_args.sep_embed = False
         self.mock_args.build_model_only = True
         self.mock_args.use_safetensors = False
-        self.mock_args.convert_weight_only = False
+        self.mock_args.convert_weights_only = False
         self.mock_args.no_cutlass_attn = True
         self.mock_args.no_cutlass_norm = True
         self.mock_args.reuse_lib = True
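
A note on the first hunk, for readers unfamiliar with the BuildArgs pattern:
`dataclasses.field()` accepts only its documented keyword arguments, so the old
`dest="convert_weights_only"` passed directly to `field()` raised
`TypeError: field() got an unexpected keyword argument 'dest'` at import time.
Moving it into `metadata`, where BuildArgs keeps its argparse options, fixes
that. Below is a minimal, self-contained sketch of how such a field can drive
argparse; the `to_argparse` helper and `BuildArgsSketch` class are hypothetical
names for illustration (mlc-llm's real conversion code is not part of this
patch), and only the field definition mirrors the patched code.

    import argparse
    from dataclasses import dataclass, field, fields


    @dataclass
    class BuildArgsSketch:
        # Mirrors the patched field: the argparse options live in `metadata`,
        # not as keyword arguments to dataclasses.field().
        convert_weights_only: bool = field(
            default=False,
            metadata={
                "dest": "convert_weights_only",
                "action": "store_true",
                "help": "Whether to only convert model weights and not build the model.",
            },
        )


    def to_argparse(cls) -> argparse.ArgumentParser:
        # Hypothetical helper: expose each dataclass field as a CLI flag,
        # e.g. `convert_weights_only` -> `--convert-weights-only`.
        parser = argparse.ArgumentParser()
        for f in fields(cls):
            parser.add_argument("--" + f.name.replace("_", "-"), default=f.default, **f.metadata)
        return parser


    args = to_argparse(BuildArgsSketch).parse_args(["--convert-weights-only"])
    assert args.convert_weights_only is True

Under this wiring the flag derived from the field name is
`--convert-weights-only`, which is why the error message in
`build_model_from_args` is updated along with the attribute accesses.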