api: add mistral function calling format to all models loaded with "mistral" format (#1053)
AlpinDale authored Dec 27, 2024
1 parent b3f9ab3 commit 1264e0b
Showing 3 changed files with 75 additions and 41 deletions.
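
All three diffs below thread a single new `tools` argument from the API surface down to the Mistral tokenizer. For reference, here is a minimal sketch of the kind of function definition such a `tools` list carries, in the OpenAI-style function-calling schema; the function name and fields are illustrative, not part of this commit:

# Hypothetical tool definition; the new `tools` parameters below expect a
# list of dicts shaped like this.
get_weather_tool = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "City name"},
            },
            "required": ["city"],
        },
    },
}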
6 changes: 5 additions & 1 deletion aphrodite/endpoints/llm.py
@@ -1,5 +1,6 @@
 from contextlib import contextmanager
-from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
+from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
+                    overload)
 
 from tqdm import tqdm
@@ -353,6 +354,7 @@ def chat(
         lora_request: Optional[LoRARequest] = None,
         chat_template: Optional[str] = None,
         add_generation_prompt: bool = True,
+        tools: Optional[List[Dict[str, Any]]] = None,
     ) -> List[RequestOutput]:
         """
         Generate responses for a chat conversation.
@@ -397,13 +399,15 @@ def chat(
                 messages=messages,
                 chat_template=chat_template,
                 add_generation_prompt=add_generation_prompt,
+                tools=tools,
             )
         else:
             prompt = apply_hf_chat_template(
                 tokenizer,
                 conversation=conversation,
                 chat_template=chat_template,
                 add_generation_prompt=add_generation_prompt,
+                tools=tools,
             )
 
         inputs: PromptInputs
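
With `tools` added to `LLM.chat`, offline generation can pass function definitions straight through to whichever chat template the tokenizer uses. A minimal sketch, assuming a model loaded with tokenizer_mode="mistral"; the model name and tool definition are illustrative:

from aphrodite import LLM, SamplingParams

# Assumption: this repo id ships a Mistral-format tokenizer.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral")

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",  # illustrative, not from this commit
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

outputs = llm.chat(
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    sampling_params=SamplingParams(max_tokens=128),
    tools=tools,  # new in this commit: forwarded to the chat template
)
print(outputs[0].outputs[0].text)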
9 changes: 5 additions & 4 deletions aphrodite/endpoints/openai/serving_chat.py
@@ -120,7 +120,8 @@ async def create_chat_completion(
         ]
 
         prompt: Union[str, List[int]]
-        if isinstance(tokenizer, MistralTokenizer):
+        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
+        if is_mistral_tokenizer:
             prompt = apply_mistral_chat_template(
                 tokenizer,
                 messages=request.messages,
@@ -156,10 +157,10 @@ async def create_chat_completion(
                 return self.create_error_response(
                     "tool_choice = \"required\" is not supported!")
 
-            # "auto" tools requires --enable-auto-tool-choice
-            # and --tool-call-parser
-            if request.tool_choice == "auto" and not (
+            if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
                     self.enable_auto_tools and self.tool_parser is not None):
+                # for hf tokenizers, "auto" tools requires
+                # --enable-auto-tool-choice and --tool-call-parser
                 return self.create_error_response(
                     "\"auto\" tool choice requires "
                     "--enable-auto-tool-choice and --tool-call-parser to be set")
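
The net effect on the OpenAI-compatible server: with a Mistral tokenizer, tool_choice="auto" no longer requires --enable-auto-tool-choice and --tool-call-parser, since the Mistral chat template encodes the tools itself. A hedged sketch using the `openai` client; the base URL, model name, and tool are assumptions, not part of this commit:

from openai import OpenAI

# Assumption: an aphrodite server is running a "mistral"-format model on
# the default port; no tool-parser flags are passed.
client = OpenAI(base_url="http://localhost:2242/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.3",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
    tool_choice="auto",  # previously rejected without the two CLI flags
)
print(response.choices[0].message)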
101 changes: 65 additions & 36 deletions aphrodite/transformers_utils/tokenizers/mistral.py
@@ -18,7 +18,7 @@
 from aphrodite.common.logger import log_once
 
 if TYPE_CHECKING:
-    from aphrodite.endpoints.chat_utils import ConversationMessage
+    from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
 
 
 @dataclass
@@ -47,25 +47,25 @@ class MistralTokenizer:
     def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
         self.mistral = tokenizer
         self.instruct = tokenizer.instruct_tokenizer
-        self.tokenizer = tokenizer.instruct_tokenizer.tokenizer
 
-        self.vocab_size = len(self.tokenizer.vocab())
-
-        assert isinstance(self.tokenizer,
-                          (Tekkenizer, SentencePieceTokenizer)), type(
-                              self.tokenizer)
-        if (is_tekken := isinstance(self.tokenizer, Tekkenizer)):
+        tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
+        if isinstance(tokenizer_, Tekkenizer):
             # Make sure special tokens will not raise
-            self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
-
-        self._is_tekken = is_tekken
+            tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
 
-        # the following attributes are set to fit VLLM's design
-        self.is_fast = True
-        self.chat_template = True
-        self.all_special_ids: List[Any] = []
-        self.all_special_tokens: List[Any] = []
-        self.all_special_tokens_extended: List[Any] = []
+            self._vocab = {
+                token: idx
+                for idx, token in enumerate(tokenizer_.vocab())
+            }
+        elif isinstance(tokenizer_, SentencePieceTokenizer):
+            self._vocab = {
+                token: idx
+                for idx, token in enumerate(tokenizer_.vocab())
+            }
+        else:
+            raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
+
+        self.tokenizer = tokenizer_
 
     @classmethod
     def from_pretrained(cls,
@@ -103,6 +103,38 @@ def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
                                             revision=revision)
         return tokenizer_file
 
+    # the following attributes are set to fit VLLM's design
+    @property
+    def all_special_tokens_extended(self) -> List[str]:
+        return []
+
+    @property
+    def all_special_tokens(self) -> List[str]:
+        return []
+
+    @property
+    def all_special_ids(self) -> List[int]:
+        return []
+
+    @property
+    def bos_token_id(self) -> int:
+        return self.tokenizer.bos_id
+
+    @property
+    def eos_token_id(self) -> int:
+        return self.tokenizer.eos_id
+
+    @property
+    def is_fast(self) -> bool:
+        return True
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self._vocab)
+
+    def __len__(self) -> int:
+        return self.vocab_size
+
     def __call__(
         self,
         prompt: str,
@@ -118,32 +150,35 @@ def __call__(
 
         return Encoding(input_ids=input_ids)
 
-    def get_added_vocab(self) -> List[str]:
+    def get_vocab(self) -> Dict[str, int]:
+        return self._vocab
+
+    def get_added_vocab(self) -> Dict[str, int]:
         # Mistral tokenizers have no added vocabulary
-        return []
+        return {}
 
     def encode(self, prompt: str) -> List[int]:
-        # `encode ` should only be used for prompt completion
+        # `encode` should only be used for prompt completion
         # it should never be used for chat_completion.
         # For chat completion use `apply_chat_template`
         return self.tokenizer.encode(prompt, bos=True, eos=False)
 
     def apply_chat_template(self,
-                            conversation: List["ConversationMessage"],
+                            messages: List["ChatCompletionMessageParam"],
                             tools: Optional[Dict[str, Any]] = None,
                             **kwargs) -> List[int]:
-        assert tools is None, "`tools` are not yet supported."
 
-        request = ChatCompletionRequest(
-            messages=conversation)  # type: ignore[type-var]
+        request = ChatCompletionRequest(messages=messages,
+                                        tools=tools)  # type: ignore[type-var]
         encoded = self.mistral.encode_chat_completion(request)
 
         # encode-decode to get clean prompt
         return encoded.tokens
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        if self._is_tekken:
-            return "".join(tokens)
+        if isinstance(self.tokenizer, Tekkenizer):
+            return "".join(t for t in tokens
+                           if t not in self.tokenizer._all_special_tokens)
         else:
             return self.tokenizer.decode(tokens)  # type: ignore[arg-type]

@@ -152,14 +187,11 @@ def decode(self, ids: Union[List[int], int]) -> str:
             ids = [ids]
         return self.tokenizer.decode(ids)
 
-    @property
-    def eos_token_id(self):
-        return self.tokenizer.eos_id
-
     def convert_ids_to_tokens(
-            self,
-            ids: List[int],
-            skip_special_tokens: Optional[bool] = True) -> List[str]:
+        self,
+        ids: List[int],
+        skip_special_tokens: bool = True,
+    ) -> List[str]:
         # TODO(Patrick) - potentially allow special tokens to not be skipped
         if not skip_special_tokens:
             log_once(
@@ -173,6 +205,3 @@ def convert_ids_to_tokens(
 
         tokens = [self.tokenizer.id_to_piece(id) for id in ids]
         return tokens
-
-    def __len__(self):
-        return self.vocab_size
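
The tokenizer refactor is what makes both entry points work: `apply_chat_template` now forwards `tools` into mistral_common's `ChatCompletionRequest` instead of asserting they are absent, and the vocab and special-token attributes become read-only properties. A minimal sketch of direct use; the repo id and messages are illustrative, not part of this commit:

from aphrodite.transformers_utils.tokenizers.mistral import MistralTokenizer

# Assumption: this repo id ships a Mistral-format tokenizer file.
tok = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

token_ids = tok.apply_chat_template(
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=[{  # previously rejected by an assert, now encoded by mistral_common
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather for a city",
            "parameters": {"type": "object", "properties": {}},
        },
    }],
)
print(len(token_ids))
print(tok.decode(token_ids))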
