From 373ee3fbbabc4c1508eed4f5c3795b23a20939a3 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Thu, 22 Feb 2024 19:10:21 +0100
Subject: [PATCH] Add Gemma chat template (#5665)

* add gemma chat template

* gemma: only apply system_prompt on non-model message
---
 llama.cpp                    | 22 ++++++++++++++++++++++
 tests/test-chat-template.cpp |  4 ++++
 2 files changed, 26 insertions(+)

diff --git a/llama.cpp b/llama.cpp
index 6ab5e1bf4..40dda265c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -12782,6 +12782,28 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<s>assistant\n";
         }
+    } else if (tmpl.find("<start_of_turn>") != std::string::npos) {
+        // google/gemma-7b-it
+        std::string system_prompt = "";
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
+                system_prompt = trim(message->content);
+                continue;
+            }
+            // in gemma, "assistant" is "model"
+            role = role == "assistant" ? "model" : message->role;
+            ss << "<start_of_turn>" << role << "\n";
+            if (!system_prompt.empty() && role != "model") {
+                ss << system_prompt << "\n\n";
+                system_prompt = "";
+            }
+            ss << trim(message->content) << "<end_of_turn>\n";
+        }
+        if (add_ass) {
+            ss << "<start_of_turn>model\n";
+        }
     } else {
         // template not supported
         return -1;
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index d02b39e14..fa2eb577b 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -29,6 +29,8 @@ int main(void) {
         "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
         // mlabonne/AlphaMonarch-7B
         "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
+        // google/gemma-7b-it
+        "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}",
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
@@ -41,6 +43,8 @@ int main(void) {
         "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
         // mlabonne/AlphaMonarch-7B
         "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n   I am an assistant   </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
+        // google/gemma-7b-it
+        "<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
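
Usage note: template detection in llama_chat_apply_template_internal is
substring-based, so any template string containing "<start_of_turn>" is routed
through the new Gemma branch. The sketch below is not part of the patch; it is
modeled on tests/test-chat-template.cpp (the four-message conversation is
illustrative) and calls the public llama_chat_apply_template API with a null
model pointer, which is allowed when a template string is passed explicitly:

    #include "llama.h"
    #include <cstdio>
    #include <vector>

    int main() {
        // illustrative conversation; a system message plus alternating turns
        llama_chat_message chat[] = {
            {"system",    "You are a helpful assistant"},
            {"user",      "Hello"},
            {"assistant", "Hi there"},
            {"user",      "Another question"},
        };
        std::vector<char> buf(1024);
        // any template string containing "<start_of_turn>" selects the
        // Gemma branch added by this patch
        int32_t n = llama_chat_apply_template(nullptr, "<start_of_turn>",
                                              chat, 4, /*add_ass=*/true,
                                              buf.data(), (int32_t) buf.size());
        if (n >= 0 && n <= (int32_t) buf.size()) {
            printf("%.*s", n, buf.data());
        }
        return 0;
    }

Per the new branch, the system message is merged into the first user turn,
"assistant" is rendered as "model", and add_ass appends an open model turn:

    <start_of_turn>user
    You are a helpful assistant

    Hello<end_of_turn>
    <start_of_turn>model
    Hi there<end_of_turn>
    <start_of_turn>user
    Another question<end_of_turn>
    <start_of_turn>model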