tool-call: factor chat template away from legacy API

This commit is contained in:
ochafik 2024-09-26 17:19:29 +01:00
parent d7ec84f78c
commit cf7bece6a7
15 changed files with 431 additions and 399 deletions

View File

@ -934,6 +934,7 @@ OBJ_LLAMA = \
OBJ_COMMON = \ OBJ_COMMON = \
common/common.o \ common/common.o \
common/chat-template.o \
common/arg.o \ common/arg.o \
common/log.o \ common/log.o \
common/console.o \ common/console.o \
@ -1170,6 +1171,8 @@ $(LIB_LLAMA_S): \
common/common.o: \ common/common.o: \
common/common.cpp \ common/common.cpp \
common/common.h \ common/common.h \
common/chat-template.cpp \
common/chat-template.h \
common/console.h \ common/console.h \
common/sampling.h \ common/sampling.h \
common/json.hpp \ common/json.hpp \
@ -1465,6 +1468,7 @@ llama-server: \
examples/server/prompt-formats.js.hpp \ examples/server/prompt-formats.js.hpp \
examples/server/json-schema-to-grammar.mjs.hpp \ examples/server/json-schema-to-grammar.mjs.hpp \
examples/server/loading.html.hpp \ examples/server/loading.html.hpp \
common/chat-template.h \
common/json.hpp \ common/json.hpp \
common/stb_image.h \ common/stb_image.h \
$(OBJ_ALL) $(OBJ_ALL)

View File

@ -54,6 +54,8 @@ add_library(${TARGET} STATIC
arg.cpp arg.cpp
arg.h arg.h
base64.hpp base64.hpp
chat-template.cpp
chat-template.h
common.cpp common.cpp
common.h common.h
console.cpp console.cpp

118
common/chat-template.cpp Normal file
View File

@ -0,0 +1,118 @@
#include "chat-template.h"
#include "minja.hpp"
#include "llama.h"
using json = nlohmann::ordered_json;
static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
std::string piece;
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
if (n_chars < 0) {
piece.resize(-n_chars);
int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
GGML_ASSERT(check == -n_chars);
}
else {
piece.resize(n_chars);
}
return piece;
}
static std::string llama_model_meta_val_str(const struct llama_model * model, const char * key) {
int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
if (tlen > 0) {
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
return std::string(curr_tmpl_buf.data(), tlen);
}
}
return "";
}
llama_chat_template llama_chat_template::from_model(
const struct llama_model * model,
const std::string & chat_template_override)
{
// TODO: handle "chatml"?
auto chat_template = chat_template_override.empty()
? llama_model_meta_val_str(model, "tokenizer.chat_template")
: chat_template_override;
auto bos_token = _llama_token_to_piece(model, llama_token_bos(model), true);
auto eos_token = _llama_token_to_piece(model, llama_token_eos(model), true);
return llama_chat_template(chat_template, bos_token, eos_token);
}
std::string llama_chat_template::apply(
const json & messages,
const json & tools,
bool add_generation_prompt) const
{
auto actual_messages = messages;
// First, "fix" messages so they have a chance to be rendered correctly by the template
if (_requires_object_arguments || !_supports_system_role) {
std::string pending_system;
auto flush_sys = [&]() {
if (!pending_system.empty()) {
actual_messages.push_back({
{"role", "user"},
{"content", pending_system},
});
pending_system.clear();
}
};
for (auto & message : actual_messages) {
if (!message.contains("role") || !message.contains("content")) {
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
}
std::string role = message.at("role");
std::string content = message.at("content");
if (!_supports_system_role) {
if (role == "system") {
if (!pending_system.empty()) pending_system += "\n";
pending_system += content;
continue;
} else {
if (role == "user") {
if (!pending_system.empty()) {
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
pending_system.clear();
}
} else {
flush_sys();
}
}
}
if (_requires_object_arguments && message.contains("tool_calls")) {
for (auto & tool_call : message.at("tool_calls")) {
std::string arguments = tool_call.at("arguments");
tool_call["arguments"] = json::parse(arguments);
}
}
}
flush_sys();
}
auto context = minja::Context::make(json({
{"messages", actual_messages},
{"add_generation_prompt", add_generation_prompt},
{"bos_token", _bos_token},
{"eos_token", _eos_token},
}));
if (!tools.is_null() && !tools.empty()) {
auto tools_val = minja::Value(tools);
context->set("tools", tools_val);
}
auto tmpl_root = minja::Parser::parse(_chat_template, {
/* .trim_blocks = */ true,
/* .lstrip_blocks = */ true,
/* .keep_trailing_newline = */ false,
});
return tmpl_root->render(context);
}

64
common/chat-template.h Normal file
View File

@ -0,0 +1,64 @@
#pragma once
#include <json.hpp>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
enum llama_tool_call_style {
Unknown,
Llama31,
FunctionaryV3Llama3,
FunctionaryV3Llama31,
Hermes2Pro,
};
class llama_chat_template {
public:
private:
llama_tool_call_style _tool_call_style = Unknown;
bool _supports_tools = true;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool _requires_object_arguments = false;
bool _supports_system_role = true;
std::string _chat_template;
std::string _bos_token;
std::string _eos_token;
public:
llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token)
: _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) {
_supports_tools = chat_template.find("tools") != std::string::npos;
_requires_object_arguments = chat_template.find("tool_call.arguments | items") != std::string::npos;
_supports_system_role = chat_template.find("System role not supported") == std::string::npos;
if (chat_template.find("<tool_call>") != std::string::npos) {
_tool_call_style = Hermes2Pro;
} else if (chat_template.find(">>>all") != std::string::npos) {
_tool_call_style = FunctionaryV3Llama3;
} else if (chat_template.find("<|start_header_id|>") != std::string::npos) {
if (chat_template.find("<function=") != std::string::npos) {
_tool_call_style = FunctionaryV3Llama31;
} else if (chat_template.find("<|python_tag|>") != std::string::npos) {
_tool_call_style = Llama31;
}
}
}
static llama_chat_template from_model(
const struct llama_model * model,
const std::string & chat_template_override);
llama_tool_call_style tool_call_style() const { return _tool_call_style; }
const std::string & chat_template() const { return _chat_template; }
bool supports_tools() const { return _supports_tools; }
std::string apply(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt) const;
};

