[Frontend][Core][Enhancement] Add S3 and remote source support for LoRA adapters

- Expanded documentation in `lora.md` to clarify adapter path options, including local paths, Hugging Face identifiers, and S3 URIs.
- Refactored `LoRAModel` to asynchronously load adapters from remote sources, enhancing flexibility and performance.
- Introduced `resolve_adapter_path` utility to manage adapter path resolution across different sources.
- Improved error handling in `WorkerLoRAManager` for better debugging and user feedback.
- Added thread-safe S3 utilities for managing S3 interactions.

This update enhances the usability and robustness of the LoRA model management system, allowing for more versatile adapter configurations.

Signed-off-by: Prashant Patel <[email protected]>
Prashant18 committed Jan 14, 2025
1 parent f35ec46 commit e660ed3
Showing 14 changed files with 1,096 additions and 95 deletions.
39 changes: 35 additions & 4 deletions docs/source/features/lora.md
@@ -26,7 +26,10 @@ llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)

We can now submit the prompts and call `llm.generate` with the `lora_request` parameter. The first parameter
of `LoRARequest` is a human identifiable name, the second parameter is a globally unique ID for the adapter and
the third parameter is the path to the LoRA adapter. The adapter path can be:
- A local filesystem path
- A Hugging Face model identifier
- An S3 URI (e.g., s3://bucket/path/to/adapter)

```python
sampling_params = SamplingParams(
@@ -36,15 +39,31 @@ sampling_params = SamplingParams(
)

prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
]

# Using a local or Hugging Face adapter
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest("sql_adapter", 1, sql_lora_path)
)

# Using an S3 adapter (requires AWS credentials)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest("s3_adapter", 2, "s3://my-bucket/path/to/adapter")
)
```

When using S3-hosted adapters, configure AWS credentials via environment variables:
```bash
export AWS_ACCESS_KEY_ID=your_key_id
export AWS_SECRET_ACCESS_KEY=your_secret_key
export AWS_DEFAULT_REGION=us-east-1
export AWS_ENDPOINT_URL=custom_endpoint # Optional: For S3-compatible storage
```
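
The thread-safe S3 utilities added by this commit are not shown on this page. The core idea is a lazily constructed, lock-guarded client that honors the variables above — a minimal sketch assuming boto3; the function name and module layout here are illustrative, not the commit's actual API:

```python
import os
import threading

import boto3

_s3_lock = threading.Lock()
_s3_client = None


def get_s3_client():
    """Return a shared S3 client, constructed once under a lock.

    boto3 clients are generally safe to use from multiple threads,
    but constructing them concurrently is not, hence the lock.
    """
    global _s3_client
    with _s3_lock:
        if _s3_client is None:
            _s3_client = boto3.client(
                "s3",
                # None falls through to the standard AWS endpoint.
                endpoint_url=os.environ.get("AWS_ENDPOINT_URL"),
            )
    return _s3_client
```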

Check out <gh-file:examples/offline_inference/multilora_inference.py> for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options.
@@ -60,6 +79,13 @@ vllm serve meta-llama/Llama-2-7b-hf \
--lora-modules sql-lora=$HOME/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/snapshots/0dfa347e8877a4d4ed19ee56c140fa518470028c/
```

For S3-hosted adapters, use the JSON format with the S3 URI:
```bash
vllm serve meta-llama/Llama-2-7b-hf \
--enable-lora \
--lora-modules '{"name": "sql-lora", "path": "s3://my-bucket/path/to/adapter", "base_model_name": "meta-llama/Llama-2-7b"}'
```
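
Once the server is up, the registered adapter should appear alongside the base model in the list returned by the models endpoint:

```bash
curl http://localhost:8000/v1/models
```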

```{note}
The commit ID `0dfa347e8877a4d4ed19ee56c140fa518470028c` may change over time. Please check the latest commit ID in your environment to ensure you are using the correct one.
```
@@ -131,7 +157,7 @@ curl -X POST http://localhost:8000/v1/load_lora_adapter \
-H "Content-Type: application/json" \
-d '{
"lora_name": "sql_adapter",
"lora_path": "/path/to/sql-lora-adapter"
"lora_path": "/path/to/sql-lora-adapter" # Can be local path, HF ID, or S3 URI
}'
```
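
The `lora_path` accepts the same forms as elsewhere: a local filesystem path, a Hugging Face identifier, or an S3 URI. For example, with an S3-hosted adapter (hypothetical bucket and path):

```bash
curl -X POST http://localhost:8000/v1/load_lora_adapter \
-H "Content-Type: application/json" \
-d '{
    "lora_name": "s3_sql_adapter",
    "lora_path": "s3://my-bucket/path/to/adapter"
}'
```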

@@ -168,6 +194,11 @@ Now, you can specify a base_model_name alongside the name and path using JSON format. For example:
--lora-modules '{"name": "sql-lora", "path": "/path/to/lora", "base_model_name": "meta-llama/Llama-2-7b"}'
```

The path can be:
- A local filesystem path
- A Hugging Face model identifier
- An S3 URI (s3://bucket/path)

For backward compatibility, you can still use the old key-value format (name=path), but `base_model_name` will remain unspecified in that case.

## Lora model lineage in model card
166 changes: 166 additions & 0 deletions tests/lora/test_sources.py
@@ -0,0 +1,166 @@
from pathlib import Path

import pytest

from vllm.lora.sources.errors import AdapterNotFoundError, SourceError
from vllm.lora.sources.local import LocalSource
from vllm.lora.sources.protocol import AdapterSource, SourceMetadata
from vllm.lora.sources.registry import AdapterSourceRegistry


@pytest.fixture
def temp_adapter(tmp_path) -> Path:
"""Create a temporary adapter file."""
adapter_path = tmp_path / "test_adapter"
adapter_path.write_text("mock adapter content")
return adapter_path


@pytest.fixture
def source_registry() -> AdapterSourceRegistry:
"""Create a test source registry."""
return AdapterSourceRegistry()


class TestSourceProtocol:

def test_source_metadata(self):
"""Test source metadata creation and validation."""
metadata = SourceMetadata(adapter_id="test",
size_bytes=1000,
format_version="1.0",
properties={"type": "test"})
assert metadata.adapter_id == "test"
assert metadata.size_bytes == 1000

def test_source_validation(self):
"""Test source URI validation."""
source = LocalSource()
assert source.can_handle("file:///path/to/adapter")
assert not source.can_handle("s3://bucket/adapter")


class TestSourceRegistry:

def test_register_source(self, source_registry):
"""Test source registration."""
source = LocalSource()
source_registry.register("local", source)
assert source_registry.get_source("local") == source

def test_source_resolution(self, source_registry):
"""Test source resolution by URI."""
source = LocalSource()
source_registry.register("local", source)
resolved = source_registry.resolve_uri("file:///path/to/adapter")
assert isinstance(resolved, LocalSource)

def test_invalid_source(self, source_registry):
"""Test handling of invalid sources."""
with pytest.raises(SourceError):
source_registry.get_source("nonexistent")


class TestLocalSource:

@pytest.mark.asyncio
async def test_local_adapter_load(self, temp_adapter):
"""Test loading a local adapter."""
source = LocalSource()
path = await source.get_adapter(str(temp_adapter))
assert path.exists()
assert path.read_text() == "mock adapter content"

@pytest.mark.asyncio
async def test_local_metadata(self, temp_adapter):
"""Test local adapter metadata."""
source = LocalSource()
metadata = await source.get_metadata(str(temp_adapter))
assert metadata.adapter_id == temp_adapter.name
assert metadata.size_bytes == len(temp_adapter.read_bytes())


class TestErrorHandling:

@pytest.mark.asyncio
async def test_missing_adapter(self):
"""Test handling of missing adapters."""
source = LocalSource()
with pytest.raises(AdapterNotFoundError):
await source.get_adapter("/nonexistent/path")

def test_invalid_uri(self, source_registry):
"""Test handling of invalid URIs."""
with pytest.raises(SourceError):
source_registry.resolve_uri("invalid://uri")


class MockSource(AdapterSource):
"""Mock source for testing source behavior."""

async def get_adapter(self, uri: str) -> Path:
if "error" in uri:
raise SourceError("Simulated error")
return Path("/mock/path")

async def get_metadata(self, uri: str) -> SourceMetadata:
return SourceMetadata(adapter_id="mock",
size_bytes=100,
format_version="1.0",
properties={})

def can_handle(self, uri: str) -> bool:
return uri.startswith("mock://")


class TestMockSource:

@pytest.mark.asyncio
async def test_mock_source(self):
"""Test mock source behavior."""
source = MockSource()
path = await source.get_adapter("mock://adapter")
assert path == Path("/mock/path")

@pytest.mark.asyncio
async def test_mock_source_error(self):
"""Test mock source error handling."""
source = MockSource()
with pytest.raises(SourceError):
await source.get_adapter("mock://error")


class TestSourceIntegration:

@pytest.mark.asyncio
async def test_source_workflow(self, source_registry, temp_adapter):
"""Test complete source workflow."""
source = LocalSource()
source_registry.register("local", source)

# Test resolution
uri = f"file://{temp_adapter}"
resolved_source = source_registry.resolve_uri(uri)

# Test loading
path = await resolved_source.get_adapter(uri)
assert path.exists()

# Test metadata
metadata = await resolved_source.get_metadata(uri)
assert metadata.adapter_id == temp_adapter.name


class TestSourcePerformance:

@pytest.mark.asyncio
async def test_concurrent_access(self, source_registry, temp_adapter):
"""Test concurrent source access."""
source = LocalSource()
source_registry.register("local", source)

import asyncio
uris = [f"file://{temp_adapter}" for _ in range(10)]
tasks = [source.get_adapter(uri) for uri in uris]
results = await asyncio.gather(*tasks)
assert all(path.exists() for path in results)
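
For orientation, the source abstraction these tests exercise looks roughly like the following. This is a sketch inferred from the tests above, not the implementation shipped in this commit:

```python
import abc
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict


class SourceError(Exception):
    """Base error for adapter source failures."""


class AdapterNotFoundError(SourceError):
    """Raised when a URI is routable but the adapter is missing."""


@dataclass
class SourceMetadata:
    adapter_id: str
    size_bytes: int
    format_version: str
    properties: Dict[str, str] = field(default_factory=dict)


class AdapterSource(abc.ABC):
    """Interface each source backend (local, S3, ...) implements."""

    @abc.abstractmethod
    async def get_adapter(self, uri: str) -> Path:
        """Fetch the adapter and return a local path to it."""

    @abc.abstractmethod
    async def get_metadata(self, uri: str) -> SourceMetadata:
        """Describe the adapter without necessarily downloading it."""

    @abc.abstractmethod
    def can_handle(self, uri: str) -> bool:
        """Cheap scheme check used to route URIs to a source."""


class AdapterSourceRegistry:
    """Maps names to sources and routes a URI to whichever registered
    source claims it via can_handle()."""

    def __init__(self) -> None:
        self._sources: Dict[str, AdapterSource] = {}

    def register(self, name: str, source: AdapterSource) -> None:
        self._sources[name] = source

    def get_source(self, name: str) -> AdapterSource:
        if name not in self._sources:
            raise SourceError(f"no source registered as {name!r}")
        return self._sources[name]

    def resolve_uri(self, uri: str) -> AdapterSource:
        for source in self._sources.values():
            if source.can_handle(uri):
                return source
        raise SourceError(f"no source can handle {uri!r}")
```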
79 changes: 68 additions & 11 deletions vllm/lora/models.py
@@ -176,7 +176,7 @@ def from_lora_tensors(
scaling_factor=peft_helper.vllm_long_context_scaling_factor)

@classmethod
async def from_remote_checkpoint(
cls,
lora_dir: str,
expected_lora_modules: List[str],
@@ -190,10 +190,10 @@ def from_local_checkpoint(
embedding_padding_modules: Optional[List[str]] = None,
weights_mapper: Optional[WeightsMapper] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint.
"""Create a LoRAModel from a remote checkpoint.
Args:
lora_dir: The path/URI to the adapter (local, HF, S3)
expected_lora_modules: Name of modules that are expected to be
replaced by lora.
max_position_embeddings: Max position embedding length. Used to
@@ -203,17 +203,25 @@ def from_local_checkpoint(
a global counter.
device: Device where the lora model is loaded.
dtype: dtype of the lora model weights.
Returns:
Loaded LoRA Model.
"""
from vllm.lora.utils import resolve_adapter_path

# Resolve the adapter path
resolved_path = await resolve_adapter_path(lora_dir)

# Now use the resolved path to load files
lora_config_path = os.path.join(resolved_path, "adapter_config.json")
lora_tensor_path = os.path.join(resolved_path,
"adapter_model.safetensors")
lora_bin_file_path = os.path.join(resolved_path, "adapter_model.bin")
new_embeddings_tensor_path = os.path.join(
lora_dir, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(lora_dir,
resolved_path, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(resolved_path,
"new_embeddings.bin")

with open(lora_config_path) as f:
config = json.load(f)

@@ -225,8 +233,8 @@ def from_local_checkpoint(
# Find unexpected modules.
# Use safetensor key as a source of truth to find expected modules.
# in peft if you have target_modules A, B, C and C does not exist
# in the model it won't error and model will be trained with A, B
# loraified. C won't exist in the safetensor but it will exist in
# the target_modules of the adapter_config.json.
unexpected_modules = []
with safetensors.safe_open(lora_tensor_path,
@@ -296,6 +304,55 @@ def from_local_checkpoint(
embedding_padding_modules=embedding_padding_modules,
weights_mapper=weights_mapper)

@classmethod
def from_local_checkpoint(
cls,
lora_dir: str,
expected_lora_modules: List[str],
*,
max_position_embeddings: Optional[int] = None,
lora_model_id: Optional[int] = None,
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
target_embedding_padding: Optional[int] = None,
embedding_modules: Optional[Dict[str, str]] = None,
embedding_padding_modules: Optional[List[str]] = None,
weights_mapper: Optional[WeightsMapper] = None,
) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint.
This is now a synchronous wrapper around from_remote_checkpoint for
backward compatibility.
Args:
lora_dir: The local path that has lora data.
expected_lora_modules: Name of modules that are expected to be
replaced by lora.
max_position_embeddings: Max position embedding length. Used to
scale the largest context length. If None, the lora model's
context length is not scaled.
lora_model_id: Lora model id. If not given, automatically set by
a global counter.
device: Device where the lora model is loaded.
dtype: dtype of the lora model weights.
Returns:
Loaded LoRA Model.
"""
import asyncio
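# NOTE: asyncio.run() starts a new event loop, so this synchronous
# wrapper must not be called from inside an already-running loop.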
return asyncio.run(
cls.from_remote_checkpoint(
lora_dir=lora_dir,
expected_lora_modules=expected_lora_modules,
max_position_embeddings=max_position_embeddings,
lora_model_id=lora_model_id,
device=device,
dtype=dtype,
target_embedding_padding=target_embedding_padding,
embedding_modules=embedding_modules,
embedding_padding_modules=embedding_padding_modules,
weights_mapper=weights_mapper))


class LoRAModelManager(AdapterModelManager):
"""A manager that manages multiple LoRA-fine-tuned models."""
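
The `resolve_adapter_path` helper imported above is not part of this excerpt. A plausible sketch of the dispatch it performs follows — the S3 download logic and cache location are assumptions for illustration, not the commit's actual code:

```python
import asyncio
import os
import tempfile
from urllib.parse import urlparse


async def resolve_adapter_path(lora_dir: str) -> str:
    """Resolve a local path, HF repo ID, or s3:// URI to a local directory."""
    if os.path.isdir(lora_dir):
        return lora_dir  # already a local checkpoint directory

    if lora_dir.startswith("s3://"):
        import boto3  # assumed dependency for S3-hosted adapters

        parsed = urlparse(lora_dir)
        bucket, prefix = parsed.netloc, parsed.path.lstrip("/")
        dest = tempfile.mkdtemp(prefix="vllm-lora-")

        def _download() -> None:
            s3 = boto3.client("s3")
            paginator = s3.get_paginator("list_objects_v2")
            for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
                for obj in page.get("Contents", []):
                    rel = obj["Key"][len(prefix):].lstrip("/")
                    target = (os.path.join(dest, rel) if rel else
                              os.path.join(dest, os.path.basename(obj["Key"])))
                    os.makedirs(os.path.dirname(target), exist_ok=True)
                    s3.download_file(bucket, obj["Key"], target)

        # boto3 is blocking; run it off the event loop thread.
        await asyncio.to_thread(_download)
        return dest

    # Fall back to treating the string as a Hugging Face Hub repo ID.
    from huggingface_hub import snapshot_download
    return await asyncio.to_thread(snapshot_download, repo_id=lora_dir)
```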
5 changes: 5 additions & 0 deletions vllm/lora/request.py
@@ -29,6 +29,7 @@ class LoRARequest(
lora_local_path: Optional[str] = msgspec.field(default=None)
long_lora_max_len: Optional[int] = None
base_model_name: Optional[str] = msgspec.field(default=None)
loading_time: float = 0.0 # Time taken to load the adapter

def __post_init__(self):
if 'lora_local_path' in self.__struct_fields__:
@@ -93,3 +94,7 @@ def __hash__(self) -> int:
identified by their names across engines.
"""
return hash(self.lora_name)

def __bool__(self) -> bool:
"""Return True if this is a valid request."""
return bool(self.lora_name and self.lora_path)
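
A quick illustration of the new fields (hypothetical values; presumably the loading machinery fills in `loading_time` after a fetch):

```python
from vllm.lora.request import LoRARequest

req = LoRARequest("sql_adapter", 1, "s3://my-bucket/path/to/adapter")
assert req  # __bool__: both lora_name and lora_path are set
print(f"adapter loaded in {req.loading_time:.2f}s")  # 0.0 until loaded
```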