tool-call: factor chat template away from legacy API

ochafik 2024-09-26 17:19:29 +01:00
parent d7ec84f78c
commit cf7bece6a7
15 changed files with 431 additions and 399 deletions

View File

@ -934,6 +934,7 @@ OBJ_LLAMA = \
OBJ_COMMON = \
common/common.o \
common/chat-template.o \
common/arg.o \
common/log.o \
common/console.o \
@ -1170,6 +1171,8 @@ $(LIB_LLAMA_S): \
common/common.o: \
common/common.cpp \
common/common.h \
common/chat-template.cpp \
common/chat-template.h \
common/console.h \
common/sampling.h \
common/json.hpp \
@ -1465,6 +1468,7 @@ llama-server: \
examples/server/prompt-formats.js.hpp \
examples/server/json-schema-to-grammar.mjs.hpp \
examples/server/loading.html.hpp \
common/chat-template.h \
common/json.hpp \
common/stb_image.h \
$(OBJ_ALL)

View File

@ -54,6 +54,8 @@ add_library(${TARGET} STATIC
arg.cpp
arg.h
base64.hpp
chat-template.cpp
chat-template.h
common.cpp
common.h
console.cpp

common/chat-template.cpp (new file, 118 lines)
View File

@ -0,0 +1,118 @@
#include "chat-template.h"
#include "minja.hpp"
#include "llama.h"
using json = nlohmann::ordered_json;
static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
std::string piece;
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\0'
const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
if (n_chars < 0) {
piece.resize(-n_chars);
int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
GGML_ASSERT(check == -n_chars);
}
else {
piece.resize(n_chars);
}
return piece;
}
static std::string llama_model_meta_val_str(const struct llama_model * model, const char * key) {
int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
if (tlen > 0) {
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
return std::string(curr_tmpl_buf.data(), tlen);
}
}
return "";
}
llama_chat_template llama_chat_template::from_model(
const struct llama_model * model,
const std::string & chat_template_override)
{
// TODO: handle "chatml"?
auto chat_template = chat_template_override.empty()
? llama_model_meta_val_str(model, "tokenizer.chat_template")
: chat_template_override;
auto bos_token = _llama_token_to_piece(model, llama_token_bos(model), true);
auto eos_token = _llama_token_to_piece(model, llama_token_eos(model), true);
return llama_chat_template(chat_template, bos_token, eos_token);
}
std::string llama_chat_template::apply(
const json & messages,
const json & tools,
bool add_generation_prompt) const
{
auto actual_messages = messages;
// First, "fix" messages so they have a chance to be rendered correctly by the template
if (_requires_object_arguments || !_supports_system_role) {
std::string pending_system;
auto flush_sys = [&]() {
if (!pending_system.empty()) {
actual_messages.push_back({
{"role", "user"},
{"content", pending_system},
});
pending_system.clear();
}
};
for (auto & message : actual_messages) {
if (!message.contains("role") || !message.contains("content")) {
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
}
std::string role = message.at("role");
std::string content = message.at("content");
if (!_supports_system_role) {
if (role == "system") {
if (!pending_system.empty()) pending_system += "\n";
pending_system += content;
continue;
} else {
if (role == "user") {
if (!pending_system.empty()) {
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
pending_system.clear();
}
} else {
flush_sys();
}
}
}
if (_requires_object_arguments && message.contains("tool_calls")) {
for (auto & tool_call : message.at("tool_calls")) {
std::string arguments = tool_call.at("arguments");
tool_call["arguments"] = json::parse(arguments);
}
}
}
flush_sys();
}
auto context = minja::Context::make(json({
{"messages", actual_messages},
{"add_generation_prompt", add_generation_prompt},
{"bos_token", _bos_token},
{"eos_token", _eos_token},
}));
if (!tools.is_null() && !tools.empty()) {
auto tools_val = minja::Value(tools);
context->set("tools", tools_val);
}
auto tmpl_root = minja::Parser::parse(_chat_template, {
/* .trim_blocks = */ true,
/* .lstrip_blocks = */ true,
/* .keep_trailing_newline = */ false,
});
return tmpl_root->render(context);
}
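
For context, here is a minimal stand-alone sketch of the system-role folding that `apply()` performs when a template rejects the system role: system content is buffered, prepended to the next user message, or emitted as a synthetic user turn. It builds a new message array rather than editing in place, and the helper name is illustrative, not part of this commit.

```cpp
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

// Illustrative stand-alone version of the folding done in llama_chat_template::apply().
static json fold_system_messages(const json & messages) {
    json out = json::array();
    std::string pending_system;
    auto flush_sys = [&]() {
        if (!pending_system.empty()) {
            out.push_back({{"role", "user"}, {"content", pending_system}});
            pending_system.clear();
        }
    };
    for (const auto & message : messages) {
        std::string role    = message.at("role");
        std::string content = message.at("content");
        if (role == "system") {
            // buffer system content instead of passing it to the template
            if (!pending_system.empty()) pending_system += "\n";
            pending_system += content;
            continue;
        }
        if (role == "user" && !pending_system.empty()) {
            // prepend the buffered system text to the next user message
            content = pending_system + (content.empty() ? "" : "\n" + content);
            pending_system.clear();
        } else if (role != "user") {
            // any other role: flush the buffer as a synthetic user turn first
            flush_sys();
        }
        out.push_back({{"role", role}, {"content", content}});
    }
    flush_sys();
    return out;
}

int main() {
    json messages = json::array({
        {{"role", "system"}, {"content", "You are terse."}},
        {{"role", "user"},   {"content", "Hi"}},
    });
    std::cout << fold_system_messages(messages).dump(2) << std::endl;
}
```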

common/chat-template.h (new file, 64 lines)
View File

@ -0,0 +1,64 @@
#pragma once
#include <json.hpp>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
enum llama_tool_call_style {
Unknown,
Llama31,
FunctionaryV3Llama3,
FunctionaryV3Llama31,
Hermes2Pro,
};
class llama_chat_template {
public:
private:
llama_tool_call_style _tool_call_style = Unknown;
bool _supports_tools = true;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool _requires_object_arguments = false;
bool _supports_system_role = true;
std::string _chat_template;
std::string _bos_token;
std::string _eos_token;
public:
llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token)
: _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) {
_supports_tools = chat_template.find("tools") != std::string::npos;
_requires_object_arguments = chat_template.find("tool_call.arguments | items") != std::string::npos;
_supports_system_role = chat_template.find("System role not supported") == std::string::npos;
if (chat_template.find("<tool_call>") != std::string::npos) {
_tool_call_style = Hermes2Pro;
} else if (chat_template.find(">>>all") != std::string::npos) {
_tool_call_style = FunctionaryV3Llama3;
} else if (chat_template.find("<|start_header_id|>") != std::string::npos) {
if (chat_template.find("<function=") != std::string::npos) {
_tool_call_style = FunctionaryV3Llama31;
} else if (chat_template.find("<|python_tag|>") != std::string::npos) {
_tool_call_style = Llama31;
}
}
}
static llama_chat_template from_model(
const struct llama_model * model,
const std::string & chat_template_override);
llama_tool_call_style tool_call_style() const { return _tool_call_style; }
const std::string & chat_template() const { return _chat_template; }
bool supports_tools() const { return _supports_tools; }
std::string apply(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt) const;
};
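
The header above is the new single entry point for Jinja-based formatting. A hedged sketch of the intended call pattern, directly from the declared API; the `render_prompt` helper is illustrative and not part of this commit.

```cpp
#include <stdexcept>
#include "chat-template.h"

// Build the template once from the model (or an override), check tool support,
// then render OpenAI-style JSON messages into a prompt string.
static std::string render_prompt(const struct llama_model * model,
                                 const json & messages,
                                 const json & tools,
                                 const std::string & template_override = "") {
    auto tmpl = llama_chat_template::from_model(model, template_override);
    if (!tools.is_null() && !tools.empty() && !tmpl.supports_tools()) {
        throw std::runtime_error("chat template does not support tools; override it with --chat-template");
    }
    return tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
}
```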

View File

@ -9,6 +9,7 @@
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "chat-template.h"
#include <algorithm>
#include <cinttypes>
@ -1511,6 +1512,20 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token>
//
bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {
auto chat_template = llama_chat_template(tmpl, "<s>", "</s>");
chat_template.apply({{
{"role", "user"},
{"content", "test"},
}}, json(), true);
return true;
} catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
return false;
}
}
llama_chat_message chat[] = {{"user", "test"}};
int res = llama_chat_apply_template(
nullptr,
@ -1519,22 +1534,14 @@ bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
1,
/* add_ass= */ true,
/* buffer= */ nullptr,
/* length= */ 0,
use_jinja,
/* tools= */ nullptr,
"<s>",
"</s>");
/* length= */ 0);
return res >= 0;
}
std::string llama_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
const std::vector<llama_chat_msg> & msgs,
bool add_ass,
bool use_jinja,
const char * tools,
const char * bos_token,
const char * eos_token) {
bool add_ass) {
int alloc_size = 0;
bool fallback = false; // indicate if we must fallback to default chatml
std::vector<llama_chat_message> chat;
@ -1547,7 +1554,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// error: chat template is not supported
if (res < 0) {
@ -1557,7 +1564,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
throw std::runtime_error("this custom template is not supported");
} else {
// If the built-in template is not supported, we default to chatml
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
fallback = true;
}
}
@ -1568,7 +1575,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
res = llama_chat_apply_template(
fallback ? nullptr : model,
fallback ? "chatml" : ptr_tmpl,
chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
chat.data(), chat.size(), add_ass, buf.data(), buf.size());
}
std::string formatted_chat(buf.data(), res);
@ -1579,13 +1586,9 @@ std::string llama_chat_format_single(const struct llama_model * model,
const std::string & tmpl,
const std::vector<llama_chat_msg> & past_msg,
const llama_chat_msg & new_msg,
bool add_ass,
bool use_jinja,
const char * tools,
const char * bos_token,
const char * eos_token) {
bool add_ass) {
std::ostringstream ss;
auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja, tools, bos_token, eos_token);
auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false);
std::vector<llama_chat_msg> chat_new(past_msg);
// if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
@ -1593,7 +1596,7 @@ std::string llama_chat_format_single(const struct llama_model * model,
};
// format chat with new_msg
chat_new.push_back(new_msg);
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja, tools, bos_token, eos_token);
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();
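
One note on the simplified wrappers above: `llama_chat_format_single` still returns only the delta that the newest message adds to the formatted history, which is what interactive front-ends append to the running prompt. A small illustrative sketch, assuming `common.h` is included and `model` is a loaded `llama_model *`:

```cpp
// Format just the chunk contributed by new_msg and record it in the history.
static std::string next_prompt_chunk(const struct llama_model * model,
                                     std::vector<llama_chat_msg> & history,
                                     const llama_chat_msg & new_msg) {
    // an empty tmpl string means "use the template embedded in the model"
    std::string delta = llama_chat_format_single(model, /* tmpl= */ "", history, new_msg,
                                                 /* add_ass= */ true);
    history.push_back(new_msg);
    return delta;
}
```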

View File

@ -471,21 +471,14 @@ std::string llama_detokenize(
// Chat template utils
//
struct llama_chat_msg_tool_call {
std::string name;
std::string arguments;
};
// same as llama_chat_message, but uses std::string and std::vector
struct llama_chat_msg {
std::string role;
std::string content;
std::string tool;
std::vector<struct llama_chat_msg_tool_call> tool_calls;
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja = false);
// Check if the template is supported or not. Returns true if it's valid
bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja);
// CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml
@ -493,22 +486,14 @@ bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja = false
std::string llama_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
const std::vector<llama_chat_msg> & chat,
bool add_ass,
bool use_jinja = false,
const char * tools = nullptr,
const char * bos_token = nullptr,
const char * eos_token = nullptr);
bool add_ass);
// Format single message, while taking into account the position of that message in chat history
std::string llama_chat_format_single(const struct llama_model * model,
const std::string & tmpl,
const std::vector<llama_chat_msg> & past_msg,
const llama_chat_msg & new_msg,
bool add_ass,
bool use_jinja = false,
const char * tools = nullptr,
const char * bos_token = nullptr,
const char * eos_token = nullptr);
bool add_ass);
// Returns an example of formatted chat
std::string llama_chat_format_example(const struct llama_model * model,

View File

@ -12,27 +12,6 @@
using json = nlohmann::ordered_json;
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3.llama3.txt
static bool needs_functionary_v3_tool_call(const std::string & chat_template) {
return chat_template.find("<|start_header_id|>") != std::string::npos
&& chat_template.find(">>>all") != std::string::npos;
}
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
static bool needs_functionary_v3_llama_3_1_tool_call(const std::string & chat_template) {
return chat_template.find("<|start_header_id|>") != std::string::npos
&& chat_template.find("<function=") != std::string::npos;
}
static bool needs_llama_3_1_tool_call(const std::string & chat_template) {
return chat_template.find("<|start_header_id|>") != std::string::npos
&& chat_template.find("<|python_tag|>") != std::string::npos;
}
static bool needs_hermes_pro_tool_call(const std::string & chat_template) {
return chat_template.find("<tool_call>") != std::string::npos;
}
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
// // https://json.nlohmann.me/features/parsing/sax_interface/
struct json_error_locator : public nlohmann::json_sax<json> {
@ -209,103 +188,31 @@ static llama_tool_calls parse_functionary_v3_tool_calls(const std::string& input
return parse_functionary_tool_calls(input, function_regex, close_regex);
}
llama_tool_calls parse_tool_calls(const json & tools, const std::string & chat_template, const std::string& input) {
if (needs_hermes_pro_tool_call(chat_template)) {
return parse_hermes_tool_calls(input);
} else if (needs_functionary_v3_tool_call(chat_template)) {
return parse_functionary_v3_tool_calls(input);
} else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) {
return parse_functionary_v3_llama_3_1_tool_calls(input);
} else if (needs_llama_3_1_tool_call(chat_template)) {
llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) {
switch (style) {
case llama_tool_call_style::Llama31:
return parse_llama_3_1_tool_calls(tools, input);
} else {
throw std::runtime_error("Unsupported chat template for tool calls");
case llama_tool_call_style::FunctionaryV3Llama3:
return parse_functionary_v3_tool_calls(input);
case llama_tool_call_style::FunctionaryV3Llama31:
return parse_functionary_v3_llama_3_1_tool_calls(input);
case llama_tool_call_style::Hermes2Pro:
return parse_hermes_tool_calls(input);
default:
throw std::runtime_error("Unsupported tool call style");
}
}
llama_tool_call_handler llama_tool_call_handler_init(
const std::string & chat_template,
const llama_chat_template & tmpl,
bool allow_content,
bool parallel_tool_calls,
const nlohmann::ordered_json & tools)
{
llama_tool_call_handler handler;
if (needs_functionary_v3_tool_call(chat_template)) {
// MeetKaiFunctionary_3_2
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
std::vector<std::string> tool_rules;
for (size_t i = 0, n = tools.size(); i < n; i++) {
auto & tool = tools[i];
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
auto tool_rule = builder.add_rule(name + "-call", "\">>>" + name + "\\n\" " + builder.add_schema(name + "-args", parameters));
tool_rules.push_back(tool_rule);
if (allow_content) {
handler.grammar_trigger_words.push_back(">>>" + name + "\n");
}
}
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
});
// handler.parser = parse_functionary_3_2_tool_calls;
} else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) {
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
// TODO: handle tool {type: code_interpreter} as python
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
std::vector<std::string> tool_rules;
for (size_t i = 0, n = tools.size(); i < n; i++) {
auto & tool = tools[i];
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
if (name == "python") {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
if (allow_content) {
handler.grammar_trigger_words.push_back("<|python_tag|>");
}
} else {
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\""));
}
}
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
if (allow_content) {
handler.grammar_trigger_words.push_back("<function=");
}
});
// handler.parser = parse_functionary_3_2_tool_calls;
} else if (needs_hermes_pro_tool_call(chat_template)) {
// NousResearchHermesPro_2
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
std::vector<std::string> tool_rules;
for (const auto & tool : tools) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
builder.resolve_refs(parameters);
tool_rules.push_back(builder.add_schema(name + "-call", {
{"type", "object"},
{"properties", json {
{"name", json {{"const", name}}},
{"arguments", parameters},
}},
{"required", json::array({"name", "arguments"})},
}));
}
auto tool_call = "\"<tool_call>\" " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space";
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
if (allow_content) {
handler.grammar_trigger_words.push_back("<tool_call>");
}
});
} else if (needs_llama_3_1_tool_call(chat_template)) {
switch (tmpl.tool_call_style()) {
case llama_tool_call_style::Llama31: {
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
static std::vector<std::string> builtin_tools {"wolfram_alpha", "brave_search"};
std::vector<std::string> tool_rules;
@ -337,9 +244,89 @@ llama_tool_call_handler llama_tool_call_handler_init(
builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | "));
});
handler.additional_stop_words.push_back("<|eom_id|>");
break;
}
case llama_tool_call_style::FunctionaryV3Llama3: {
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
std::vector<std::string> tool_rules;
for (size_t i = 0, n = tools.size(); i < n; i++) {
auto & tool = tools[i];
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
auto tool_rule = builder.add_rule(name + "-call", "\">>>" + name + "\\n\" " + builder.add_schema(name + "-args", parameters));
tool_rules.push_back(tool_rule);
if (allow_content) {
handler.grammar_trigger_words.push_back(">>>" + name + "\n");
}
}
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
});
// handler.parser = parse_functionary_3_2_tool_calls;
break;
}
case llama_tool_call_style::FunctionaryV3Llama31: {
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
// TODO: handle tool {type: code_interpreter} as python
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
std::vector<std::string> tool_rules;
for (size_t i = 0, n = tools.size(); i < n; i++) {
auto & tool = tools[i];
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
if (name == "python") {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
if (allow_content) {
handler.grammar_trigger_words.push_back("<|python_tag|>");
}
} else {
// TODO: generic thoughtful schema.
throw std::runtime_error("Unsupported tool call style!");
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\""));
}
}
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
if (allow_content) {
handler.grammar_trigger_words.push_back("<function=");
}
});
// handler.parser = parse_functionary_3_2_tool_calls;
break;
}
case llama_tool_call_style::Hermes2Pro: {
// NousResearchHermesPro_2
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
std::vector<std::string> tool_rules;
for (const auto & tool : tools) {
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
builder.resolve_refs(parameters);
tool_rules.push_back(builder.add_schema(name + "-call", {
{"type", "object"},
{"properties", json {
{"name", json {{"const", name}}},
{"arguments", parameters},
}},
{"required", json::array({"name", "arguments"})},
}));
}
auto tool_call = "\"<tool_call>\" " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space";
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
if (allow_content) {
handler.grammar_trigger_words.push_back("<tool_call>");
}
});
break;
}
default:
throw std::runtime_error("Unsupported tool call style");
}
return handler;
}
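
Together with the tool-call.h changes below, the intended flow is: derive the grammar and trigger words from the template's tool-call style before generation, then parse the raw model output with the matching style. A hedged sketch only; the helper name and the sample output string (Hermes 2 Pro shape) are illustrative.

```cpp
#include <cstdio>
#include "tool-call.h" // also pulls in chat-template.h

// Hedged sketch of one request/response round built on the refactored API.
static void run_tool_call_round(const llama_chat_template & tmpl, const json & tools) {
    auto handler = llama_tool_call_handler_init(tmpl, /* allow_content= */ true,
                                                /* parallel_tool_calls= */ false, tools);
    // ... constrain sampling with handler.grammar and handler.grammar_trigger_words,
    //     add handler.additional_stop_words to the stop list, then generate ...
    std::string output = "<tool_call>{\"name\": \"test\", \"arguments\": {}}</tool_call>"; // placeholder model output
    auto result = parse_tool_calls(tmpl.tool_call_style(), tools, output);
    for (const auto & call : result.tool_calls) {
        printf("model requested %s(%s)\n", call.name.c_str(), call.arguments.c_str());
    }
}
```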

View File

@ -5,22 +5,29 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "chat-template.h"
struct llama_tool_call {
std::string name;
std::string arguments;
};
struct llama_tool_calls {
std::string content;
std::vector<llama_chat_msg_tool_call> tool_calls;
std::vector<llama_tool_call> tool_calls;
};
struct llama_tool_call_handler {
std::string grammar;
std::vector<std::string> grammar_trigger_words;
std::vector<std::string> additional_stop_words;
nlohmann::ordered_json updated_tools;
};
llama_tool_calls parse_tool_calls(const nlohmann::ordered_json & tools, const std::string & chat_template, const std::string& input);
llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
llama_tool_call_handler llama_tool_call_handler_init(
const std::string & chat_template,
const llama_chat_template & tmpl,
bool allow_content,
bool parallel_tool_calls,
const nlohmann::ordered_json & tools);

View File

@ -571,6 +571,12 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte
```shell
llama-server --jinja -hfr lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF -hff Meta-Llama-3.1-8B-Instruct-Q5_K_M.gguf -fa
# https://huggingface.co/meetkai/functionary-medium-v3.2
llama-server --jinja -hfr bartowski/functionary-medium-v3.2-GGUF -hff functionary-medium-v3.2-IQ4_XS.gguf -fa
# https://huggingface.co/meetkai/functionary-medium-v3.1
llama-server --jinja -hfr meetkai/functionary-medium-v3.1-GGUF -hff functionary-medium-llama-3.1.Q4_0.gguf -fa
curl http://localhost:8080/v1/chat/completions \
-d '{
"model": "gpt-3.5-turbo",

View File

@ -662,7 +662,7 @@ struct server_context {
bool validate_model_chat_template(bool use_jinja) const {
llama_chat_message chat[] = {{"user", "test"}};
const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0, use_jinja, nullptr, nullptr, nullptr);
const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
return res > 0;
}
@ -2860,9 +2860,11 @@ int main(int argc, char ** argv) {
return;
}
auto chat_template = llama_chat_template::from_model(ctx_server.model, params.chat_template);
json data;
try {
data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template, params.use_jinja);
data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, params.use_jinja);
} catch (const std::runtime_error & e) {
res_error(res, format_error_response(e.what(), ERROR_TYPE_NOT_SUPPORTED));
return;
@ -2880,7 +2882,7 @@ int main(int argc, char ** argv) {
ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
// multitask is never supported in chat completion, there is only one result
try {
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, chat_template, /*.streaming =*/ false, verbose);
res_ok(res, result_oai);
} catch (const std::runtime_error & e) {
res_error(res, format_error_response(e.what(), ERROR_TYPE_SERVER));

View File

@ -23,19 +23,19 @@ Feature: llama.cpp server
And a model test
And <n_predict> max tokens to predict
And a user prompt write a hello world in python
And a tool choice <tool_choice>
And a tool choice required
And tools <tools>
And an OAI compatible chat completions request with no api error
Then tool <tool_name> is called with arguments <tool_arguments>
Examples: Prompts
| template_name | n_predict | tool_name | tool_arguments | tool_choice | tools |
| meetkai-functionary-medium-v3.1 | 128 | test | {} | required | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
| meetkai-functionary-medium-v3.1 | 128 | ipython | {"code": "Yes, you can."} | required | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
| meetkai-functionary-medium-v3.2 | 128 | test | {} | required | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
| meetkai-functionary-medium-v3.2 | 128 | ipython | {"code": "Yes,"} | required | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
| meta-llama-Meta-Llama-3.1-8B-Instruct | 64 | test | {} | required | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
| meta-llama-Meta-Llama-3.1-8B-Instruct | 16 | ipython | {"code": "it and "} | required | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
| template_name | n_predict | tool_name | tool_arguments | tools |
| meetkai-functionary-medium-v3.1 | 128 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
| meetkai-functionary-medium-v3.1 | 128 | ipython | {"code": "I'm sorry,"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
| meetkai-functionary-medium-v3.2 | 128 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
| meetkai-functionary-medium-v3.2 | 128 | ipython | {"code": "Yes,"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
| meta-llama-Meta-Llama-3.1-8B-Instruct | 64 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
| meta-llama-Meta-Llama-3.1-8B-Instruct | 16 | ipython | {"code": ". A"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
Scenario: OAI Compatibility w/ no tool

View File

@ -14,6 +14,7 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "chat-template.h"
#include "json.hpp"
#include "minja.hpp"
#include "tool-call.h"
@ -64,40 +65,30 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
for (size_t i = 0; i < messages.size(); ++i) {
const auto & curr_msg = messages[i];
llama_chat_msg msg;
msg.role = json_value(curr_msg, "role", std::string(""));
msg.tool = json_value(curr_msg, "tool", std::string(""));
std::string role = json_value(curr_msg, "role", std::string(""));
std::string content;
if (curr_msg.contains("content")) {
if (curr_msg["content"].is_string()) {
msg.content = curr_msg["content"].get<std::string>();
content = curr_msg["content"].get<std::string>();
} else if (curr_msg["content"].is_array()) {
for (const auto & part : curr_msg["content"]) {
if (part.contains("text")) {
msg.content += "\n" + part["text"].get<std::string>();
content += "\n" + part["text"].get<std::string>();
}
}
} else if (!(curr_msg.is_null() && curr_msg.contains("tool_calls"))) {
throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367): " + curr_msg.dump());
} else {
throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
}
} else {
throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
}
if (curr_msg.contains("tool_calls") && curr_msg["tool_calls"].is_array()) {
for (const auto & tool_call : curr_msg["tool_calls"]) {
if (json_value(tool_call, "type", std::string("")) == "function"
&& tool_call.contains("function") && tool_call["function"].is_object()) {
msg.tool_calls.push_back({
json_value(tool_call["function"], "name", std::string("")),
json_value(tool_call["function"], "arguments", std::string(""))
});
}
}
}
chat.emplace_back(std::move(msg));
chat.push_back({role, content});
}
const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true, use_jinja, tools.is_null() ? nullptr : tools.dump().c_str());
const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
return formatted_chat;
@ -315,38 +306,12 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
// OAI utils
//
static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
std::string piece;
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\0'
const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
if (n_chars < 0) {
piece.resize(-n_chars);
int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
GGML_ASSERT(check == -n_chars);
}
else {
piece.resize(n_chars);
}
return piece;
}
std::string llama_model_meta_val_str(const struct llama_model * model, const char * key) {
int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
if (tlen > 0) {
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
return std::string(curr_tmpl_buf.data(), tlen);
}
}
return "";
}
static json oaicompat_completion_params_parse(
const struct llama_model * model,
const json & body, /* openai api json semantics */
const std::string & chat_template_src,
bool use_jinja) {
const llama_chat_template & tmpl,
bool use_jinja)
{
json llama_params;
llama_params["__oaicompat"] = true;
@ -355,16 +320,15 @@ static json oaicompat_completion_params_parse(
auto has_tools = tools.is_array() && !tools.empty();
// Apply chat template to the list of messages
auto chat_template = chat_template_src.empty() ? llama_model_meta_val_str(model, "tokenizer.chat_template") : chat_template_src;
llama_params["chat_template"] = chat_template;
llama_params["chat_template"] = tmpl.chat_template();
if (use_jinja) {
if (has_tools && chat_template.find("tools") == std::string::npos) {
if (has_tools && !tmpl.supports_tools()) {
throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");
}
} else if (has_tools) {
throw std::runtime_error("Tools are only supported in --jinja mode");
}
llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"), tools, use_jinja);
// Handle "stop" field
if (body.contains("stop") && body.at("stop").is_string()) {
@ -399,11 +363,16 @@ static json oaicompat_completion_params_parse(
} else if (!response_type.empty() && response_type != "text") {
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
}
} else if (use_jinja && tool_choice != "none" && has_tools) {
bool parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
bool allow_content = tool_choice != "required";
}
auto handler = llama_tool_call_handler_init(chat_template, allow_content, parallel_tool_calls, tools);
if (use_jinja) {
bool allow_content = tool_choice != "required";
if (tool_choice != "none" && has_tools) {
bool parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
llama_params["parse_tool_calls"] = true;
llama_params["parallel_tool_calls"] = parallel_tool_calls;
auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, tools);
for (const auto & stop : handler.additional_stop_words) {
llama_params["stop"].push_back(stop);
@ -415,10 +384,19 @@ static json oaicompat_completion_params_parse(
}
llama_params["grammar_trigger_words"] = triggers;
}
if (!handler.updated_tools.is_null()) {
tools = handler.updated_tools;
}
if (!handler.grammar.empty()) {
if (llama_params.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
}
llama_params["grammar"] = handler.grammar;
llama_params["parse_tool_calls"] = true;
llama_params["parallel_tool_calls"] = parallel_tool_calls;
}
}
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
} else {
llama_params["prompt"] = format_chat(model, tmpl.chat_template(), body.at("messages"), tools, /* use_jinja= */ false);
}
// Handle "n" field
@ -458,7 +436,7 @@ static json oaicompat_completion_params_parse(
return llama_params;
}
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, const llama_chat_template & tmpl, bool streaming = false, bool verbose = false) {
bool stopped_word = result.count("stopped_word") != 0;
bool stopped_eos = json_value(result, "stopped_eos", false);
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@ -474,9 +452,8 @@ static json format_final_response_oaicompat(const json & request, const json & r
auto tools = json_value(request, "tools", json::array());
json tool_calls;
json message_content;
printf("# CONTENT: %s\n\n", content.c_str());
if (json_value(request, "parse_tool_calls", false)
&& !(parsed_tool_calls = parse_tool_calls(tools, chat_template, content)).tool_calls.empty()) {
&& !(parsed_tool_calls = parse_tool_calls(tmpl.tool_call_style(), tools, content)).tool_calls.empty()) {
finish_reason = "tool";
if (!parsed_tool_calls.content.empty()) {
message_content = parsed_tool_calls.content;
@ -514,7 +491,6 @@ static json format_final_response_oaicompat(const json & request, const json & r
}},
{"id", completion_id}
};
printf("# RES: %s\n\n", res.dump(2).c_str());
// extra fields for debugging purposes
if (verbose) {

View File

@ -377,19 +377,9 @@ extern "C" {
} llama_sampler_chain_params;
// used in chat template
typedef struct llama_chat_message_tool_call {
const char * name;
const char * arguments;
} llama_chat_message_tool_call;
typedef struct llama_chat_message {
const char * role;
const char * content;
const char * tool;
const llama_chat_message_tool_call * tool_calls;
uint32_t n_tool_calls;
} llama_chat_message;
// lora adapter
@ -986,11 +976,7 @@ extern "C" {
size_t n_msg,
bool add_ass,
char * buf,
int32_t length,
bool use_jinja,
const char * tools,
const char * bos_token,
const char * eos_token);
int32_t length);
//
// Sampling API

View File

@ -2,8 +2,6 @@
#include "llama-vocab.h"
#include "llama-sampling.h"
#include "minja.hpp"
#include "unicode.h"
#include "ggml.h"
@ -21004,95 +21002,7 @@ int32_t llama_detokenize(
static int32_t llama_chat_apply_template_internal(
const std::string & tmpl,
const std::vector<const llama_chat_message *> & chat,
std::string & dest, bool add_ass,
bool use_jinja,
const std::string & tools,
const std::string & bos_token, const std::string & eos_token) {
if (use_jinja) {
auto system_not_supported = tmpl.find("System role not supported") != std::string::npos;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
auto tool_call_args_must_be_objects = tmpl.find("tool_call.arguments | items") != std::string::npos;
auto messages = json::array();
std::string pending_system;
auto flush_sys = [&]() {
if (!pending_system.empty()) {
messages.push_back({
{"role", "user"},
{"content", pending_system},
});
pending_system.clear();
}
};
for (const auto * msg : chat) {
std::string role(msg->role);
std::string content(msg->content);
if (system_not_supported) {
if (role == "system") {
if (!pending_system.empty()) pending_system += "\n";
pending_system += content;
continue;
} else {
if (role == "user") {
if (!pending_system.empty()) {
content = pending_system + (content.empty() ? "" : "\n" + content);
pending_system.clear();
}
} else {
flush_sys();
}
}
}
auto message = json({
{"role", role},
{"content", content},
});
if (msg->tool) message["tool"] = msg->tool;
if (msg->n_tool_calls) {
auto tool_calls = json::array();
for (uint32_t i = 0; i < msg->n_tool_calls; i++) {
auto args = msg->tool_calls[i].arguments;
tool_calls.push_back(json({
{"type", "function"},
{"function", {
{"name", msg->tool_calls[i].name},
{"arguments", tool_call_args_must_be_objects ? json::parse(args) : args},
}}
}));
}
messages["tool_calls"] = tool_calls;
}
messages.push_back(message);
}
flush_sys();
auto context = minja::Context::make(json({
{"messages", messages},
{"add_generation_prompt", add_ass},
{"bos_token", bos_token},
{"eos_token", eos_token},
}));
if (!tools.empty()) {
auto tools_val = minja::Value(json::parse(tools));
context->set("tools", tools_val);
}
auto tmpl_root = minja::Parser::parse(tmpl, {
/* .trim_blocks = */ true,
/* .lstrip_blocks = */ true,
/* .keep_trailing_newline = */ false,
});
try {
dest = tmpl_root->render(context);
return dest.size();
} catch (const std::runtime_error & err) {
LLAMA_LOG_ERROR("Error in jinja template: %s\n", err.what());
return -1;
}
}
std::string & dest, bool add_ass) {
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
std::stringstream ss;
@ -21360,11 +21270,7 @@ int32_t llama_chat_apply_template(
size_t n_msg,
bool add_ass,
char * buf,
int32_t length,
bool use_jinja,
const char * tools,
const char * bos_token,
const char * eos_token) {
int32_t length) {
std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
if (tmpl == nullptr) {
GGML_ASSERT(model != nullptr);
@ -21379,16 +21285,6 @@ int32_t llama_chat_apply_template(
curr_tmpl = std::string(model_template.data(), model_template.size());
}
}
std::string curr_bos_token(bos_token ? bos_token : "");
std::string curr_eos_token(eos_token ? eos_token : "");
if (bos_token == nullptr) {
GGML_ASSERT(model != nullptr);
curr_bos_token = llama_token_to_piece(model, llama_token_bos(model), true);
}
if (eos_token == nullptr) {
GGML_ASSERT(model != nullptr);
curr_eos_token = llama_token_to_piece(model, llama_token_eos(model), true);
}
// format the chat to string
std::vector<const llama_chat_message *> chat_vec;
@ -21398,7 +21294,7 @@ int32_t llama_chat_apply_template(
}
std::string formatted_chat;
int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass, use_jinja, tools == nullptr ? "" : tools, curr_bos_token, curr_eos_token);
int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
if (res < 0) {
return res;
}
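
With the Jinja and tool parameters gone, the plain C entry point is back to the familiar two-pass pattern. A hedged sketch, assuming `llama.h`, `<string>`, `<vector>` are included and `model` is loaded:

```cpp
// First pass sizes the output (a negative result means the template is unsupported),
// the second pass fills the buffer.
static std::string format_test_prompt(const struct llama_model * model) {
    llama_chat_message chat[] = {{"user", "test"}};
    const int32_t needed = llama_chat_apply_template(model, /* tmpl= */ nullptr, chat, 1,
                                                     /* add_ass= */ true, nullptr, 0);
    if (needed < 0) {
        return "";
    }
    std::vector<char> buf(needed);
    llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size());
    return std::string(buf.data(), needed);
}
```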

View File

@ -20,9 +20,9 @@ static void assert_equals(const std::string & expected, const std::string & actu
cmake -B build -DLLAMA_CURL=1 -DCMAKE_BUILD_TYPE=Release && cmake --build build -t test-tool-call -j && ./build/bin/test-tool-call
*/
static void test_parse_tool_call(const json & tools, const std::string & chat_template, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) {
static void test_parse_tool_call(llama_tool_call_style style, const json & tools, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) {
std::cout << "# Testing: " << input << std::endl << std::flush;
auto result = parse_tool_calls(tools, chat_template, input);
auto result = parse_tool_calls(style, tools, input);
assert_equals(expected_content, result.content);
auto tool_calls = json::array();
for (const auto & tc : result.tool_calls) {
@ -59,8 +59,7 @@ int main() {
{"tools", tools}
};
std::string hermes_2_pro_like_tmpl = "Hermes 2 Pro template should have <tool_call> inside it";
test_parse_tool_call(tools, hermes_2_pro_like_tmpl,
test_parse_tool_call(llama_tool_call_style::Hermes2Pro, tools,
"<tool_call>{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}</tool_call>",
"",
json {{
@ -72,8 +71,7 @@ int main() {
}}
}});
std::string functionary_v3_like_tmpl = "Functionary 3.2 template should have <|start_header_id|> and then some >>>all inside it";
test_parse_tool_call(tools, functionary_v3_like_tmpl,
test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools,
">>>ipython\n{\"code\": \"print('Hello, world!')\"}",
"",
json {{
@ -84,7 +82,7 @@ int main() {
}).dump()}
}}
}});
test_parse_tool_call(tools, functionary_v3_like_tmpl,
test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools,
">>>test\n{ } \n ",
"",
json {{
@ -94,8 +92,7 @@ int main() {
}}
}});
std::string functionary_v3_llama_3_1_like_tmpl = "Functionary 3.2 template for llama 3.1 should have <|start_header_id|> and then some <function=foo>{...}</function> inside it";
test_parse_tool_call(tools, functionary_v3_llama_3_1_like_tmpl,
test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama31, tools,
"Hell<function=foo>{\"arg1\": 1}</function>o, world<function=bar>{\"arg2\": 2}</function>!",
"Hello, world!",
json {
@ -116,7 +113,7 @@ int main() {
}}
},
});
test_parse_tool_call(tools, functionary_v3_llama_3_1_like_tmpl,
test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama31, tools,
"<function=test>{ } </function> ",
" ",
json {{
@ -126,8 +123,7 @@ int main() {
}}
}});
std::string llama_3_1_like_tmpl = "Llama 3.1 template should have <|start_header_id|> and <|python_tag|> inside it";
test_parse_tool_call(tools, llama_3_1_like_tmpl,
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
"<|python_tag|>this could be anything",
"",
json {{
@ -138,7 +134,7 @@ int main() {
}).dump()}
}}
}});
test_parse_tool_call(tools, llama_3_1_like_tmpl,
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
"I'm thinking<|python_tag|>",
"I'm thinking",
json {{
@ -147,7 +143,7 @@ int main() {
{"arguments", (json {{"code", ""}}).dump()}
}}
}});
test_parse_tool_call(tools, llama_3_1_like_tmpl,
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
"",
json {{
@ -158,7 +154,7 @@ int main() {
}).dump()}
}}
}});
test_parse_tool_call(tools, llama_3_1_like_tmpl,
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array());