View File

@ -9,6 +9,7 @@
#include "json.hpp" #include "json.hpp"
#include "json-schema-to-grammar.h" #include "json-schema-to-grammar.h"
#include "llama.h" #include "llama.h"
#include "chat-template.h"
#include <algorithm> #include <algorithm>
#include <cinttypes> #include <cinttypes>
@ -1511,6 +1512,20 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token>
// //
bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) { bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {
auto chat_template = llama_chat_template(tmpl, "<s>", "</s>");
chat_template.apply({{
{"role", "user"},
{"content", "test"},
}}, json(), true);
return true;
} catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
return false;
}
}
llama_chat_message chat[] = {{"user", "test"}}; llama_chat_message chat[] = {{"user", "test"}};
int res = llama_chat_apply_template( int res = llama_chat_apply_template(
nullptr, nullptr,
@ -1519,22 +1534,14 @@ bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
1, 1,
/* add_ass= */ true, /* add_ass= */ true,
/* buffer= */ nullptr, /* buffer= */ nullptr,
/* length= */ 0, /* length= */ 0);
use_jinja,
/* tools= */ nullptr,
"<s>",
"</s>");
return res >= 0; return res >= 0;
} }
std::string llama_chat_apply_template(const struct llama_model * model, std::string llama_chat_apply_template(const struct llama_model * model,
const std::string & tmpl, const std::string & tmpl,
const std::vector<llama_chat_msg> & msgs, const std::vector<llama_chat_msg> & msgs,
bool add_ass, bool add_ass) {
bool use_jinja,
const char * tools,
const char * bos_token,
const char * eos_token) {
int alloc_size = 0; int alloc_size = 0;
bool fallback = false; // indicate if we must fallback to default chatml bool fallback = false; // indicate if we must fallback to default chatml
std::vector<llama_chat_message> chat; std::vector<llama_chat_message> chat;
@ -1547,7 +1554,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
std::vector<char> buf(alloc_size); std::vector<char> buf(alloc_size);
// run the first time to get the total output length // run the first time to get the total output length
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token); int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// error: chat template is not supported // error: chat template is not supported
if (res < 0) { if (res < 0) {
@ -1557,7 +1564,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
throw std::runtime_error("this custom template is not supported"); throw std::runtime_error("this custom template is not supported");
} else { } else {
// If the built-in template is not supported, we default to chatml // If the built-in template is not supported, we default to chatml
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token); res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
fallback = true; fallback = true;
} }
} }
@ -1568,7 +1575,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
res = llama_chat_apply_template( res = llama_chat_apply_template(
fallback ? nullptr : model, fallback ? nullptr : model,
fallback ? "chatml" : ptr_tmpl, fallback ? "chatml" : ptr_tmpl,
chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token); chat.data(), chat.size(), add_ass, buf.data(), buf.size());
} }
std::string formatted_chat(buf.data(), res); std::string formatted_chat(buf.data(), res);
@ -1579,13 +1586,9 @@ std::string llama_chat_format_single(const struct llama_model * model,
const std::string & tmpl, const std::string & tmpl,
const std::vector<llama_chat_msg> & past_msg, const std::vector<llama_chat_msg> & past_msg,
const llama_chat_msg & new_msg, const llama_chat_msg & new_msg,
bool add_ass, bool add_ass) {
bool use_jinja,
const char * tools,
const char * bos_token,
const char * eos_token) {
std::ostringstream ss; std::ostringstream ss;
auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja, tools, bos_token, eos_token); auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false);
std::vector<llama_chat_msg> chat_new(past_msg); std::vector<llama_chat_msg> chat_new(past_msg);
// if the past_msg ends with a newline, we must preserve it in the formatted version // if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
@ -1593,7 +1596,7 @@ std::string llama_chat_format_single(const struct llama_model * model,
}; };
// format chat with new_msg // format chat with new_msg
chat_new.push_back(new_msg); chat_new.push_back(new_msg);
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja, tools, bos_token, eos_token); auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
// get the diff part // get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str(); return ss.str();

View File

