From c96ff886b04187af0da31faddee2977c3ae2c806 Mon Sep 17 00:00:00 2001
From: Ajinkya Tejankar
Date: Wed, 27 Nov 2024 12:44:33 -0800
Subject: [PATCH] Add support for targeting cross_attn layers in mllama (#693)

---
 .../models/custom_modeling/mllama.py          | 136 +++++++++++-------
 server/lorax_server/models/flash_causal_lm.py |  10 +-
 server/lorax_server/models/mllama.py          |  27 ++--
 3 files changed, 107 insertions(+), 66 deletions(-)

diff --git a/server/lorax_server/models/custom_modeling/mllama.py b/server/lorax_server/models/custom_modeling/mllama.py
index d912f4d5d..bf20c27c1 100644
--- a/server/lorax_server/models/custom_modeling/mllama.py
+++ b/server/lorax_server/models/custom_modeling/mllama.py
@@ -38,11 +38,14 @@
     TensorParallelRowLinear,
 )
 from lorax_server.utils.lora import (
+    DOWN_PROJ,
     FC1,
     FC2,
+    GATE_PROJ,
     K_PROJ,
     O_PROJ,
     Q_PROJ,
+    UP_PROJ,
     V_PROJ,
 )
 
@@ -242,7 +245,7 @@ def __init__(self, *, prefix, config, weights, layer_id, model_type):
     def forward(self, hidden_states: torch.Tensor, adapter_data: AdapterBatchData) -> torch.Tensor:
         hidden_states = self.fc1(hidden_states, adapter_data)
         hidden_states = self.activation_fn(hidden_states)
-        hidden_states = self.fc2(hidden_states, adapter_data)
+        hidden_states = self.fc2(hidden_states.view(-1, hidden_states.shape[-1]), adapter_data)
         return hidden_states
 
 
@@ -329,7 +332,7 @@ def forward(
         attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
+        attn_output = attn_output.view(batch_size * q_seq_len, -1)
 
         output = self.o_proj(attn_output, adapter_data)
         return output
@@ -691,29 +694,55 @@ def __init__(self, *, prefix, config, weights, layer_idx):
         self.num_heads = self.num_heads // weights.process_group.size()
         self.num_key_value_heads = self.num_key_value_heads // weights.process_group.size()
 
-        self.q_proj = TensorParallelColumnLinear.load(
-            config,
-            prefix=f"{prefix}.q_proj",
-            weights=weights,
-            bias=False,
+        self.q_proj = TensorParallelMultiAdapterLinear.load(
+            TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.q_proj"],
+                weights=weights,
+                dim=0,
+                bias=False,
+            ),
+            layer_idx,
+            [Q_PROJ],
+            sizes=[self.head_size * self.num_heads],
+            process_group=weights.process_group,
         )
-        self.k_proj = TensorParallelColumnLinear.load(
-            config,
-            prefix=f"{prefix}.k_proj",
-            weights=weights,
-            bias=False,
+        self.k_proj = TensorParallelMultiAdapterLinear.load(
+            TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.k_proj"],
+                weights=weights,
+                dim=0,
+                bias=False,
+            ),
+            layer_idx,
+            [K_PROJ],
+            sizes=[self.head_size * self.num_key_value_heads],
+            process_group=weights.process_group,
         )
-        self.v_proj = TensorParallelColumnLinear.load(
-            config,
-            prefix=f"{prefix}.v_proj",
-            weights=weights,
-            bias=False,
+        self.v_proj = TensorParallelMultiAdapterLinear.load(
+            TensorParallelColumnLinear.load_multi(
+                config,
+                prefixes=[f"{prefix}.v_proj"],
+                weights=weights,
+                dim=0,
+                bias=False,
+            ),
+            layer_idx,
+            [V_PROJ],
+            sizes=[self.head_size * self.num_key_value_heads],
+            process_group=weights.process_group,
         )
-        self.o_proj = TensorParallelRowLinear.load(
-            config,
-            prefix=f"{prefix}.o_proj",
-            weights=weights,
-            bias=False,
+        self.o_proj = TensorParallelAdapterRowLinear.load(
+            TensorParallelRowLinear.load(
+                config,
+                prefix=f"{prefix}.o_proj",
+                weights=weights,
+                bias=False,
+            ),
+            layer_idx,
+            O_PROJ,
+            process_group=weights.process_group,
         )
         self.q_norm = MllamaTextRMSNorm.load(prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps)
 
@@ -727,11 +756,12 @@ def forward(
         # past_key_value=None,
         # attention_mask: Optional[torch.Tensor] = None,
         # cache_position: Optional[torch.LongTensor] = None,
+        adapter_data: Optional[AdapterBatchData] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
         # hidden_states = hidden_states.unsqueeze(0)
         # bsz, q_len, _ = hidden_states.size()
-        query_states = self.q_proj(hidden_states)
+        query_states = self.q_proj(hidden_states, adapter_data)
         query_states = query_states.view(-1, self.num_heads, self.head_size)
         query_states = self.q_norm(query_states)
 
@@ -744,8 +774,8 @@ def forward(
             indices,
         ) = cross_attention_states
 
-        key_states = self.k_proj(cross_attention_states)
-        value_states = self.v_proj(cross_attention_states)
+        key_states = self.k_proj(cross_attention_states, adapter_data)
+        value_states = self.v_proj(cross_attention_states, adapter_data)
         key_states = key_states.view(-1, self.num_key_value_heads, self.head_size)
         value_states = value_states.view(-1, self.num_key_value_heads, self.head_size)
         key_states = self.k_norm(key_states)
@@ -779,38 +809,54 @@ def forward(
             False,
             None,
         )[0]
-        attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data)
         return attn_output
 
 
 # Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
 class MllamaTextMLP(nn.Module):
-    def __init__(self, *, prefix, config, weights):
+    def __init__(self, *, prefix, config, weights, layer_idx):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size // weights.process_group.size()
 
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+        gate_up_proj = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
             weights=weights,
             dim=0,
             bias=False,
         )
-        self.down_proj = TensorParallelRowLinear.load(
+        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+            gate_up_proj,
+            layer_idx,
+            [GATE_PROJ, UP_PROJ],
+            sizes=[
+                config.intermediate_size,
+                config.intermediate_size,
+            ],
+            process_group=weights.process_group,
+        )
+        down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
            weights=weights,
             bias=False,
         )
+        self.down_proj = TensorParallelAdapterRowLinear.load(
+            down_proj,
+            layer_idx,
+            DOWN_PROJ,
+            process_group=weights.process_group,
+        )
         self.act_fn = ACT2FN[config.hidden_act]
 
-    def forward(self, x):
+    def forward(self, x, adapter_data):
         shape = x.shape
-        gate_up_states = self.gate_up_proj(x)
+        gate_up_states = self.gate_up_proj(x, adapter_data)
         gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size)
-        result = self.down_proj(self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        result = self.down_proj(self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data)
         return result
 
 
@@ -834,7 +880,7 @@ def __init__(self, layer_id, prefix, config, weights) -> None:
             weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False
         )
 
-        self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+        self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights, layer_idx=layer_idx)
         self.post_attention_layernorm = MllamaTextRMSNorm.load(
             prefix=f"{prefix}.post_attention_layernorm",
             weights=weights,
@@ -877,12 +923,13 @@ def forward(
             hidden_states=hidden_states,
             # attention_mask=cross_attention_mask,
             cross_attention_states=cross_attention_states,
+            adapter_data=adapter_data,
         )
         hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
 
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, adapter_data)
         hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
 
         out_hidden_states[indices] = hidden_states
@@ -922,29 +969,10 @@ def __init__(self, prefix, config, weights):
         config.text_config._attn_implementation = "sdpa"
         self.hidden_size = config.text_config.hidden_size
         cross_attention_layers = getattr(config.text_config, "cross_attention_layers", [])
-        # note(ajinkya): Since cross attention layers are not currently targeted, we need to handle
-        # the case of some layers not having lora adapters which lorax doesn't currently support.
-        # Hence, this hack where we a dict that goes from actual layer index to index if the layers
-        # were filtered according to their types. For exmaple:
-        # all layers = [0, 1, 2, 3, 4]
-        # cross attention layers = [1, 3]
-        # layer wise layer ids = [0, 0, 1, 1, 2]
-        # since layers 1 and 3 are of different type they are indexed as if they are sequential
-        # this prevents illegal memory access errors from running the punica kernels
-        layer_wise_layer_id = [0] * config.text_config.num_hidden_layers
-        i = j = 0
-        for k in range(config.text_config.num_hidden_layers):
-            if j == len(cross_attention_layers) or k < cross_attention_layers[j]:
-                layer_wise_layer_id[k] = i
-                i += 1
-            else:
-                layer_wise_layer_id[k] = j
-                j += 1
-
         def create_layer(layer_id, prefix, config, weights):
             layer_cls = FlashLlamaCrossLayer if layer_id in cross_attention_layers else FlashLlamaLayer
             return layer_cls(
-                layer_wise_layer_id[layer_id],
+                layer_id,
                 prefix=prefix,
                 config=config,
                 weights=weights,
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index 2e3990328..48a3e17b2 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -1568,6 +1568,11 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) ->
 
         return out
 
+    # note(ajinkya): hack needed to make sure that we can target cross_attn layers in mllama
+    # default behavior is to just return prefill state, but mllama always returns True
+    def adapter_prefill_state(self, prefill: bool) -> bool:
+        return prefill
+
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: FlashCausalLMBatch, is_warmup: bool = False
@@ -1595,13 +1600,14 @@ def generate_token(
 
         # Assign pointers to adapter weights
         # TODO(travis): don't update this if indices haven't changed
-        self.punica_wrapper.update_metadata(adapter_meta, prefill)
+        adapter_prefill_state = self.adapter_prefill_state(prefill)
+        self.punica_wrapper.update_metadata(adapter_meta, adapter_prefill_state)
         adapter_data = AdapterBatchData.from_meta(
             adapter_meta,
             self.layer_to_adapter_weights,
             self.layer_to_lora_weights,
             self.punica_wrapper,
-            prefill,
+            adapter_prefill_state,
             batch.prefill_head_indices,
         )
 
diff --git a/server/lorax_server/models/mllama.py b/server/lorax_server/models/mllama.py
index 2d2fd77ff..822b3ae04 100644
--- a/server/lorax_server/models/mllama.py
+++ b/server/lorax_server/models/mllama.py
@@ -203,11 +203,8 @@ def get_num_layers_for_type(self, layer_type: str) -> int:
             return len(self.model.vision_model.global_transformer.layers)
         if "VISION_TRANSFORMER_" in layer_type:
             return len(self.model.vision_model.transformer.layers)
-        return [
-            layer_id
-            for layer_id, layer in enumerate(self.model.text_model.model.layers)
-            if not isinstance(layer, FlashLlamaCrossLayer)
-        ]
+
+        return len(self.model.text_model.model.layers)
 
     def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
         layer_weights = {}
@@ -215,11 +212,15 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
         prefix = "language_model.model.layers"
         for i, layer in enumerate(self.model.text_model.model.layers):
             if isinstance(layer, FlashLlamaCrossLayer):
-                continue
-            layer_weights[(i, Q_PROJ)] = (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value)
-            layer_weights[(i, K_PROJ)] = (f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.query_key_value)
-            layer_weights[(i, V_PROJ)] = (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value)
-            layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj)
+                layer_weights[(i, Q_PROJ)] = (f"{prefix}.{i}.cross_attn.q_proj", layer.cross_attn.q_proj)
+                layer_weights[(i, K_PROJ)] = (f"{prefix}.{i}.cross_attn.k_proj", layer.cross_attn.k_proj)
+                layer_weights[(i, V_PROJ)] = (f"{prefix}.{i}.cross_attn.v_proj", layer.cross_attn.v_proj)
+                layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.cross_attn.o_proj", layer.cross_attn.o_proj)
+            else:
+                layer_weights[(i, Q_PROJ)] = (f"{prefix}.{i}.self_attn.q_proj", layer.self_attn.query_key_value)
+                layer_weights[(i, K_PROJ)] = (f"{prefix}.{i}.self_attn.k_proj", layer.self_attn.query_key_value)
+                layer_weights[(i, V_PROJ)] = (f"{prefix}.{i}.self_attn.v_proj", layer.self_attn.query_key_value)
+                layer_weights[(i, O_PROJ)] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj)
 
             layer_weights[(i, GATE_PROJ)] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj)
             layer_weights[(i, UP_PROJ)] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj)
@@ -255,6 +256,12 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
 
         return layer_weights
 
+    # note(ajinkya): for cross_attn in mllama we need to disable bgmv kernels
+    # during decode, but doing this selectively for cross_attn is tricky so
+    # simply resorting to sgmv kernels by always passing prefill=True
+    def adapter_prefill_state(self, prefill: bool) -> bool:
+        return True
+
     def forward(
         self,
         batch: VlmCausalLMBatch,
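Usage note (illustrative, not part of the commit): with this change, a LoRA adapter for mllama that targets the standard projection module names is applied to both self_attn and cross_attn layers at serve time, since the adapter-to-layer mapping above now covers FlashLlamaCrossLayer as well. A minimal sketch of such an adapter configuration, assuming a PEFT-style LoraConfig (the rank, alpha, and task_type values here are placeholders, not taken from this patch):

    from peft import LoraConfig

    # Hypothetical adapter config: PEFT matches target_modules by suffix, so the
    # names below cover both self_attn.* and cross_attn.* projections in the
    # mllama text model, plus the MLP projections targeted by this patch.
    adapter_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",
    )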