-
-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Frontend][Core][Enhancement] Add S3 and remote source support for LoR…
…A adapters - Expanded documentation in `lora.md` to clarify adapter path options, including local paths, Hugging Face identifiers, and S3 URIs. - Refactored `LoRAModel` to asynchronously load adapters from remote sources, enhancing flexibility and performance. - Introduced `resolve_adapter_path` utility to manage adapter path resolution across different sources. - Improved error handling in `WorkerLoRAManager` for better debugging and user feedback. - Added thread-safe S3 utilities for managing S3 interactions. This update enhances the usability and robustness of the LoRA model management system, allowing for more versatile adapter configurations. Signed-off-by: Prashant Patel <[email protected]>
- Loading branch information
1 parent
f35ec46
commit e660ed3
Showing
14 changed files
with
1,096 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
from pathlib import Path | ||
|
||
import pytest | ||
|
||
from vllm.lora.sources.errors import AdapterNotFoundError, SourceError | ||
from vllm.lora.sources.local import LocalSource | ||
from vllm.lora.sources.protocol import AdapterSource, SourceMetadata | ||
from vllm.lora.sources.registry import AdapterSourceRegistry | ||
|
||
|
||
@pytest.fixture
def temp_adapter(tmp_path) -> Path:
    """Write a placeholder adapter file and return its path.

    The file is called ``test_adapter``, lives under pytest's ``tmp_path``
    directory, and contains a fixed text payload that tests can read back.
    """
    adapter_file = tmp_path / "test_adapter"
    adapter_file.write_text("mock adapter content")
    return adapter_file
|
||
|
||
@pytest.fixture
def source_registry() -> AdapterSourceRegistry:
    """Return a newly constructed, default ``AdapterSourceRegistry``."""
    registry = AdapterSourceRegistry()
    return registry
|
||
|
||
class TestSourceProtocol:
    """Checks ``SourceMetadata`` construction and ``LocalSource.can_handle``."""

    def test_source_metadata(self):
        """Build a ``SourceMetadata`` and read back its id and size."""
        metadata = SourceMetadata(
            adapter_id="test",
            size_bytes=1000,
            format_version="1.0",
            properties={"type": "test"},
        )
        assert metadata.adapter_id == "test"
        assert metadata.size_bytes == 1000

    def test_source_validation(self):
        """``LocalSource`` accepts a ``file://`` URI and rejects ``s3://``."""
        local = LocalSource()
        file_uri = "file:///path/to/adapter"
        s3_uri = "s3://bucket/adapter"
        assert local.can_handle(file_uri)
        assert not local.can_handle(s3_uri)
|
||
|
||
class TestSourceRegistry:
    """Registration, lookup and URI resolution on ``AdapterSourceRegistry``."""

    def test_register_source(self, source_registry):
        """A source registered under a name is returned by that name.

        Identity (``is``) is asserted rather than equality so the test
        cannot pass on a different-but-equal instance.
        """
        source = LocalSource()
        source_registry.register("local", source)
        assert source_registry.get_source("local") is source

    def test_source_resolution(self, source_registry):
        """A ``file://`` URI resolves to a ``LocalSource`` once registered."""
        source = LocalSource()
        source_registry.register("local", source)
        resolved = source_registry.resolve_uri("file:///path/to/adapter")
        assert isinstance(resolved, LocalSource)

    def test_invalid_source(self, source_registry):
        """Looking up an unregistered name raises ``SourceError``."""
        with pytest.raises(SourceError):
            source_registry.get_source("nonexistent")
|
||
|
||
class TestLocalSource:
    """Exercises ``LocalSource`` against a real file on disk."""

    @pytest.mark.asyncio
    async def test_local_adapter_load(self, temp_adapter):
        """``get_adapter`` returns a path whose content matches the fixture."""
        local = LocalSource()
        loaded = await local.get_adapter(str(temp_adapter))
        assert loaded.exists()
        assert loaded.read_text() == "mock adapter content"

    @pytest.mark.asyncio
    async def test_local_metadata(self, temp_adapter):
        """``get_metadata`` reports the file's name and its size in bytes."""
        local = LocalSource()
        info = await local.get_metadata(str(temp_adapter))
        assert info.adapter_id == temp_adapter.name
        expected_size = len(temp_adapter.read_bytes())
        assert info.size_bytes == expected_size
