[Frontend][Core][Enhancement] Add S3 and remote source support for LoRA adapters #12029

Open · wants to merge 2 commits into base: main
35 changes: 35 additions & 0 deletions docs/source/features/lora.md
@@ -212,3 +212,38 @@ $ curl http://localhost:8000/v1/models
]
}
```

## S3 Support for LoRA Adapters

vLLM now supports loading LoRA adapters directly from S3. You can specify S3 paths in the same way as local paths:

```bash
vllm serve meta-llama/Llama-2-7b-hf \
--enable-lora \
--lora-modules sql-lora=s3://my-bucket/path/to/lora-adapter
```

Or, using the JSON format with a base model specification:

```bash
vllm serve meta-llama/Llama-2-7b-hf \
--enable-lora \
--lora-modules '{"name": "sql-lora", "path": "s3://my-bucket/path/to/lora-adapter", "base_model_name": "meta-llama/Llama-2-7b"}'
```

For dynamic loading, you can use the S3 path directly:

```bash
curl -X POST http://localhost:8000/v1/load_lora_adapter \
-H "Content-Type: application/json" \
-d '{
"lora_name": "sql_adapter",
"lora_path": "s3://my-bucket/path/to/lora-adapter"
}'
```

The S3 implementation (see the sketch after this list):
- Streams data directly from S3 to memory to avoid disk I/O
- Handles large files efficiently with chunked downloading
- Supports both safetensors and .bin formats
- Requires proper AWS credentials configuration
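
As a point of reference, a minimal sketch of what such a streaming source could look like is shown below. It reuses the `S3LoRASource` name from this PR, but the body is an illustrative assumption built on public boto3/safetensors APIs, not the PR's actual `vllm/lora/sources.py`.

```python
# Illustrative sketch only -- not the actual vllm/lora/sources.py from this PR.
# Assumes boto3/botocore are installed and AWS credentials are configured.
import io
import json
from typing import Dict
from urllib.parse import urlparse

import boto3
import safetensors.torch
import torch


class S3LoRASource:
    """Loads LoRA adapter files from s3://bucket/prefix directly into memory."""

    CHUNK_SIZE = 8 * 1024 * 1024  # download in 8 MiB chunks

    def __init__(self, s3_uri: str):
        parsed = urlparse(s3_uri)
        self._bucket = parsed.netloc
        self._prefix = parsed.path.strip("/")
        self._client = boto3.client("s3")

    def _read_object(self, filename: str) -> bytes:
        """Stream an object from S3 into memory in chunks (no disk I/O)."""
        key = f"{self._prefix}/{filename}"
        body = self._client.get_object(Bucket=self._bucket, Key=key)["Body"]
        buf = io.BytesIO()
        for chunk in body.iter_chunks(chunk_size=self.CHUNK_SIZE):
            buf.write(chunk)
        return buf.getvalue()

    def get_config(self) -> dict:
        return json.loads(self._read_object("adapter_config.json"))

    def get_tensors(self) -> Dict[str, torch.Tensor]:
        # Prefer safetensors; fall back to the .bin format.
        try:
            return safetensors.torch.load(
                self._read_object("adapter_model.safetensors"))
        except self._client.exceptions.NoSuchKey:
            raw = self._read_object("adapter_model.bin")
            return torch.load(io.BytesIO(raw), map_location="cpu")
```

Credential handling is left to boto3's default provider chain (environment variables, shared config files, or instance roles).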
4 changes: 3 additions & 1 deletion setup.py
@@ -647,7 +647,9 @@ def _read_requirements(filename: str) -> List[str]:
"tensorizer": ["tensorizer>=2.9.0"],
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
"audio": ["librosa", "soundfile"], # Required for audio processing
"video": ["decord"] # Required for video processing
"video": ["decord"], # Required for video processing
"s3": ["boto3>=1.26.0",
"botocore>=1.29.0"], # Optional S3 support for LoRA adapters
},
cmdclass=cmdclass,
package_data=package_data,
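
If this extra ships as written above, the optional S3 dependencies would presumably be installed through the usual extras syntax; the commands below are an assumption based on the `s3` extra name added in this diff.

```bash
# Install the optional S3 dependencies via the "s3" extra added above.
pip install "vllm[s3]"

# Or, from a source checkout of this branch:
pip install -e ".[s3]"
```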
43 changes: 37 additions & 6 deletions vllm/entrypoints/openai/serving_models.py
@@ -13,6 +13,7 @@
UnloadLoraAdapterRequest)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.lora.sources import LoRASourceError, S3LoRASource
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.utils import AtomicCounter

@@ -85,13 +86,43 @@ async def init_static_loras(self):
Raises if any fail to load"""
if self.static_lora_modules is None:
return

for lora in self.static_lora_modules:
load_request = LoadLoraAdapterRequest(lora_path=lora.path,
lora_name=lora.name)
load_result = await self.load_lora_adapter(
request=load_request, base_model_name=lora.base_model_name)
if isinstance(load_result, ErrorResponse):
raise ValueError(load_result.message)
try:
# Create LoRA source if S3 path
if lora.path.startswith("s3://"):
source = S3LoRASource(lora.path)
# Get local path after downloading
lora_path = source.get_local_path()
Collaborator: Why is there S3-specific logic here in the frontend? Doesn't the engine's LoRA cache handle the downloading from S3 URLs?

Collaborator: IMHO, we should support downloading LoRA models from different storages or hubs, like PR #10762. I wonder if we could implement this in a more extensible way, like OOT (out-of-tree).

else:
lora_path = lora.path

load_request = LoadLoraAdapterRequest(lora_path=lora_path,
lora_name=lora.name)

load_result = await self.load_lora_adapter(
request=load_request, base_model_name=lora.base_model_name)

if isinstance(load_result, ErrorResponse):
raise ValueError(load_result.message)

except LoRASourceError as e:
# Preserve original error details
if hasattr(e, 'original_error'):
logger.error("Failed to load LoRA %s: %s, caused by: %s",
lora.name, str(e), str(e.original_error))
else:
logger.error("Failed to load LoRA %s: %s", lora.name,
str(e))
raise ValueError(
f"Failed to load LoRA adapter {lora.name}: {str(e)}"
) from e
except Exception as e:
logger.error("Unexpected error loading LoRA %s: %s", lora.name,
str(e))
raise ValueError(
f"Failed to load LoRA adapter {lora.name}: {str(e)}"
) from e

def is_base_model(self, model_name):
return any(model.name == model_name for model in self.base_model_paths)
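
To illustrate the extensibility concern raised in the comments above, one possible shape is a registry keyed by URL scheme, so the frontend never branches on `s3://` directly and out-of-tree backends can plug in. Everything below, names included, is a hypothetical sketch rather than code from this PR.

```python
# Hypothetical sketch of a scheme-based LoRA source registry (not in this PR).
from typing import Callable, Dict
from urllib.parse import urlparse

# Maps a URL scheme ("s3", "gs", "hf", ...) to a resolver that fetches the
# remote adapter and returns a path (or handle) the engine can load from.
_SOURCE_RESOLVERS: Dict[str, Callable[[str], str]] = {}


def register_lora_source(scheme: str, resolver: Callable[[str], str]) -> None:
    """Let in-tree or out-of-tree plugins register new storage backends."""
    _SOURCE_RESOLVERS[scheme] = resolver


def resolve_lora_path(path: str) -> str:
    """Resolve a LoRA path, delegating to a registered backend if one matches."""
    scheme = urlparse(path).scheme
    if scheme in _SOURCE_RESOLVERS:
        return _SOURCE_RESOLVERS[scheme](path)
    return path  # plain local path: pass through unchanged


# Hypothetical wiring for the S3 source introduced by this PR:
# register_lora_source("s3", lambda uri: S3LoRASource(uri).get_local_path())
```

With a registry like this, `init_static_loras` could simply call `resolve_lora_path(lora.path)` and stay storage-agnostic.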
114 changes: 85 additions & 29 deletions vllm/lora/models.py
@@ -23,6 +23,7 @@
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.sources import LoRASource, LoRASourceError
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
is_regex_target_modules,
parse_fine_tuned_lora_name, replace_submodule)
@@ -178,7 +179,7 @@ def from_lora_tensors(
@classmethod
def from_local_checkpoint(
cls,
lora_dir: str,
lora_dir_or_source: Union[str, LoRASource],
expected_lora_modules: List[str],
*,
max_position_embeddings: Optional[int] = None,
@@ -190,67 +191,122 @@ def from_local_checkpoint(
embedding_padding_modules: Optional[List[str]] = None,
weights_mapper: Optional[WeightsMapper] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint.
"""Create a LoRAModel from a local checkpoint or source.

Args:
lora_dir: The local path that has lora data.
expected_lora_modules: Name of modules that are expected to be
replaced by lora.
max_position_embeddings: Max position embedding length. Used to
scaling the largest context length. If None, the lora model's
context length is not scaled.
lora_model_id: Lora model id. If not given, automatically set by
a global counter.
device: Device where the lora model is loaded.
dtype: dtype of the lora model weights.

lora_dir_or_source: Either local path or LoRASource instance
expected_lora_modules: Name of modules expected to be replaced
max_position_embeddings: Max position embedding length
lora_model_id: Lora model id
device: Device where lora model is loaded
dtype: dtype of lora model weights

Returns:
Loaded LoRA Model.
Loaded LoRA Model
"""
if isinstance(lora_dir_or_source, LoRASource):
try:
# Direct loading from source
config = lora_dir_or_source.get_config()
config[
"vllm_max_position_embeddings"] = max_position_embeddings
peft_helper = PEFTHelper.from_dict(config)

# Get tensors and validate modules
tensors = lora_dir_or_source.get_tensors()
embeddings = lora_dir_or_source.get_embeddings(
) # Optional, can be None

# Validate modules (reuse existing validation logic)
validation_modules = []
for tensor_name in tensors:
module_name, _, _ = parse_fine_tuned_lora_name(
tensor_name, weights_mapper)
part_name = module_name.split(".")[-1]
if part_name not in expected_lora_modules:
validation_modules.append(module_name)

if validation_modules and not is_regex_target_modules(
peft_helper.target_modules, expected_lora_modules):
raise ValueError(
f"Expected target modules in {expected_lora_modules}"
f" but received {validation_modules}."
" Please verify that the loaded LoRA module is correct"
)

return cls.from_lora_tensors(
lora_model_id=get_lora_id()
if lora_model_id is None else lora_model_id,
tensors=tensors,
peft_helper=peft_helper,
device=device,
dtype=dtype,
embeddings=embeddings,
target_embedding_padding=target_embedding_padding,
embedding_modules=embedding_modules,
embedding_padding_modules=embedding_padding_modules,
weights_mapper=weights_mapper)
except LoRASourceError as e:
# Preserve the original error details
if hasattr(e, 'original_error') and e.original_error:
logger.error("LoRA loading error: %s, caused by: %s",
str(e), str(e.original_error))
else:
logger.error("LoRA loading error: %s", str(e))
raise ValueError(
f"Failed to load LoRA adapter: {str(e)}") from e
except Exception as e:
logger.error("Unexpected error loading LoRA: %s", str(e))
raise ValueError(
f"Failed to load LoRA adapter: {str(e)}") from e

# Existing local path logic
lora_dir = lora_dir_or_source
lora_config_path = os.path.join(lora_dir, "adapter_config.json")
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
new_embeddings_tensor_path = os.path.join(
lora_dir, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(lora_dir,
"new_embeddings.bin")

with open(lora_config_path) as f:
config = json.load(f)

config["vllm_max_position_embeddings"] = max_position_embeddings
peft_helper = PEFTHelper.from_dict(config)
unexpected_modules: List[Union[list[str], str]]

if os.path.isfile(lora_tensor_path):
tensors: Dict[str, torch.Tensor] = {}
model_tensors: Dict[str, torch.Tensor] = {}
# Find unexpected modules.
# Use safetensor key as a source of truth to find expected modules.
# in peft if you have target_modules A, B, C and C does not exist
# in the model it wont error and model will be trained with A, B
# loraified. C wont exist in the safetensor but it will exist in
# in the model it won't error and model will be trained with A, B
# loraified. C won't exist in the safetensor but it will exist in
# the target_modules of the adapter_config.json.
unexpected_modules = []
found_modules = []
Collaborator: I think the name unexpected_modules conveyed the meaning that these are the modules we did not expect to find. The name found_modules is a bit misleading, since it's not all the modules that were found; it's actually the found modules minus expected_lora_modules.

with safetensors.safe_open(lora_tensor_path,
framework="pt") as f: # type: ignore
for lora_module in f.keys(): # noqa
module_name, _, _ = parse_fine_tuned_lora_name(
lora_module, weights_mapper)
part_name = module_name.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module_name)
if unexpected_modules:
found_modules.append(module_name)
if found_modules:
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" but received {found_modules}."
f" Please verify that the loaded LoRA module is correct"
)
# Load tensors if there are only expected modules.
for module in f.keys(): # noqa
tensors[module] = f.get_tensor(module)
model_tensors[module] = f.get_tensor(module)
elif os.path.isfile(lora_bin_file_path):
# When a bin file is provided, we rely on config to find unexpected
# modules.
unexpected_modules = []
found_modules = []
target_modules = peft_helper.target_modules
if not isinstance(target_modules, list):
target_modules = [target_modules]
@@ -259,19 +315,19 @@ def from_local_checkpoint(
# such as:layers.11.self_attn.k_proj
part_name = module.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module)
found_modules.append(module)
# loaded lora's target modules must be a subset of
# expected_lora_modules. It is not reliable. See
# https://github.com/vllm-project/vllm/pull/5909. But there's no
# other better mechanism.
if unexpected_modules and not is_regex_target_modules(
if found_modules and not is_regex_target_modules(
peft_helper.target_modules, expected_lora_modules):
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" but received {found_modules}."
f" Please verify that the loaded LoRA module is correct")
tensors = torch.load(lora_bin_file_path, map_location=device)
model_tensors = torch.load(lora_bin_file_path, map_location=device)
else:
raise ValueError(f"{lora_dir} doesn't contain tensors")

@@ -286,7 +342,7 @@ def from_local_checkpoint(
return cls.from_lora_tensors(
lora_model_id=get_lora_id()
if lora_model_id is None else lora_model_id,
tensors=tensors,
tensors=model_tensors,
peft_helper=peft_helper,
device=device,
dtype=dtype,
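
For completeness, a hedged usage sketch of the updated classmethod is shown below; the module names, dtype, and S3 URI are placeholder values, and the keyword arguments mirror the signature visible in this diff.

```python
# Illustrative usage of the updated classmethod (placeholder values only).
import torch

from vllm.lora.models import LoRAModel
from vllm.lora.sources import S3LoRASource

expected = ["q_proj", "k_proj", "v_proj", "o_proj"]  # example module names

# Existing behavior: load from a local directory.
local_lora = LoRAModel.from_local_checkpoint(
    "/path/to/lora-adapter",
    expected_lora_modules=expected,
    device="cuda",
    dtype=torch.float16,
)

# New behavior from this PR: load through a LoRASource instance.
s3_lora = LoRAModel.from_local_checkpoint(
    S3LoRASource("s3://my-bucket/path/to/lora-adapter"),
    expected_lora_modules=expected,
    device="cuda",
    dtype=torch.float16,
)
```

Both calls return a `LoRAModel`; the only difference is where the tensors and `adapter_config.json` are read from.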