diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index 2f53ffd6c..4ec97c5d6 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -32,6 +32,7 @@
 from lorax_server.utils.graph import GraphCache
 from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata
 from lorax_server.utils.segments import SegmentConcatBuilder, find_segments
+from lorax_server.utils.sources import HUB
 from lorax_server.utils.state import warmup_mode
 from lorax_server.utils.tokenizer import TokenizerManager

@@ -731,7 +732,7 @@ def __init__(
         sliding_window: Optional[int] = None,
         compile: bool = False,
         adapter_id: str = BASE_MODEL_ADAPTER_ID,
-        dynamic_adapter_loading_enabled: bool = True,
+        adapter_source: str = HUB,
     ):
         global SLIDING_WINDOW
         global SLIDING_WINDOW_BLOCKS
@@ -751,7 +752,8 @@
             world_size=world_size,
             sliding_window=sliding_window,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
+            dynamic_adapter_loading_enabled=True,
         )

         if sliding_window is not None:
diff --git a/server/lorax_server/models/flash_gemma.py b/server/lorax_server/models/flash_gemma.py
index c1e17aafd..0030907d9 100644
--- a/server/lorax_server/models/flash_gemma.py
+++ b/server/lorax_server/models/flash_gemma.py
@@ -12,7 +12,6 @@
     GemmaConfig,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -63,29 +62,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )

         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -107,7 +88,7 @@
             world_size=world_size,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )

     @property
diff --git a/server/lorax_server/models/flash_gpt2.py b/server/lorax_server/models/flash_gpt2.py
index 388452b18..cfedac509 100644
--- a/server/lorax_server/models/flash_gpt2.py
+++ b/server/lorax_server/models/flash_gpt2.py
@@ -20,7 +20,6 @@
 )
 from lorax_server.utils import (
     compute_delta_weight,
-    create_merged_weight_files,
     get_start_stop_idxs_for_rank,
     initialize_torch_distributed,
     load_module_map,
@@ -70,23 +69,6 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
@@ -114,7 +96,7 @@
             world_size=world_size,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )

     @property
diff --git a/server/lorax_server/models/flash_llama.py b/server/lorax_server/models/flash_llama.py
index 07632613b..82dde9199 100644
--- a/server/lorax_server/models/flash_llama.py
+++ b/server/lorax_server/models/flash_llama.py
@@ -13,7 +13,6 @@
     LlamaConfig,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -64,29 +63,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )

         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -108,7 +89,7 @@
             world_size=world_size,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )

     @property
diff --git a/server/lorax_server/models/flash_mistral.py b/server/lorax_server/models/flash_mistral.py
index 97df03804..b4fa228b5 100644
--- a/server/lorax_server/models/flash_mistral.py
+++ b/server/lorax_server/models/flash_mistral.py
@@ -12,7 +12,6 @@
     MistralConfig,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -61,29 +60,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )

         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -106,7 +87,7 @@
             sliding_window=config.sliding_window,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )

     @property
diff --git a/server/lorax_server/models/flash_mixtral.py b/server/lorax_server/models/flash_mixtral.py
index 76772ce90..5cd7babfe 100644
--- a/server/lorax_server/models/flash_mixtral.py
+++ b/server/lorax_server/models/flash_mixtral.py
@@ -31,7 +31,6 @@
     MixtralConfig,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -361,29 +360,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )

         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -406,7 +387,7 @@
             sliding_window=config.sliding_window,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )

     @property
diff --git a/server/lorax_server/models/flash_phi.py b/server/lorax_server/models/flash_phi.py
index ac0be9435..f7d0156d2 100644
--- a/server/lorax_server/models/flash_phi.py
+++ b/server/lorax_server/models/flash_phi.py
@@ -18,7 +18,6 @@
     PhiConfig,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -69,29 +68,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )

         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -114,7 +95,7 @@
             world_size=world_size,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )

     @property
diff --git a/server/lorax_server/models/flash_qwen.py b/server/lorax_server/models/flash_qwen.py
index 55e439fc6..01013b0eb 100644
--- a/server/lorax_server/models/flash_qwen.py
+++ b/server/lorax_server/models/flash_qwen.py
@@ -17,7 +17,6 @@
     QwenConfig,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -68,29 +67,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )

         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -113,7 +94,7 @@
             world_size=world_size,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )

     @property
diff --git a/server/lorax_server/models/flash_qwen2.py b/server/lorax_server/models/flash_qwen2.py
index 693b1e381..1b3161c41 100644
--- a/server/lorax_server/models/flash_qwen2.py
+++ b/server/lorax_server/models/flash_qwen2.py
@@ -19,7 +19,6 @@
     FlashQwen2ForCausalLM,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -69,29 +68,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )

         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -116,7 +97,7 @@
             sliding_window=config.sliding_window,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )

     @property
diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py
index 047bf8cd4..228b15fa8 100644
--- a/server/lorax_server/models/model.py
+++ b/server/lorax_server/models/model.py
@@ -7,12 +7,14 @@
 from typing import Dict, List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase

+from lorax_server.adapters.utils import download_adapter
 from lorax_server.models.types import Batch, GeneratedText
 from lorax_server.pb.generate_pb2 import AdapterParameters, AdapterSource, InfoResponse
 from lorax_server.utils.adapter import (
     BASE_MODEL_ADAPTER_ID,
     load_and_merge_adapters,
 )
+from lorax_server.utils.sources import HUB
 from lorax_server.utils.tokenizer import TokenizerManager
 from lorax_server.adapters.weights import LayerAdapterWeights
 from lorax_server.utils.weights import shard_on_dim
@@ -33,6 +35,7 @@ def __init__(
         world_size: int = 1,
         sliding_window: Optional[int] = None,
         adapter_id: str = BASE_MODEL_ADAPTER_ID,
+        adapter_source: str = HUB,
         dynamic_adapter_loading_enabled: bool = True,
     ):
         self.model_id = model_id
@@ -59,6 +62,15 @@ def __init__(
             is not None
         )

+        if adapter_id and adapter_id != BASE_MODEL_ADAPTER_ID:
+            download_adapter(adapter_id, adapter_source, api_token=None)
+            self.load_adapter(
+                AdapterParameters(adapter_ids=[adapter_id]),
+                adapter_source,
+                adapter_index=0,
+                api_token=None,
+            )
+
         self.check_initialized()

     @property
diff --git a/server/lorax_server/server.py b/server/lorax_server/server.py
index 6aff3f372..0217d9d5b 100644
--- a/server/lorax_server/server.py
+++ b/server/lorax_server/server.py
@@ -12,6 +12,7 @@
 from pathlib import Path
 from typing import List, Optional

+from lorax_server.adapters.utils import download_adapter
 from lorax_server.cache import Cache
 from lorax_server.cli import _download_weights
 from lorax_server.interceptor import ExceptionInterceptor
@@ -142,24 +143,7 @@ async def DownloadAdapter(self, request: generate_pb2.DownloadAdapterRequest, co
                 logger.info("No adapter to download for base model. Skipping.")
                 continue

-            if adapter_source == PBASE:
-                adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token)
-                adapter_source = S3
-
-            if adapter_source == HUB:
-                # Quick auth check on the repo against the token
-                HfApi(token=api_token).model_info(adapter_id, revision=None)
-
-            # fail fast if ID is not an adapter (i.e. it is a full model)
-            source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token)
-            source.load_config()
-
-            _download_weights(
-                adapter_id, source=adapter_source, api_token=api_token
-            )
-
-            # Calculate size of adapter to be loaded
-            adapter_bytes += source.get_weight_bytes()
+            adapter_bytes += download_adapter(adapter_id, adapter_source, api_token)

        adapter_memory_size = self.model.adapter_memory_size()
        if adapter_memory_size > 0:
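
Note: the new `download_adapter` helper in `lorax_server/adapters/utils.py` is not included in this diff. Based on the block removed from `DownloadAdapter` in `server.py`, a minimal sketch of what it consolidates might look like the following. This is an assumption, not the actual implementation; in particular, the import locations of `PBASE`, `S3`, `get_model_source`, and `map_pbase_model_id_to_s3`, and the exact signature, are guesses.

    # Hypothetical sketch of lorax_server/adapters/utils.py::download_adapter,
    # assembled from the logic removed from server.py above.
    # Import paths and signature are assumptions.
    from typing import Optional

    from huggingface_hub import HfApi

    from lorax_server.cli import _download_weights
    from lorax_server.utils.sources import (  # assumed location of these symbols
        HUB,
        PBASE,
        S3,
        get_model_source,
        map_pbase_model_id_to_s3,
    )


    def download_adapter(adapter_id: str, adapter_source: str, api_token: Optional[str]) -> int:
        if adapter_source == PBASE:
            # Resolve Predibase adapter IDs to an S3 location before downloading
            adapter_id = map_pbase_model_id_to_s3(adapter_id, api_token)
            adapter_source = S3

        if adapter_source == HUB:
            # Quick auth check on the repo against the token
            HfApi(token=api_token).model_info(adapter_id, revision=None)

        # Fail fast if the ID is not an adapter (i.e. it is a full model)
        source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token)
        source.load_config()

        _download_weights(adapter_id, source=adapter_source, api_token=api_token)

        # Return the adapter size so callers can accumulate adapter_bytes
        return source.get_weight_bytes()

With a helper of this shape available, `Model.__init__` (see the `model.py` hunk) can download and dynamically load an adapter passed at model instantiation, which is what allows the per-architecture `create_merged_weight_files` blocks above to be deleted.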