From 4553e83bbdc145c97e0590a8123bdbe9ec10907a Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 26 May 2024 19:49:47 +0800 Subject: [PATCH 01/12] init internlm2 lora support --- vllm/model_executor/models/internlm2.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index e75c567f589c8..b405adf362a34 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -6,7 +6,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig +from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -242,13 +242,31 @@ def forward( class InternLM2ForCausalLM(nn.Module): + packed_modules_mapping = { + "gate_up_proj": [ + "w1", + "w3", + ] + } + + # LoRA specific attributes + supported_lora_modules = [ + "wqkv", + "wo", + "gate_up_proj", + "w2", + ] + embedding_modules = {} + embedding_padding_modules = [] def __init__( self, config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: + del lora_config # Unused. super().__init__() self.config = config self.quant_config = quant_config From aa9f498b140e2607ea9364fd37e738af83a41bcb Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 26 May 2024 21:01:29 +0800 Subject: [PATCH 02/12] init internlm2 lora support --- vllm/model_executor/models/internlm2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index b405adf362a34..d78b7983d1d0d 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -243,6 +243,7 @@ def forward( class InternLM2ForCausalLM(nn.Module): packed_modules_mapping = { + "wqkv": ["wqkv"], "gate_up_proj": [ "w1", "w3", From 3000604ae8cb837cd7811e42a0ff48d6329a3edc Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 27 May 2024 12:45:40 +0800 Subject: [PATCH 03/12] add internlm2 lora test --- tests/lora/conftest.py | 5 +++ tests/lora/test_internlm2.py | 66 ++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 tests/lora/test_internlm2.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index e5cf9cd48b65d..f388e0b58e571 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -170,6 +170,11 @@ def phi2_lora_files(): return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora") +@pytest.fixture(scope="session") +def phi2_lora_files(): + return snapshot_download(repo_id="isotr0py/InternLM2-1_8B-test-sql-lora") + + @pytest.fixture(scope="session") def long_context_lora_files_16k_1(): return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1") diff --git a/tests/lora/test_internlm2.py b/tests/lora/test_internlm2.py new file mode 100644 index 0000000000000..97a68c574ea48 --- /dev/null +++ b/tests/lora/test_internlm2.py @@ -0,0 +1,66 @@ +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "internlm/internlm2-1_8b" + +PROMPT_TEMPLATE = "[user] question: {sql_prompt}\n\n context: {context}\n\n [/user] [assistant]" # noqa: E501 + + +def do_sample(llm, 
lora_path: str, lora_id: int) -> str: + prompts = [ + PROMPT_TEMPLATE.format( + sql_prompt= + "Which catalog publisher has published the most catalogs?", + context="CREATE TABLE catalogs (catalog_publisher VARCHAR);"), + PROMPT_TEMPLATE.format( + sql_prompt= + "Which trip started from the station with the largest dock count? Give me the trip id.", # noqa: E501 + context= + "CREATE TABLE trip (id VARCHAR, start_station_id VARCHAR); CREATE TABLE station (id VARCHAR, dock_count VARCHAR);" # noqa: E501 + ), + PROMPT_TEMPLATE.format( + sql_prompt= + "How many marine species are found in the Southern Ocean?", # noqa: E501 + context= + "CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));" # noqa: E501 + ), + ] + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=64, + stop="[/assistant]") + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None, + ) + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +def test_internlm2_lora(internlm2_lora_files): + # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, + # Otherwise, the lora-test will fail due to CUDA OOM. + llm = vllm.LLM(MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=2) + + expected_lora_output = [ + "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501 + "SELECT trip.id FROM trip JOIN station ON trip.start_station_id = station.id WHERE station.dock_count = (SELECT MAX(dock_count) FROM station);", # noqa: E501 + "SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';", # noqa: E501 + ] + + output1 = do_sample(llm, internlm2_lora_files, lora_id=1) + for i in range(len(expected_lora_output)): + assert output1[i].startswith(expected_lora_output[i]) + output2 = do_sample(llm, internlm2_lora_files, lora_id=2) + for i in range(len(expected_lora_output)): + assert output2[i].startswith(expected_lora_output[i]) From a723e7b252c250c1433074dc27c56d079c5ca7df Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 27 May 2024 13:19:03 +0800 Subject: [PATCH 04/12] fix internlm2 lora test --- tests/lora/conftest.py | 2 +- tests/lora/test_internlm2.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index f388e0b58e571..1d148cffc07c4 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -171,7 +171,7 @@ def phi2_lora_files(): @pytest.fixture(scope="session") -def phi2_lora_files(): +def internlm2_lora_files(): return snapshot_download(repo_id="isotr0py/InternLM2-1_8B-test-sql-lora") diff --git a/tests/lora/test_internlm2.py b/tests/lora/test_internlm2.py index 97a68c574ea48..4b0f76dd795ac 100644 --- a/tests/lora/test_internlm2.py +++ b/tests/lora/test_internlm2.py @@ -3,7 +3,7 @@ MODEL_PATH = "internlm/internlm2-1_8b" -PROMPT_TEMPLATE = "[user] question: {sql_prompt}\n\n context: {context}\n\n [/user] [assistant]" # noqa: E501 +PROMPT_TEMPLATE = "[user] question: {sql_prompt}\n\n context: {context}\n\n [/user] [assistant] " # noqa: E501 def do_sample(llm, lora_path: str, lora_id: int) -> str: @@ -48,13 +48,14 @@ def 
test_internlm2_lora(internlm2_lora_files): # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, # Otherwise, the lora-test will fail due to CUDA OOM. llm = vllm.LLM(MODEL_PATH, + trust_remote_code=True, max_model_len=1024, enable_lora=True, max_loras=2) expected_lora_output = [ "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501 - "SELECT trip.id FROM trip JOIN station ON trip.start_station_id = station.id WHERE station.dock_count = (SELECT MAX(dock_count) FROM station);", # noqa: E501 + "SELECT trip.id FROM trip JOIN station ON trip.start_station_id = station.id WHERE station.dock_count = (SELECT MAX(station.dock_count) FROM station);", # noqa: E501 "SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';", # noqa: E501 ] From e94d6a456fca1ff6f5517c172c646410063e325c Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 27 May 2024 14:18:18 +0800 Subject: [PATCH 05/12] fix internlm2 test --- tests/lora/test_internlm2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/lora/test_internlm2.py b/tests/lora/test_internlm2.py index 4b0f76dd795ac..6f1e0c8db46c9 100644 --- a/tests/lora/test_internlm2.py +++ b/tests/lora/test_internlm2.py @@ -14,9 +14,9 @@ def do_sample(llm, lora_path: str, lora_id: int) -> str: context="CREATE TABLE catalogs (catalog_publisher VARCHAR);"), PROMPT_TEMPLATE.format( sql_prompt= - "Which trip started from the station with the largest dock count? Give me the trip id.", # noqa: E501 + "Remove the 'vehicle_safety_testing' table and its records.", # noqa: E501 context= - "CREATE TABLE trip (id VARCHAR, start_station_id VARCHAR); CREATE TABLE station (id VARCHAR, dock_count VARCHAR);" # noqa: E501 + "CREATE TABLE vehicle_safety_testing (id INT PRIMARY KEY, vehicle_model VARCHAR(255), test_score FLOAT);" # noqa: E501 ), PROMPT_TEMPLATE.format( sql_prompt= @@ -55,7 +55,7 @@ def test_internlm2_lora(internlm2_lora_files): expected_lora_output = [ "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501 - "SELECT trip.id FROM trip JOIN station ON trip.start_station_id = station.id WHERE station.dock_count = (SELECT MAX(station.dock_count) FROM station);", # noqa: E501 + "DROP TABLE vehicle_safety_testing;", # noqa: E501 "SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';", # noqa: E501 ] From de8bf1fffdde120e14b67b003ff533e99acaa3f9 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 27 May 2024 16:10:12 +0800 Subject: [PATCH 06/12] use enforce_eager=True --- tests/lora/test_internlm2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/lora/test_internlm2.py b/tests/lora/test_internlm2.py index 6f1e0c8db46c9..505f3662be2f5 100644 --- a/tests/lora/test_internlm2.py +++ b/tests/lora/test_internlm2.py @@ -51,7 +51,8 @@ def test_internlm2_lora(internlm2_lora_files): trust_remote_code=True, max_model_len=1024, enable_lora=True, - max_loras=2) + max_loras=2, + enforce_eager=True) expected_lora_output = [ "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501 From 93c2d067300f1862f3ebd8fa23889fbf1e813366 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 15 Jun 2024 23:24:51 +0800 Subject: [PATCH 07/12] format code --- tests/lora/test_internlm2.py | 6 
++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/lora/test_internlm2.py b/tests/lora/test_internlm2.py index 505f3662be2f5..54bacff734715 100644 --- a/tests/lora/test_internlm2.py +++ b/tests/lora/test_internlm2.py @@ -1,3 +1,5 @@ +from typing import List + import vllm from vllm.lora.request import LoRARequest @@ -6,7 +8,7 @@ PROMPT_TEMPLATE = "[user] question: {sql_prompt}\n\n context: {context}\n\n [/user] [assistant] " # noqa: E501 -def do_sample(llm, lora_path: str, lora_id: int) -> str: +def do_sample(llm, lora_path: str, lora_id: int) -> List[str]: prompts = [ PROMPT_TEMPLATE.format( sql_prompt= @@ -35,7 +37,7 @@ def do_sample(llm, lora_path: str, lora_id: int) -> str: if lora_id else None, ) # Print the outputs. - generated_texts = [] + generated_texts: List[str] = [] for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text.strip() From 333a99395e61eddd462147dd2e286fb87882d252 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 28 Nov 2024 13:23:20 +0800 Subject: [PATCH 08/12] update internlm2 lora static variable Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/internlm2.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 906128940ff76..a946af955030a 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -27,7 +27,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import SupportsLoRA, SupportsPP from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -319,7 +319,18 @@ def forward( return hidden_states -class InternLM2ForCausalLM(nn.Module, SupportsPP): +class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): + packed_modules_mapping = {"gate_up_proj": ["gate_proj", "up_proj"]} + + # LoRA specific attributes + supported_lora_modules = [ + "wqkv", + "wo", + "gate_up_proj", + "w2", + ] + embedding_modules = {} + embedding_padding_modules = [] def __init__(self, *, @@ -329,8 +340,12 @@ def __init__(self, super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config self.quant_config = quant_config + self.lora_config = lora_config + self.model = model_type(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.output = ParallelLMHead(config.vocab_size, From c4ea705cb12f6b1f074afd4d6c98c7d7a18a342d Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 28 Nov 2024 13:33:30 +0800 Subject: [PATCH 09/12] fix wqkv lora Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/internlm2.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index a946af955030a..924295aa1a60d 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -320,7 +320,10 @@ def forward( class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): - packed_modules_mapping = {"gate_up_proj": ["gate_proj", "up_proj"]} + packed_modules_mapping = { + "wqkv": ["wqkv"], + "gate_up_proj": ["gate_proj", "up_proj"] + } # LoRA specific attributes supported_lora_modules = [ From 
90f56b14c2817837dc847e50041751349a80c472 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 28 Nov 2024 13:37:10 +0800 Subject: [PATCH 10/12] fix gate_up lora Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/internlm2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 924295aa1a60d..41b9f110d771f 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -322,7 +322,7 @@ def forward( class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): packed_modules_mapping = { "wqkv": ["wqkv"], - "gate_up_proj": ["gate_proj", "up_proj"] + "gate_up_proj": ["w1", "w3"], } # LoRA specific attributes From 76f4a31b85b10f55dc57a9857eefb7c7c8d0a9ac Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 28 Nov 2024 14:12:49 +0800 Subject: [PATCH 11/12] update doc Signed-off-by: Isotr0py <2037008807@qq.com> --- docs/source/models/supported_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c5fbb30b24e28..8b6c2133814b7 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -177,7 +177,7 @@ Text Generation * - :code:`InternLM2ForCausalLM` - InternLM2 - :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc. - - + - ✅︎ - ✅︎ * - :code:`JAISLMHeadModel` - Jais From f6203769dd9a914f0994aeac89668aff37c04818 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 28 Nov 2024 14:39:13 +0800 Subject: [PATCH 12/12] remove test script Signed-off-by: Isotr0py <2037008807@qq.com> --- tests/lora/conftest.py | 5 --- tests/lora/test_internlm2.py | 70 ------------------------------------ 2 files changed, 75 deletions(-) delete mode 100644 tests/lora/test_internlm2.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 1594a3f1893e2..29ecf37808205 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -210,11 +210,6 @@ def phi2_lora_files(): return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora") -@pytest.fixture(scope="session") -def internlm2_lora_files(): - return snapshot_download(repo_id="isotr0py/InternLM2-1_8B-test-sql-lora") - - @pytest.fixture(scope="session") def long_context_lora_files_16k_1(): return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1") diff --git a/tests/lora/test_internlm2.py b/tests/lora/test_internlm2.py deleted file mode 100644 index 54bacff734715..0000000000000 --- a/tests/lora/test_internlm2.py +++ /dev/null @@ -1,70 +0,0 @@ -from typing import List - -import vllm -from vllm.lora.request import LoRARequest - -MODEL_PATH = "internlm/internlm2-1_8b" - -PROMPT_TEMPLATE = "[user] question: {sql_prompt}\n\n context: {context}\n\n [/user] [assistant] " # noqa: E501 - - -def do_sample(llm, lora_path: str, lora_id: int) -> List[str]: - prompts = [ - PROMPT_TEMPLATE.format( - sql_prompt= - "Which catalog publisher has published the most catalogs?", - context="CREATE TABLE catalogs (catalog_publisher VARCHAR);"), - PROMPT_TEMPLATE.format( - sql_prompt= - "Remove the 'vehicle_safety_testing' table and its records.", # noqa: E501 - context= - "CREATE TABLE vehicle_safety_testing (id INT PRIMARY KEY, vehicle_model VARCHAR(255), test_score FLOAT);" # noqa: E501 - ), - PROMPT_TEMPLATE.format( - sql_prompt= - "How many marine species are found in the 
Southern Ocean?", # noqa: E501 - context= - "CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));" # noqa: E501 - ), - ] - sampling_params = vllm.SamplingParams(temperature=0, - max_tokens=64, - stop="[/assistant]") - outputs = llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None, - ) - # Print the outputs. - generated_texts: List[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -def test_internlm2_lora(internlm2_lora_files): - # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, - # Otherwise, the lora-test will fail due to CUDA OOM. - llm = vllm.LLM(MODEL_PATH, - trust_remote_code=True, - max_model_len=1024, - enable_lora=True, - max_loras=2, - enforce_eager=True) - - expected_lora_output = [ - "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501 - "DROP TABLE vehicle_safety_testing;", # noqa: E501 - "SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';", # noqa: E501 - ] - - output1 = do_sample(llm, internlm2_lora_files, lora_id=1) - for i in range(len(expected_lora_output)): - assert output1[i].startswith(expected_lora_output[i]) - output2 = do_sample(llm, internlm2_lora_files, lora_id=2) - for i in range(len(expected_lora_output)): - assert output2[i].startswith(expected_lora_output[i])