From a62c6ff0129a86691e07463fbdbfdbfef943d9cd Mon Sep 17 00:00:00 2001 From: jerryzhuang Date: Tue, 29 Oct 2024 13:06:58 +1100 Subject: [PATCH] add chat template directly Signed-off-by: jerryzhuang --- docker/presets/models/tfs/Dockerfile | 8 +------ .../chat_templates/falcon-instruct.jinja | 20 ++++++++++++++++ .../chat_templates/llama-2-chat.jinja | 24 +++++++++++++++++++ .../chat_templates/llama-3-instruct.jinja | 18 ++++++++++++++ .../chat_templates/mistral-instruct.jinja | 19 +++++++++++++++ .../chat_templates/phi-3-small.jinja | 18 ++++++++++++++ presets/inference/chat_templates/phi-3.jinja | 17 +++++++++++++ 7 files changed, 117 insertions(+), 7 deletions(-) create mode 100644 presets/inference/chat_templates/falcon-instruct.jinja create mode 100644 presets/inference/chat_templates/llama-2-chat.jinja create mode 100644 presets/inference/chat_templates/llama-3-instruct.jinja create mode 100644 presets/inference/chat_templates/mistral-instruct.jinja create mode 100644 presets/inference/chat_templates/phi-3-small.jinja create mode 100644 presets/inference/chat_templates/phi-3.jinja diff --git a/docker/presets/models/tfs/Dockerfile b/docker/presets/models/tfs/Dockerfile index 6e1c0d045..da5a3b732 100644 --- a/docker/presets/models/tfs/Dockerfile +++ b/docker/presets/models/tfs/Dockerfile @@ -24,13 +24,7 @@ COPY kaito/presets/inference/${MODEL_TYPE}/inference_api.py \ COPY kaito/presets/inference/vllm/inference_api.py /workspace/vllm/inference_api.py # Chat template -RUN apt update && apt install -y git && \ - rm /var/lib/apt/lists/* -r -RUN git clone https://github.com/chujiezheng/chat_templates /tmp/chat_templates && \ - cd /tmp/chat_templates && \ - git reset --hard 670a2eb && \ - cp -r ./chat_templates/ /workspace/ && \ - rm -rf /tmp/chat_templates +ADD kaito/presets/inference/chat_templates /workspace/chat_templates # Model weights COPY ${WEIGHTS_PATH} /workspace/weights diff --git a/presets/inference/chat_templates/falcon-instruct.jinja b/presets/inference/chat_templates/falcon-instruct.jinja new file mode 100644 index 000000000..19699ff6d --- /dev/null +++ b/presets/inference/chat_templates/falcon-instruct.jinja @@ -0,0 +1,20 @@ +{% if messages[0]['role'] == 'system' %} + {% set system_message = messages[0]['content'] %} + {% set messages = messages[1:] %} +{% else %} + {% set system_message = '' %} +{% endif %} + +{{ system_message | trim }} +{% for message in messages %} + {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {% endif %} + + {% set content = message['content'].replace('\r\n', '\n').replace('\n\n', '\n') %} + {{ '\n\n' + message['role'] | capitalize + ': ' + content | trim }} +{% endfor %} + +{% if add_generation_prompt %} + {{ '\n\nAssistant:' }} +{% endif %} \ No newline at end of file diff --git a/presets/inference/chat_templates/llama-2-chat.jinja b/presets/inference/chat_templates/llama-2-chat.jinja new file mode 100644 index 000000000..8b0bc1e4b --- /dev/null +++ b/presets/inference/chat_templates/llama-2-chat.jinja @@ -0,0 +1,24 @@ +{% if messages[0]['role'] == 'system' %} + {% set system_message = '<>\n' + messages[0]['content'] | trim + '\n<>\n\n' %} + {% set messages = messages[1:] %} +{% else %} + {% set system_message = '' %} +{% endif %} + +{% for message in messages %} + {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {% endif %} + + {% if loop.index0 == 0 %} + {% set content = system_message + message['content'] %} + {% else %} + {% set content = message['content'] %} + {% endif %} + + {% if message['role'] == 'user' %} + {{ bos_token + '[INST] ' + content | trim + ' [/INST]' }} + {% elif message['role'] == 'assistant' %} + {{ ' ' + content | trim + ' ' + eos_token }} + {% endif %} +{% endfor %} \ No newline at end of file diff --git a/presets/inference/chat_templates/llama-3-instruct.jinja b/presets/inference/chat_templates/llama-3-instruct.jinja new file mode 100644 index 000000000..40dbc415f --- /dev/null +++ b/presets/inference/chat_templates/llama-3-instruct.jinja @@ -0,0 +1,18 @@ +{% if messages[0]['role'] == 'system' %} + {% set offset = 1 %} +{% else %} + {% set offset = 0 %} +{% endif %} + +{{ bos_token }} +{% for message in messages %} + {% if (message['role'] == 'user') != (loop.index0 % 2 == offset) %} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {% endif %} + + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }} +{% endfor %} + +{% if add_generation_prompt %} + {{ '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }} +{% endif %} \ No newline at end of file diff --git a/presets/inference/chat_templates/mistral-instruct.jinja b/presets/inference/chat_templates/mistral-instruct.jinja new file mode 100644 index 000000000..9ec65aa99 --- /dev/null +++ b/presets/inference/chat_templates/mistral-instruct.jinja @@ -0,0 +1,19 @@ +{% if messages[0]['role'] == 'system' %} + {% set system_message = messages[0]['content'] | trim + '\n\n' %} + {% set messages = messages[1:] %} +{% else %} + {% set system_message = '' %} +{% endif %} + +{{ bos_token + system_message}} +{% for message in messages %} + {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {% endif %} + + {% if message['role'] == 'user' %} + {{ '[INST] ' + message['content'] | trim + ' [/INST]' }} + {% elif message['role'] == 'assistant' %} + {{ ' ' + message['content'] | trim + eos_token }} + {% endif %} +{% endfor %} \ No newline at end of file diff --git a/presets/inference/chat_templates/phi-3-small.jinja b/presets/inference/chat_templates/phi-3-small.jinja new file mode 100644 index 000000000..f66f223d6 --- /dev/null +++ b/presets/inference/chat_templates/phi-3-small.jinja @@ -0,0 +1,18 @@ +{% if messages[0]['role'] == 'system' %} + {% set offset = 1 %} +{% else %} + {% set offset = 0 %} +{% endif %} + +{{ bos_token }} +{% for message in messages %} + {% if (message['role'] == 'user') != (loop.index0 % 2 == offset) %} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {% endif %} + + {{ '<|' + message['role'] + '|>\n' + message['content'] | trim + '<|end|>' + '\n' }} +{% endfor %} + +{% if add_generation_prompt %} + {{ '<|assistant|>\n' }} +{% endif %} \ No newline at end of file diff --git a/presets/inference/chat_templates/phi-3.jinja b/presets/inference/chat_templates/phi-3.jinja new file mode 100644 index 000000000..c94a2423a --- /dev/null +++ b/presets/inference/chat_templates/phi-3.jinja @@ -0,0 +1,17 @@ +{% if messages[0]['role'] == 'system' %} + {% set offset = 1 %} +{% else %} + {% set offset = 0 %} +{% endif %} + +{% for message in messages %} + {% if (message['role'] == 'user') != (loop.index0 % 2 == offset) %} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {% endif %} + + {{ '<|' + message['role'] + '|>\n' + message['content'] | trim + '<|end|>' + '\n' }} +{% endfor %} + +{% if add_generation_prompt %} + {{ '<|assistant|>\n' }} +{% endif %} \ No newline at end of file