mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
llama : add llama_chat_apply_template() (#5538)
* llama: add llama_chat_apply_template * test-chat-template: remove dedundant vector * chat_template: do not use std::string for buffer * add clarification for llama_chat_apply_template * llama_chat_apply_template: add zephyr template * llama_chat_apply_template: correct docs * llama_chat_apply_template: use term "chat" everywhere * llama_chat_apply_template: change variable name to "tmpl"
This commit is contained in:
parent
3a9cb4ca64
commit
11b12de39b
4
Makefile
4
Makefile
@ -867,3 +867,7 @@ tests/test-model-load-cancel: tests/test-model-load-cancel.cpp ggml.o llama.o te
|
|||||||
tests/test-autorelease: tests/test-autorelease.cpp ggml.o llama.o tests/get-model.cpp $(COMMON_DEPS) $(OBJS)
|
tests/test-autorelease: tests/test-autorelease.cpp ggml.o llama.o tests/get-model.cpp $(COMMON_DEPS) $(OBJS)
|
||||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||||
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||||
|
|
||||||
|
tests/test-chat-template: tests/test-chat-template.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
|
||||||
|
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||||
|
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
|
||||||
|
117
llama.cpp
117
llama.cpp
@ -12508,6 +12508,123 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// trim whitespace from the beginning and end of a string
|
||||||
|
static std::string trim(const std::string & str) {
|
||||||
|
size_t start = 0;
|
||||||
|
size_t end = str.size();
|
||||||
|
while (start < end && isspace(str[start])) {
|
||||||
|
start += 1;
|
||||||
|
}
|
||||||
|
while (end > start && isspace(str[end - 1])) {
|
||||||
|
end -= 1;
|
||||||
|
}
|
||||||
|
return str.substr(start, end - start);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple version of "llama_apply_chat_template" that only works with strings
|
||||||
|
// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
|
||||||
|
static int32_t llama_chat_apply_template_internal(
|
||||||
|
const std::string & tmpl,
|
||||||
|
const std::vector<const llama_chat_message *> & chat,
|
||||||
|
std::string & dest, bool add_ass) {
|
||||||
|
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
|
||||||
|
std::stringstream ss;
|
||||||
|
if (tmpl.find("<|im_start|>") != std::string::npos) {
|
||||||
|
// chatml template
|
||||||
|
for (auto message : chat) {
|
||||||
|
ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
|
||||||
|
}
|
||||||
|
if (add_ass) {
|
||||||
|
ss << "<|im_start|>assistant\n";
|
||||||
|
}
|
||||||
|
} else if (tmpl.find("[INST]") != std::string::npos) {
|
||||||
|
// llama2 template and its variants
|
||||||
|
// [variant] support system message
|
||||||
|
bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
|
||||||
|
// [variant] space before + after response
|
||||||
|
bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
|
||||||
|
// [variant] add BOS inside history
|
||||||
|
bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos;
|
||||||
|
// [variant] trim spaces from the input message
|
||||||
|
bool strip_message = tmpl.find("content.strip()") != std::string::npos;
|
||||||
|
// construct the prompt
|
||||||
|
bool is_inside_turn = true; // skip BOS at the beginning
|
||||||
|
ss << "[INST] ";
|
||||||
|
for (auto message : chat) {
|
||||||
|
std::string content = strip_message ? trim(message->content) : message->content;
|
||||||
|
std::string role(message->role);
|
||||||
|
if (!is_inside_turn) {
|
||||||
|
is_inside_turn = true;
|
||||||
|
ss << (add_bos_inside_history ? "<s>[INST] " : "[INST] ");
|
||||||
|
}
|
||||||
|
if (role == "system") {
|
||||||
|
if (support_system_message) {
|
||||||
|
ss << "<<SYS>>\n" << content << "\n<</SYS>>\n\n";
|
||||||
|
} else {
|
||||||
|
// if the model does not support system message, we still include it in the first message, but without <<SYS>>
|
||||||
|
ss << content << "\n";
|
||||||
|
}
|
||||||
|
} else if (role == "user") {
|
||||||
|
ss << content << " [/INST]";
|
||||||
|
} else {
|
||||||
|
ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>";
|
||||||
|
is_inside_turn = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// llama2 templates seem to not care about "add_generation_prompt"
|
||||||
|
} else if (tmpl.find("<|user|>") != std::string::npos) {
|
||||||
|
// zephyr template
|
||||||
|
for (auto message : chat) {
|
||||||
|
ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
|
||||||
|
}
|
||||||
|
if (add_ass) {
|
||||||
|
ss << "<|assistant|>\n";
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// template not supported
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
dest = ss.str();
|
||||||
|
return dest.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
LLAMA_API int32_t llama_chat_apply_template(
|
||||||
|
const struct llama_model * model,
|
||||||
|
const char * tmpl,
|
||||||
|
const struct llama_chat_message * chat,
|
||||||
|
size_t n_msg,
|
||||||
|
bool add_ass,
|
||||||
|
char * buf,
|
||||||
|
int32_t length) {
|
||||||
|
std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
|
||||||
|
if (tmpl == nullptr) {
|
||||||
|
GGML_ASSERT(model != nullptr);
|
||||||
|
// load template from model
|
||||||
|
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
|
||||||
|
std::string template_key = "tokenizer.chat_template";
|
||||||
|
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), curr_tmpl.size());
|
||||||
|
if (res < 0) {
|
||||||
|
// worst case: there is no information about template, we will use chatml by default
|
||||||
|
curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal
|
||||||
|
} else {
|
||||||
|
curr_tmpl = std::string(model_template.data(), model_template.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// format the chat to string
|
||||||
|
std::vector<const llama_chat_message *> chat_vec;
|
||||||
|
chat_vec.resize(n_msg);
|
||||||
|
for (size_t i = 0; i < n_msg; i++) {
|
||||||
|
chat_vec[i] = &chat[i];
|
||||||
|
}
|
||||||
|
std::string formatted_chat;
|
||||||
|
int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
|
||||||
|
if (res < 0) {
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
strncpy(buf, formatted_chat.c_str(), length);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
struct llama_timings llama_get_timings(struct llama_context * ctx) {
|
struct llama_timings llama_get_timings(struct llama_context * ctx) {
|
||||||
struct llama_timings result = {
|
struct llama_timings result = {
|
||||||
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
|
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
|
||||||
|
25
llama.h
25
llama.h
@ -305,6 +305,12 @@ extern "C" {
|
|||||||
int32_t n_eval;
|
int32_t n_eval;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// used in chat template
|
||||||
|
typedef struct llama_chat_message {
|
||||||
|
const char * role;
|
||||||
|
const char * content;
|
||||||
|
} llama_chat_message;
|
||||||
|
|
||||||
// Helpers for getting default parameters
|
// Helpers for getting default parameters
|
||||||
LLAMA_API struct llama_model_params llama_model_default_params(void);
|
LLAMA_API struct llama_model_params llama_model_default_params(void);
|
||||||
LLAMA_API struct llama_context_params llama_context_default_params(void);
|
LLAMA_API struct llama_context_params llama_context_default_params(void);
|
||||||
@ -699,6 +705,25 @@ extern "C" {
|
|||||||
char * buf,
|
char * buf,
|
||||||
int32_t length);
|
int32_t length);
|
||||||
|
|
||||||
|
/// Apply chat template. Inspired by hf apply_chat_template() on python.
|
||||||
|
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
|
||||||
|
/// NOTE: This function only support some known jinja templates. It is not a jinja parser.
|
||||||
|
/// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
|
||||||
|
/// @param chat Pointer to a list of multiple llama_chat_message
|
||||||
|
/// @param n_msg Number of llama_chat_message in this chat
|
||||||
|
/// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
|
||||||
|
/// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)
|
||||||
|
/// @param length The size of the allocated buffer
|
||||||
|
/// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
|
||||||
|
LLAMA_API int32_t llama_chat_apply_template(
|
||||||
|
const struct llama_model * model,
|
||||||
|
const char * tmpl,
|
||||||
|
const struct llama_chat_message * chat,
|
||||||
|
size_t n_msg,
|
||||||
|
bool add_ass,
|
||||||
|
char * buf,
|
||||||
|
int32_t length);
|
||||||
|
|
||||||
//
|
//
|
||||||
// Grammar
|
// Grammar
|
||||||
//
|
//
|
||||||
|
@ -28,6 +28,7 @@ endfunction()
|
|||||||
llama_build_and_test_executable(test-quantize-fns.cpp)
|
llama_build_and_test_executable(test-quantize-fns.cpp)
|
||||||
llama_build_and_test_executable(test-quantize-perf.cpp)
|
llama_build_and_test_executable(test-quantize-perf.cpp)
|
||||||
llama_build_and_test_executable(test-sampling.cpp)
|
llama_build_and_test_executable(test-sampling.cpp)
|
||||||
|
llama_build_and_test_executable(test-chat-template.cpp)
|
||||||
|
|
||||||
llama_build_executable(test-tokenizer-0-llama.cpp)
|
llama_build_executable(test-tokenizer-0-llama.cpp)
|
||||||
llama_test_executable (test-tokenizer-0-llama test-tokenizer-0-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
|
llama_test_executable (test-tokenizer-0-llama test-tokenizer-0-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
|
||||||
|
64
tests/test-chat-template.cpp
Normal file
64
tests/test-chat-template.cpp
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#undef NDEBUG
|
||||||
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
int main(void) {
|
||||||
|
llama_chat_message conversation[] = {
|
||||||
|
{"system", "You are a helpful assistant"},
|
||||||
|
{"user", "Hello"},
|
||||||
|
{"assistant", "Hi there"},
|
||||||
|
{"user", "Who are you"},
|
||||||
|
{"assistant", " I am an assistant "},
|
||||||
|
{"user", "Another question"},
|
||||||
|
};
|
||||||
|
size_t message_count = 6;
|
||||||
|
std::vector<std::string> templates = {
|
||||||
|
// teknium/OpenHermes-2.5-Mistral-7B
|
||||||
|
"{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
|
||||||
|
// mistralai/Mistral-7B-Instruct-v0.2
|
||||||
|
"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
|
||||||
|
// TheBloke/FusionNet_34Bx2_MoE-AWQ
|
||||||
|
"{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
|
||||||
|
// bofenghuang/vigogne-2-70b-chat
|
||||||
|
"{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
|
||||||
|
};
|
||||||
|
std::vector<std::string> expected_substr = {
|
||||||
|
"<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant",
|
||||||
|
"[/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
|
||||||
|
"</s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
|
||||||
|
"[/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
|
||||||
|
};
|
||||||
|
std::vector<char> formatted_chat(1024);
|
||||||
|
int32_t res;
|
||||||
|
|
||||||
|
// test invalid chat template
|
||||||
|
res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
|
||||||
|
assert(res < 0);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < templates.size(); i++) {
|
||||||
|
std::string custom_template = templates[i];
|
||||||
|
std::string substr = expected_substr[i];
|
||||||
|
formatted_chat.resize(1024);
|
||||||
|
res = llama_chat_apply_template(
|
||||||
|
nullptr,
|
||||||
|
custom_template.c_str(),
|
||||||
|
conversation,
|
||||||
|
message_count,
|
||||||
|
true,
|
||||||
|
formatted_chat.data(),
|
||||||
|
formatted_chat.size()
|
||||||
|
);
|
||||||
|
formatted_chat.resize(res);
|
||||||
|
std::string output(formatted_chat.data(), formatted_chat.size());
|
||||||
|
std::cout << output << "\n-------------------------\n";
|
||||||
|
// expect the "formatted_chat" to contain pre-defined strings
|
||||||
|
assert(output.find(substr) != std::string::npos);
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user