feat: Segment cache by active LoRAs; change key format

richdougherty committed Nov 6, 2024
1 parent 5dc0a1e commit d434c77

Showing 2 changed files with 98 additions and 34 deletions.
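Before the diff, a minimal sketch of the new key format introduced here (assuming a llama_cpp build that already contains this commit; the adapter path and scale are made up for illustration). Because LlamaCacheKey is a frozen, eq-enabled dataclass, entries for the same tokens under different active LoRA adapters compare as distinct keys:

from llama_cpp.llama_cache import LlamaCacheKey

# Same prompt tokens, different active adapters -> distinct, hashable keys.
base_key = LlamaCacheKey(active_lora_adapters=(), tokens=(1, 2, 3))
lora_key = LlamaCacheKey(
    active_lora_adapters=(("adapters/style-lora.gguf", 0.8),),  # hypothetical path/scale
    tokens=(1, 2, 3),
)

assert base_key != lora_key
states = {base_key: "state without LoRA", lora_key: "state with LoRA"}
assert len(states) == 2  # the two entries never collide

Note that tokens must be passed as a tuple; __post_init__ raises ValueError otherwise.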
12 changes: 8 additions & 4 deletions llama_cpp/llama.py
@@ -34,6 +34,7 @@
from .llama_types import *
from .llama_grammar import LlamaGrammar
from .llama_cache import (
LlamaCacheKey,
BaseLlamaCache,
LlamaCache, # type: ignore
LlamaDiskCache, # type: ignore
@@ -407,7 +408,7 @@ def __init__(
# Dict from LoRA path to wrapper
self._lora_adapters_paths: Dict[str, internals.LlamaLoraAdapter] = {}
# Immutable value representing active adapters for use as a key
self._lora_adapters_active: Tuple[Tuple[str, float]] = ()
self._lora_adapters_active: Tuple[Tuple[str, float], ...] = ()

if self.lora_adapters:
for lora_path, scale in self.lora_adapters.copy().items():
@@ -1315,7 +1316,8 @@ def logit_bias_processor(

if self.cache:
try:
cache_item = self.cache[(self._lora_adapters_active, prompt_tokens)]
cache_key = LlamaCacheKey(active_lora_adapters=self._lora_adapters_active, tokens=tuple(prompt_tokens))
cache_item = self.cache[cache_key]
cache_prefix_len = Llama.longest_token_prefix(
cache_item.input_ids.tolist(), prompt_tokens
)
@@ -1653,15 +1655,17 @@ def logit_bias_processor(
if self.cache:
if self.verbose:
print("Llama._create_completion: cache save", file=sys.stderr)
self.cache[(self._lora_adapters_active, prompt_tokens + completion_tokens)] = self.save_state()
cache_key = LlamaCacheKey(active_lora_adapters=self._lora_adapters_active, tokens=tuple(prompt_tokens + completion_tokens))
self.cache[cache_key] = self.save_state()
if self.verbose:
print("Llama._create_completion: cache saved", file=sys.stderr)
return

if self.cache:
if self.verbose:
print("Llama._create_completion: cache save", file=sys.stderr)
self.cache[(self._lora_adapters_active, prompt_tokens + completion_tokens)] = self.save_state()
cache_key = LlamaCacheKey(active_lora_adapters=self._lora_adapters_active, tokens=tuple(prompt_tokens + completion_tokens))
self.cache[cache_key] = self.save_state()

text_str = text.decode("utf-8", errors="ignore")

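Condensed, the cache interaction that the hunks above add to the completion path looks roughly like the sketch below. This is a paraphrase, not the actual method bodies: verbose logging and the prefix-reuse logic are omitted, the helper names are invented, and the arguments stand in for self._lora_adapters_active, self.cache, and the local token lists:

from typing import List, Tuple

from llama_cpp.llama_cache import BaseLlamaCache, LlamaCacheKey

def lookup_cached_state(cache: BaseLlamaCache,
                        active_lora_adapters: Tuple[Tuple[str, float], ...],
                        prompt_tokens: List[int]):
    # Mirrors the lookup hunk: build the new key from the active adapters plus
    # the prompt tokens, and treat a KeyError as a cache miss.
    key = LlamaCacheKey(active_lora_adapters=active_lora_adapters,
                        tokens=tuple(prompt_tokens))
    try:
        return cache[key]
    except KeyError:
        return None

def store_completion_state(cache: BaseLlamaCache,
                           active_lora_adapters: Tuple[Tuple[str, float], ...],
                           prompt_tokens: List[int],
                           completion_tokens: List[int],
                           state) -> None:
    # Mirrors the save hunks: the state returned by Llama.save_state() is keyed
    # by the full prompt + completion token sequence plus the active adapters.
    key = LlamaCacheKey(active_lora_adapters=active_lora_adapters,
                        tokens=tuple(prompt_tokens + completion_tokens))
    cache[key] = state
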
120 changes: 90 additions & 30 deletions llama_cpp/llama_cache.py
@@ -1,5 +1,6 @@
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import (
Optional,
Sequence,
@@ -13,38 +14,93 @@

from .llama_types import *

@dataclass(eq=True, frozen=True)
class LlamaCacheKey:
"""A key in a LlamaCache. Stores tokens to key by. Also stores
information about active LoRA adapters, because we need different
cached values for different active adapters, even for the same tokens."""
active_lora_adapters: Tuple[Tuple[str, float], ...]
tokens: Tuple[int, ...]

def __post_init__(self):
if not isinstance(self.tokens, tuple):
raise ValueError("tokens must be a tuple")

class BaseLlamaCache(ABC):
"""Base cache class for a llama.cpp model."""

def __init__(self, capacity_bytes: int = (2 << 30)):
self.capacity_bytes = capacity_bytes

def _convert_to_cache_key(self, key: Union[Sequence[int], LlamaCacheKey]) -> LlamaCacheKey:
"""Convert raw tokens to a key if needed"""
if isinstance(key, LlamaCacheKey):
return key
else:
return LlamaCacheKey(active_lora_adapters=(), tokens=tuple(key))

@property
@abstractmethod
def cache_size(self) -> int:
raise NotImplementedError

def _find_longest_prefix_key(
self,
key: Tuple[int, ...],
) -> Optional[Tuple[int, ...]]:
key: LlamaCacheKey,
) -> Optional[LlamaCacheKey]:
"""Find the cached key with the longest matching token prefix. A match also requires that the active
LoRA adapters match exactly.
Args:
key (LlamaCacheKey): The key to find a prefix match for.
Returns:
Optional[LlamaCacheKey]: The key with the longest matching prefix, or None if no match found.
"""
pass

@abstractmethod
def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
def __getitem__(self, key: Union[Sequence[int], LlamaCacheKey]) -> "llama_cpp.llama.LlamaState":
"""Retrieve a cached state by key, matching on the longest common token prefix. A match also requires
that the active LoRA adapters match exactly.
Args:
key: Key to look up. Raw token sequences are supported for backwards compatibility
and assume no active LoRA adapters.
Returns:
llama_cpp.llama.LlamaState: The cached state for the entry sharing the longest token prefix.
Raises:
KeyError: If no prefix match is found.
"""
raise NotImplementedError

@abstractmethod
def __contains__(self, key: Sequence[int]) -> bool:
def __contains__(self, key: Union[Sequence[int], LlamaCacheKey]) -> bool:
"""Check if any cached key shares a token prefix with the given key.
Args:
key: Key to look up. Raw token sequences are supported for backwards compatibility
and assume no active LoRA adapters.
Returns:
bool: True if any cached key shares a token prefix with this key.
"""
raise NotImplementedError

@abstractmethod
def __setitem__(
self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"
self, key: Union[Sequence[int], LlamaCacheKey], value: "llama_cpp.llama.LlamaState"
) -> None:
raise NotImplementedError
"""Store a state keyed on its tokens and information about active LoRA adapters.
Args:
key: Key to store. Raw token sequences are supported for backwards compatibility
and assume no active LoRA adapters
value: The state to cache
"""
raise NotImplementedError

class LlamaRAMCache(BaseLlamaCache):
"""Cache for a llama.cpp model using RAM."""
@@ -53,7 +109,7 @@ def __init__(self, capacity_bytes: int = (2 << 30)):
super().__init__(capacity_bytes)
self.capacity_bytes = capacity_bytes
self.cache_state: OrderedDict[
Tuple[int, ...], "llama_cpp.llama.LlamaState"
LlamaCacheKey, "llama_cpp.llama.LlamaState"
] = OrderedDict()

@property
@@ -62,34 +118,33 @@ def cache_size(self):

def _find_longest_prefix_key(
self,
key: Tuple[int, ...],
) -> Optional[Tuple[int, ...]]:
key: LlamaCacheKey,
) -> Optional[LlamaCacheKey]:
min_len = 0
min_key = None
keys = (
(k, llama_cpp.llama.Llama.longest_token_prefix(k, key))
for k in self.cache_state.keys()
)
for k, prefix_len in keys:
min_key: Optional[LlamaCacheKey] = None
for k in self.cache_state.keys():
if k.active_lora_adapters != key.active_lora_adapters: continue
if len(k.tokens) < min_len: continue # Optimization
prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k.tokens, key.tokens)
if prefix_len > min_len:
min_len = prefix_len
min_key = k
return min_key

def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
key = tuple(key)
def __getitem__(self, key: Union[Sequence[int], LlamaCacheKey]) -> "llama_cpp.llama.LlamaState":
key = self._convert_to_cache_key(key)
_key = self._find_longest_prefix_key(key)
if _key is None:
raise KeyError("Key not found")
value = self.cache_state[_key]
self.cache_state.move_to_end(_key)
return value

def __contains__(self, key: Sequence[int]) -> bool:
def __contains__(self, key: Union[Sequence[int], LlamaCacheKey]) -> bool:
return self._find_longest_prefix_key(self._convert_to_cache_key(key)) is not None

def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
key = tuple(key)
def __setitem__(self, key: Union[Sequence[int], LlamaCacheKey], value: "llama_cpp.llama.LlamaState"):
key = self._convert_to_cache_key(key)
if key in self.cache_state:
del self.cache_state[key]
self.cache_state[key] = value
@@ -116,19 +171,24 @@ def cache_size(self):

def _find_longest_prefix_key(
self,
key: Tuple[int, ...],
) -> Optional[Tuple[int, ...]]:
key: LlamaCacheKey,
) -> Optional[LlamaCacheKey]:
min_len = 0
min_key: Optional[Tuple[int, ...]] = None
for k in self.cache.iterkeys(): # type: ignore
prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k, key)
if not isinstance(k, LlamaCacheKey):
print("LlamaDiskCache: Disk cache keys must be LlamaCacheKey objects: skipping")
continue
if k.active_lora_adapters != key.active_lora_adapters: continue
if len(k.tokens) < min_len: continue # Optimization
prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k.tokens, key.tokens)
if prefix_len > min_len:
min_len = prefix_len
min_key = k # type: ignore
min_key = k
return min_key

def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
key = tuple(key)
def __getitem__(self, key: Union[Sequence[int], LlamaCacheKey]) -> "llama_cpp.llama.LlamaState":
key = self._convert_to_cache_key(key)
_key = self._find_longest_prefix_key(key)
if _key is None:
raise KeyError("Key not found")
@@ -138,12 +198,12 @@ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
# self.cache.push(_key, side="front") # type: ignore
return value

def __contains__(self, key: Sequence[int]) -> bool:
return self._find_longest_prefix_key(tuple(key)) is not None
def __contains__(self, key: Union[Sequence[int], LlamaCacheKey]) -> bool:
return self._find_longest_prefix_key(self._convert_to_cache_key(key)) is not None

def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"):
def __setitem__(self, key: Union[Sequence[int], LlamaCacheKey], value: "llama_cpp.llama.LlamaState"):
print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
key = tuple(key)
key = self._convert_to_cache_key(key)
if key in self.cache:
print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
del self.cache[key]
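To see the segmentation end to end, here is a small usage sketch against LlamaRAMCache (again assuming a build containing this commit). The stand-in state object only exposes llama_state_size, which is an assumption about what the RAM cache's size accounting reads from a real llama_cpp.llama.LlamaState, and the adapter path is made up:

from types import SimpleNamespace

from llama_cpp.llama_cache import LlamaCacheKey, LlamaRAMCache

# Stand-in for llama_cpp.llama.LlamaState; assumed to need only
# .llama_state_size for the cache's capacity accounting.
fake_state = SimpleNamespace(llama_state_size=1)

cache = LlamaRAMCache(capacity_bytes=2 << 30)
no_lora = ()
with_lora = (("adapters/style-lora.gguf", 0.8),)  # hypothetical adapter/scale

cache[LlamaCacheKey(active_lora_adapters=no_lora, tokens=(1, 2, 3, 4))] = fake_state

# Prefix matching only considers entries whose active adapters match exactly.
assert cache[LlamaCacheKey(active_lora_adapters=no_lora, tokens=(1, 2, 3, 4, 5))] is fake_state

missing = LlamaCacheKey(active_lora_adapters=with_lora, tokens=(1, 2, 3, 4, 5))
try:
    cache[missing]
    raise AssertionError("expected a cache miss for a different adapter set")
except KeyError:
    pass  # different adapter set -> cache miss, even though the tokens match

# Raw token sequences are still accepted and treated as having no active adapters.
assert cache[(1, 2, 3, 4, 5)] is fake_state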