Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-13 12:10:18 +00:00)
tool-call: refactor llama_chat_template class + use in validate_model_chat_template
This commit is contained in:
parent cf7bece6a7
commit 9cfe4d7202
@@ -1,5 +1,4 @@
 #include "chat-template.h"
-#include "minja.hpp"
 #include "llama.h"
 
 using json = nlohmann::ordered_json;
@@ -31,14 +30,39 @@ static std::string llama_model_meta_val_str(const struct llama_model * model, co
     return "";
 }
 
+llama_chat_template::llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token)
+    : _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) {
+
+    _supports_tools = chat_template.find("tools") != std::string::npos;
+    _requires_object_arguments = chat_template.find("tool_call.arguments | items") != std::string::npos;
+    _supports_system_role = chat_template.find("System role not supported") == std::string::npos;
+
+    if (chat_template.find("<tool_call>") != std::string::npos) {
+        _tool_call_style = Hermes2Pro;
+    } else if (chat_template.find(">>>all") != std::string::npos) {
+        _tool_call_style = FunctionaryV3Llama3;
+    } else if (chat_template.find("<|start_header_id|>") != std::string::npos) {
+        if (chat_template.find("<function=") != std::string::npos) {
+            _tool_call_style = FunctionaryV3Llama31;
+        } else if (chat_template.find("<|python_tag|>") != std::string::npos) {
+            _tool_call_style = Llama31;
+        }
+    }
+    _template_root = minja::Parser::parse(_chat_template, {
+        /* .trim_blocks = */ true,
+        /* .lstrip_blocks = */ true,
+        /* .keep_trailing_newline = */ false,
+    });
+}
+
 llama_chat_template llama_chat_template::from_model(
     const struct llama_model * model,
-    const std::string & chat_template_override)
+    const char * chat_template_override)
 {
     // TODO: handle "chatml"?
-    auto chat_template = chat_template_override.empty()
-        ? llama_model_meta_val_str(model, "tokenizer.chat_template")
-        : chat_template_override;
+    std::string chat_template = chat_template_override
+        ? chat_template_override
+        : llama_model_meta_val_str(model, "tokenizer.chat_template");
     auto bos_token = _llama_token_to_piece(model, llama_token_bos(model), true);
     auto eos_token = _llama_token_to_piece(model, llama_token_eos(model), true);
     return llama_chat_template(chat_template, bos_token, eos_token);
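The constructor added above infers the tool-call dialect by scanning the raw template source for distinctive markers. The following standalone sketch is not part of the commit; it isolates the same heuristic so it can be tried by hand. The enum names come from chat-template.h, everything else (function name, sample template fragments) is illustrative.

// Standalone sketch of the marker-based style detection used in the constructor above.
#include <iostream>
#include <string>

enum llama_tool_call_style { Unknown, Llama31, FunctionaryV3Llama3, FunctionaryV3Llama31, Hermes2Pro };

static llama_tool_call_style detect_style(const std::string & tmpl) {
    if (tmpl.find("<tool_call>") != std::string::npos) return Hermes2Pro;
    if (tmpl.find(">>>all") != std::string::npos) return FunctionaryV3Llama3;
    if (tmpl.find("<|start_header_id|>") != std::string::npos) {
        if (tmpl.find("<function=") != std::string::npos) return FunctionaryV3Llama31;
        if (tmpl.find("<|python_tag|>") != std::string::npos) return Llama31;
    }
    return Unknown;
}

int main() {
    // Hermes-2-Pro style templates wrap tool calls in <tool_call> tags.
    std::cout << (detect_style("... <tool_call>{{ tc }}</tool_call> ...") == Hermes2Pro) << "\n";
    // Llama-3.1 style templates use header ids plus <|python_tag|>.
    std::cout << (detect_style("<|start_header_id|> ... <|python_tag|>") == Llama31) << "\n";
}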
@@ -69,9 +93,9 @@ std::string llama_chat_template::apply(
             throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
         }
         std::string role = message.at("role");
-        std::string content = message.at("content");
 
-        if (!_supports_system_role) {
+        if (!message["content"].is_null() && !_supports_system_role) {
+            std::string content = message.at("content");
             if (role == "system") {
                 if (!pending_system.empty()) pending_system += "\n";
                 pending_system += content;
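For templates that reject the system role (detected via the "System role not supported" probe in the constructor), system content is not emitted directly: it is accumulated in pending_system, newline-separated, so it can be folded into a later turn (that part of apply() is outside this hunk). A minimal sketch of just the accumulation step, assuming an ordered_json array of {"role", "content"} objects; the sample messages are illustrative.

// Illustrative only: accumulate system-message content the way the hunk above does.
#include <json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::ordered_json;

int main() {
    json messages = json::array({
        {{"role", "system"}, {"content", "You are terse."}},
        {{"role", "system"}, {"content", "Answer in French."}},
        {{"role", "user"},   {"content", "Bonjour!"}},
    });

    std::string pending_system;
    for (const auto & message : messages) {
        if (message.at("role") == "system" && !message.at("content").is_null()) {
            if (!pending_system.empty()) pending_system += "\n";
            pending_system += message.at("content").get<std::string>();
        }
    }
    // Prints "You are terse.\nAnswer in French." — ready to be merged into a later message.
    std::cout << pending_system << "\n";
}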
@@ -89,8 +113,11 @@ std::string llama_chat_template::apply(
         }
         if (_requires_object_arguments && message.contains("tool_calls")) {
             for (auto & tool_call : message.at("tool_calls")) {
-                std::string arguments = tool_call.at("arguments");
-                tool_call["arguments"] = json::parse(arguments);
+                if (tool_call["type"] == "function") {
+                    auto & function = tool_call.at("function");
+                    std::string arguments = function.at("arguments");
+                    function["arguments"] = json::parse(arguments);
+                }
            }
         }
     }
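The change above moves the normalization under the OpenAI-style layout: "arguments" arrives as a JSON-encoded string nested under "function", but templates that iterate with "tool_call.arguments | items" need an actual object. A self-contained sketch of that normalization; the message data and the "get_weather" name are illustrative, only the parsing step mirrors the diff.

// Illustrative sketch: parse the stringified "arguments" of a function tool call
// into a JSON object, as the loop in the hunk above does.
#include <json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::ordered_json;

int main() {
    json message = {
        {"role", "assistant"},
        {"content", nullptr},
        {"tool_calls", json::array({
            {
                {"type", "function"},
                {"function", {
                    {"name", "get_weather"},
                    {"arguments", "{\"location\": \"Paris\"}"}   // a string, not an object
                }}
            }
        })}
    };

    for (auto & tool_call : message.at("tool_calls")) {
        if (tool_call["type"] == "function") {
            auto & function = tool_call.at("function");
            std::string arguments = function.at("arguments");
            function["arguments"] = json::parse(arguments);       // now a real object
        }
    }

    // {"location":"Paris"} is now addressable as function.arguments.location by the template.
    std::cout << message["tool_calls"][0]["function"]["arguments"].dump() << "\n";
}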
@@ -99,20 +126,11 @@ std::string llama_chat_template::apply(
 
     auto context = minja::Context::make(json({
         {"messages", actual_messages},
+        {"tools", tools},
         {"add_generation_prompt", add_generation_prompt},
         {"bos_token", _bos_token},
         {"eos_token", _eos_token},
     }));
 
-    if (!tools.is_null() && !tools.empty()) {
-        auto tools_val = minja::Value(tools);
-        context->set("tools", tools_val);
-    }
-
-    auto tmpl_root = minja::Parser::parse(_chat_template, {
-        /* .trim_blocks = */ true,
-        /* .lstrip_blocks = */ true,
-        /* .keep_trailing_newline = */ false,
-    });
-    return tmpl_root->render(context);
+    return _template_root->render(context);
 }
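With this hunk the template is parsed once in the constructor (stored in _template_root) and apply() only builds a context and renders it, with tools passed unconditionally as a context variable. A minimal rendering sketch, assuming minja's Parser::parse / Context::make / render behave as the calls in this diff suggest; the toy template and variables are illustrative, not a real chat template.

// Sketch only: parse once, render many times, mirroring the minja calls in the diff.
#include "minja.hpp"
#include <json.hpp>
#include <iostream>

using json = nlohmann::ordered_json;

int main() {
    auto root = minja::Parser::parse("{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}", {
        /* .trim_blocks = */ true,
        /* .lstrip_blocks = */ true,
        /* .keep_trailing_newline = */ false,
    });

    auto context = minja::Context::make(json({
        {"messages", json::array({
            {{"role", "user"}, {"content", "hello"}},
        })},
        {"add_generation_prompt", true},
        {"tools", json()},
    }));

    // Reusing `root` avoids re-parsing the template on every request.
    std::cout << root->render(context) << "\n";
}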
@@ -1,11 +1,13 @@
 #pragma once
 
+#include "minja.hpp"
 #include <json.hpp>
 #include <string>
 #include <vector>
 
 using json = nlohmann::ordered_json;
 
+
 enum llama_tool_call_style {
     Unknown,
     Llama31,
@@ -27,30 +29,14 @@ class llama_chat_template {
     std::string _chat_template;
     std::string _bos_token;
     std::string _eos_token;
+    std::unique_ptr<minja::TemplateNode> _template_root;
 
   public:
-    llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token)
-        : _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) {
-
-        _supports_tools = chat_template.find("tools") != std::string::npos;
-        _requires_object_arguments = chat_template.find("tool_call.arguments | items") != std::string::npos;
-        _supports_system_role = chat_template.find("System role not supported") == std::string::npos;
-
-        if (chat_template.find("<tool_call>") != std::string::npos) {
-            _tool_call_style = Hermes2Pro;
-        } else if (chat_template.find(">>>all") != std::string::npos) {
-            _tool_call_style = FunctionaryV3Llama3;
-        } else if (chat_template.find("<|start_header_id|>") != std::string::npos) {
-            if (chat_template.find("<function=") != std::string::npos) {
-                _tool_call_style = FunctionaryV3Llama31;
-            } else if (chat_template.find("<|python_tag|>") != std::string::npos) {
-                _tool_call_style = Llama31;
-            }
-        }
-    }
-
+    llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token);
 
     static llama_chat_template from_model(
         const struct llama_model * model,
-        const std::string & chat_template_override);
+        const char * chat_template_override = nullptr);
 
     llama_tool_call_style tool_call_style() const { return _tool_call_style; }
 
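Taken together, the header now exposes a small surface: construct via from_model() (with an optional C-string template override defaulting to nullptr), inspect tool_call_style(), and render with apply(messages, tools, add_generation_prompt). A hedged usage sketch follows, assuming a llama_model has already been loaded elsewhere; the message contents and the build_prompt wrapper are illustrative.

// Sketch of the public API after this refactor; `model` is assumed to be a
// previously loaded llama_model pointer (loading code omitted).
#include "chat-template.h"

std::string build_prompt(const struct llama_model * model) {
    // nullptr override -> use the template embedded in the GGUF metadata.
    auto tmpl = llama_chat_template::from_model(model, /* chat_template_override = */ nullptr);

    json messages = json::array({
        {{"role", "system"}, {"content", "You are a helpful assistant."}},
        {{"role", "user"},   {"content", "What is the capital of France?"}},
    });

    // No tools in this example; pass a null json as validate_model_chat_template() does below.
    return tmpl.apply(messages, /* tools = */ json(), /* add_generation_prompt = */ true);
}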
@@ -662,9 +662,23 @@ struct server_context {
     bool validate_model_chat_template(bool use_jinja) const {
         llama_chat_message chat[] = {{"user", "test"}};
 
-        const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
+        if (use_jinja) {
+            auto chat_template = llama_chat_template::from_model(model);
+            try {
+                chat_template.apply({{
+                    {"role", "user"},
+                    {"content", "test"},
+                }}, json(), true);
+                return true;
+            } catch (const std::exception & e) {
+                SRV_ERR("failed to apply template: %s\n", e.what());
+                return false;
+            }
+        } else {
+            const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
 
-        return res > 0;
+            return res > 0;
+        }
     }
 
     void init() {
@@ -2860,7 +2874,7 @@ int main(int argc, char ** argv) {
         return;
     }
 
-    auto chat_template = llama_chat_template::from_model(ctx_server.model, params.chat_template);
+    auto chat_template = llama_chat_template::from_model(ctx_server.model, params.chat_template.empty() ? nullptr : params.chat_template.c_str());
 
     json data;
     try {