WIP: default adapter
tgaddair committed Apr 1, 2024
1 parent bdb3297 commit 837a1e3
Showing 11 changed files with 26 additions and 179 deletions.
server/lorax_server/models/flash_causal_lm.py (6 changes: 4 additions & 2 deletions)

@@ -32,6 +32,7 @@
 from lorax_server.utils.graph import GraphCache
 from lorax_server.adapters import AdapterBatchData, AdapterBatchMetadata
 from lorax_server.utils.segments import SegmentConcatBuilder, find_segments
+from lorax_server.utils.sources import HUB
 from lorax_server.utils.state import warmup_mode
 from lorax_server.utils.tokenizer import TokenizerManager
 
@@ -731,7 +732,7 @@ def __init__(
         sliding_window: Optional[int] = None,
         compile: bool = False,
         adapter_id: str = BASE_MODEL_ADAPTER_ID,
-        dynamic_adapter_loading_enabled: bool = True,
+        adapter_source: str = HUB,
     ):
         global SLIDING_WINDOW
         global SLIDING_WINDOW_BLOCKS
@@ -751,7 +752,8 @@ def __init__(
             world_size=world_size,
             sliding_window=sliding_window,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
+            dynamic_adapter_loading_enabled=True,
         )
 
         if sliding_window is not None:
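Read as a whole, this first set of hunks changes the constructor contract of the shared base class: callers now say where adapters come from (adapter_source) instead of whether dynamic loading is allowed, and the flag is pinned to True when forwarded upward. A minimal, self-contained sketch of that pattern follows; the class names, the HUB value, and the parent signature are assumptions for illustration, not the real lorax_server definitions.

from typing import Optional

HUB = "hub"                     # assumed value; stands in for lorax_server.utils.sources.HUB
BASE_MODEL_ADAPTER_ID = ""      # assumed sentinel meaning "no adapter baked in"


class FlashLMStandIn:
    """Stand-in for the parent class that receives the forwarded keyword arguments."""

    def __init__(self, adapter_id: str, adapter_source: str, dynamic_adapter_loading_enabled: bool):
        self.adapter_id = adapter_id
        self.adapter_source = adapter_source
        self.dynamic_adapter_loading_enabled = dynamic_adapter_loading_enabled


class FlashCausalLMSketch(FlashLMStandIn):
    """Mirrors the new constructor contract: adapter_source replaces the old boolean flag."""

    def __init__(
        self,
        adapter_id: str = BASE_MODEL_ADAPTER_ID,
        adapter_source: str = HUB,
        sliding_window: Optional[int] = None,
    ):
        super().__init__(
            adapter_id=adapter_id,
            adapter_source=adapter_source,
            # Dynamic adapter loading is now always on; callers can no longer disable it here.
            dynamic_adapter_loading_enabled=True,
        )
        self.sliding_window = sliding_window


model = FlashCausalLMSketch(adapter_id="my-lora", adapter_source=HUB)
assert model.dynamic_adapter_loading_enabled is True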
server/lorax_server/models/flash_gemma.py (21 changes: 1 addition & 20 deletions)

@@ -12,7 +12,6 @@
     GemmaConfig,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -63,29 +62,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )
 
         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -107,7 +88,7 @@ def __init__(
             world_size=world_size,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )
 
     @property
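The Gemma hunks above, and the essentially identical hunks in the flash_gpt2, flash_llama, flash_mistral, flash_mixtral, flash_phi, flash_qwen, and flash_qwen2 files below, delete the merge-on-startup branch entirely. A hedged before/after sketch of that load path, using stub helpers rather than the real lorax_server functions:

from typing import List, Optional, Tuple


def create_merged_weight_files_stub(adapter_id: str, model_id: str,
                                    model_weight_filenames: List[str],
                                    adapter_source: str) -> List[str]:
    # Stand-in for the removed helper that baked adapter weights into the base weights.
    return [f"{name}.merged-{adapter_id}" for name in model_weight_filenames]


def resolve_weights_old(adapter_id: str, model_id: str, filenames: List[str],
                        adapter_source: str) -> Tuple[Optional[List[str]], bool]:
    # Old path: a non-empty adapter_id triggered an eager merge at startup and
    # returned dynamic_adapter_loading_enabled=False for the rest of the process.
    if len(adapter_id) > 0:
        merged = create_merged_weight_files_stub(adapter_id, model_id, filenames, adapter_source)
        return merged, False
    return None, True


def resolve_weights_new(filenames: List[str]) -> Tuple[None, bool]:
    # New path: no merge step; adapter_id and adapter_source are forwarded to the
    # base class unchanged, and dynamic adapter loading stays enabled.
    return None, True


print(resolve_weights_old("my-lora", "google/gemma-2b", ["model.safetensors"], "hub"))
print(resolve_weights_new(["model.safetensors"]))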
server/lorax_server/models/flash_gpt2.py (20 changes: 1 addition & 19 deletions)

@@ -20,7 +20,6 @@
 )
 from lorax_server.utils import (
     compute_delta_weight,
-    create_merged_weight_files,
     get_start_stop_idxs_for_rank,
     initialize_torch_distributed,
     load_module_map,
@@ -70,23 +69,6 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
@@ -114,7 +96,7 @@ def __init__(
             world_size=world_size,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )
 
     @property
server/lorax_server/models/flash_llama.py (21 changes: 1 addition & 20 deletions)

@@ -13,7 +13,6 @@
     LlamaConfig,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -64,29 +63,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )
 
         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -108,7 +89,7 @@ def __init__(
             world_size=world_size,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )
 
     @property
server/lorax_server/models/flash_mistral.py (21 changes: 1 addition & 20 deletions)

@@ -12,7 +12,6 @@
     MistralConfig,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -61,29 +60,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )
 
         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -106,7 +87,7 @@ def __init__(
             sliding_window=config.sliding_window,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )
 
     @property
server/lorax_server/models/flash_mixtral.py (21 changes: 1 addition & 20 deletions)

@@ -31,7 +31,6 @@
     MixtralConfig,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -361,29 +360,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )
 
         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -406,7 +387,7 @@ def __init__(
             sliding_window=config.sliding_window,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )
 
     @property
server/lorax_server/models/flash_phi.py (21 changes: 1 addition & 20 deletions)

@@ -18,7 +18,6 @@
     PhiConfig,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -69,29 +68,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )
 
         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -114,7 +95,7 @@ def __init__(
             world_size=world_size,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )
 
     @property
server/lorax_server/models/flash_qwen.py (21 changes: 1 addition & 20 deletions)

@@ -17,7 +17,6 @@
     QwenConfig,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -68,29 +67,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )
 
         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -113,7 +94,7 @@ def __init__(
             world_size=world_size,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
         )
 
     @property
server/lorax_server/models/flash_qwen2.py (21 changes: 1 addition & 20 deletions)

@@ -19,7 +19,6 @@
     FlashQwen2ForCausalLM,
 )
 from lorax_server.utils import (
-    create_merged_weight_files,
     initialize_torch_distributed,
     weight_files,
     Weights,
@@ -69,29 +68,11 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
 
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-
-        # if adapter_id passed in as part of model instantiation, then we merge
-        # the adapter weights with the model weights. This also disables dynamic
-        # adapter loading, since the model is now itself initialized with an adapter.
-        merged_weight_filenames = None
-        dynamic_adapter_loading_enabled = True
-        if len(adapter_id) > 0:
-            logger.info(f"Merging adapter weights from adapter_id {adapter_id} into model weights.")
-            # Need to pass the adapter source here
-            merged_weight_filenames = create_merged_weight_files(
-                adapter_id, model_id, model_weight_filenames=filenames, adapter_source=adapter_source
-            )
-            dynamic_adapter_loading_enabled = False
-            adapter_id = adapter_id
-        else:
-            adapter_id = BASE_MODEL_ADAPTER_ID
-
         weights = Weights(
             filenames,
             device,
             dtype,
             process_group=self.process_group,
-            merged_weight_filenames=merged_weight_filenames
         )
 
         if config.quantize in ["gptq", "awq", "eetq"]:
@@ -116,7 +97,7 @@ def __init__(
             sliding_window=config.sliding_window,
             compile=compile,
             adapter_id=adapter_id,
-            dynamic_adapter_loading_enabled=dynamic_adapter_loading_enabled,
+            adapter_source=adapter_source,
        )
 
     @property
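Taken together, every flash_* model now accepts adapter_source and leaves adapter handling to the dynamic loading path. A hypothetical caller-side illustration follows; the names below are invented for the example, and the real server wiring is outside this diff.

HUB = "hub"  # assumed identifier for the Hugging Face Hub source


class DummyFlashModel:
    """Illustrative stand-in for any of the flash_* model classes above."""

    def __init__(self, model_id: str, adapter_id: str = "", adapter_source: str = HUB):
        self.model_id = model_id
        self.adapter_id = adapter_id
        self.adapter_source = adapter_source


def build_model(model_cls, model_id: str, adapter_id: str = "", adapter_source: str = HUB):
    # The caller no longer reasons about whether an eager merge will disable
    # dynamic loading; it only says where adapter weights should be fetched from.
    return model_cls(model_id=model_id, adapter_id=adapter_id, adapter_source=adapter_source)


model = build_model(DummyFlashModel, "google/gemma-2b", adapter_id="my-lora", adapter_source=HUB)
print(model.adapter_source)  # -> "hub"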
(Diffs for the remaining 2 changed files are not shown here.)
