
Commit 26f6d41

Fixed imports
tgaddair committed Apr 1, 2024
1 parent 837a1e3 commit 26f6d41
Showing 8 changed files with 98 additions and 267 deletions.
92 changes: 2 additions & 90 deletions server/lorax_server/cli.py
@@ -7,6 +7,8 @@
from typing import Optional
from enum import Enum

from lorax_server.utils.weights import download_weights as _download_weights


app = typer.Typer()

@@ -91,96 +93,6 @@ def serve(
)


def _download_weights(
model_id: str,
revision: Optional[str] = None,
extension: str = ".safetensors",
auto_convert: bool = True,
source: str = "hub",
api_token: Optional[str] = None,
):
# Import here after the logger is added to log potential import exceptions
from lorax_server import utils
from lorax_server.utils import sources
model_source = sources.get_model_source(source, model_id, revision, extension, api_token)

# Test if files were already downloaded
try:
model_source.weight_files()
logger.info("Files are already present on the host. " "Skipping download.")
return
# Local files not found
except (utils.LocalEntryNotFoundError, FileNotFoundError):
pass

is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
"WEIGHTS_CACHE_OVERRIDE", None
) is not None

if not is_local_model:
# TODO: Combine into class that takes the source as input
# Try to download weights from the hub
try:
model_source.download_model_assets()
return
# No weights found on the hub with this extension
except utils.EntryNotFoundError as e:
# Check if we want to automatically convert to safetensors or if we can use .bin weights instead
if not extension == ".safetensors" or not auto_convert:
raise e

# Try to see if there are local pytorch weights
try:
# Get weights for a local model, a hub cached model and inside the WEIGHTS_CACHE_OVERRIDE
local_pt_files = model_source.weight_files(extension=".bin")

# No local pytorch weights
except utils.LocalEntryNotFoundError:
if extension == ".safetensors":
logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. "
f"Downloading PyTorch weights."
)

# Try to see if there are pytorch weights on the hub
pt_filenames = model_source.remote_weight_files(extension=".bin")
# Download pytorch weights
local_pt_files = model_source.download_weights(pt_filenames)

if auto_convert:
logger.warning(
f"No safetensors weights found for model {model_id} at revision {revision}. "
f"Converting PyTorch weights to safetensors."
)

# Safetensors final filenames
local_st_files = [
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
for p in local_pt_files
]
try:
from transformers import AutoConfig
import transformers

config_path = sources.get_config_path(model_id, source)
config = AutoConfig.from_pretrained(
config_path,
revision=revision,
)
architecture = config.architectures[0]

class_ = getattr(transformers, architecture)

# Name for this variable depends on transformers version.
discard_names = getattr(class_, "_tied_weights_keys", [])
discard_names.extend(getattr(class_, "_keys_to_ignore_on_load_missing", []))

except Exception as e:
discard_names = []
# Convert pytorch weights to safetensors
utils.convert_files(local_pt_files, local_st_files, discard_names)


@app.command()
def download_weights(
model_id: str,
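
The net effect of this hunk is that the download logic formerly defined in cli.py now lives in lorax_server.utils.weights and is pulled back in under its old private name via the aliased import above. A minimal sketch of how the download_weights command can now delegate to the relocated helper (the command's full signature is collapsed in this view, so the parameters below are assumptions mirroring the deleted function):

# Hypothetical sketch (not the actual file contents): the CLI command simply
# forwards its arguments to the helper that now lives in lorax_server/utils/weights.py.
from typing import Optional

import typer

from lorax_server.utils.weights import download_weights as _download_weights

app = typer.Typer()


@app.command()
def download_weights(
    model_id: str,
    revision: Optional[str] = None,
    extension: str = ".safetensors",
    auto_convert: bool = True,
    source: str = "hub",
    api_token: Optional[str] = None,
):
    # The hub/local lookup, .bin fallback, and safetensors conversion steps
    # previously implemented here are now handled inside the relocated helper.
    _download_weights(model_id, revision, extension, auto_convert, source, api_token)
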
3 changes: 0 additions & 3 deletions server/lorax_server/models/flash_gpt2.py
@@ -19,10 +19,7 @@
LM_HEAD,
)
from lorax_server.utils import (
compute_delta_weight,
get_start_stop_idxs_for_rank,
initialize_torch_distributed,
load_module_map,
weight_files,
Weights,
)
1 change: 1 addition & 0 deletions server/lorax_server/models/flash_mixtral.py
@@ -211,6 +211,7 @@ def from_pb(
max_length = max(max_length, input_length + max_new_tokens)

adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device)
print("!!! ADAPTER INDICES", adapter_indices)

request_tokenizers = [
tokenizers.get_tokenizer(r.adapter_index, tokenizer)
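
For context on what the new debug print will emit: adapter_indices is built by concatenating the per-request index tensors in adapter_indices_list and casting the result to int64 on the target device. A toy, CPU-only illustration (the values and per-request lengths are made up, not taken from the repository):

import torch

# Three requests assigned to adapters 0, 0 and 2, contributing 2, 1 and 3
# entries respectively (illustrative only; the actual length of each entry
# depends on how from_pb builds adapter_indices_list).
adapter_indices_list = [
    torch.full((2,), 0),
    torch.full((1,), 0),
    torch.full((3,), 2),
]
adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64)
print("!!! ADAPTER INDICES", adapter_indices)  # -> tensor([0, 0, 0, 2, 2, 2])
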
1 change: 0 additions & 1 deletion server/lorax_server/server.py
@@ -14,7 +14,6 @@

from lorax_server.adapters.utils import download_adapter
from lorax_server.cache import Cache
from lorax_server.cli import _download_weights
from lorax_server.interceptor import ExceptionInterceptor
from lorax_server.models import Model, get_model
from lorax_server.pb import generate_pb2_grpc, generate_pb2
4 changes: 0 additions & 4 deletions server/lorax_server/utils/__init__.py
@@ -1,6 +1,4 @@
from lorax_server.utils.adapter import (
compute_delta_weight,
create_merged_weight_files,
load_module_map,
)
from lorax_server.utils.convert import convert_file, convert_files
@@ -33,8 +31,6 @@
)

__all__ = [
"compute_delta_weight",
"create_merged_weight_files",
"load_module_map",
"convert_file",
"convert_files",
170 changes: 3 additions & 167 deletions server/lorax_server/utils/adapter.py
@@ -1,24 +1,14 @@
from dataclasses import dataclass
import os
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, List, Dict, Set, Tuple
from typing import TYPE_CHECKING, Set, Tuple
import warnings

import torch
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from loguru import logger
from peft.utils import transpose
from safetensors.torch import load_file, save_file
from safetensors.torch import load_file
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
from tqdm import tqdm
from filelock import FileLock

from lorax_server.pb import generate_pb2
from lorax_server.utils.sources import get_model_source, get_config_path, weight_files
from lorax_server.utils.sources import get_model_source, get_config_path
from lorax_server.utils.merges.strategies import merge_adapters
from lorax_server.adapters.lora import get_scaling_factor

if TYPE_CHECKING:
from lorax_server.adapters.config import AdapterConfig, ModuleMap
@@ -145,157 +135,3 @@ def load_module_map(
# map the model weights to the relevant adapter weights (LoRA A and B matrices)
module_map, adapter_weight_names = adapter_config.map_weights_for_model(adapter_weights, weight_names)
return module_map, adapter_config, adapter_weight_names, adapter_tokenizer


def compute_delta_weight(
lora_A: torch.Tensor,
lora_B: torch.Tensor,
fan_in_fan_out: bool,
alpha: float,
r: float,
uses_rslora: bool = False
) -> torch.Tensor:
"""Computes the delta weight for a Linear layer given A and B LoRA matrices.
TODO: add logic for other module types beyond Linear layers.
Reference: https://github.com/huggingface/peft/blob/v0.4.0/src/peft/tuners/lora.py#L799-L806
"""
scaling = get_scaling_factor(alpha, r, uses_rslora=uses_rslora)
delta_weight = transpose(lora_B @ lora_A, fan_in_fan_out) * scaling
return delta_weight


def merge_adapter_weights(
model_weights: Dict[str, torch.Tensor],
adapter_weights: Dict[str, torch.Tensor],
adapter_config: "AdapterConfig"
) -> Tuple[Dict[str, torch.Tensor], Set[str]]:
"""
Merges the adapter weights into the model weights.
Args:
model_weights (Dict[str, torch.Tensor]): The weights of the base model.
adapter_weights (Dict[str, torch.Tensor]): The weights of the adapters.
adapter_config (AdapterConfig): The configuration for the adapter.
Returns:
Tuple[Dict[str, torch.Tensor], Set[str]]: A tuple containing the merged weights and the set of processed adapter weight names.
"""
from lorax_server.adapters.lora import LoraConfig

if not isinstance(adapter_config, LoraConfig):
raise ValueError(f"Unsupported adapter config type: {type(adapter_config)}")

module_mapping = defaultdict(dict)
processed_adapter_weight_names = set()

# map the original tensor names to their adapter counterparts
for weight_name in model_weights:
end_idx = weight_name.rfind(".weight")
key = weight_name[:end_idx]
for adapter_weight_name in adapter_weights:
if key in adapter_weight_name:
# example value: 'base_model.model.model.layers.10.self_attn.v_proj.lora_B.weight'
# matrix_type gets the second to last element in the module name, i.e. 'lora_B'
matrix_type = adapter_weight_name.split(".")[-2]
module_mapping[weight_name][matrix_type] = adapter_weight_name
processed_adapter_weight_names.add(adapter_weight_name)

# merge adapter weights into model weights
merged_weights = {}
for weight_name, adapter_weight_names in tqdm(
module_mapping.items(), desc="Merging adapter weights", total=len(module_mapping)):

# TODO: support adapter types beyond LoRA
# TODO: put this on GPU if it is available. This should greatly speedup compute_delta_weight
lora_A = adapter_weights[adapter_weight_names["lora_A"]]
lora_B = adapter_weights[adapter_weight_names["lora_B"]]
delta_weight = compute_delta_weight(
lora_A,
lora_B,
adapter_config.fan_in_fan_out,
adapter_config.lora_alpha,
adapter_config.r,
uses_rslora=adapter_config.use_rslora,
)

# transpose delta weight if necessary
# TODO(geoffrey): I believe this is required when using Conv1D layers (gpt2).
# We can likely take this out once we've switched to using Linear layers.
if (delta_weight.shape != model_weights[weight_name].shape and
delta_weight.T.shape == model_weights[weight_name].shape):
delta_weight = delta_weight.T
merged_weights[weight_name] = model_weights[weight_name] + delta_weight
return merged_weights, processed_adapter_weight_names


def create_merged_weight_files(
adapter_id: str,
model_id: str,
model_weight_filenames: List[Path],
adapter_source: str = "hub",
) -> List[Path]:
"""Creates merged weight files for the given adapter ID and filenames."""
api_token = None # TODO(travis): add support for API token
source = get_model_source(adapter_source, adapter_id, api_token=api_token)
adapter_filenames = source.weight_files()

adapter_config = source.load_config()
if adapter_config.base_model_name_or_path != model_id:
expected_config = AutoConfig.from_pretrained(model_id)
model_config = AutoConfig.from_pretrained(adapter_config.base_model_name_or_path)
if model_config.architectures == expected_config.architectures:
warnings.warn(
f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. "
f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead."
)
else:
# TODO(travis): revisit this when we support classification heads which will not use CausalLM
raise ValueError(f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. "
f"Use --model-id '{adapter_config.base_model_name_or_path}' instead.")

# load adapter weights from all shards (should have relatively small memory footprint)
adapter_weights = {}
for filename in adapter_filenames:
adapter_weights.update(load_file(filename))
remaining_adapter_weight_names = set(adapter_weights.keys())

merged_weight_directory = Path(HUGGINGFACE_HUB_CACHE) / f"models--{adapter_id.replace('/', '--')}-merged"
# just grab the existing files if they already exist and return immediately
lock = FileLock(str(merged_weight_directory)+ ".lock")
with lock:
if merged_weight_directory.is_dir():
logger.info(f"Merged weight directory {merged_weight_directory} exist, skipping merge computation.")
return weight_files(merged_weight_directory)
else:
logger.info("Merged weight files do not exist, computing merge.")
os.makedirs(merged_weight_directory)

merged_weight_filenames = []
for i, filename in enumerate(model_weight_filenames):
logger.info(
f"Merging adapter weights into model weights in "
f"{filename} ({i+1} / {len(model_weight_filenames)})..."
)
model_weights = load_file(filename)
merged_weights, processed_adapter_weight_names = merge_adapter_weights(
model_weights, adapter_weights, adapter_config)

merged_adapter_filename = Path(merged_weight_directory, os.path.basename(filename))
save_file(merged_weights, merged_adapter_filename)
logger.debug(f"Saved merged weights into {merged_adapter_filename}")

merged_weight_filenames.append(merged_adapter_filename)
remaining_adapter_weight_names = remaining_adapter_weight_names.difference(
processed_adapter_weight_names)

if len(remaining_adapter_weight_names) > 0:
logger.warning("WARNING: The following lora weights were not merged into the model weights:")
for lora_name in remaining_adapter_weight_names:
logger.warning("\t" + lora_name)

logger.info(
f"Finished merging adapter weights. Merged weight files saved to: {merged_weight_directory}")
return merged_weight_filenames
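
Since this hunk deletes compute_delta_weight and merge_adapter_weights from adapter.py, a small standalone sketch of the arithmetic those helpers performed may help when reading the removal. It assumes a plain Linear layer (no fan_in_fan_out transpose) and illustrative shapes; the scaling follows the usual LoRA convention of alpha / r (alpha / sqrt(r) under rsLoRA), which is presumably what the deleted code obtained from get_scaling_factor:

import math

import torch

# Illustrative shapes only: base weight is (out_features, in_features) = (32, 16), rank r = 4.
r, alpha, uses_rslora = 4, 16, False
base_weight = torch.randn(32, 16)
lora_A = torch.randn(r, 16)   # projects in_features -> r
lora_B = torch.randn(32, r)   # projects r -> out_features

# Standard LoRA scaling; rsLoRA divides by sqrt(r) instead of r.
scaling = alpha / math.sqrt(r) if uses_rslora else alpha / r

# Delta weight for a Linear layer, as in the deleted compute_delta_weight
# (the transpose applied for fan_in_fan_out / Conv1D layers is omitted here).
delta_weight = (lora_B @ lora_A) * scaling

# Merging, as in the deleted merge_adapter_weights: adapted weight = base + delta.
merged_weight = base_weight + delta_weight
assert merged_weight.shape == base_weight.shape
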
4 changes: 2 additions & 2 deletions server/lorax_server/utils/sources/source.py
@@ -4,7 +4,6 @@
from typing import Optional, List
from pathlib import Path

from lorax_server.adapters import load_adapter_config
from lorax_server.adapters.config import AdapterConfig


@@ -132,7 +131,8 @@ def get_weight_bytes(self) -> int:
return total_size

def load_config(self) -> AdapterConfig:
from lorax_server.adapters import load_adapter_config

config_path = self.download_file("config.json", ignore_errors=True)
adapter_config_path = self.download_file("adapter_config.json", ignore_errors=True)
return load_adapter_config(config_path, adapter_config_path, self.api_token)

