mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 04:00:16 +00:00
tool-call
: factor chat template away from legacy API
This commit is contained in:
parent
d7ec84f78c
commit
cf7bece6a7
4
Makefile
4
Makefile
@ -934,6 +934,7 @@ OBJ_LLAMA = \
|
||||
|
||||
OBJ_COMMON = \
|
||||
common/common.o \
|
||||
common/chat-template.o \
|
||||
common/arg.o \
|
||||
common/log.o \
|
||||
common/console.o \
|
||||
@ -1170,6 +1171,8 @@ $(LIB_LLAMA_S): \
|
||||
common/common.o: \
|
||||
common/common.cpp \
|
||||
common/common.h \
|
||||
common/chat-template.cpp \
|
||||
common/chat-template.h \
|
||||
common/console.h \
|
||||
common/sampling.h \
|
||||
common/json.hpp \
|
||||
@ -1465,6 +1468,7 @@ llama-server: \
|
||||
examples/server/prompt-formats.js.hpp \
|
||||
examples/server/json-schema-to-grammar.mjs.hpp \
|
||||
examples/server/loading.html.hpp \
|
||||
common/chat-template.h \
|
||||
common/json.hpp \
|
||||
common/stb_image.h \
|
||||
$(OBJ_ALL)
|
||||
|
@ -54,6 +54,8 @@ add_library(${TARGET} STATIC
|
||||
arg.cpp
|
||||
arg.h
|
||||
base64.hpp
|
||||
chat-template.cpp
|
||||
chat-template.h
|
||||
common.cpp
|
||||
common.h
|
||||
console.cpp
|
||||
|
118
common/chat-template.cpp
Normal file
118
common/chat-template.cpp
Normal file
@ -0,0 +1,118 @@
|
||||
#include "chat-template.h"
|
||||
#include "minja.hpp"
|
||||
#include "llama.h"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
|
||||
std::string piece;
|
||||
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
|
||||
const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
|
||||
if (n_chars < 0) {
|
||||
piece.resize(-n_chars);
|
||||
int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
|
||||
GGML_ASSERT(check == -n_chars);
|
||||
}
|
||||
else {
|
||||
piece.resize(n_chars);
|
||||
}
|
||||
|
||||
return piece;
|
||||
}
|
||||
|
||||
static std::string llama_model_meta_val_str(const struct llama_model * model, const char * key) {
|
||||
int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
|
||||
if (tlen > 0) {
|
||||
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
|
||||
if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
|
||||
return std::string(curr_tmpl_buf.data(), tlen);
|
||||
}
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
llama_chat_template llama_chat_template::from_model(
|
||||
const struct llama_model * model,
|
||||
const std::string & chat_template_override)
|
||||
{
|
||||
// TODO: handle "chatml"?
|
||||
auto chat_template = chat_template_override.empty()
|
||||
? llama_model_meta_val_str(model, "tokenizer.chat_template")
|
||||
: chat_template_override;
|
||||
auto bos_token = _llama_token_to_piece(model, llama_token_bos(model), true);
|
||||
auto eos_token = _llama_token_to_piece(model, llama_token_eos(model), true);
|
||||
return llama_chat_template(chat_template, bos_token, eos_token);
|
||||
}
|
||||
|
||||
std::string llama_chat_template::apply(
|
||||
const json & messages,
|
||||
const json & tools,
|
||||
bool add_generation_prompt) const
|
||||
{
|
||||
auto actual_messages = messages;
|
||||
|
||||
// First, "fix" messages so they have a chance to be rendered correctly by the template
|
||||
|
||||
if (_requires_object_arguments || !_supports_system_role) {
|
||||
std::string pending_system;
|
||||
auto flush_sys = [&]() {
|
||||
if (!pending_system.empty()) {
|
||||
actual_messages.push_back({
|
||||
{"role", "user"},
|
||||
{"content", pending_system},
|
||||
});
|
||||
pending_system.clear();
|
||||
}
|
||||
};
|
||||
for (auto & message : actual_messages) {
|
||||
if (!message.contains("role") || !message.contains("content")) {
|
||||
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
|
||||
}
|
||||
std::string role = message.at("role");
|
||||
std::string content = message.at("content");
|
||||
|
||||
if (!_supports_system_role) {
|
||||
if (role == "system") {
|
||||
if (!pending_system.empty()) pending_system += "\n";
|
||||
pending_system += content;
|
||||
continue;
|
||||
} else {
|
||||
if (role == "user") {
|
||||
if (!pending_system.empty()) {
|
||||
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
|
||||
pending_system.clear();
|
||||
}
|
||||
} else {
|
||||
flush_sys();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (_requires_object_arguments && message.contains("tool_calls")) {
|
||||
for (auto & tool_call : message.at("tool_calls")) {
|
||||
std::string arguments = tool_call.at("arguments");
|
||||
tool_call["arguments"] = json::parse(arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
flush_sys();
|
||||
}
|
||||
|
||||
auto context = minja::Context::make(json({
|
||||
{"messages", actual_messages},
|
||||
{"add_generation_prompt", add_generation_prompt},
|
||||
{"bos_token", _bos_token},
|
||||
{"eos_token", _eos_token},
|
||||
}));
|
||||
|
||||
if (!tools.is_null() && !tools.empty()) {
|
||||
auto tools_val = minja::Value(tools);
|
||||
context->set("tools", tools_val);
|
||||
}
|
||||
|
||||
auto tmpl_root = minja::Parser::parse(_chat_template, {
|
||||
/* .trim_blocks = */ true,
|
||||
/* .lstrip_blocks = */ true,
|
||||
/* .keep_trailing_newline = */ false,
|
||||
});
|
||||
return tmpl_root->render(context);
|
||||
}
|
64
common/chat-template.h
Normal file
64
common/chat-template.h
Normal file
@ -0,0 +1,64 @@
|
||||
#pragma once
|
||||
|
||||
#include <json.hpp>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
enum llama_tool_call_style {
|
||||
Unknown,
|
||||
Llama31,
|
||||
FunctionaryV3Llama3,
|
||||
FunctionaryV3Llama31,
|
||||
Hermes2Pro,
|
||||
};
|
||||
|
||||
class llama_chat_template {
|
||||
public:
|
||||
|
||||
private:
|
||||
llama_tool_call_style _tool_call_style = Unknown;
|
||||
bool _supports_tools = true;
|
||||
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
|
||||
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
|
||||
bool _requires_object_arguments = false;
|
||||
bool _supports_system_role = true;
|
||||
std::string _chat_template;
|
||||
std::string _bos_token;
|
||||
std::string _eos_token;
|
||||
public:
|
||||
llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token)
|
||||
: _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) {
|
||||
|
||||
_supports_tools = chat_template.find("tools") != std::string::npos;
|
||||
_requires_object_arguments = chat_template.find("tool_call.arguments | items") != std::string::npos;
|
||||
_supports_system_role = chat_template.find("System role not supported") == std::string::npos;
|
||||
|
||||
if (chat_template.find("<tool_call>") != std::string::npos) {
|
||||
_tool_call_style = Hermes2Pro;
|
||||
} else if (chat_template.find(">>>all") != std::string::npos) {
|
||||
_tool_call_style = FunctionaryV3Llama3;
|
||||
} else if (chat_template.find("<|start_header_id|>") != std::string::npos) {
|
||||
if (chat_template.find("<function=") != std::string::npos) {
|
||||
_tool_call_style = FunctionaryV3Llama31;
|
||||
} else if (chat_template.find("<|python_tag|>") != std::string::npos) {
|
||||
_tool_call_style = Llama31;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static llama_chat_template from_model(
|
||||
const struct llama_model * model,
|
||||
const std::string & chat_template_override);
|
||||
|
||||
llama_tool_call_style tool_call_style() const { return _tool_call_style; }
|
||||
|
||||
const std::string & chat_template() const { return _chat_template; }
|
||||
bool supports_tools() const { return _supports_tools; }
|
||||
|
||||
std::string apply(
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools,
|
||||
bool add_generation_prompt) const;
|
||||
};
|
@ -9,6 +9,7 @@
|
||||
#include "json.hpp"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama.h"
|
||||
#include "chat-template.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
@ -1511,6 +1512,20 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token>
|
||||
//
|
||||
|
||||
bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
try {
|
||||
auto chat_template = llama_chat_template(tmpl, "<s>", "</s>");
|
||||
chat_template.apply({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}}, json(), true);
|
||||
return true;
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
llama_chat_message chat[] = {{"user", "test"}};
|
||||
int res = llama_chat_apply_template(
|
||||
nullptr,
|
||||
@ -1519,22 +1534,14 @@ bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
||||
1,
|
||||
/* add_ass= */ true,
|
||||
/* buffer= */ nullptr,
|
||||
/* length= */ 0,
|
||||
use_jinja,
|
||||
/* tools= */ nullptr,
|
||||
"<s>",
|
||||
"</s>");
|
||||
/* length= */ 0);
|
||||
return res >= 0;
|
||||
}
|
||||
|
||||
std::string llama_chat_apply_template(const struct llama_model * model,
|
||||
const std::string & tmpl,
|
||||
const std::vector<llama_chat_msg> & msgs,
|
||||
bool add_ass,
|
||||
bool use_jinja,
|
||||
const char * tools,
|
||||
const char * bos_token,
|
||||
const char * eos_token) {
|
||||
bool add_ass) {
|
||||
int alloc_size = 0;
|
||||
bool fallback = false; // indicate if we must fallback to default chatml
|
||||
std::vector<llama_chat_message> chat;
|
||||
@ -1547,7 +1554,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
|
||||
std::vector<char> buf(alloc_size);
|
||||
|
||||
// run the first time to get the total output length
|
||||
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
|
||||
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
|
||||
// error: chat template is not supported
|
||||
if (res < 0) {
|
||||
@ -1557,7 +1564,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
|
||||
throw std::runtime_error("this custom template is not supported");
|
||||
} else {
|
||||
// If the built-in template is not supported, we default to chatml
|
||||
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
|
||||
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
fallback = true;
|
||||
}
|
||||
}
|
||||
@ -1568,7 +1575,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
|
||||
res = llama_chat_apply_template(
|
||||
fallback ? nullptr : model,
|
||||
fallback ? "chatml" : ptr_tmpl,
|
||||
chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
|
||||
chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
}
|
||||
|
||||
std::string formatted_chat(buf.data(), res);
|
||||
@ -1579,13 +1586,9 @@ std::string llama_chat_format_single(const struct llama_model * model,
|
||||
const std::string & tmpl,
|
||||
const std::vector<llama_chat_msg> & past_msg,
|
||||
const llama_chat_msg & new_msg,
|
||||
bool add_ass,
|
||||
bool use_jinja,
|
||||
const char * tools,
|
||||
const char * bos_token,
|
||||
const char * eos_token) {
|
||||
bool add_ass) {
|
||||
std::ostringstream ss;
|
||||
auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja, tools, bos_token, eos_token);
|
||||
auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false);
|
||||
std::vector<llama_chat_msg> chat_new(past_msg);
|
||||
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
||||
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
|
||||
@ -1593,7 +1596,7 @@ std::string llama_chat_format_single(const struct llama_model * model,
|
||||
};
|
||||
// format chat with new_msg
|
||||
chat_new.push_back(new_msg);
|
||||
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja, tools, bos_token, eos_token);
|
||||
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
|
||||
// get the diff part
|
||||
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
||||
return ss.str();
|
||||
|
@ -471,21 +471,14 @@ std::string llama_detokenize(
|
||||
// Chat template utils
|
||||
//
|
||||
|
||||
struct llama_chat_msg_tool_call {
|
||||
std::string name;
|
||||
std::string arguments;
|
||||
};
|
||||
|
||||
// same as llama_chat_message, but uses std::string and std::vector
|
||||
struct llama_chat_msg {
|
||||
std::string role;
|
||||
std::string content;
|
||||
std::string tool;
|
||||
std::vector<struct llama_chat_msg_tool_call> tool_calls;
|
||||
};
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja = false);
|
||||
// Check if the template is supported or not. Returns true if it's valid
|
||||
bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja);
|
||||
|
||||
// CPP wrapper for llama_chat_apply_template
|
||||
// If the built-in template is not supported, we default to chatml
|
||||
@ -493,22 +486,14 @@ bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja = false
|
||||
std::string llama_chat_apply_template(const struct llama_model * model,
|
||||
const std::string & tmpl,
|
||||
const std::vector<llama_chat_msg> & chat,
|
||||
bool add_ass,
|
||||
bool use_jinja = false,
|
||||
const char * tools = nullptr,
|
||||
const char * bos_token = nullptr,
|
||||
const char * eos_token = nullptr);
|
||||
bool add_ass);
|
||||
|
||||
// Format single message, while taking into account the position of that message in chat history
|
||||
std::string llama_chat_format_single(const struct llama_model * model,
|
||||
const std::string & tmpl,
|
||||
const std::vector<llama_chat_msg> & past_msg,
|
||||
const llama_chat_msg & new_msg,
|
||||
bool add_ass,
|
||||
bool use_jinja = false,
|
||||
const char * tools = nullptr,
|
||||
const char * bos_token = nullptr,
|
||||
const char * eos_token = nullptr);
|
||||
bool add_ass);
|
||||
|
||||
// Returns an example of formatted chat
|
||||
std::string llama_chat_format_example(const struct llama_model * model,
|
||||
|
@ -12,27 +12,6 @@
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3.llama3.txt
|
||||
static bool needs_functionary_v3_tool_call(const std::string & chat_template) {
|
||||
return chat_template.find("<|start_header_id|>") != std::string::npos
|
||||
&& chat_template.find(">>>all") != std::string::npos;
|
||||
}
|
||||
|
||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||
static bool needs_functionary_v3_llama_3_1_tool_call(const std::string & chat_template) {
|
||||
return chat_template.find("<|start_header_id|>") != std::string::npos
|
||||
&& chat_template.find("<function=") != std::string::npos;
|
||||
}
|
||||
|
||||
static bool needs_llama_3_1_tool_call(const std::string & chat_template) {
|
||||
return chat_template.find("<|start_header_id|>") != std::string::npos
|
||||
&& chat_template.find("<|python_tag|>") != std::string::npos;
|
||||
}
|
||||
|
||||
static bool needs_hermes_pro_tool_call(const std::string & chat_template) {
|
||||
return chat_template.find("<tool_call>") != std::string::npos;
|
||||
}
|
||||
|
||||
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
|
||||
// // https://json.nlohmann.me/features/parsing/sax_interface/
|
||||
struct json_error_locator : public nlohmann::json_sax<json> {
|
||||
@ -209,137 +188,145 @@ static llama_tool_calls parse_functionary_v3_tool_calls(const std::string& input
|
||||
return parse_functionary_tool_calls(input, function_regex, close_regex);
|
||||
}
|
||||
|
||||
llama_tool_calls parse_tool_calls(const json & tools, const std::string & chat_template, const std::string& input) {
|
||||
if (needs_hermes_pro_tool_call(chat_template)) {
|
||||
return parse_hermes_tool_calls(input);
|
||||
} else if (needs_functionary_v3_tool_call(chat_template)) {
|
||||
return parse_functionary_v3_tool_calls(input);
|
||||
} else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) {
|
||||
return parse_functionary_v3_llama_3_1_tool_calls(input);
|
||||
} else if (needs_llama_3_1_tool_call(chat_template)) {
|
||||
return parse_llama_3_1_tool_calls(tools, input);
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported chat template for tool calls");
|
||||
llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) {
|
||||
switch (style) {
|
||||
case llama_tool_call_style::Llama31:
|
||||
return parse_llama_3_1_tool_calls(tools, input);
|
||||
case llama_tool_call_style::FunctionaryV3Llama3:
|
||||
return parse_functionary_v3_tool_calls(input);
|
||||
case llama_tool_call_style::FunctionaryV3Llama31:
|
||||
return parse_functionary_v3_llama_3_1_tool_calls(input);
|
||||
case llama_tool_call_style::Hermes2Pro:
|
||||
return parse_hermes_tool_calls(input);
|
||||
default:
|
||||
throw std::runtime_error("Unsupported tool call style");
|
||||
}
|
||||
}
|
||||
|
||||
llama_tool_call_handler llama_tool_call_handler_init(
|
||||
const std::string & chat_template,
|
||||
const llama_chat_template & tmpl,
|
||||
bool allow_content,
|
||||
bool parallel_tool_calls,
|
||||
const nlohmann::ordered_json & tools)
|
||||
{
|
||||
llama_tool_call_handler handler;
|
||||
|
||||
if (needs_functionary_v3_tool_call(chat_template)) {
|
||||
// MeetKaiFunctionary_3_2
|
||||
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
for (size_t i = 0, n = tools.size(); i < n; i++) {
|
||||
auto & tool = tools[i];
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
auto tool_rule = builder.add_rule(name + "-call", "\">>>" + name + "\\n\" " + builder.add_schema(name + "-args", parameters));
|
||||
tool_rules.push_back(tool_rule);
|
||||
switch (tmpl.tool_call_style()) {
|
||||
case llama_tool_call_style::Llama31: {
|
||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||
static std::vector<std::string> builtin_tools {"wolfram_alpha", "brave_search"};
|
||||
std::vector<std::string> tool_rules;
|
||||
|
||||
for (const auto & tool : tools) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
builder.resolve_refs(parameters);
|
||||
if (name == "ipython" || std::find(builtin_tools.begin(), builtin_tools.end(), name) != builtin_tools.end()) {
|
||||
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
|
||||
if (allow_content) {
|
||||
handler.grammar_trigger_words.push_back("<|python_tag|>");
|
||||
}
|
||||
} else {
|
||||
//"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"\\n{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
|
||||
builder.add_schema(name + "-args", parameters) +
|
||||
" \"}\""));
|
||||
if (allow_content) {
|
||||
handler.grammar_trigger_words.push_back("\n{\"" + name + "\"");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | "));
|
||||
});
|
||||
handler.additional_stop_words.push_back("<|eom_id|>");
|
||||
break;
|
||||
}
|
||||
case llama_tool_call_style::FunctionaryV3Llama3: {
|
||||
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
for (size_t i = 0, n = tools.size(); i < n; i++) {
|
||||
auto & tool = tools[i];
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
auto tool_rule = builder.add_rule(name + "-call", "\">>>" + name + "\\n\" " + builder.add_schema(name + "-args", parameters));
|
||||
tool_rules.push_back(tool_rule);
|
||||
if (allow_content) {
|
||||
handler.grammar_trigger_words.push_back(">>>" + name + "\n");
|
||||
}
|
||||
}
|
||||
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
|
||||
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
});
|
||||
// handler.parser = parse_functionary_3_2_tool_calls;
|
||||
break;
|
||||
}
|
||||
case llama_tool_call_style::FunctionaryV3Llama31: {
|
||||
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
|
||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||
// TODO: handle tool {type: code_interpreter} as python
|
||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
for (size_t i = 0, n = tools.size(); i < n; i++) {
|
||||
auto & tool = tools[i];
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
if (name == "python") {
|
||||
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
||||
if (allow_content) {
|
||||
handler.grammar_trigger_words.push_back("<|python_tag|>");
|
||||
}
|
||||
} else {
|
||||
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\""));
|
||||
}
|
||||
}
|
||||
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
|
||||
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
if (allow_content) {
|
||||
handler.grammar_trigger_words.push_back(">>>" + name + "\n");
|
||||
handler.grammar_trigger_words.push_back("<function=");
|
||||
}
|
||||
}
|
||||
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
|
||||
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
});
|
||||
// handler.parser = parse_functionary_3_2_tool_calls;
|
||||
} else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) {
|
||||
// ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
|
||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||
// TODO: handle tool {type: code_interpreter} as python
|
||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
for (size_t i = 0, n = tools.size(); i < n; i++) {
|
||||
auto & tool = tools[i];
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
if (name == "python") {
|
||||
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
||||
if (allow_content) {
|
||||
handler.grammar_trigger_words.push_back("<|python_tag|>");
|
||||
}
|
||||
} else {
|
||||
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\""));
|
||||
});
|
||||
// handler.parser = parse_functionary_3_2_tool_calls;
|
||||
break;
|
||||
}
|
||||
case llama_tool_call_style::Hermes2Pro: {
|
||||
// NousResearchHermesPro_2
|
||||
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
for (const auto & tool : tools) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
builder.resolve_refs(parameters);
|
||||
tool_rules.push_back(builder.add_schema(name + "-call", {
|
||||
{"type", "object"},
|
||||
{"properties", json {
|
||||
{"name", json {{"const", name}}},
|
||||
{"arguments", parameters},
|
||||
}},
|
||||
{"required", json::array({"name", "arguments"})},
|
||||
}));
|
||||
}
|
||||
}
|
||||
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
|
||||
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
if (allow_content) {
|
||||
handler.grammar_trigger_words.push_back("<function=");
|
||||
}
|
||||
});
|
||||
// handler.parser = parse_functionary_3_2_tool_calls;
|
||||
} else if (needs_hermes_pro_tool_call(chat_template)) {
|
||||
// NousResearchHermesPro_2
|
||||
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
for (const auto & tool : tools) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
builder.resolve_refs(parameters);
|
||||
tool_rules.push_back(builder.add_schema(name + "-call", {
|
||||
{"type", "object"},
|
||||
{"properties", json {
|
||||
{"name", json {{"const", name}}},
|
||||
{"arguments", parameters},
|
||||
}},
|
||||
{"required", json::array({"name", "arguments"})},
|
||||
}));
|
||||
}
|
||||
|
||||
auto tool_call = "\"<tool_call>\" " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space";
|
||||
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
if (allow_content) {
|
||||
handler.grammar_trigger_words.push_back("<tool_call>");
|
||||
}
|
||||
});
|
||||
} else if (needs_llama_3_1_tool_call(chat_template)) {
|
||||
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||
static std::vector<std::string> builtin_tools {"wolfram_alpha", "brave_search"};
|
||||
std::vector<std::string> tool_rules;
|
||||
|
||||
for (const auto & tool : tools) {
|
||||
const auto & function = tool["function"];
|
||||
std::string name = function["name"];
|
||||
auto parameters = function["parameters"];
|
||||
builder.resolve_refs(parameters);
|
||||
if (name == "ipython" || std::find(builtin_tools.begin(), builtin_tools.end(), name) != builtin_tools.end()) {
|
||||
tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
|
||||
if (allow_content) {
|
||||
handler.grammar_trigger_words.push_back("<|python_tag|>");
|
||||
}
|
||||
} else {
|
||||
//"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
|
||||
tool_rules.push_back(
|
||||
builder.add_rule(
|
||||
name + "-call",
|
||||
"\"\\n{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
|
||||
builder.add_schema(name + "-args", parameters) +
|
||||
" \"}\""));
|
||||
if (allow_content) {
|
||||
handler.grammar_trigger_words.push_back("\n{\"" + name + "\"");
|
||||
}
|
||||
auto tool_call = "\"<tool_call>\" " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space";
|
||||
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
if (allow_content) {
|
||||
handler.grammar_trigger_words.push_back("<tool_call>");
|
||||
}
|
||||
}
|
||||
|
||||
builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | "));
|
||||
});
|
||||
handler.additional_stop_words.push_back("<|eom_id|>");
|
||||
} else {
|
||||
// TODO: generic thoughtful schema.
|
||||
throw std::runtime_error("Unsupported tool call style!");
|
||||
});
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("Unsupported tool call style");
|
||||
}
|
||||
return handler;
|
||||
}
|
||||
|
@ -5,22 +5,29 @@
|
||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
#include "chat-template.h"
|
||||
|
||||
struct llama_tool_call {
|
||||
std::string name;
|
||||
std::string arguments;
|
||||
};
|
||||
|
||||
struct llama_tool_calls {
|
||||
std::string content;
|
||||
std::vector<llama_chat_msg_tool_call> tool_calls;
|
||||
std::vector<llama_tool_call> tool_calls;
|
||||
};
|
||||
|
||||
struct llama_tool_call_handler {
|
||||
std::string grammar;
|
||||
std::vector<std::string> grammar_trigger_words;
|
||||
std::vector<std::string> additional_stop_words;
|
||||
nlohmann::ordered_json updated_tools;
|
||||
};
|
||||
|
||||
llama_tool_calls parse_tool_calls(const nlohmann::ordered_json & tools, const std::string & chat_template, const std::string& input);
|
||||
llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
|
||||
|
||||
llama_tool_call_handler llama_tool_call_handler_init(
|
||||
const std::string & chat_template,
|
||||
const llama_chat_template & tmpl,
|
||||
bool allow_content,
|
||||
bool parallel_tool_calls,
|
||||
const nlohmann::ordered_json & tools);
|
||||
|
@ -571,6 +571,12 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte
|
||||
```shell
|
||||
llama-server --jinja -hfr lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF -hff Meta-Llama-3.1-8B-Instruct-Q5_K_M.gguf -fa
|
||||
|
||||
# https://huggingface.co/meetkai/functionary-medium-v3.2
|
||||
llama-server --jinja -hfr bartowski/functionary-medium-v3.2-GGUF -hff functionary-medium-v3.2-IQ4_XS.gguf -fa
|
||||
|
||||
# https://huggingface.co/meetkai/functionary-medium-v3.1
|
||||
llama-server --jinja -hfr meetkai/functionary-medium-v3.1-GGUF -hff functionary-medium-llama-3.1.Q4_0.gguf -fa
|
||||
|
||||
curl http://localhost:8080/v1/chat/completions \
|
||||
-d '{
|
||||
"model": "gpt-3.5-turbo",
|
||||
|
@ -662,7 +662,7 @@ struct server_context {
|
||||
bool validate_model_chat_template(bool use_jinja) const {
|
||||
llama_chat_message chat[] = {{"user", "test"}};
|
||||
|
||||
const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0, use_jinja, nullptr, nullptr, nullptr);
|
||||
const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
|
||||
|
||||
return res > 0;
|
||||
}
|
||||
@ -2860,9 +2860,11 @@ int main(int argc, char ** argv) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto chat_template = llama_chat_template::from_model(ctx_server.model, params.chat_template);
|
||||
|
||||
json data;
|
||||
try {
|
||||
data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template, params.use_jinja);
|
||||
data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, params.use_jinja);
|
||||
} catch (const std::runtime_error & e) {
|
||||
res_error(res, format_error_response(e.what(), ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
@ -2880,7 +2882,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
|
||||
// multitask is never support in chat completion, there is only one result
|
||||
try {
|
||||
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
|
||||
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, chat_template, /*.streaming =*/ false, verbose);
|
||||
res_ok(res, result_oai);
|
||||
} catch (const std::runtime_error & e) {
|
||||
res_error(res, format_error_response(e.what(), ERROR_TYPE_SERVER));
|
||||
|
@ -23,19 +23,19 @@ Feature: llama.cpp server
|
||||
And a model test
|
||||
And <n_predict> max tokens to predict
|
||||
And a user prompt write a hello world in python
|
||||
And a tool choice <tool_choice>
|
||||
And a tool choice required
|
||||
And tools <tools>
|
||||
And an OAI compatible chat completions request with no api error
|
||||
Then tool <tool_name> is called with arguments <tool_arguments>
|
||||
|
||||
Examples: Prompts
|
||||
| template_name | n_predict | tool_name | tool_arguments | tool_choice | tools |
|
||||
| meetkai-functionary-medium-v3.1 | 128 | test | {} | required | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
|
||||
| meetkai-functionary-medium-v3.1 | 128 | ipython | {"code": "Yes, you can."} | required | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
|
||||
| meetkai-functionary-medium-v3.2 | 128 | test | {} | required | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
|
||||
| meetkai-functionary-medium-v3.2 | 128 | ipython | {"code": "Yes,"} | required | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
|
||||
| meta-llama-Meta-Llama-3.1-8B-Instruct | 64 | test | {} | required | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
|
||||
| meta-llama-Meta-Llama-3.1-8B-Instruct | 16 | ipython | {"code": "it and "} | required | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
|
||||
| template_name | n_predict | tool_name | tool_arguments | tools |
|
||||
| meetkai-functionary-medium-v3.1 | 128 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
|
||||
| meetkai-functionary-medium-v3.1 | 128 | ipython | {"code": "I'm sorry,"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
|
||||
| meetkai-functionary-medium-v3.2 | 128 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
|
||||
| meetkai-functionary-medium-v3.2 | 128 | ipython | {"code": "Yes,"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
|
||||
| meta-llama-Meta-Llama-3.1-8B-Instruct | 64 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
|
||||
| meta-llama-Meta-Llama-3.1-8B-Instruct | 16 | ipython | {"code": ". A"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
|
||||
|
||||
|
||||
Scenario: OAI Compatibility w/ no tool
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "chat-template.h"
|
||||
#include "json.hpp"
|
||||
#include "minja.hpp"
|
||||
#include "tool-call.h"
|
||||
@ -64,40 +65,30 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
|
||||
for (size_t i = 0; i < messages.size(); ++i) {
|
||||
const auto & curr_msg = messages[i];
|
||||
|
||||
llama_chat_msg msg;
|
||||
msg.role = json_value(curr_msg, "role", std::string(""));
|
||||
msg.tool = json_value(curr_msg, "tool", std::string(""));
|
||||
std::string role = json_value(curr_msg, "role", std::string(""));
|
||||
|
||||
std::string content;
|
||||
|
||||
if (curr_msg.contains("content")) {
|
||||
if (curr_msg["content"].is_string()) {
|
||||
msg.content = curr_msg["content"].get<std::string>();
|
||||
content = curr_msg["content"].get<std::string>();
|
||||
} else if (curr_msg["content"].is_array()) {
|
||||
for (const auto & part : curr_msg["content"]) {
|
||||
if (part.contains("text")) {
|
||||
msg.content += "\n" + part["text"].get<std::string>();
|
||||
content += "\n" + part["text"].get<std::string>();
|
||||
}
|
||||
}
|
||||
} else if (!(curr_msg.is_null() && curr_msg.contains("tool_calls"))) {
|
||||
throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367): " + curr_msg.dump());
|
||||
} else {
|
||||
throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
|
||||
}
|
||||
if (curr_msg.contains("tool_calls") && curr_msg["tool_calls"].is_array()) {
|
||||
for (const auto & tool_call : curr_msg["tool_calls"]) {
|
||||
if (json_value(tool_call, "type", std::string("")) == "function"
|
||||
&& tool_call.contains("function") && tool_call["function"].is_object()) {
|
||||
msg.tool_calls.push_back({
|
||||
json_value(tool_call["function"], "name", std::string("")),
|
||||
json_value(tool_call["function"], "arguments", std::string(""))
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
chat.emplace_back(std::move(msg));
|
||||
|
||||
chat.push_back({role, content});
|
||||
}
|
||||
|
||||
const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true, use_jinja, tools.is_null() ? nullptr : tools.dump().c_str());
|
||||
const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
|
||||
LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
|
||||
|
||||
return formatted_chat;
|
||||
@ -315,38 +306,12 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
|
||||
// OAI utils
|
||||
//
|
||||
|
||||
static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
|
||||
std::string piece;
|
||||
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
|
||||
const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
|
||||
if (n_chars < 0) {
|
||||
piece.resize(-n_chars);
|
||||
int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
|
||||
GGML_ASSERT(check == -n_chars);
|
||||
}
|
||||
else {
|
||||
piece.resize(n_chars);
|
||||
}
|
||||
|
||||
return piece;
|
||||
}
|
||||
|
||||
std::string llama_model_meta_val_str(const struct llama_model * model, const char * key) {
|
||||
int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
|
||||
if (tlen > 0) {
|
||||
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
|
||||
if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
|
||||
return std::string(curr_tmpl_buf.data(), tlen);
|
||||
}
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
static json oaicompat_completion_params_parse(
|
||||
const struct llama_model * model,
|
||||
const json & body, /* openai api json semantics */
|
||||
const std::string & chat_template_src,
|
||||
bool use_jinja) {
|
||||
const llama_chat_template & tmpl,
|
||||
bool use_jinja)
|
||||
{
|
||||
json llama_params;
|
||||
|
||||
llama_params["__oaicompat"] = true;
|
||||
@ -355,16 +320,15 @@ static json oaicompat_completion_params_parse(
|
||||
auto has_tools = tools.is_array() && !tools.empty();
|
||||
|
||||
// Apply chat template to the list of messages
|
||||
auto chat_template = chat_template_src.empty() ? llama_model_meta_val_str(model, "tokenizer.chat_template") : chat_template_src;
|
||||
llama_params["chat_template"] = chat_template;
|
||||
llama_params["chat_template"] = tmpl.chat_template();
|
||||
|
||||
if (use_jinja) {
|
||||
if (has_tools && chat_template.find("tools") == std::string::npos) {
|
||||
if (has_tools && !tmpl.supports_tools()) {
|
||||
throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");
|
||||
}
|
||||
} else if (has_tools) {
|
||||
throw std::runtime_error("Tools are only supported in --jinja mode");
|
||||
}
|
||||
llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"), tools, use_jinja);
|
||||
|
||||
// Handle "stop" field
|
||||
if (body.contains("stop") && body.at("stop").is_string()) {
|
||||
@ -399,26 +363,40 @@ static json oaicompat_completion_params_parse(
|
||||
} else if (!response_type.empty() && response_type != "text") {
|
||||
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
|
||||
}
|
||||
} else if (use_jinja && tool_choice != "none" && has_tools) {
|
||||
bool parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
|
||||
}
|
||||
|
||||
if (use_jinja) {
|
||||
bool allow_content = tool_choice != "required";
|
||||
if (tool_choice != "none" && has_tools) {
|
||||
bool parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
|
||||
llama_params["parse_tool_calls"] = true;
|
||||
llama_params["parallel_tool_calls"] = parallel_tool_calls;
|
||||
|
||||
auto handler = llama_tool_call_handler_init(chat_template, allow_content, parallel_tool_calls, tools);
|
||||
auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, tools);
|
||||
|
||||
for (const auto & stop : handler.additional_stop_words) {
|
||||
llama_params["stop"].push_back(stop);
|
||||
}
|
||||
if (!handler.grammar_trigger_words.empty()) {
|
||||
auto triggers = json::array();
|
||||
for (const auto & word : handler.grammar_trigger_words) {
|
||||
triggers.push_back(word);
|
||||
for (const auto & stop : handler.additional_stop_words) {
|
||||
llama_params["stop"].push_back(stop);
|
||||
}
|
||||
if (!handler.grammar_trigger_words.empty()) {
|
||||
auto triggers = json::array();
|
||||
for (const auto & word : handler.grammar_trigger_words) {
|
||||
triggers.push_back(word);
|
||||
}
|
||||
llama_params["grammar_trigger_words"] = triggers;
|
||||
}
|
||||
if (handler.updated_tools.is_null()) {
|
||||
tools = handler.updated_tools;
|
||||
}
|
||||
if (!handler.grammar.empty()) {
|
||||
if (llama_params.contains("grammar")) {
|
||||
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
|
||||
}
|
||||
llama_params["grammar"] = handler.grammar;
|
||||
}
|
||||
llama_params["grammar_trigger_words"] = triggers;
|
||||
}
|
||||
|
||||
llama_params["grammar"] = handler.grammar;
|
||||
llama_params["parse_tool_calls"] = true;
|
||||
llama_params["parallel_tool_calls"] = parallel_tool_calls;
|
||||
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
|
||||
} else {
|
||||
llama_params["prompt"] = format_chat(model, tmpl.chat_template(), body.at("messages"), tools, /* use_jinja= */ false);
|
||||
}
|
||||
|
||||
// Handle "n" field
|
||||
@ -458,7 +436,7 @@ static json oaicompat_completion_params_parse(
|
||||
return llama_params;
|
||||
}
|
||||
|
||||
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
|
||||
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, const llama_chat_template & tmpl, bool streaming = false, bool verbose = false) {
|
||||
bool stopped_word = result.count("stopped_word") != 0;
|
||||
bool stopped_eos = json_value(result, "stopped_eos", false);
|
||||
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
|
||||
@ -474,9 +452,8 @@ static json format_final_response_oaicompat(const json & request, const json & r
|
||||
auto tools = json_value(request, "tools", json::array());
|
||||
json tool_calls;
|
||||
json message_content;
|
||||
printf("# CONTENT: %s\n\n", content.c_str());
|
||||
if (json_value(request, "parse_tool_calls", false)
|
||||
&& !(parsed_tool_calls = parse_tool_calls(tools, chat_template, content)).tool_calls.empty()) {
|
||||
&& !(parsed_tool_calls = parse_tool_calls(tmpl.tool_call_style(), tools, content)).tool_calls.empty()) {
|
||||
finish_reason = "tool";
|
||||
if (!parsed_tool_calls.content.empty()) {
|
||||
message_content = parsed_tool_calls.content;
|
||||
@ -514,7 +491,6 @@ static json format_final_response_oaicompat(const json & request, const json & r
|
||||
}},
|
||||
{"id", completion_id}
|
||||
};
|
||||
printf("# RES: %s\n\n", res.dump(2).c_str());
|
||||
|
||||
// extra fields for debugging purposes
|
||||
if (verbose) {
|
||||
|
@ -377,19 +377,9 @@ extern "C" {
|
||||
} llama_sampler_chain_params;
|
||||
|
||||
// used in chat template
|
||||
|
||||
typedef struct llama_chat_message_tool_call {
|
||||
const char * name;
|
||||
const char * arguments;
|
||||
} llama_chat_message_tool_call;
|
||||
|
||||
typedef struct llama_chat_message {
|
||||
const char * role;
|
||||
const char * content;
|
||||
const char * tool;
|
||||
|
||||
const llama_chat_message_tool_call * tool_calls;
|
||||
uint32_t n_tool_calls;
|
||||
} llama_chat_message;
|
||||
|
||||
// lora adapter
|
||||
@ -986,11 +976,7 @@ extern "C" {
|
||||
size_t n_msg,
|
||||
bool add_ass,
|
||||
char * buf,
|
||||
int32_t length,
|
||||
bool use_jinja,
|
||||
const char * tools,
|
||||
const char * bos_token,
|
||||
const char * eos_token);
|
||||
int32_t length);
|
||||
|
||||
//
|
||||
// Sampling API
|
||||
|
110
src/llama.cpp
110
src/llama.cpp
@ -2,8 +2,6 @@
|
||||
#include "llama-vocab.h"
|
||||
#include "llama-sampling.h"
|
||||
|
||||
#include "minja.hpp"
|
||||
|
||||
#include "unicode.h"
|
||||
|
||||
#include "ggml.h"
|
||||
@ -21004,95 +21002,7 @@ int32_t llama_detokenize(
|
||||
static int32_t llama_chat_apply_template_internal(
|
||||
const std::string & tmpl,
|
||||
const std::vector<const llama_chat_message *> & chat,
|
||||
std::string & dest, bool add_ass,
|
||||
bool use_jinja,
|
||||
const std::string & tools,
|
||||
const std::string & bos_token, const std::string & eos_token) {
|
||||
|
||||
if (use_jinja) {
|
||||
auto system_not_supported = tmpl.find("System role not supported") != std::string::npos;
|
||||
|
||||
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
|
||||
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
|
||||
auto tool_call_args_must_be_objects = tmpl.find("tool_call.arguments | items") != std::string::npos;
|
||||
|
||||
auto messages = json::array();
|
||||
|
||||
std::string pending_system;
|
||||
auto flush_sys = [&]() {
|
||||
if (!pending_system.empty()) {
|
||||
messages.push_back({
|
||||
{"role", "user"},
|
||||
{"content", pending_system},
|
||||
});
|
||||
pending_system.clear();
|
||||
}
|
||||
};
|
||||
for (const auto * msg : chat) {
|
||||
std::string role(msg->role);
|
||||
std::string content(msg->content);
|
||||
if (system_not_supported) {
|
||||
if (role == "system") {
|
||||
if (!pending_system.empty()) pending_system += "\n";
|
||||
pending_system += content;
|
||||
continue;
|
||||
} else {
|
||||
if (role == "user") {
|
||||
if (!pending_system.empty()) {
|
||||
content = pending_system + (content.empty() ? "" : "\n" + content);
|
||||
pending_system.clear();
|
||||
}
|
||||
} else {
|
||||
flush_sys();
|
||||
}
|
||||
}
|
||||
}
|
||||
auto message = json({
|
||||
{"role", role},
|
||||
{"content", content},
|
||||
});
|
||||
if (msg->tool) message["tool"] = msg->tool;
|
||||
if (msg->n_tool_calls) {
|
||||
auto tool_calls = json::array();
|
||||
for (uint32_t i = 0; i < msg->n_tool_calls; i++) {
|
||||
auto args = msg->tool_calls[i].arguments;
|
||||
tool_calls.push_back(json({
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", msg->tool_calls[i].name},
|
||||
{"arguments", tool_call_args_must_be_objects ? json::parse(args) : args},
|
||||
}}
|
||||
}));
|
||||
}
|
||||
messages["tool_calls"] = tool_calls;
|
||||
}
|
||||
messages.push_back(message);
|
||||
}
|
||||
flush_sys();
|
||||
|
||||
auto context = minja::Context::make(json({
|
||||
{"messages", messages},
|
||||
{"add_generation_prompt", add_ass},
|
||||
{"bos_token", bos_token},
|
||||
{"eos_token", eos_token},
|
||||
}));
|
||||
if (!tools.empty()) {
|
||||
auto tools_val = minja::Value(json::parse(tools));
|
||||
context->set("tools", tools_val);
|
||||
}
|
||||
auto tmpl_root = minja::Parser::parse(tmpl, {
|
||||
/* .trim_blocks = */ true,
|
||||
/* .lstrip_blocks = */ true,
|
||||
/* .keep_trailing_newline = */ false,
|
||||
});
|
||||
try {
|
||||
dest = tmpl_root->render(context);
|
||||
return dest.size();
|
||||
} catch (const std::runtime_error & err) {
|
||||
LLAMA_LOG_ERROR("Error in jinja template: %s\n", err.what());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
std::string & dest, bool add_ass) {
|
||||
|
||||
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
|
||||
std::stringstream ss;
|
||||
@ -21360,11 +21270,7 @@ int32_t llama_chat_apply_template(
|
||||
size_t n_msg,
|
||||
bool add_ass,
|
||||
char * buf,
|
||||
int32_t length,
|
||||
bool use_jinja,
|
||||
const char * tools,
|
||||
const char * bos_token,
|
||||
const char * eos_token) {
|
||||
int32_t length) {
|
||||
std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
|
||||
if (tmpl == nullptr) {
|
||||
GGML_ASSERT(model != nullptr);
|
||||
@ -21379,16 +21285,6 @@ int32_t llama_chat_apply_template(
|
||||
curr_tmpl = std::string(model_template.data(), model_template.size());
|
||||
}
|
||||
}
|
||||
std::string curr_bos_token(bos_token ? bos_token : "");
|
||||
std::string curr_eos_token(eos_token ? eos_token : "");
|
||||
if (bos_token == nullptr) {
|
||||
GGML_ASSERT(model != nullptr);
|
||||
curr_bos_token = llama_token_to_piece(model, llama_token_bos(model), true);
|
||||
}
|
||||
if (eos_token == nullptr) {
|
||||
GGML_ASSERT(model != nullptr);
|
||||
curr_eos_token = llama_token_to_piece(model, llama_token_eos(model), true);
|
||||
}
|
||||
|
||||
// format the chat to string
|
||||
std::vector<const llama_chat_message *> chat_vec;
|
||||
@ -21398,7 +21294,7 @@ int32_t llama_chat_apply_template(
|
||||
}
|
||||
|
||||
std::string formatted_chat;
|
||||
int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass, use_jinja, tools == nullptr ? "" : tools, curr_bos_token, curr_eos_token);
|
||||
int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
|
||||
if (res < 0) {
|
||||
return res;
|
||||
}
|
||||
|
@ -20,9 +20,9 @@ static void assert_equals(const std::string & expected, const std::string & actu
|
||||
cmake -B build -DLLAMA_CURL=1 -DCMAKE_BUILD_TYPE=Release && cmake --build build -t test-tool-call -j && ./build/bin/test-tool-call
|
||||
*/
|
||||
|
||||
static void test_parse_tool_call(const json & tools, const std::string & chat_template, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) {
|
||||
static void test_parse_tool_call(llama_tool_call_style style, const json & tools, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) {
|
||||
std::cout << "# Testing: " << input << std::endl << std::flush;
|
||||
auto result = parse_tool_calls(tools, chat_template, input);
|
||||
auto result = parse_tool_calls(style, tools, input);
|
||||
assert_equals(expected_content, result.content);
|
||||
auto tool_calls = json::array();
|
||||
for (const auto & tc : result.tool_calls) {
|
||||
@ -59,8 +59,7 @@ int main() {
|
||||
{"tools", tools}
|
||||
};
|
||||
|
||||
std::string hermes_2_pro_like_tmpl = "Hermes 2 Pro template should have <tool_call> inside it";
|
||||
test_parse_tool_call(tools, hermes_2_pro_like_tmpl,
|
||||
test_parse_tool_call(llama_tool_call_style::Hermes2Pro, tools,
|
||||
"<tool_call>{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}</tool_call>",
|
||||
"",
|
||||
json {{
|
||||
@ -72,8 +71,7 @@ int main() {
|
||||
}}
|
||||
}});
|
||||
|
||||
std::string functionary_v3_like_tmpl = "Functionary 3.2 template should have <|start_header_id|> and then some >>>all inside it";
|
||||
test_parse_tool_call(tools, functionary_v3_like_tmpl,
|
||||
test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools,
|
||||
">>>ipython\n{\"code\": \"print('Hello, world!')\"}",
|
||||
"",
|
||||
json {{
|
||||
@ -84,7 +82,7 @@ int main() {
|
||||
}).dump()}
|
||||
}}
|
||||
}});
|
||||
test_parse_tool_call(tools, functionary_v3_like_tmpl,
|
||||
test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools,
|
||||
">>>test\n{ } \n ",
|
||||
"",
|
||||
json {{
|
||||
@ -94,8 +92,7 @@ int main() {
|
||||
}}
|
||||
}});
|
||||
|
||||
std::string functionary_v3_llama_3_1_like_tmpl = "Functionary 3.2 template for llama 3.1 should have <|start_header_id|> and then some <function=foo>{...}</function> inside it";
|
||||
test_parse_tool_call(tools, functionary_v3_llama_3_1_like_tmpl,
|
||||
test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama31, tools,
|
||||
"Hell<function=foo>{\"arg1\": 1}</function>o, world<function=bar>{\"arg2\": 2}</function>!",
|
||||
"Hello, world!",
|
||||
json {
|
||||
@ -116,7 +113,7 @@ int main() {
|
||||
}}
|
||||
},
|
||||
});
|
||||
test_parse_tool_call(tools, functionary_v3_llama_3_1_like_tmpl,
|
||||
test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama31, tools,
|
||||
"<function=test>{ } </function> ",
|
||||
" ",
|
||||
json {{
|
||||
@ -126,8 +123,7 @@ int main() {
|
||||
}}
|
||||
}});
|
||||
|
||||
std::string llama_3_1_like_tmpl = "Llama 3.1 template should have <|start_header_id|> and <|python_tag|> inside it";
|
||||
test_parse_tool_call(tools, llama_3_1_like_tmpl,
|
||||
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
||||
"<|python_tag|>this could be anything",
|
||||
"",
|
||||
json {{
|
||||
@ -138,7 +134,7 @@ int main() {
|
||||
}).dump()}
|
||||
}}
|
||||
}});
|
||||
test_parse_tool_call(tools, llama_3_1_like_tmpl,
|
||||
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
||||
"I'm thinking<|python_tag|>",
|
||||
"I'm thinking",
|
||||
json {{
|
||||
@ -147,7 +143,7 @@ int main() {
|
||||
{"arguments", (json {{"code", ""}}).dump()}
|
||||
}}
|
||||
}});
|
||||
test_parse_tool_call(tools, llama_3_1_like_tmpl,
|
||||
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
||||
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
|
||||
"",
|
||||
json {{
|
||||
@ -158,7 +154,7 @@ int main() {
|
||||
}).dump()}
|
||||
}}
|
||||
}});
|
||||
test_parse_tool_call(tools, llama_3_1_like_tmpl,
|
||||
test_parse_tool_call(llama_tool_call_style::Llama31, tools,
|
||||
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
|
||||
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array());
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user