
Commit eb076c8

Update all use of "convert_weight_only" to "convert_weights_only"
Lunderberg committed Nov 9, 2023
1 parent 42eb0bb commit eb076c8
Showing 3 changed files with 25 additions and 14 deletions.
10 changes: 5 additions & 5 deletions mlc_llm/core.py
@@ -192,10 +192,10 @@ class BuildArgs:
)
convert_weights_only: bool = field(
default=False,
dest="convert_weights_only",
metadata={
"help": "Whether to only convert model weights and not build the model.",
"dest": "convert_weights_only",
"action": "store_true",
"help": "Whether to only convert model weights and not build the model.",
},
)
build_model_only: bool = field(
@@ -750,7 +750,7 @@ def build_model_from_args(args: argparse.Namespace):
"and it is highly recommended to use q4f16_1 instead"
)
if args.num_shards > 1:
- if (not args.build_model_only) and (not args.convert_weight_only):
+ if (not args.build_model_only) and (not args.convert_weights_only):
raise ValueError(
"`num_shards` should be used together with "
"`--build-model-only` and `--convert-weight-only`"
@@ -774,7 +774,7 @@ def build_model_from_args(args: argparse.Namespace):
with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f:
config = json.load(i_f)

- if not use_cache or args.convert_weight_only:
+ if not use_cache or args.convert_weights_only:
model_generators = {
"llama": llama,
"mistral": mistral,
@@ -869,7 +869,7 @@ def build_model_from_args(args: argparse.Namespace):
max_window_size=model_config.max_sequence_length,
)

- if args.convert_weight_only:
+ if args.convert_weights_only:
exit(0)

mod = mod_transform_before_build(mod, param_manager, args, model_config)
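
The BuildArgs hunk above moves dest out of the field() call and into metadata, next to "action" and "help". A minimal sketch of why that placement matters, assuming a helper that turns the dataclass into an argparse parser (the helper below is hypothetical, not mlc_llm's actual code): dataclasses.field() does not accept argparse keywords such as dest, so they have to travel inside metadata.

import argparse
from dataclasses import dataclass, field, fields


@dataclass
class BuildArgsSketch:
    # Mirrors the field from the diff: argparse options live in metadata.
    convert_weights_only: bool = field(
        default=False,
        metadata={
            "dest": "convert_weights_only",
            "action": "store_true",
            "help": "Whether to only convert model weights and not build the model.",
        },
    )


def to_argparser(args_cls) -> argparse.ArgumentParser:
    # Hypothetical helper: map each dataclass field to one --flag,
    # forwarding whatever argparse keywords the metadata carries.
    parser = argparse.ArgumentParser()
    for fld in fields(args_cls):
        flag = "--" + fld.name.replace("_", "-")
        parser.add_argument(flag, default=fld.default, **dict(fld.metadata))
    return parser


parsed = to_argparser(BuildArgsSketch).parse_args(["--convert-weights-only"])
assert parsed.convert_weights_only is True
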
27 changes: 19 additions & 8 deletions mlc_llm/relax_model/stablelm_3b.py
@@ -40,7 +40,7 @@ def __init__(
combine_matmul=True,
num_shards=1,
build_model_only=False,
- convert_weight_only=False,
+ convert_weights_only=False,
**kwargs,
):
self.dtype = dtype
@@ -376,17 +376,21 @@ def forward(
all_seq_len_shape=all_seq_len_shape,
)
if self.self_attn.num_shards > 1:
- residual = nn.emit(residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype))
+ residual = nn.emit(
+     residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype)
+ )
hidden_states = nn.emit(residual + hidden_states)
if self.self_attn.num_shards > 1:
hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum"))

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if self.mlp.num_shards > 1:
- residual = nn.emit(residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype))
+ residual = nn.emit(
+     residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype)
+ )
hidden_states = nn.emit(residual + hidden_states)
if self.mlp.num_shards > 1:
hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum"))
@@ -444,7 +448,9 @@ def forward(self, input_ids: relax.Expr):


class StableLM3bModell(nn.Module):
- def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False):
+ def __init__(
+     self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False
+ ):
rotary_embedding = RotaryEmbedding(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
@@ -461,7 +467,10 @@ def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_em
self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype)

self.layers = ModuleList(
- [StableLM3bDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)]
+ [
+     StableLM3bDecoderLayer(config, rotary_embedding)
+     for _ in range(config.num_hidden_layers)
+ ]
)
self.norm = LayerNorm(config.hidden_size, dtype=config.dtype, eps=config.norm_eps)

@@ -530,7 +539,9 @@ def forward(


class StableLM3bForCausalLM(nn.Module):
- def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False):
+ def __init__(
+     self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False
+ ):
self.model = StableLM3bModell(config, vocab_size_var, sep_embed)
self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False)

@@ -779,7 +790,7 @@ def get_model(args, hf_config):
combine_matmul=True,
num_shards=args.num_shards,
build_model_only=args.build_model_only,
- convert_weight_only=args.convert_weight_only,
+ convert_weights_only=args.convert_weights_only,
)
if max_seq_len != -1:
config.max_sequence_length = max_seq_len
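
The forward() hunks above only rewrap long lines; the tensor-parallel arithmetic is unchanged: each shard divides the residual by num_shards, adds its partial attention/MLP output, and a sum-allreduce then reproduces residual + full_output on every shard. A small NumPy check of that identity, a sketch under the assumption that the per-shard outputs sum to the full layer output and with ccl.allreduce stood in by a plain sum:

import numpy as np

num_shards = 4
rng = np.random.default_rng(0)

residual = rng.normal(size=(2, 8)).astype("float32")
# Assume each shard computes a partial layer output whose sum is the full output.
partial_outputs = [rng.normal(size=(2, 8)).astype("float32") for _ in range(num_shards)]
full_output = sum(partial_outputs)

# Single-device reference: residual + full layer output.
reference = residual + full_output

# Sharded form mirroring the diff: pre-divide the residual on every shard,
# add the shard-local partial output, then allreduce with "sum".
per_shard = [residual / num_shards + out for out in partial_outputs]
allreduced = sum(per_shard)  # stands in for ccl.allreduce(hidden_states, "sum")

np.testing.assert_allclose(allreduced, reference, rtol=1e-5, atol=1e-6)
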
2 changes: 1 addition & 1 deletion tests/legacy-python/test_build_model_from_args.py
@@ -27,7 +27,7 @@ def setUp(self):
self.mock_args.sep_embed = False
self.mock_args.build_model_only = True
self.mock_args.use_safetensors = False
- self.mock_args.convert_weight_only = False
+ self.mock_args.convert_weights_only = False
self.mock_args.no_cutlass_attn = True
self.mock_args.no_cutlass_norm = True
self.mock_args.reuse_lib = True
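
The test change is a one-line rename of the mocked attribute. Renames like this are easy to miss in tests: if the args object is a MagicMock (common in this kind of setup, though not confirmed here), a leftover read of the old convert_weight_only name does not fail, it silently returns a fresh mock. A short sketch of that failure mode next to a plain namespace, which raises instead (hypothetical snippet, not the actual test code):

from types import SimpleNamespace
from unittest.mock import MagicMock

# MagicMock swallows the stale attribute name: the typo goes unnoticed.
mock_args = MagicMock()
mock_args.convert_weights_only = False
print(bool(mock_args.convert_weight_only))  # prints True: a fresh, truthy MagicMock

# A plain namespace surfaces the stale name immediately.
strict_args = SimpleNamespace(convert_weights_only=False)
try:
    strict_args.convert_weight_only
except AttributeError as err:
    print("stale attribute name:", err)
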