@ -471,21 +471,14 @@ std::string llama_detokenize(
// Chat template utils // Chat template utils
// //
struct llama_chat_msg_tool_call {
std::string name;
std::string arguments;
};
// same as llama_chat_message, but uses std::string and std::vector // same as llama_chat_message, but uses std::string and std::vector
struct llama_chat_msg { struct llama_chat_msg {
std::string role; std::string role;
std::string content; std::string content;
std::string tool;
std::vector<struct llama_chat_msg_tool_call> tool_calls;
}; };
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid // Check if the template is supported or not. Returns true if it's valid
bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja = false); bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja);
// CPP wrapper for llama_chat_apply_template // CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml // If the built-in template is not supported, we default to chatml
@ -493,22 +486,14 @@ bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja = false
std::string llama_chat_apply_template(const struct llama_model * model, std::string llama_chat_apply_template(const struct llama_model * model,
const std::string & tmpl, const std::string & tmpl,
const std::vector<llama_chat_msg> & chat, const std::vector<llama_chat_msg> & chat,
bool add_ass, bool add_ass);
bool use_jinja = false,
const char * tools = nullptr,
const char * bos_token = nullptr,
const char * eos_token = nullptr);
// Format single message, while taking into account the position of that message in chat history // Format single message, while taking into account the position of that message in chat history
std::string llama_chat_format_single(const struct llama_model * model, std::string llama_chat_format_single(const struct llama_model * model,
const std::string & tmpl, const std::string & tmpl,
const std::vector<llama_chat_msg> & past_msg, const std::vector<llama_chat_msg> & past_msg,
const llama_chat_msg & new_msg, const llama_chat_msg & new_msg,
bool add_ass, bool add_ass);
bool use_jinja = false,
const char * tools = nullptr,
const char * bos_token = nullptr,
const char * eos_token = nullptr);
// Returns an example of formatted chat // Returns an example of formatted chat
std::string llama_chat_format_example(const struct llama_model * model, std::string llama_chat_format_example(const struct llama_model * model,

View File

@ -12,27 +12,6 @@
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3.llama3.txt
static bool needs_functionary_v3_tool_call(const std::string & chat_template) {
return chat_template.find("<|start_header_id|>") != std::string::npos
&& chat_template.find(">>>all") != std::string::npos;
}
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
static bool needs_functionary_v3_llama_3_1_tool_call(const std::string & chat_template) {
return chat_template.find("<|start_header_id|>") != std::string::npos
&& chat_template.find("<function=") != std::string::npos;
}
static bool needs_llama_3_1_tool_call(const std::string & chat_template) {
return chat_template.find("<|start_header_id|>") != std::string::npos
&& chat_template.find("<|python_tag|>") != std::string::npos;
}
static bool needs_hermes_pro_tool_call(const std::string & chat_template) {
return chat_template.find("<tool_call>") != std::string::npos;
}
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) { static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
// // https://json.nlohmann.me/features/parsing/sax_interface/ // // https://json.nlohmann.me/features/parsing/sax_interface/
struct json_error_locator : public nlohmann::json_sax<json> { struct json_error_locator : public nlohmann::json_sax<json> {
@ -209,137 +188,145 @@ static llama_tool_calls parse_functionary_v3_tool_calls(const std::string& input
return parse_functionary_tool_calls(input, function_regex, close_regex); return parse_functionary_tool_calls(input, function_regex, close_regex);
} }
llama_tool_calls parse_tool_calls(const json & tools, const std::string & chat_template, const std::string& input) { llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) {
if (needs_hermes_pro_tool_call(chat_template)) { switch (style) {
return parse_hermes_tool_calls(input); case llama_tool_call_style::Llama31:
} else if (needs_functionary_v3_tool_call(chat_template)) { return parse_llama_3_1_tool_calls(tools, input);
return parse_functionary_v3_tool_calls(input); case llama_tool_call_style::FunctionaryV3Llama3:
} else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) { return parse_functionary_v3_tool_calls(input);
return parse_functionary_v3_llama_3_1_tool_calls(input); case llama_tool_call_style::FunctionaryV3Llama31:
} else if (needs_llama_3_1_tool_call(chat_template)) { return parse_functionary_v3_llama_3_1_tool_calls(input);
return parse_llama_3_1_tool_calls(tools, input); case llama_tool_call_style::Hermes2Pro:
} else { return parse_hermes_tool_calls(input);
throw std::runtime_error("Unsupported chat template for tool calls"); default:
throw std::runtime_error("Unsupported tool call style");
} }
} }
llama_tool_call_handler llama_tool_call_handler_init( llama_tool_call_handler llama_tool_call_handler_init(
const std::string & chat_template, const llama_chat_template & tmpl,
bool allow_content, bool allow_content,
bool parallel_tool_calls, bool parallel_tool_calls,
const nlohmann::ordered_json & tools) const nlohmann::ordered_json & tools)
{ {
llama_tool_call_handler handler; llama_tool_call_handler handler;
if (needs_functionary_v3_tool_call(chat_template)) { switch (tmpl.tool_call_style()) {
// MeetKaiFunctionary_3_2 case llama_tool_call_style::Llama31: {
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar static std::vector<std::string> builtin_tools {"wolfram_alpha", "brave_search"};
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { std::vector<std::string> tool_rules;
std::vector<std::string> tool_rules;
for (size_t i = 0, n = tools.size(); i < n; i++) { for (const auto & tool : tools) {
auto & tool = tools[i]; const auto & function = tool["function"];
const auto & function = tool["function"]; std::string name = function["name"];
std::string name = function["name"]; auto parameters = function["parameters"];
auto parameters = function["parameters"]; builder.resolve_refs(parameters);
auto tool_rule = builder.add_rule(name + "-call", "\">>>" + name + "\\n\" " + builder.add_schema(name + "-args", parameters)); if (name == "ipython" || std::find(builtin_tools.begin(), builtin_tools.end(), name) != builtin_tools.end()) {
tool_rules.push_back(tool_rule); tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
if (allow_content) {
handler.grammar_trigger_words.push_back("<|python_tag|>");
}
} else {
//"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"\\n{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) +
" \"}\""));
if (allow_content) {
handler.grammar_trigger_words.push_back("\n{\"" + name + "\"");
}
}
}
builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | "));
});
handler.additional_stop_words.push_back("<|eom_id|>");
break;
}
case llama_tool_call_style::FunctionaryV3Llama3: {
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
std::vector<std::string> tool_rules;
for (size_t i = 0, n = tools.size(); i < n; i++) {
auto & tool = tools[i];
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
auto tool_rule = builder.add_rule(name + "-call", "\">>>" + name + "\\n\" " + builder.add_schema(name + "-args", parameters));
tool_rules.push_back(tool_rule);
if (allow_content) {
handler.grammar_trigger_words.push_back(">>>" + name + "\n");
}
}
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
});
// handler.parser = parse_functionary_3_2_tool_calls;
break;
}
case llama_tool_call_style::FunctionaryV3Llama31: {
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
// TODO: handle tool {type: code_interpreter} as python
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
std::vector<std::string> tool_rules;
for (size_t i = 0, n = tools.size(); i < n; i++) {
auto & tool = tools[i];
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
if (name == "python") {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
if (allow_content) {
handler.grammar_trigger_words.push_back("<|python_tag|>");
}
} else {
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\""));
}
}
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
if (allow_content) { if (allow_content) {
handler.grammar_trigger_words.push_back(">>>" + name + "\n"); handler.grammar_trigger_words.push_back("<function=");
} }
} });
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space"; // handler.parser = parse_functionary_3_2_tool_calls;
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); break;
}); }
// handler.parser = parse_functionary_3_2_tool_calls; case llama_tool_call_style::Hermes2Pro: {
} else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) { // NousResearchHermesPro_2
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
// TODO: handle tool {type: code_interpreter} as python std::vector<std::string> tool_rules;
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { for (const auto & tool : tools) {
std::vector<std::string> tool_rules; const auto & function = tool["function"];
for (size_t i = 0, n = tools.size(); i < n; i++) { std::string name = function["name"];
auto & tool = tools[i]; auto parameters = function["parameters"];
const auto & function = tool["function"]; builder.resolve_refs(parameters);
std::string name = function["name"]; tool_rules.push_back(builder.add_schema(name + "-call", {
auto parameters = function["parameters"]; {"type", "object"},
if (name == "python") { {"properties", json {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); {"name", json {{"const", name}}},
if (allow_content) { {"arguments", parameters},
handler.grammar_trigger_words.push_back("<|python_tag|>"); }},
} {"required", json::array({"name", "arguments"})},
} else { }));
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\""));
} }
}
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
if (allow_content) {
handler.grammar_trigger_words.push_back("<function=");
}
});
// handler.parser = parse_functionary_3_2_tool_calls;
} else if (needs_hermes_pro_tool_call(chat_template)) {
// NousResearchHermesPro_2
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
std::vector<std::string> tool_rules;
for (const auto & tool : tools) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
builder.resolve_refs(parameters);
tool_rules.push_back(builder.add_schema(name + "-call", {
{"type", "object"},
{"properties", json {
{"name", json {{"const", name}}},
{"arguments", parameters},
}},
{"required", json::array({"name", "arguments"})},
}));
}
auto tool_call = "\"<tool_call>\" " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space"; auto tool_call = "\"<tool_call>\" " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space";
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
if (allow_content) { if (allow_content) {
handler.grammar_trigger_words.push_back("<tool_call>"); handler.grammar_trigger_words.push_back("<tool_call>");
}
});
} else if (needs_llama_3_1_tool_call(chat_template)) {
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
static std::vector<std::string> builtin_tools {"wolfram_alpha", "brave_search"};
std::vector<std::string> tool_rules;
for (const auto & tool : tools) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
builder.resolve_refs(parameters);
if (name == "ipython" || std::find(builtin_tools.begin(), builtin_tools.end(), name) != builtin_tools.end()) {
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
if (allow_content) {
handler.grammar_trigger_words.push_back("<|python_tag|>");
}
} else {
//"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"\\n{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) +
" \"}\""));
if (allow_content) {
handler.grammar_trigger_words.push_back("\n{\"" + name + "\"");
}
} }
} });
break;
builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | ")); }
}); default:
handler.additional_stop_words.push_back("<|eom_id|>"); throw std::runtime_error("Unsupported tool call style");
} else {
// TODO: generic thoughtful schema.
throw std::runtime_error("Unsupported tool call style!");
} }
return handler; return handler;
} }

