diff --git a/README.md b/README.md
index 3dcf6b1f4e..e12368fb7d 100644
--- a/README.md
+++ b/README.md
@@ -140,6 +140,7 @@ Every model is written from scratch to maximize performance and remove layers of
 | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) |
 | Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) |
 | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
+| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
 | Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |
 | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
 | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
diff --git a/litgpt/config.py b/litgpt/config.py
index 54420826bb..475f017e50 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -2134,10 +2134,10 @@ def norm_class(self) -> Type:
 configs.extend(qwq)
 
 
+
 #############
 # Salamandra
 #############
-
 salamandra = [
     # https://huggingface.co/BSC-LT/salamandra-2b-instruct/blob/main/config.json
     dict(
@@ -2189,4 +2189,78 @@ def norm_class(self) -> Type:
     configs.append(copy)
 
 
+###############
+# SmolLM2
+###############
+smollm2 = [
+    # https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config.json
+    dict(
+        name="SmolLM2-135M{}",
+        hf_config=dict(org="HuggingFaceTB", name="SmolLM2-135M{}"),
+        block_size=8192,
+        vocab_size=49152,
+        padded_vocab_size=49152,
+        n_layer=30,
+        n_head=9,
+        n_embd=576,
+        n_query_groups=3,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=1536,
+        rope_base=100000,
+        norm_eps=1e-5,
+    ),
+    # https://huggingface.co/HuggingFaceTB/SmolLM2-360M/blob/main/config.json
+    dict(
+        name="SmolLM2-360M{}",
+        hf_config=dict(org="HuggingFaceTB", name="SmolLM2-360M{}"),
+        block_size=8192,
+        vocab_size=49152,
+        padded_vocab_size=49152,
+        n_layer=32,
+        n_head=15,
+        n_embd=960,
+        n_query_groups=5,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=2560,
+        rope_base=100000,
+        norm_eps=1e-5,
+    ),
+    # https://huggingface.co/HuggingFaceTB/SmolLM2-1.7B/blob/main/config.json
+    dict(
+        name="SmolLM2-1.7B{}",
+        hf_config=dict(org="HuggingFaceTB", name="SmolLM2-1.7B{}"),
+        block_size=8192,
+        vocab_size=49152,
+        padded_vocab_size=49152,
+        n_layer=24,
+        n_head=32,
+        n_embd=2048,
+        n_query_groups=32,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=8192,
+        rope_base=130000,
+        norm_eps=1e-5,
+    ),
+]
+
+for c in smollm2:
+    for kind in ("", "-Instruct"):
+        copy = deepcopy(c)
+        copy["name"] = c["name"].format(kind)
+        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
+        configs.append(copy)
+
+
 name_to_config = {config["name"]: config for config in configs}
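Reviewer note on the config block above: the `{}` suffix in each `name` and `hf_config.name` is a template that the trailing loop expands into base and `-Instruct` variants, so six SmolLM2 configs end up registered in `name_to_config`. A standalone sketch of that expansion, trimmed to just the name fields (illustration only, not litgpt code):

```python
# Sketch of how the "{}" name templates expand into the six registered configs.
from copy import deepcopy

smollm2 = [
    {"name": "SmolLM2-135M{}", "hf_config": {"org": "HuggingFaceTB", "name": "SmolLM2-135M{}"}},
    {"name": "SmolLM2-360M{}", "hf_config": {"org": "HuggingFaceTB", "name": "SmolLM2-360M{}"}},
    {"name": "SmolLM2-1.7B{}", "hf_config": {"org": "HuggingFaceTB", "name": "SmolLM2-1.7B{}"}},
]

configs = []
for c in smollm2:
    for kind in ("", "-Instruct"):
        copy = deepcopy(c)  # deepcopy so the nested hf_config dict isn't shared
        copy["name"] = c["name"].format(kind)
        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
        configs.append(copy)

print([c["name"] for c in configs])
# ['SmolLM2-135M', 'SmolLM2-135M-Instruct', 'SmolLM2-360M',
#  'SmolLM2-360M-Instruct', 'SmolLM2-1.7B', 'SmolLM2-1.7B-Instruct']
```

Also worth noting from the values above: the 135M and 360M configs use grouped-query attention (9 heads over 3 KV groups, 15 over 5), while the 1.7B sets `n_query_groups=32` equal to `n_head`, i.e. plain multi-head attention.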
diff --git a/litgpt/prompts.py b/litgpt/prompts.py
index 51426e1523..09b3277c7d 100644
--- a/litgpt/prompts.py
+++ b/litgpt/prompts.py
@@ -300,6 +300,12 @@ def apply(self, prompt: str, **kwargs: str) -> str:
         return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
 
 
+class SmolLM2(PromptStyle):
+    def apply(self, prompt: str, **kwargs: str) -> str:
+        system_message = "You are a helpful AI assistant named SmolLM, trained by Hugging Face"
+        return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+
+
 # Maps prompt style names to PromptStyle classes
 prompt_styles: Dict[str, Type[PromptStyle]] = {
     # Dataset-specific prompt styles
@@ -326,6 +332,7 @@ def apply(self, prompt: str, **kwargs: str) -> str:
     "qwen2.5": Qwen2_5,
     "qwen2.5-math": Qwen2_5_Math,
     "qwq": QwQ,
+    "smollm2": SmolLM2,
     "salamandra": Salamandra,
 }
 
@@ -371,6 +378,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
         return Qwen2_5()
     if re.search(r"QwQ-.*", model_name):
         return QwQ()
+    if re.search(r"SmolLM2.*-Instruct", model_name):
+        return SmolLM2()
     if re.search(r"salamandra-.*-instruct", model_name):
         return Salamandra()
     return Default()
diff --git a/litgpt/scripts/download.py b/litgpt/scripts/download.py
index c1af2af133..fc6c153fad 100644
--- a/litgpt/scripts/download.py
+++ b/litgpt/scripts/download.py
@@ -131,7 +131,7 @@ def find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[s
     with gated_repo_catcher(repo_id, access_token):
         info = repo_info(repo_id, token=access_token)
     filenames = [f.rfilename for f in info.siblings]
-    bins = list(filter_repo_objects(items=filenames, allow_patterns=["*.bin*"]))
+    bins = list(filter_repo_objects(items=filenames, allow_patterns=["*model*.bin*"]))
     safetensors = list(filter_repo_objects(items=filenames, allow_patterns=["*.safetensors*"]))
     return bins, safetensors
diff --git a/litgpt/tokenizer.py b/litgpt/tokenizer.py
index 6018a44734..10f7d031f6 100644
--- a/litgpt/tokenizer.py
+++ b/litgpt/tokenizer.py
@@ -87,7 +87,7 @@ def token_to_id(self, token: str) -> int:
             raise ValueError(f"token {token!r} not found in the collection.")
         return id_
 
-    def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
+    def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
         if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file():
             return False
         with open(tokenizer_config_path, encoding="utf-8") as fp:
@@ -96,6 +96,8 @@ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
         # `PreTrainedTokenizerFast`
         if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3")):
             return True
+        if checkpoint_dir.stem.startswith("SmolLM2") and checkpoint_dir.name.endswith("Instruct"):
+            return True
         if "add_bos_token" in config:
             return config["add_bos_token"]
         # if `add_bos_token` isn't in the config file, but LLaMA tokenizer is used - return True.
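The new prompt style is plain ChatML wrapped around SmolLM2's default system message; the `download.py` change narrows the glob to `*model*.bin*`, presumably so that non-weight `.bin` files in a repo are skipped (that rationale is my assumption, not stated in the diff). A minimal standalone sketch of the string `SmolLM2.apply` renders — the helper name is hypothetical, not litgpt API:

```python
# Sketch mirroring the SmolLM2 PromptStyle above; `render_smollm2_prompt`
# is a hypothetical name used only for this illustration.
def render_smollm2_prompt(prompt: str) -> str:
    system_message = "You are a helpful AI assistant named SmolLM, trained by Hugging Face"
    return (
        f"<|im_start|>system\n{system_message}<|im_end|>\n"
        f"<|im_start|>user\n{prompt}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )

print(render_smollm2_prompt("What is the capital of France?"))
# <|im_start|>system
# You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
# <|im_start|>user
# What is the capital of France?<|im_end|>
# <|im_start|>assistant
```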
diff --git a/tests/test_model.py b/tests/test_model.py
index 39e92f1204..89e926d173 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -852,6 +852,7 @@ def test_against_original_qwen_2_5(model_name, device, dtype):
     theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
     torch.testing.assert_close(ours_y, theirs_y)
 
+
 @torch.inference_mode()
 @pytest.mark.parametrize("model_name", ("salamandra-2b", "salamandra-7b"))
 @pytest.mark.parametrize(
@@ -910,6 +911,66 @@ def test_against_original_salamandra(model_name, device, dtype):
     ours_y = ours_model(x)
     theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
     torch.testing.assert_close(ours_y, theirs_y)
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("model_name", ("SmolLM2-135M", "SmolLM2-360M", "SmolLM2-1.7B"))
+@pytest.mark.parametrize(
+    ("device", "dtype"),
+    [
+        (torch.device("cpu"), torch.float32),
+        pytest.param(
+            torch.device("cuda"),
+            torch.float16,
+            marks=[
+                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
+                # is slightly different
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_original_smollm2(model_name, device, dtype):
+    torch.set_default_dtype(dtype)
+
+    ours_config = Config.from_name(
+        model_name,
+        padded_vocab_size=10000,
+        n_layer=2,
+        n_head=8,
+        n_embd=32,
+        n_query_groups=2,
+        intermediate_size=86,
+    )
+    T = 5
+    theirs_config = LlamaConfig(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        num_attention_heads=ours_config.n_head,
+        num_hidden_layers=ours_config.n_layer,
+        intermediate_size=ours_config.intermediate_size,
+        max_position_embeddings=T,
+        rms_norm_eps=ours_config.norm_eps,
+        num_key_value_heads=ours_config.n_query_groups,
+        rope_theta=ours_config.rope_base,
+        attention_bias=ours_config.bias,
+    )
+    assert ours_config.intermediate_size == theirs_config.intermediate_size
+
+    theirs_model = LlamaForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+    state_dict = {}
+    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
+    ours_model = GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y)
 
 
 @RunIf(dynamo=True)
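Since SmolLM2 follows the Llama architecture, the test above checks logit parity against Hugging Face's `LlamaForCausalLM` on a deliberately scaled-down config (2 layers, 32-dim embeddings) after copying weights across with `copy_weights_hf_llama`. One way to run just this test locally — a sketch assuming `pytest` and `transformers` are installed:

```python
# Run only the new SmolLM2 parity test; equivalent to
# `pytest tests/test_model.py -k test_against_original_smollm2` on the CLI.
import pytest

raise SystemExit(pytest.main(["tests/test_model.py", "-k", "test_against_original_smollm2", "-v"]))
```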
diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md
index 49d3a6b619..876db1916a 100644
--- a/tutorials/download_model_weights.md
+++ b/tutorials/download_model_weights.md
@@ -40,6 +40,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
 | Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) |
 | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
 | RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
+| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
 | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
 | Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |
 | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
@@ -122,6 +123,12 @@ google/gemma-2b-it
 google/gemma-7b
 google/gemma-7b-it
 h2oai/h2o-danube2-1.8b-chat
+HuggingFaceTB/SmolLM2-135M
+HuggingFaceTB/SmolLM2-135M-Instruct
+HuggingFaceTB/SmolLM2-360M
+HuggingFaceTB/SmolLM2-360M-Instruct
+HuggingFaceTB/SmolLM2-1.7B
+HuggingFaceTB/SmolLM2-1.7B-Instruct
 lmsys/longchat-13b-16k
 lmsys/longchat-7b-16k
 lmsys/vicuna-13b-v1.3
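Once this lands, the checkpoints listed above should work through LitGPT's usual entry points. A usage sketch via the Python API — the repo id comes from the list above, but treat the exact `generate` arguments as an assumption, since they can vary between LitGPT versions:

```python
# Sketch: load a SmolLM2 checkpoint (downloaded from the HF Hub on first
# use) and generate a short completion with LitGPT's Python API.
from litgpt import LLM

llm = LLM.load("HuggingFaceTB/SmolLM2-135M-Instruct")
print(llm.generate("What do llamas eat?", max_new_tokens=50))
```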