|
||
|
||
class TestErrorHandling:
    """Failure paths: missing adapter files and unresolvable URIs."""

    @pytest.mark.asyncio
    async def test_missing_adapter(self):
        """Requesting a nonexistent path raises ``AdapterNotFoundError``."""
        local = LocalSource()
        missing_path = "/nonexistent/path"
        with pytest.raises(AdapterNotFoundError):
            await local.get_adapter(missing_path)

    def test_invalid_uri(self, source_registry):
        """Resolving a URI on an empty registry raises ``SourceError``."""
        unknown_uri = "invalid://uri"
        with pytest.raises(SourceError):
            source_registry.resolve_uri(unknown_uri)
|
||
|
||
class MockSource(AdapterSource):
    """Stand-in adapter source with canned responses.

    It touches neither disk nor network: every answer is a fixed value
    derived only from the URI string.
    """

    async def get_adapter(self, uri: str) -> Path:
        """Return a fixed path, or raise if the URI contains ``error``."""
        if "error" not in uri:
            return Path("/mock/path")
        raise SourceError("Simulated error")

    async def get_metadata(self, uri: str) -> SourceMetadata:
        """Return the same canned metadata regardless of ``uri``."""
        canned = SourceMetadata(
            adapter_id="mock",
            size_bytes=100,
            format_version="1.0",
            properties={},
        )
        return canned

    def can_handle(self, uri: str) -> bool:
        """Report whether ``uri`` begins with the ``mock://`` scheme."""
        scheme_prefix = "mock://"
        return uri.startswith(scheme_prefix)
|
||
|
||
class TestMockSource:
    """Covers every method of ``MockSource``.

    ``get_metadata`` and ``can_handle`` were previously defined on the mock
    but never exercised; they are tested here alongside ``get_adapter``.
    """

    @pytest.mark.asyncio
    async def test_mock_source(self):
        """A plain ``mock://`` URI yields the fixed mock path."""
        source = MockSource()
        path = await source.get_adapter("mock://adapter")
        assert path == Path("/mock/path")

    @pytest.mark.asyncio
    async def test_mock_source_error(self):
        """A URI containing ``error`` raises ``SourceError``."""
        source = MockSource()
        with pytest.raises(SourceError):
            await source.get_adapter("mock://error")

    @pytest.mark.asyncio
    async def test_mock_source_metadata(self):
        """``get_metadata`` returns the canned id and size."""
        source = MockSource()
        metadata = await source.get_metadata("mock://adapter")
        assert metadata.adapter_id == "mock"
        assert metadata.size_bytes == 100

    def test_mock_source_can_handle(self):
        """``can_handle`` accepts only URIs starting with ``mock://``."""
        source = MockSource()
        assert source.can_handle("mock://adapter")
        assert not source.can_handle("file:///path/to/adapter")
|
||
|
||
class TestSourceIntegration:
    """End-to-end check: register, resolve, load, then read metadata."""

    @pytest.mark.asyncio
    async def test_source_workflow(self, source_registry, temp_adapter):
        """Run a ``file://`` URI through the registry and the source."""
        local = LocalSource()
        source_registry.register("local", local)

        # Resolve the URI back to a source through the registry.
        adapter_uri = f"file://{temp_adapter}"
        handler = source_registry.resolve_uri(adapter_uri)

        # The resolved source must hand back a path that exists.
        adapter_path = await handler.get_adapter(adapter_uri)
        assert adapter_path.exists()

        # The metadata id must match the fixture file's name.
        info = await handler.get_metadata(adapter_uri)
        assert info.adapter_id == temp_adapter.name
|
||
|
||
class TestSourcePerformance:
    """Concurrent use of one ``LocalSource`` instance."""

    @pytest.mark.asyncio
    async def test_concurrent_access(self, source_registry, temp_adapter):
        """Issue ten ``get_adapter`` calls at once and check every result."""
        import asyncio

        source = LocalSource()
        source_registry.register("local", source)

        # All ten requests target the same adapter URI.
        adapter_uri = f"file://{temp_adapter}"
        pending = [source.get_adapter(adapter_uri) for _ in range(10)]
        results = await asyncio.gather(*pending)
        for path in results:
            assert path.exists()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.