Support LoRA hotswapping and multiple LoRAs at a time #1817
base: main
@@ -93,22 +93,14 @@ def __init__(self, params: GptParams) -> None:
         if self.params.ignore_eos:
             self.params.logit_bias[llama_cpp.llama_token_eos()] = -float("inf")

-        if len(self.params.lora_adapter) > 0:
-            if (
-                llama_cpp.llama_apply_lora_from_file(
-                    self.ctx,
-                    self.params.lora_adapter.encode("utf8"),
-                    (
-                        self.params.lora_base.encode("utf8")
-                        if len(self.params.lora_base) > 0
-                        else None
-                    ),
-                    self.params.n_threads,
-                )
-                != 0
-            ):
-                print("error: failed to apply lora adapter")
-                return
+        for lora_path, scale in [(pth, 1.0) for pth in self.params.lora] + self.params.lora_scaled:
Review comment: I didn't test this extensively, but this code at least worked this far - the actual example failed later for me for unrelated reasons.
+            lora_adapter = llama_cpp.llama_lora_adapter_init(
+                self.model,
+                lora_path.encode("utf8"))
+            if lora_adapter is None:
+                raise RuntimeError(f"error: failed to load lora adapter '{lora_path}'")
+            if scale != 0.0:
+                llama_cpp.llama_lora_adapter_set(self.ctx, lora_adapter, scale)

         print(file=sys.stderr)
         print(
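As an illustration outside the diff: the new loop consumes two parameter lists, plain paths in params.lora (applied at scale 1.0) and (path, scale) pairs in params.lora_scaled. A tiny sketch of how they combine, with made-up file names and the field shapes assumed from the loop itself:

# Illustrative only: assumed shapes of the new GptParams fields.
lora = ["style-adapter.gguf"]                 # hypothetical plain adapter paths, scale 1.0
lora_scaled = [("domain-adapter.gguf", 0.5)]  # hypothetical (path, scale) pairs

pairs = [(pth, 1.0) for pth in lora] + lora_scaled
print(pairs)  # [('style-adapter.gguf', 1.0), ('domain-adapter.gguf', 0.5)]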
@@ -285,6 +285,18 @@ def kv_cache_seq_keep(self, seq_id: int):
     def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
         llama_cpp.llama_kv_cache_seq_add(self.ctx, seq_id, p0, p1, shift)

+    def lora_adapter_set(self, adapter: LlamaLoraAdapter, scale: float):
+        return_code = llama_cpp.llama_lora_adapter_set(self.ctx, adapter.lora_adapter, scale)
+        if return_code != 0:
+            raise RuntimeError(f"lora_adapter_set returned {return_code}")
+
+    def lora_adapter_remove(self, adapter: LlamaLoraAdapter) -> bool:
+        return_code = llama_cpp.llama_lora_adapter_remove(self.ctx, adapter.lora_adapter)
+        return return_code != 0
+
+    def lora_adapter_clear(self):
+        llama_cpp.llama_lora_adapter_clear(self.ctx)
+
     def get_state_size(self) -> int:
         return llama_cpp.llama_get_state_size(self.ctx)
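As an illustration outside the diff: these three wrapper methods are what enable hotswapping at the context level. A minimal sketch of the intended call sequence, assuming ctx is the internal context wrapper that gained these methods and adapter_a / adapter_b are LlamaLoraAdapter instances created for the same model (all names here are made up for illustration):

def hotswap_demo(ctx, adapter_a, adapter_b):
    # Activate adapter A at full strength, then decode some tokens with it.
    ctx.lora_adapter_set(adapter_a, 1.0)
    # ... decode ...

    # Detach A; lora_adapter_remove returns True when the adapter was attached.
    ctx.lora_adapter_remove(adapter_a)

    # Swap in adapter B at half strength without rebuilding the context.
    ctx.lora_adapter_set(adapter_b, 0.5)
    # ... decode ...

    # Drop every attached adapter and fall back to the plain base weights.
    ctx.lora_adapter_clear()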
@@ -861,3 +873,45 @@ def close(self):

     def __del__(self):
         self.close()
+
+
+class LlamaLoraAdapter:
+    """Intermediate Python wrapper for a llama.cpp llama_lora_adapter.
+    NOTE: For stability it's recommended you use the Llama class instead."""
+
+    def __init__(
+        self,
+        model: LlamaModel,
+        lora_path: str,
+        *,
+        verbose: bool = True,
+    ):
+        self.model = model
+        self.lora_path = lora_path
+
+        lora_adapter = None
+
+        if not os.path.exists(lora_path):
+            raise ValueError(f"LoRA adapter path does not exist: {lora_path}")
+
+        with suppress_stdout_stderr(disable=verbose):
+            lora_adapter = llama_cpp.llama_lora_adapter_init(
+                self.model.model,
+                self.lora_path.encode("utf-8"),
+            )
+
+        if lora_adapter is None:
+            raise RuntimeError(
+                f"Failed to initialize LoRA adapter from lora path: {self.lora_path}"
+            )
+        # The llama_lora_adapter will be freed by the llama_model as part of its
+        # lifecycle. The llama_model destructor destroys each llama_lora_adapter,
+        # and the destructor for llama_lora_adapter calls llama_lora_adapter_free.
+        # All we do here is register a callback that clears the wrapped reference
+        # when the LlamaModel wrapper is closed, so this LlamaLoraAdapter wrapper
+        # no longer points at an adapter that has already been freed.
+        def clear_lora_adapter():
+            self.lora_adapter = None
+
+        self.model._exit_stack.callback(clear_lora_adapter)
Review comment: This seemed to be a clean way to keep the reference back to the parent.
+
+        self.lora_adapter = lora_adapter
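As an illustration outside the diff: the lifecycle comment above relies on contextlib.ExitStack callbacks registered on the parent wrapper. A small, self-contained sketch of that pattern, with names invented for the demo:

import contextlib

class Parent:
    def __init__(self):
        self._exit_stack = contextlib.ExitStack()

    def close(self):
        # Runs every registered callback (in LIFO order).
        self._exit_stack.close()

class Child:
    def __init__(self, parent: Parent):
        self.handle = object()  # stands in for the native lora_adapter pointer

        def clear_handle():
            self.handle = None  # drop the reference once the parent frees it

        parent._exit_stack.callback(clear_handle)

parent = Parent()
child = Child(parent)
parent.close()
assert child.handle is None  # the wrapper no longer points at freed state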
Review comment: Needed this fancy argparse action to match the llama.cpp argument format, which takes two arguments:
https://github.com/ggerganov/llama.cpp/blob/master/common/arg.cpp#L1546-L1551
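As an illustration outside the diff: one way such a two-argument option can be expressed with argparse is a custom Action that consumes a path and a scale per occurrence and appends (path, float(scale)) tuples. The class name, option names, and defaults below are assumptions for illustration, not necessarily what this PR does.

import argparse

class AppendPathScale(argparse.Action):
    """Collect (path, float(scale)) tuples from an option that takes two values."""

    def __call__(self, parser, namespace, values, option_string=None):
        path, scale = values  # nargs=2 guarantees exactly two strings
        items = list(getattr(namespace, self.dest) or [])
        items.append((path, float(scale)))
        setattr(namespace, self.dest, items)

parser = argparse.ArgumentParser()
parser.add_argument("--lora", action="append", default=[],
                    help="LoRA adapter applied at scale 1.0")
parser.add_argument("--lora-scaled", nargs=2, metavar=("PATH", "SCALE"),
                    action=AppendPathScale, default=[],
                    help="LoRA adapter with a user-defined scale")

args = parser.parse_args(["--lora", "a.gguf", "--lora-scaled", "b.gguf", "0.5"])
print(args.lora)         # ['a.gguf']
print(args.lora_scaled)  # [('b.gguf', 0.5)]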