Support LoRA hotswapping and multiple LoRAs at a time #1817

Open · wants to merge 1 commit into base: main

2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

- feat: Add hot-swapping for LoRA adapters

## [0.3.2]

- feat: Update llama.cpp to ggerganov/llama.cpp@74d73dc85cc2057446bf63cc37ff649ae7cebd80
1 change: 1 addition & 0 deletions docs/api-reference.md
@@ -22,6 +22,7 @@ High-level Python bindings for llama.cpp.
- __call__
- create_chat_completion
- create_chat_completion_openai_v1
- set_lora_adapter_scale
- set_cache
- save_state
- load_state
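
For orientation, here is a minimal sketch of how the new high-level method listed above might be used for hot-swapping. The `(path, scale)` signature and the file paths are assumptions, since the method body is not part of the hunks shown here.

```python
# Hedged sketch: hot-swapping LoRA adapters through the high-level API.
# The set_lora_adapter_scale(path, scale) form and all paths are assumptions.
from llama_cpp import Llama

llm = Llama(model_path="models/base-model.gguf")

# Enable one adapter at full strength and generate.
llm.set_lora_adapter_scale("adapters/style-a.gguf", 1.0)
print(llm("Write a haiku about autumn.", max_tokens=64)["choices"][0]["text"])

# Hot-swap: scale the first adapter to zero and enable a second one,
# without reloading the base model.
llm.set_lora_adapter_scale("adapters/style-a.gguf", 0.0)
llm.set_lora_adapter_scale("adapters/style-b.gguf", 0.7)
print(llm("Write a haiku about autumn.", max_tokens=64)["choices"][0]["text"])
```
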
74 changes: 56 additions & 18 deletions examples/low_level_api/common.py
@@ -3,10 +3,11 @@
import re

from dataclasses import dataclass, field
from typing import List

# Based on https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
from typing import List, Sequence, Tuple
import typing

# Based on https://github.com/ggerganov/llama.cpp/blob/master/common/common.cpp
# and https://github.com/ggerganov/llama.cpp/blob/master/common/arg.cpp

@dataclass
class GptParams:
@@ -40,8 +41,8 @@ class GptParams:
input_suffix: str = ""
antiprompt: List[str] = field(default_factory=list)

lora_adapter: str = ""
lora_base: str = ""
lora: List[str] = None
lora_scaled: List[Tuple[str, float]] = None

memory_f16: bool = True
random_prompt: bool = False
@@ -257,16 +258,56 @@ def gpt_params_parse(argv=None):
parser.add_argument(
"--lora",
type=str,
default="",
help="apply LoRA adapter (implies --no-mmap)",
dest="lora_adapter",
)
parser.add_argument(
"--lora-base",
type=str,
default="",
help="optional model to use as a base for the layers modified by the LoRA adapter",
dest="lora_base",
action="append",
default=[],
help="path to LoRA adapter (can be repeated to use multiple adapters)",
metavar="FNAME",
dest="lora",
)

class MultiTupleAction(argparse.Action):
[Contributor (author) comment] Needed this fancy argparse action to match the llama.cpp argument format, which takes two arguments:
https://github.com/ggerganov/llama.cpp/blob/master/common/arg.cpp#L1546-L1551

"""Action for handling multiple arguments as tuples with type conversion"""
def __init__(self,
option_strings: Sequence[str],
dest: str,
nargs: int = None,
type: Tuple = None,
metavar: Tuple = None,
**kwargs):
self.tuple_type = type
super().__init__(
option_strings=option_strings,
dest=dest,
type=str,  # accept raw strings here; converted to the tuple's types in __call__
nargs=nargs,
metavar=metavar,
**kwargs
)

def __call__(self, parser, namespace, values, option_string=None):
if len(values) != self.nargs:
parser.error(
f'{option_string} requires {len(self.metavar)} arguments: '
f'{" ".join(self.metavar)}'
)

converted_values = tuple(value_type(value) for value_type, value in zip(typing.get_args(self.tuple_type), values))
# Initialize list if needed
if not hasattr(namespace, self.dest):
setattr(namespace, self.dest, [])

# Add the converted tuple to the list
getattr(namespace, self.dest).append(converted_values)

parser.add_argument(
"--lora-scaled",
action=MultiTupleAction,
nargs=2,
type=Tuple[str, float],
help="path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
metavar=('FNAME', 'SCALE'),
dest='lora_scaled',
default=[],
)
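
As a quick check of the two flags above, here is a hedged sketch of what they parse to via `gpt_params_parse` (defined in this file). The adapter paths are placeholders, a real invocation may need additional flags such as a model path, and the import assumes the script is run from `examples/low_level_api/`.

```python
# Hedged sketch: parsing the new --lora / --lora-scaled flags.
# Adapter paths are placeholders; other flags may also be required.
from common import gpt_params_parse  # i.e. examples/low_level_api/common.py

params = gpt_params_parse([
    "--lora", "adapters/a.gguf",
    "--lora", "adapters/b.gguf",
    "--lora-scaled", "adapters/c.gguf", "0.5",
])

print(params.lora)         # ['adapters/a.gguf', 'adapters/b.gguf']
print(params.lora_scaled)  # [('adapters/c.gguf', 0.5)]
```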

parser.add_argument(
@@ -375,9 +416,6 @@ def gpt_params_parse(argv=None):
delattr(args, "logit_bias_str")
params = GptParams(**vars(args))

if params.lora_adapter:
params.use_mmap = False

if logit_bias_str != None:
for i in logit_bias_str:
if m := re.match(r"(\d+)([-+]\d+)", i):
24 changes: 8 additions & 16 deletions examples/low_level_api/low_level_api_chat_cpp.py
@@ -93,22 +93,14 @@ def __init__(self, params: GptParams) -> None:
if self.params.ignore_eos:
self.params.logit_bias[llama_cpp.llama_token_eos()] = -float("inf")

if len(self.params.lora_adapter) > 0:
if (
llama_cpp.llama_apply_lora_from_file(
self.ctx,
self.params.lora_adapter.encode("utf8"),
(
self.params.lora_base.encode("utf8")
if len(self.params.lora_base) > 0
else None
),
self.params.n_threads,
)
!= 0
):
print("error: failed to apply lora adapter")
return
for lora_path, scale in [(pth, 1.0) for pth in self.params.lora] + self.params.lora_scaled:
[Contributor (author) comment] I didn't test this extensively, but this code at least worked up to this point; the example itself failed later for me for unrelated reasons.

lora_adapter = llama_cpp.llama_lora_adapter_init(
self.model,
lora_path.encode("utf8"))
if lora_adapter is None:
raise RuntimeError(f"error: failed to load lora adapter '{lora_path}'")
if scale != 0.0:
llama_cpp.llama_lora_adapter_set(self.ctx, lora_adapter, scale)

print(file=sys.stderr)
print(
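
Outside the chat example, the same low-level calls can be exercised directly against the ctypes bindings. A hedged sketch follows; the adapter-related names match those used in the hunk above, while the model/adapter paths and parameter choices are placeholders.

```python
# Hedged sketch: load a model and context through the low-level llama_cpp
# bindings, attach a LoRA adapter, then detach it again. Paths are placeholders.
import llama_cpp

llama_cpp.llama_backend_init()

model = llama_cpp.llama_load_model_from_file(
    b"models/base-model.gguf", llama_cpp.llama_model_default_params()
)
ctx = llama_cpp.llama_new_context_with_model(
    model, llama_cpp.llama_context_default_params()
)

# Adapters are initialized against the model, then enabled per context.
adapter = llama_cpp.llama_lora_adapter_init(model, b"adapters/a.gguf")
if adapter is None:
    raise RuntimeError("failed to load LoRA adapter")
llama_cpp.llama_lora_adapter_set(ctx, adapter, 1.0)

# ... run inference, then detach all adapters from the context ...
llama_cpp.llama_lora_adapter_clear(ctx)
```
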
54 changes: 54 additions & 0 deletions llama_cpp/_internals.py
@@ -285,6 +285,18 @@ def kv_cache_seq_keep(self, seq_id: int):
def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
llama_cpp.llama_kv_cache_seq_add(self.ctx, seq_id, p0, p1, shift)

def lora_adapter_set(self, adapter: LlamaLoraAdapter, scale: float):
return_code = llama_cpp.llama_lora_adapter_set(self.ctx, adapter.lora_adapter, scale)
if return_code != 0:
raise RuntimeError(f"lora_adapter_set returned {return_code}")

def lora_adapter_remove(self, adapter: LlamaLoraAdapter) -> bool:
return_code = llama_cpp.llama_lora_adapter_remove(self.ctx, adapter.lora_adapter)
return return_code != 0

def lora_adapter_clear(self):
llama_cpp.llama_lora_adapter_clear(self.ctx)

def get_state_size(self) -> int:
return llama_cpp.llama_get_state_size(self.ctx)

@@ -861,3 +873,45 @@ def close(self):

def __del__(self):
self.close()

class LlamaLoraAdapter:
"""Intermediate Python wrapper for a llama.cpp llama_lora_adapter.
NOTE: For stability it's recommended you use the Llama class instead."""

def __init__(
self,
model: LlamaModel,
lora_path: str,
*,
verbose: bool = True,
):
self.model = model
self.lora_path = lora_path

lora_adapter = None

if not os.path.exists(lora_path):
raise ValueError(f"LoRA adapter path does not exist: {lora_path}")

with suppress_stdout_stderr(disable=verbose):
lora_adapter = llama_cpp.llama_lora_adapter_init(
self.model.model,
self.lora_path.encode("utf-8"),
)

if lora_adapter is None:
raise RuntimeError(
f"Failed to initialize LoRA adapter from lora path: {self.lora_path}"
)

# The llama_lora_adapter will be freed by the llama_model as part of its
# lifecycle. The llama_model destructor destroys each llama_lora_adapter,
# and the destructor for llama_lora_adapter calls llama_lora_adapter_free.
# All we do here is clear the wrapped reference when the LlamaModel wrapper
# is closed, so that the LlamaLoraAdapter wrapper's reference is cleared
# when the llama_lora_adapters are freed.
def clear_lora_adapter():
self.lora_adapter = None
self.model._exit_stack.callback(clear_lora_adapter)
[Contributor (author) comment, @richdougherty, Nov 24, 2024] This seemed to be a clean way to keep the reference back to the parent LlamaModel up to date.

self.lora_adapter = lora_adapter
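
Putting the `_internals` pieces together, here is a hedged sketch of hot-swapping with the new wrappers. The `LlamaModel`/`LlamaContext` constructor keywords are assumptions based on the existing internal API, the paths are placeholders, and, as the class docstring notes, the high-level `Llama` class is the recommended entry point.

```python
# Hedged sketch: hot-swapping adapters via the internal wrappers.
# Constructor keywords for LlamaModel/LlamaContext are assumed; paths are
# placeholders. Prefer the high-level Llama class for real use.
import llama_cpp
from llama_cpp._internals import LlamaContext, LlamaLoraAdapter, LlamaModel

model = LlamaModel(
    path_model="models/base-model.gguf",
    params=llama_cpp.llama_model_default_params(),
)
ctx = LlamaContext(model=model, params=llama_cpp.llama_context_default_params())

adapter_a = LlamaLoraAdapter(model, "adapters/a.gguf")
adapter_b = LlamaLoraAdapter(model, "adapters/b.gguf")

ctx.lora_adapter_set(adapter_a, 1.0)   # enable A
# ... run inference ...
ctx.lora_adapter_remove(adapter_a)     # detach A without reloading the model
ctx.lora_adapter_set(adapter_b, 0.5)   # enable B at half strength
ctx.lora_adapter_clear()               # or drop every adapter at once
```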