View File

@ -5,22 +5,29 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT: // Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT #define JSON_ASSERT GGML_ASSERT
#include "json.hpp" #include "json.hpp"
#include "chat-template.h"
struct llama_tool_call {
std::string name;
std::string arguments;
};
struct llama_tool_calls { struct llama_tool_calls {
std::string content; std::string content;
std::vector<llama_chat_msg_tool_call> tool_calls; std::vector<llama_tool_call> tool_calls;
}; };
struct llama_tool_call_handler { struct llama_tool_call_handler {
std::string grammar; std::string grammar;
std::vector<std::string> grammar_trigger_words; std::vector<std::string> grammar_trigger_words;
std::vector<std::string> additional_stop_words; std::vector<std::string> additional_stop_words;
nlohmann::ordered_json updated_tools;
}; };
llama_tool_calls parse_tool_calls(const nlohmann::ordered_json & tools, const std::string & chat_template, const std::string& input); llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
llama_tool_call_handler llama_tool_call_handler_init( llama_tool_call_handler llama_tool_call_handler_init(
const std::string & chat_template, const llama_chat_template & tmpl,
bool allow_content, bool allow_content,
bool parallel_tool_calls, bool parallel_tool_calls,
const nlohmann::ordered_json & tools); const nlohmann::ordered_json & tools);

View File

