Skip to content

Commit

Permalink
fix aria model
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao committed Nov 25, 2024
1 parent 9db713a commit d42d04f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 28 deletions.
26 changes: 4 additions & 22 deletions vllm/model_executor/models/aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
LlamaModel)
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
is_pp_missing_parameter,
make_layers, maybe_prefix,
maybe_prefix,
merge_multimodal_embeddings)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
Expand Down Expand Up @@ -363,27 +363,9 @@ class AriaMoELMModel(LlamaModel):
"""

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

# FIXME: this is a hack to disable the compilation of the model
self.do_not_compile = True

self.layers = None

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MoEDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",
)
super().__init__(vllm_config=vllm_config,
prefix=prefix,
layer_type=MoEDecoderLayer)

# Adapted from LlamaModel.load_weights with the modification of adding
# the expert weights mapping to `stacked_params_mapping`
Expand Down
16 changes: 10 additions & 6 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

import torch
from torch import nn
Expand Down Expand Up @@ -273,7 +273,11 @@ def forward(
@support_torch_compile
class LlamaModel(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
super().__init__()

config = vllm_config.model_config.hf_config
Expand All @@ -299,10 +303,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: LlamaDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
lambda prefix: layer_type(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
Expand Down

0 comments on commit d42d04f

Please sign in to comment.