@ -571,6 +571,12 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte
```shell ```shell
llama-server --jinja -hfr lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF -hff Meta-Llama-3.1-8B-Instruct-Q5_K_M.gguf -fa llama-server --jinja -hfr lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF -hff Meta-Llama-3.1-8B-Instruct-Q5_K_M.gguf -fa
# https://huggingface.co/meetkai/functionary-medium-v3.2
llama-server --jinja -hfr bartowski/functionary-medium-v3.2-GGUF -hff functionary-medium-v3.2-IQ4_XS.gguf -fa
# https://huggingface.co/meetkai/functionary-medium-v3.1
llama-server --jinja -hfr meetkai/functionary-medium-v3.1-GGUF -hff functionary-medium-llama-3.1.Q4_0.gguf -fa
curl http://localhost:8080/v1/chat/completions \ curl http://localhost:8080/v1/chat/completions \
-d '{ -d '{
"model": "gpt-3.5-turbo", "model": "gpt-3.5-turbo",

View File

@ -662,7 +662,7 @@ struct server_context {
bool validate_model_chat_template(bool use_jinja) const { bool validate_model_chat_template(bool use_jinja) const {
llama_chat_message chat[] = {{"user", "test"}}; llama_chat_message chat[] = {{"user", "test"}};
const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0, use_jinja, nullptr, nullptr, nullptr); const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
return res > 0; return res > 0;
} }
@ -2860,9 +2860,11 @@ int main(int argc, char ** argv) {
return; return;
} }
auto chat_template = llama_chat_template::from_model(ctx_server.model, params.chat_template);
json data; json data;
try { try {
data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template, params.use_jinja); data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, params.use_jinja);
} catch (const std::runtime_error & e) { } catch (const std::runtime_error & e) {
res_error(res, format_error_response(e.what(), ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response(e.what(), ERROR_TYPE_NOT_SUPPORTED));
return; return;
@ -2880,7 +2882,7 @@ int main(int argc, char ** argv) {
ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) { ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
// multitask is never support in chat completion, there is only one result // multitask is never support in chat completion, there is only one result
try { try {
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose); json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, chat_template, /*.streaming =*/ false, verbose);
res_ok(res, result_oai); res_ok(res, result_oai);
} catch (const std::runtime_error & e) { } catch (const std::runtime_error & e) {
res_error(res, format_error_response(e.what(), ERROR_TYPE_SERVER)); res_error(res, format_error_response(e.what(), ERROR_TYPE_SERVER));

View File

@ -23,19 +23,19 @@ Feature: llama.cpp server
And a model test And a model test
And <n_predict> max tokens to predict And <n_predict> max tokens to predict
And a user prompt write a hello world in python And a user prompt write a hello world in python
And a tool choice <tool_choice> And a tool choice required
And tools <tools> And tools <tools>
And an OAI compatible chat completions request with no api error And an OAI compatible chat completions request with no api error
Then tool <tool_name> is called with arguments <tool_arguments> Then tool <tool_name> is called with arguments <tool_arguments>
Examples: Prompts Examples: Prompts
| template_name | n_predict | tool_name | tool_arguments | tool_choice | tools | | template_name | n_predict | tool_name | tool_arguments | tools |
| meetkai-functionary-medium-v3.1 | 128 | test | {} | required | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] | | meetkai-functionary-medium-v3.1 | 128 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
| meetkai-functionary-medium-v3.1 | 128 | ipython | {"code": "Yes, you can."} | required | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] | | meetkai-functionary-medium-v3.1 | 128 | ipython | {"code": "I'm sorry,"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
| meetkai-functionary-medium-v3.2 | 128 | test | {} | required | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] | | meetkai-functionary-medium-v3.2 | 128 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
| meetkai-functionary-medium-v3.2 | 128 | ipython | {"code": "Yes,"} | required | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] | | meetkai-functionary-medium-v3.2 | 128 | ipython | {"code": "Yes,"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
| meta-llama-Meta-Llama-3.1-8B-Instruct | 64 | test | {} | required | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] | | meta-llama-Meta-Llama-3.1-8B-Instruct | 64 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
| meta-llama-Meta-Llama-3.1-8B-Instruct | 16 | ipython | {"code": "it and "} | required | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] | | meta-llama-Meta-Llama-3.1-8B-Instruct | 16 | ipython | {"code": ". A"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
Scenario: OAI Compatibility w/ no tool Scenario: OAI Compatibility w/ no tool

View File

@ -14,6 +14,7 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT: // Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT #define JSON_ASSERT GGML_ASSERT
#include "chat-template.h"
#include "json.hpp" #include "json.hpp"
#include "minja.hpp" #include "minja.hpp"
#include "tool-call.h" #include "tool-call.h"
@ -64,40 +65,30 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
for (size_t i = 0; i < messages.size(); ++i) { for (size_t i = 0; i < messages.size(); ++i) {
const auto & curr_msg = messages[i]; const auto & curr_msg = messages[i];
llama_chat_msg msg; std::string role = json_value(curr_msg, "role", std::string(""));
msg.role = json_value(curr_msg, "role", std::string(""));
msg.tool = json_value(curr_msg, "tool", std::string("")); std::string content;
if (curr_msg.contains("content")) { if (curr_msg.contains("content")) {
if (curr_msg["content"].is_string()) { if (curr_msg["content"].is_string()) {
msg.content = curr_msg["content"].get<std::string>(); content = curr_msg["content"].get<std::string>();
} else if (curr_msg["content"].is_array()) { } else if (curr_msg["content"].is_array()) {
for (const auto & part : curr_msg["content"]) { for (const auto & part : curr_msg["content"]) {
if (part.contains("text")) { if (part.contains("text")) {
msg.content += "\n" + part["text"].get<std::string>(); content += "\n" + part["text"].get<std::string>();
} }
} }
} else if (!(curr_msg.is_null() && curr_msg.contains("tool_calls"))) { } else {
throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367): " + curr_msg.dump()); throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
} }
} else { } else {
throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
} }
if (curr_msg.contains("tool_calls") && curr_msg["tool_calls"].is_array()) {
for (const auto & tool_call : curr_msg["tool_calls"]) { chat.push_back({role, content});
if (json_value(tool_call, "type", std::string("")) == "function"
&& tool_call.contains("function") && tool_call["function"].is_object()) {
msg.tool_calls.push_back({
json_value(tool_call["function"], "name", std::string("")),
json_value(tool_call["function"], "arguments", std::string(""))
});
}
}
}
chat.emplace_back(std::move(msg));
} }
const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true, use_jinja, tools.is_null() ? nullptr : tools.dump().c_str()); const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str()); LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
return formatted_chat; return formatted_chat;
@ -315,38 +306,12 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
// OAI utils // OAI utils
// //
static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
std::string piece;
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
if (n_chars < 0) {
piece.resize(-n_chars);
int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
GGML_ASSERT(check == -n_chars);
}
else {
piece.resize(n_chars);
}
return piece;
}
std::string llama_model_meta_val_str(const struct llama_model * model, const char * key) {
int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
if (tlen > 0) {
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
return std::string(curr_tmpl_buf.data(), tlen);
}
}
return "";
}
static json oaicompat_completion_params_parse( static json oaicompat_completion_params_parse(
const struct llama_model * model, const struct llama_model * model,
const json & body, /* openai api json semantics */ const json & body, /* openai api json semantics */
const std::string & chat_template_src, const llama_chat_template & tmpl,
bool use_jinja) { bool use_jinja)
{
json llama_params; json llama_params;
llama_params["__oaicompat"] = true; llama_params["__oaicompat"] = true;
@ -355,16 +320,15 @@ static json oaicompat_completion_params_parse(
auto has_tools = tools.is_array() && !tools.empty(); auto has_tools = tools.is_array() && !tools.empty();
// Apply chat template to the list of messages // Apply chat template to the list of messages
auto chat_template = chat_template_src.empty() ? llama_model_meta_val_str(model, "tokenizer.chat_template") : chat_template_src; llama_params["chat_template"] = tmpl.chat_template();
llama_params["chat_template"] = chat_template;
if (use_jinja) { if (use_jinja) {
if (has_tools && chat_template.find("tools") == std::string::npos) { if (has_tools && !tmpl.supports_tools()) {
throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template."); throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");
} }
} else if (has_tools) { } else if (has_tools) {
throw std::runtime_error("Tools are only supported in --jinja mode"); throw std::runtime_error("Tools are only supported in --jinja mode");
} }
llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"), tools, use_jinja);
// Handle "stop" field // Handle "stop" field
if (body.contains("stop") && body.at("stop").is_string()) { if (body.contains("stop") && body.at("stop").is_string()) {
@ -399,26 +363,40 @@ static json oaicompat_completion_params_parse(
} else if (!response_type.empty() && response_type != "text") { } else if (!response_type.empty() && response_type != "text") {
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
} }
} else if (use_jinja && tool_choice != "none" && has_tools) { }
bool parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
if (use_jinja) {
bool allow_content = tool_choice != "required"; bool allow_content = tool_choice != "required";
if (tool_choice != "none" && has_tools) {
bool parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
llama_params["parse_tool_calls"] = true;
llama_params["parallel_tool_calls"] = parallel_tool_calls;
auto handler = llama_tool_call_handler_init(chat_template, allow_content, parallel_tool_calls, tools); auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, tools);
for (const auto & stop : handler.additional_stop_words) { for (const auto & stop : handler.additional_stop_words) {
llama_params["stop"].push_back(stop); llama_params["stop"].push_back(stop);
} }
if (!handler.grammar_trigger_words.empty()) { if (!handler.grammar_trigger_words.empty()) {
auto triggers = json::array(); auto triggers = json::array();
for (const auto & word : handler.grammar_trigger_words) { for (const auto & word : handler.grammar_trigger_words) {
triggers.push_back(word); triggers.push_back(word);
}
llama_params["grammar_trigger_words"] = triggers;
}
if (handler.updated_tools.is_null()) {
tools = handler.updated_tools;
}
if (!handler.grammar.empty()) {
if (llama_params.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
}
llama_params["grammar"] = handler.grammar;
} }
llama_params["grammar_trigger_words"] = triggers;
} }
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
llama_params["grammar"] = handler.grammar; } else {
llama_params["parse_tool_calls"] = true; llama_params["prompt"] = format_chat(model, tmpl.chat_template(), body.at("messages"), tools, /* use_jinja= */ false);
llama_params["parallel_tool_calls"] = parallel_tool_calls;
} }
// Handle "n" field // Handle "n" field
@ -458,7 +436,7 @@ static json oaicompat_completion_params_parse(
return llama_params; return llama_params;
} }
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) { static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, const llama_chat_template & tmpl, bool streaming = false, bool verbose = false) {
bool stopped_word = result.count("stopped_word") != 0; bool stopped_word = result.count("stopped_word") != 0;
bool stopped_eos = json_value(result, "stopped_eos", false); bool stopped_eos = json_value(result, "stopped_eos", false);
int num_tokens_predicted = json_value(result, "tokens_predicted", 0); int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@ -474,9 +452,8 @@ static json format_final_response_oaicompat(const json & request, const json & r
auto tools = json_value(request, "tools", json::array()); auto tools = json_value(request, "tools", json::array());
json tool_calls; json tool_calls;
json message_content; json message_content;
printf("# CONTENT: %s\n\n", content.c_str());
if (json_value(request, "parse_tool_calls", false) if (json_value(request, "parse_tool_calls", false)
&& !(parsed_tool_calls = parse_tool_calls(tools, chat_template, content)).tool_calls.empty()) { && !(parsed_tool_calls = parse_tool_calls(tmpl.tool_call_style(), tools, content)).tool_calls.empty()) {
finish_reason = "tool"; finish_reason = "tool";
if (!parsed_tool_calls.content.empty()) { if (!parsed_tool_calls.content.empty()) {
message_content = parsed_tool_calls.content; message_content = parsed_tool_calls.content;
@ -514,7 +491,6 @@ static json format_final_response_oaicompat(const json & request, const json & r
}}, }},
{"id", completion_id} {"id", completion_id}
}; };
printf("# RES: %s\n\n", res.dump(2).c_str());
// extra fields for debugging purposes // extra fields for debugging purposes
if (verbose) { if (verbose) {

View File

@ -377,19 +377,9 @@ extern "C" {
} llama_sampler_chain_params; } llama_sampler_chain_params;
// used in chat template // used in chat template
typedef struct llama_chat_message_tool_call {
const char * name;
const char * arguments;
} llama_chat_message_tool_call;
typedef struct llama_chat_message { typedef struct llama_chat_message {
const char * role; const char * role;
const char * content; const char * content;
const char * tool;
const llama_chat_message_tool_call * tool_calls;
uint32_t n_tool_calls;
} llama_chat_message; } llama_chat_message;
// lora adapter // lora adapter
@ -986,11 +976,7 @@ extern "C" {
size_t n_msg, size_t n_msg,
bool add_ass, bool add_ass,
char * buf, char * buf,
int32_t length, int32_t length);
bool use_jinja,
const char * tools,
const char * bos_token,
const char * eos_token);
// //
// Sampling API // Sampling API

View File

@ -2,8 +2,6 @@
#include "llama-vocab.h" #include "llama-vocab.h"
#include "llama-sampling.h" #include "llama-sampling.h"
#include "minja.hpp"
#include "unicode.h" #include "unicode.h"
#include "ggml.h" #include "ggml.h"
@ -21004,95 +21002,7 @@ int32_t llama_detokenize(
static int32_t llama_chat_apply_template_internal( static int32_t llama_chat_apply_template_internal(
const std::string & tmpl, const std::string & tmpl,
const std::vector<const llama_chat_message *> & chat, const std::vector<const llama_chat_message *> & chat,
std::string & dest, bool add_ass, std::string & dest, bool add_ass) {
bool use_jinja,
const std::string & tools,
const std::string & bos_token, const std::string & eos_token) {
if (use_jinja) {
auto system_not_supported = tmpl.find("System role not supported") != std::string::npos;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
auto tool_call_args_must_be_objects = tmpl.find("tool_call.arguments | items") != std::string::npos;
auto messages = json::array();
std::string pending_system;
auto flush_sys = [&]() {
if (!pending_system.empty()) {
messages.push_back({
{"role", "user"},
{"content", pending_system},
});
pending_system.clear();
}
};
for (const auto * msg : chat) {
std::string role(msg->role);
std::string content(msg->content);
if (system_not_supported) {
if (role == "system") {
if (!pending_system.empty()) pending_system += "\n";
pending_system += content;
continue;
} else {
if (role == "user") {
if (!pending_system.empty()) {
content = pending_system + (content.empty() ? "" : "\n" + content);
pending_system.clear();
}
} else {
flush_sys();
}
}
}
auto message = json({
{"role", role},
{"content", content},
});
if (msg->tool) message["tool"] = msg->tool;
if (msg->n_tool_calls) {
auto tool_calls = json::array();
for (uint32_t i = 0; i < msg->n_tool_calls; i++) {
auto args = msg->tool_calls[i].arguments;
tool_calls.push_back(json({
{"type", "function"},
{"function", {
{"name", msg->tool_calls[i].name},
{"arguments", tool_call_args_must_be_objects ? json::parse(args) : args},
}}
}));
}
messages["tool_calls"] = tool_calls;
}
messages.push_back(message);
}
flush_sys();
auto context = minja::Context::make(json({
{"messages", messages},
{"add_generation_prompt", add_ass},
{"bos_token", bos_token},
{"eos_token", eos_token},
}));
if (!tools.empty()) {
auto tools_val = minja::Value(json::parse(tools));
context->set("tools", tools_val);
}
auto tmpl_root = minja::Parser::parse(tmpl, {
/* .trim_blocks = */ true,
/* .lstrip_blocks = */ true,
/* .keep_trailing_newline = */ false,
});
try {
dest = tmpl_root->render(context);
return dest.size();
} catch (const std::runtime_error & err) {
LLAMA_LOG_ERROR("Error in jinja template: %s\n", err.what());
return -1;
}
}
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
std::stringstream ss; std::stringstream ss;
@ -21360,11 +21270,7 @@ int32_t llama_chat_apply_template(
size_t n_msg, size_t n_msg,
bool add_ass, bool add_ass,
char * buf, char * buf,
int32_t length, int32_t length) {
bool use_jinja,
const char * tools,
const char * bos_token,
const char * eos_token) {
std::string curr_tmpl(tmpl == nullptr ? "" : tmpl); std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
if (tmpl == nullptr) { if (tmpl == nullptr) {
GGML_ASSERT(model != nullptr); GGML_ASSERT(model != nullptr);
@ -21379,16 +21285,6 @@ int32_t llama_chat_apply_template(
curr_tmpl = std::string(model_template.data(), model_template.size()); curr_tmpl = std::string(model_template.data(), model_template.size());
} }
} }
std::string curr_bos_token(bos_token ? bos_token : "");
std::string curr_eos_token(eos_token ? eos_token : "");
if (bos_token == nullptr) {
GGML_ASSERT(model != nullptr);
curr_bos_token = llama_token_to_piece(model, llama_token_bos(model), true);
}
if (eos_token == nullptr) {
GGML_ASSERT(model != nullptr);
curr_eos_token = llama_token_to_piece(model, llama_token_eos(model), true);
}
// format the chat to string // format the chat to string
std::vector<const llama_chat_message *> chat_vec; std::vector<const llama_chat_message *> chat_vec;
@ -21398,7 +21294,7 @@ int32_t llama_chat_apply_template(
} }
std::string formatted_chat; std::string formatted_chat;
int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass, use_jinja, tools == nullptr ? "" : tools, curr_bos_token, curr_eos_token); int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
if (res < 0) { if (res < 0) {
return res; return res;
} }

View File

@ -20,9 +20,9 @@ static void assert_equals(const std::string & expected, const std::string & actu
cmake -B build -DLLAMA_CURL=1 -DCMAKE_BUILD_TYPE=Release && cmake --build build -t test-tool-call -j && ./build/bin/test-tool-call cmake -B build -DLLAMA_CURL=1 -DCMAKE_BUILD_TYPE=Release && cmake --build build -t test-tool-call -j && ./build/bin/test-tool-call
*/ */
static void test_parse_tool_call(const json & tools, const std::string & chat_template, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) { static void test_parse_tool_call(llama_tool_call_style style, const json & tools, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) {
std::cout << "# Testing: " << input << std::endl << std::flush; std::cout << "# Testing: " << input << std::endl << std::flush;
auto result = parse_tool_calls(tools, chat_template, input); auto result = parse_tool_calls(style, tools, input);
assert_equals(expected_content, result.content); assert_equals(expected_content, result.content);
auto tool_calls = json::array(); auto tool_calls = json::array();
for (const auto & tc : result.tool_calls) { for (const auto & tc : result.tool_calls) {
@ -59,8 +59,7 @@ int main() {
{"tools", tools} {"tools", tools}
}; };
std::string hermes_2_pro_like_tmpl = "Hermes 2 Pro template should have <tool_call> inside it"; test_parse_tool_call(llama_tool_call_style::Hermes2Pro, tools,
test_parse_tool_call(tools, hermes_2_pro_like_tmpl,
"<tool_call>{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}</tool_call>", "<tool_call>{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}</tool_call>",
"", "",
json {{ json {{
@ -72,8 +71,7 @@ int main() {
}} }}
}}); }});
std::string functionary_v3_like_tmpl = "Functionary 3.2 template should have <|start_header_id|> and then some >>>all inside it"; test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools,
test_parse_tool_call(tools, functionary_v3_like_tmpl,
">>>ipython\n{\"code\": \"print('Hello, world!')\"}", ">>>ipython\n{\"code\": \"print('Hello, world!')\"}",
"", "",
json {{ json {{
@ -84,7 +82,7 @@ int main() {
}).dump()} }).dump()}
}} }}
}}); }});
test_parse_tool_call(tools, functionary_v3_like_tmpl, test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools,
">>>test\n{ } \n ", ">>>test\n{ } \n ",
"", "",
json {{ json {{
@ -94,8 +92,7 @@ int main() {
}} }}
}}); }});
std::string functionary_v3_llama_3_1_like_tmpl = "Functionary 3.2 template for llama 3.1 should have <|start_header_id|> and then some <function=foo>{...}</function> inside it"; test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama31, tools,
test_parse_tool_call(tools, functionary_v3_llama_3_1_like_tmpl,
"Hell<function=foo>{\"arg1\": 1}</function>o, world<function=bar>{\"arg2\": 2}</function>!", "Hell<function=foo>{\"arg1\": 1}</function>o, world<function=bar>{\"arg2\": 2}</function>!",
"Hello, world!", "Hello, world!",
json { json {
@ -116,7 +113,7 @@ int main() {
}} }}
}, },
}); });
test_parse_tool_call(tools, functionary_v3_llama_3_1_like_tmpl, test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama31, tools,
"<function=test>{ } </function> ", "<function=test>{ } </function> ",
" ", " ",
json {{ json {{
@ -126,8 +123,7 @@ int main() {
}} }}
}}); }});
std::string llama_3_1_like_tmpl = "Llama 3.1 template should have <|start_header_id|> and <|python_tag|> inside it"; test_parse_tool_call(llama_tool_call_style::Llama31, tools,
test_parse_tool_call(tools, llama_3_1_like_tmpl,
"<|python_tag|>this could be anything", "<|python_tag|>this could be anything",
"", "",
json {{ json {{
@ -138,7 +134,7 @@ int main() {
}).dump()} }).dump()}
}} }}
}}); }});
test_parse_tool_call(tools, llama_3_1_like_tmpl, test_parse_tool_call(llama_tool_call_style::Llama31, tools,
"I'm thinking<|python_tag|>", "I'm thinking<|python_tag|>",
"I'm thinking", "I'm thinking",
json {{ json {{
@ -147,7 +143,7 @@ int main() {
{"arguments", (json {{"code", ""}}).dump()} {"arguments", (json {{"code", ""}}).dump()}
}} }}
}}); }});
test_parse_tool_call(tools, llama_3_1_like_tmpl, test_parse_tool_call(llama_tool_call_style::Llama31, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
"", "",
json {{ json {{
@ -158,7 +154,7 @@ int main() {
}).dump()} }).dump()}
}} }}
}}); }});
test_parse_tool_call(tools, llama_3_1_like_tmpl, test_parse_tool_call(llama_tool_call_style::Llama31, tools,
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array()); "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array());