mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-13 04:00:16 +00:00)

tool-call: factor chat template away from legacy API

This commit is contained in:
  parent d7ec84f78c
  commit cf7bece6a7
Makefile (4 changes)

@@ -934,6 +934,7 @@ OBJ_LLAMA = \
 OBJ_COMMON = \
 	common/common.o \
+	common/chat-template.o \
 	common/arg.o \
 	common/log.o \
 	common/console.o \
@@ -1170,6 +1171,8 @@ $(LIB_LLAMA_S): \
 common/common.o: \
 	common/common.cpp \
 	common/common.h \
+	common/chat-template.cpp \
+	common/chat-template.h \
 	common/console.h \
 	common/sampling.h \
 	common/json.hpp \
@@ -1465,6 +1468,7 @@ llama-server: \
 	examples/server/prompt-formats.js.hpp \
 	examples/server/json-schema-to-grammar.mjs.hpp \
 	examples/server/loading.html.hpp \
+	common/chat-template.h \
 	common/json.hpp \
 	common/stb_image.h \
 	$(OBJ_ALL)
@@ -54,6 +54,8 @@ add_library(${TARGET} STATIC
     arg.cpp
     arg.h
     base64.hpp
+    chat-template.cpp
+    chat-template.h
     common.cpp
     common.h
     console.cpp
common/chat-template.cpp (new file, 118 lines)

@@ -0,0 +1,118 @@
#include "chat-template.h"
#include "minja.hpp"
#include "llama.h"

using json = nlohmann::ordered_json;

static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
    std::string piece;
    piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
    const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
    if (n_chars < 0) {
        piece.resize(-n_chars);
        int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
        GGML_ASSERT(check == -n_chars);
    }
    else {
        piece.resize(n_chars);
    }

    return piece;
}

static std::string llama_model_meta_val_str(const struct llama_model * model, const char * key) {
    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
    if (tlen > 0) {
        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
            return std::string(curr_tmpl_buf.data(), tlen);
        }
    }
    return "";
}

llama_chat_template llama_chat_template::from_model(
    const struct llama_model * model,
    const std::string & chat_template_override)
{
    // TODO: handle "chatml"?
    auto chat_template = chat_template_override.empty()
        ? llama_model_meta_val_str(model, "tokenizer.chat_template")
        : chat_template_override;
    auto bos_token = _llama_token_to_piece(model, llama_token_bos(model), true);
    auto eos_token = _llama_token_to_piece(model, llama_token_eos(model), true);
    return llama_chat_template(chat_template, bos_token, eos_token);
}

std::string llama_chat_template::apply(
    const json & messages,
    const json & tools,
    bool add_generation_prompt) const
{
    auto actual_messages = messages;

    // First, "fix" messages so they have a chance to be rendered correctly by the template

    if (_requires_object_arguments || !_supports_system_role) {
        std::string pending_system;
        auto flush_sys = [&]() {
            if (!pending_system.empty()) {
                actual_messages.push_back({
                    {"role", "user"},
                    {"content", pending_system},
                });
                pending_system.clear();
            }
        };
        for (auto & message : actual_messages) {
            if (!message.contains("role") || !message.contains("content")) {
                throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
            }
            std::string role = message.at("role");
            std::string content = message.at("content");

            if (!_supports_system_role) {
                if (role == "system") {
                    if (!pending_system.empty()) pending_system += "\n";
                    pending_system += content;
                    continue;
                } else {
                    if (role == "user") {
                        if (!pending_system.empty()) {
                            message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
                            pending_system.clear();
                        }
                    } else {
                        flush_sys();
                    }
                }
            }
            if (_requires_object_arguments && message.contains("tool_calls")) {
                for (auto & tool_call : message.at("tool_calls")) {
                    std::string arguments = tool_call.at("arguments");
                    tool_call["arguments"] = json::parse(arguments);
                }
            }
        }
        flush_sys();
    }

    auto context = minja::Context::make(json({
        {"messages", actual_messages},
        {"add_generation_prompt", add_generation_prompt},
        {"bos_token", _bos_token},
        {"eos_token", _eos_token},
    }));

    if (!tools.is_null() && !tools.empty()) {
        auto tools_val = minja::Value(tools);
        context->set("tools", tools_val);
    }

    auto tmpl_root = minja::Parser::parse(_chat_template, {
        /* .trim_blocks = */ true,
        /* .lstrip_blocks = */ true,
        /* .keep_trailing_newline = */ false,
    });
    return tmpl_root->render(context);
}
common/chat-template.h (new file, 64 lines)

@@ -0,0 +1,64 @@
#pragma once

#include <json.hpp>
#include <string>
#include <vector>

using json = nlohmann::ordered_json;

enum llama_tool_call_style {
    Unknown,
    Llama31,
    FunctionaryV3Llama3,
    FunctionaryV3Llama31,
    Hermes2Pro,
};

class llama_chat_template {
  public:

  private:
    llama_tool_call_style _tool_call_style = Unknown;
    bool _supports_tools = true;
    // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
    bool _requires_object_arguments = false;
    bool _supports_system_role = true;
    std::string _chat_template;
    std::string _bos_token;
    std::string _eos_token;

  public:
    llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token)
        : _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) {

        _supports_tools = chat_template.find("tools") != std::string::npos;
        _requires_object_arguments = chat_template.find("tool_call.arguments | items") != std::string::npos;
        _supports_system_role = chat_template.find("System role not supported") == std::string::npos;

        if (chat_template.find("<tool_call>") != std::string::npos) {
            _tool_call_style = Hermes2Pro;
        } else if (chat_template.find(">>>all") != std::string::npos) {
            _tool_call_style = FunctionaryV3Llama3;
        } else if (chat_template.find("<|start_header_id|>") != std::string::npos) {
            if (chat_template.find("<function=") != std::string::npos) {
                _tool_call_style = FunctionaryV3Llama31;
            } else if (chat_template.find("<|python_tag|>") != std::string::npos) {
                _tool_call_style = Llama31;
            }
        }
    }

    static llama_chat_template from_model(
        const struct llama_model * model,
        const std::string & chat_template_override);

    llama_tool_call_style tool_call_style() const { return _tool_call_style; }

    const std::string & chat_template() const { return _chat_template; }
    bool supports_tools() const { return _supports_tools; }

    std::string apply(
        const nlohmann::ordered_json & messages,
        const nlohmann::ordered_json & tools,
        bool add_generation_prompt) const;
};
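The class above is usable outside the server as well. Below is a minimal usage sketch based only on the declarations in this header and the definitions in chat-template.cpp above; the function name render_prompt and the loaded model pointer are illustrative, not part of the commit.

    // Sketch, assuming a llama_model * that has already been loaded.
    #include "chat-template.h"

    std::string render_prompt(const struct llama_model * model) {
        // Empty override -> read tokenizer.chat_template (and the BOS/EOS pieces) from the model.
        auto tmpl = llama_chat_template::from_model(model, /* chat_template_override= */ "");

        json messages = json::array({
            {{"role", "system"}, {"content", "You are a helpful assistant."}},
            {{"role", "user"},   {"content", "Hello!"}},
        });

        // No tools here; passing an empty json() mirrors what llama_chat_verify_template does below.
        return tmpl.apply(messages, /* tools= */ json(), /* add_generation_prompt= */ true);
    }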
@@ -9,6 +9,7 @@
 #include "json.hpp"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
+#include "chat-template.h"
 
 #include <algorithm>
 #include <cinttypes>
@@ -1511,6 +1512,20 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token>
 //
 
 bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+    if (use_jinja) {
+        try {
+            auto chat_template = llama_chat_template(tmpl, "<s>", "</s>");
+            chat_template.apply({{
+                {"role", "user"},
+                {"content", "test"},
+            }}, json(), true);
+            return true;
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+            return false;
+        }
+    }
+
     llama_chat_message chat[] = {{"user", "test"}};
     int res = llama_chat_apply_template(
         nullptr,
@@ -1519,22 +1534,14 @@ bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
         1,
         /* add_ass= */ true,
         /* buffer= */ nullptr,
-        /* length= */ 0,
-        use_jinja,
-        /* tools= */ nullptr,
-        "<s>",
-        "</s>");
+        /* length= */ 0);
     return res >= 0;
 }
 
 std::string llama_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & msgs,
-        bool add_ass,
-        bool use_jinja,
-        const char * tools,
-        const char * bos_token,
-        const char * eos_token) {
+        bool add_ass) {
     int alloc_size = 0;
     bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
@@ -1547,7 +1554,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
@@ -1557,7 +1564,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
             throw std::runtime_error("this custom template is not supported");
         } else {
             // If the built-in template is not supported, we default to chatml
-            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
+            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
             fallback = true;
         }
     }
@@ -1568,7 +1575,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         res = llama_chat_apply_template(
             fallback ? nullptr : model,
             fallback ? "chatml" : ptr_tmpl,
-            chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token);
+            chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
 
     std::string formatted_chat(buf.data(), res);
@@ -1579,13 +1586,9 @@ std::string llama_chat_format_single(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & past_msg,
         const llama_chat_msg & new_msg,
-        bool add_ass,
-        bool use_jinja,
-        const char * tools,
-        const char * bos_token,
-        const char * eos_token) {
+        bool add_ass) {
     std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja, tools, bos_token, eos_token);
+    auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false);
     std::vector<llama_chat_msg> chat_new(past_msg);
     // if the past_msg ends with a newline, we must preserve it in the formatted version
     if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
@@ -1593,7 +1596,7 @@ std::string llama_chat_format_single(const struct llama_model * model,
     };
     // format chat with new_msg
     chat_new.push_back(new_msg);
-    auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja, tools, bos_token, eos_token);
+    auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
     // get the diff part
     ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
     return ss.str();
@@ -471,21 +471,14 @@ std::string llama_detokenize(
 // Chat template utils
 //
 
-struct llama_chat_msg_tool_call {
-    std::string name;
-    std::string arguments;
-};
-
 // same as llama_chat_message, but uses std::string and std::vector
 struct llama_chat_msg {
     std::string role;
     std::string content;
-    std::string tool;
-    std::vector<struct llama_chat_msg_tool_call> tool_calls;
 };
 
-// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
-bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja = false);
+// Check if the template is supported or not. Returns true if it's valid
+bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja);
 
 // CPP wrapper for llama_chat_apply_template
 // If the built-in template is not supported, we default to chatml
@@ -493,22 +486,14 @@ bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja = false
 std::string llama_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & chat,
-        bool add_ass,
-        bool use_jinja = false,
-        const char * tools = nullptr,
-        const char * bos_token = nullptr,
-        const char * eos_token = nullptr);
+        bool add_ass);
 
 // Format single message, while taking into account the position of that message in chat history
 std::string llama_chat_format_single(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & past_msg,
         const llama_chat_msg & new_msg,
-        bool add_ass,
-        bool use_jinja = false,
-        const char * tools = nullptr,
-        const char * bos_token = nullptr,
-        const char * eos_token = nullptr);
+        bool add_ass);
 
 // Returns an example of formatted chat
 std::string llama_chat_format_example(const struct llama_model * model,
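For callers that stay on the legacy path, the wrapper is back to its pre-tool-call shape. A short sketch of a call against the simplified signature; treating an empty template string as "use the model's built-in template" is an assumption about the wrapper, not something shown in this diff, and the variable names are illustrative.

    // Sketch: llama_chat_msg now only carries role and content.
    std::vector<llama_chat_msg> msgs = {
        {"system", "You are a helpful assistant."},
        {"user",   "Hello!"},
    };
    // Assumed: an empty tmpl selects the model's own template; chatml is the documented fallback.
    std::string prompt = llama_chat_apply_template(model, /* tmpl= */ "", msgs, /* add_ass= */ true);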
@@ -12,27 +12,6 @@
 
 using json = nlohmann::ordered_json;
 
-// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3.llama3.txt
-static bool needs_functionary_v3_tool_call(const std::string & chat_template) {
-    return chat_template.find("<|start_header_id|>") != std::string::npos
-        && chat_template.find(">>>all") != std::string::npos;
-}
-
-// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
-static bool needs_functionary_v3_llama_3_1_tool_call(const std::string & chat_template) {
-    return chat_template.find("<|start_header_id|>") != std::string::npos
-        && chat_template.find("<function=") != std::string::npos;
-}
-
-static bool needs_llama_3_1_tool_call(const std::string & chat_template) {
-    return chat_template.find("<|start_header_id|>") != std::string::npos
-        && chat_template.find("<|python_tag|>") != std::string::npos;
-}
-
-static bool needs_hermes_pro_tool_call(const std::string & chat_template) {
-    return chat_template.find("<tool_call>") != std::string::npos;
-}
-
 static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
     // // https://json.nlohmann.me/features/parsing/sax_interface/
     struct json_error_locator : public nlohmann::json_sax<json> {
@@ -209,137 +188,145 @@ static llama_tool_calls parse_functionary_v3_tool_calls(const std::string& input
     return parse_functionary_tool_calls(input, function_regex, close_regex);
 }
 
-llama_tool_calls parse_tool_calls(const json & tools, const std::string & chat_template, const std::string& input) {
-    if (needs_hermes_pro_tool_call(chat_template)) {
-        return parse_hermes_tool_calls(input);
-    } else if (needs_functionary_v3_tool_call(chat_template)) {
-        return parse_functionary_v3_tool_calls(input);
-    } else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) {
-        return parse_functionary_v3_llama_3_1_tool_calls(input);
-    } else if (needs_llama_3_1_tool_call(chat_template)) {
-        return parse_llama_3_1_tool_calls(tools, input);
-    } else {
-        throw std::runtime_error("Unsupported chat template for tool calls");
+llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) {
+    switch (style) {
+        case llama_tool_call_style::Llama31:
+            return parse_llama_3_1_tool_calls(tools, input);
+        case llama_tool_call_style::FunctionaryV3Llama3:
+            return parse_functionary_v3_tool_calls(input);
+        case llama_tool_call_style::FunctionaryV3Llama31:
+            return parse_functionary_v3_llama_3_1_tool_calls(input);
+        case llama_tool_call_style::Hermes2Pro:
+            return parse_hermes_tool_calls(input);
+        default:
+            throw std::runtime_error("Unsupported tool call style");
     }
 }
 
 llama_tool_call_handler llama_tool_call_handler_init(
-    const std::string & chat_template,
+    const llama_chat_template & tmpl,
     bool allow_content,
     bool parallel_tool_calls,
     const nlohmann::ordered_json & tools)
 {
     llama_tool_call_handler handler;
 
-    if (needs_functionary_v3_tool_call(chat_template)) {
-        // MeetKaiFunctionary_3_2
-        // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
-        // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
-        handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
-            std::vector<std::string> tool_rules;
-            for (size_t i = 0, n = tools.size(); i < n; i++) {
-                auto & tool = tools[i];
-                const auto & function = tool["function"];
-                std::string name = function["name"];
-                auto parameters = function["parameters"];
-                auto tool_rule = builder.add_rule(name + "-call", "\">>>" + name + "\\n\" " + builder.add_schema(name + "-args", parameters));
-                tool_rules.push_back(tool_rule);
-                if (allow_content) {
-                    handler.grammar_trigger_words.push_back(">>>" + name + "\n");
-                }
-            }
-            auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
-            builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
-        });
-        // handler.parser = parse_functionary_3_2_tool_calls;
-    } else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) {
-        // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
-        // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
-        // TODO: handle tool {type: code_interpreter} as python
-        handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
-            std::vector<std::string> tool_rules;
-            for (size_t i = 0, n = tools.size(); i < n; i++) {
-                auto & tool = tools[i];
-                const auto & function = tool["function"];
-                std::string name = function["name"];
-                auto parameters = function["parameters"];
-                if (name == "python") {
-                    tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
-                    if (allow_content) {
-                        handler.grammar_trigger_words.push_back("<|python_tag|>");
-                    }
-                } else {
-                    tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\""));
-                }
-            }
-            auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
-            builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
-            if (allow_content) {
-                handler.grammar_trigger_words.push_back("<function=");
-            }
-        });
-        // handler.parser = parse_functionary_3_2_tool_calls;
-    } else if (needs_hermes_pro_tool_call(chat_template)) {
-        // NousResearchHermesPro_2
-        // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
-        handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
-            std::vector<std::string> tool_rules;
-            for (const auto & tool : tools) {
-                const auto & function = tool["function"];
-                std::string name = function["name"];
-                auto parameters = function["parameters"];
-                builder.resolve_refs(parameters);
-                tool_rules.push_back(builder.add_schema(name + "-call", {
-                    {"type", "object"},
-                    {"properties", json {
-                        {"name", json {{"const", name}}},
-                        {"arguments", parameters},
-                    }},
-                    {"required", json::array({"name", "arguments"})},
-                }));
-            }
-
-            auto tool_call = "\"<tool_call>\" " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space";
-            builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
-            if (allow_content) {
-                handler.grammar_trigger_words.push_back("<tool_call>");
-            }
-        });
-    } else if (needs_llama_3_1_tool_call(chat_template)) {
-        handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
-            static std::vector<std::string> builtin_tools {"wolfram_alpha", "brave_search"};
-            std::vector<std::string> tool_rules;
-
-            for (const auto & tool : tools) {
-                const auto & function = tool["function"];
-                std::string name = function["name"];
-                auto parameters = function["parameters"];
-                builder.resolve_refs(parameters);
-                if (name == "ipython" || std::find(builtin_tools.begin(), builtin_tools.end(), name) != builtin_tools.end()) {
-                    tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
-                    if (allow_content) {
-                        handler.grammar_trigger_words.push_back("<|python_tag|>");
-                    }
-                } else {
-                    //"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
-                    tool_rules.push_back(
-                        builder.add_rule(
-                            name + "-call",
-                            "\"\\n{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
-                                builder.add_schema(name + "-args", parameters) +
-                            " \"}\""));
-                    if (allow_content) {
-                        handler.grammar_trigger_words.push_back("\n{\"" + name + "\"");
-                    }
-                }
-            }
-
-            builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | "));
-        });
-        handler.additional_stop_words.push_back("<|eom_id|>");
-    } else {
-        // TODO: generic thoughtful schema.
-        throw std::runtime_error("Unsupported tool call style!");
+    switch (tmpl.tool_call_style()) {
+        case llama_tool_call_style::Llama31: {
+            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+                static std::vector<std::string> builtin_tools {"wolfram_alpha", "brave_search"};
+                std::vector<std::string> tool_rules;
+
+                for (const auto & tool : tools) {
+                    const auto & function = tool["function"];
+                    std::string name = function["name"];
+                    auto parameters = function["parameters"];
+                    builder.resolve_refs(parameters);
+                    if (name == "ipython" || std::find(builtin_tools.begin(), builtin_tools.end(), name) != builtin_tools.end()) {
+                        tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*"));
+                        if (allow_content) {
+                            handler.grammar_trigger_words.push_back("<|python_tag|>");
+                        }
+                    } else {
+                        //"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " +
+                        tool_rules.push_back(
+                            builder.add_rule(
+                                name + "-call",
+                                "\"\\n{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
+                                    builder.add_schema(name + "-args", parameters) +
+                                " \"}\""));
+                        if (allow_content) {
+                            handler.grammar_trigger_words.push_back("\n{\"" + name + "\"");
+                        }
+                    }
+                }
+
+                builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | "));
+            });
+            handler.additional_stop_words.push_back("<|eom_id|>");
+            break;
+        }
+        case llama_tool_call_style::FunctionaryV3Llama3: {
+            // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
+            // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
+            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+                std::vector<std::string> tool_rules;
+                for (size_t i = 0, n = tools.size(); i < n; i++) {
+                    auto & tool = tools[i];
+                    const auto & function = tool["function"];
+                    std::string name = function["name"];
+                    auto parameters = function["parameters"];
+                    auto tool_rule = builder.add_rule(name + "-call", "\">>>" + name + "\\n\" " + builder.add_schema(name + "-args", parameters));
+                    tool_rules.push_back(tool_rule);
+                    if (allow_content) {
+                        handler.grammar_trigger_words.push_back(">>>" + name + "\n");
+                    }
+                }
+                auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
+                builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
+            });
+            // handler.parser = parse_functionary_3_2_tool_calls;
+            break;
+        }
+        case llama_tool_call_style::FunctionaryV3Llama31: {
+            // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja
+            // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
+            // TODO: handle tool {type: code_interpreter} as python
+            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+                std::vector<std::string> tool_rules;
+                for (size_t i = 0, n = tools.size(); i < n; i++) {
+                    auto & tool = tools[i];
+                    const auto & function = tool["function"];
+                    std::string name = function["name"];
+                    auto parameters = function["parameters"];
+                    if (name == "python") {
+                        tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
+                        if (allow_content) {
+                            handler.grammar_trigger_words.push_back("<|python_tag|>");
+                        }
+                    } else {
+                        tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\""));
+                    }
+                }
+                auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
+                builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
+                if (allow_content) {
+                    handler.grammar_trigger_words.push_back("<function=");
+                }
+            });
+            // handler.parser = parse_functionary_3_2_tool_calls;
+            break;
+        }
+        case llama_tool_call_style::Hermes2Pro: {
+            // NousResearchHermesPro_2
+            // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
+            handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+                std::vector<std::string> tool_rules;
+                for (const auto & tool : tools) {
+                    const auto & function = tool["function"];
+                    std::string name = function["name"];
+                    auto parameters = function["parameters"];
+                    builder.resolve_refs(parameters);
+                    tool_rules.push_back(builder.add_schema(name + "-call", {
+                        {"type", "object"},
+                        {"properties", json {
+                            {"name", json {{"const", name}}},
+                            {"arguments", parameters},
+                        }},
+                        {"required", json::array({"name", "arguments"})},
+                    }));
+                }
+
+                auto tool_call = "\"<tool_call>\" " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"</tool_call>\" space";
+                builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
+                if (allow_content) {
+                    handler.grammar_trigger_words.push_back("<tool_call>");
+                }
+            });
+            break;
+        }
+        default:
+            throw std::runtime_error("Unsupported tool call style");
     }
     return handler;
 }
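Putting the two halves together, callers are now expected to branch on the style enum instead of re-sniffing the template string. A rough usage sketch follows; the model pointer, the OpenAI-style tools array, and generated_text are placeholders, and the server changes further down are the real integration.

    // Sketch, assuming a loaded model and an OpenAI-style `tools` json array.
    auto tmpl = llama_chat_template::from_model(model, /* chat_template_override= */ "");

    // Grammar + trigger words tailored to the template's tool-call dialect.
    auto handler = llama_tool_call_handler_init(tmpl, /* allow_content= */ true,
                                                /* parallel_tool_calls= */ false, tools);

    // ... sample with handler.grammar / handler.grammar_trigger_words / handler.additional_stop_words ...

    // Parse the raw completion back into structured calls.
    llama_tool_calls calls = parse_tool_calls(tmpl.tool_call_style(), tools, generated_text);
    for (const auto & call : calls.tool_calls) {
        // call.name and call.arguments (arguments kept as a JSON string)
    }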
@@ -5,22 +5,29 @@
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
+#include "chat-template.h"
 
+struct llama_tool_call {
+    std::string name;
+    std::string arguments;
+};
+
 struct llama_tool_calls {
     std::string content;
-    std::vector<llama_chat_msg_tool_call> tool_calls;
+    std::vector<llama_tool_call> tool_calls;
 };
 
 struct llama_tool_call_handler {
     std::string grammar;
     std::vector<std::string> grammar_trigger_words;
     std::vector<std::string> additional_stop_words;
+    nlohmann::ordered_json updated_tools;
 };
 
-llama_tool_calls parse_tool_calls(const nlohmann::ordered_json & tools, const std::string & chat_template, const std::string& input);
+llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
 
 llama_tool_call_handler llama_tool_call_handler_init(
-    const std::string & chat_template,
+    const llama_chat_template & tmpl,
     bool allow_content,
     bool parallel_tool_calls,
     const nlohmann::ordered_json & tools);
@@ -571,6 +571,12 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte
 ```shell
 llama-server --jinja -hfr lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF -hff Meta-Llama-3.1-8B-Instruct-Q5_K_M.gguf -fa
+
+# https://huggingface.co/meetkai/functionary-medium-v3.2
+llama-server --jinja -hfr bartowski/functionary-medium-v3.2-GGUF -hff functionary-medium-v3.2-IQ4_XS.gguf -fa
+
+# https://huggingface.co/meetkai/functionary-medium-v3.1
+llama-server --jinja -hfr meetkai/functionary-medium-v3.1-GGUF -hff functionary-medium-llama-3.1.Q4_0.gguf -fa
 
 curl http://localhost:8080/v1/chat/completions \
   -d '{
     "model": "gpt-3.5-turbo",
@@ -662,7 +662,7 @@ struct server_context {
     bool validate_model_chat_template(bool use_jinja) const {
         llama_chat_message chat[] = {{"user", "test"}};
 
-        const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0, use_jinja, nullptr, nullptr, nullptr);
+        const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
 
         return res > 0;
     }
@@ -2860,9 +2860,11 @@ int main(int argc, char ** argv) {
             return;
         }
 
+        auto chat_template = llama_chat_template::from_model(ctx_server.model, params.chat_template);
+
         json data;
         try {
-            data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template, params.use_jinja);
+            data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, params.use_jinja);
         } catch (const std::runtime_error & e) {
             res_error(res, format_error_response(e.what(), ERROR_TYPE_NOT_SUPPORTED));
             return;
@@ -2880,7 +2882,7 @@ int main(int argc, char ** argv) {
         ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
             // multitask is never support in chat completion, there is only one result
             try {
-                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
+                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, chat_template, /*.streaming =*/ false, verbose);
                 res_ok(res, result_oai);
             } catch (const std::runtime_error & e) {
                 res_error(res, format_error_response(e.what(), ERROR_TYPE_SERVER));
@@ -23,19 +23,19 @@ Feature: llama.cpp server
     And a model test
     And <n_predict> max tokens to predict
     And a user prompt write a hello world in python
-    And a tool choice <tool_choice>
+    And a tool choice required
     And tools <tools>
     And an OAI compatible chat completions request with no api error
     Then tool <tool_name> is called with arguments <tool_arguments>
 
     Examples: Prompts
-      | template_name                         | n_predict | tool_name | tool_arguments            | tool_choice | tools |
-      | meetkai-functionary-medium-v3.1       | 128       | test      | {}                        | required    | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
-      | meetkai-functionary-medium-v3.1       | 128       | ipython   | {"code": "Yes, you can."} | required    | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
-      | meetkai-functionary-medium-v3.2       | 128       | test      | {}                        | required    | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
-      | meetkai-functionary-medium-v3.2       | 128       | ipython   | {"code": "Yes,"}          | required    | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
-      | meta-llama-Meta-Llama-3.1-8B-Instruct | 64        | test      | {}                        | required    | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
-      | meta-llama-Meta-Llama-3.1-8B-Instruct | 16        | ipython   | {"code": "it and "}       | required    | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
+      | template_name                         | n_predict | tool_name | tool_arguments            | tools |
+      | meetkai-functionary-medium-v3.1       | 128       | test      | {}                        | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
+      | meetkai-functionary-medium-v3.1       | 128       | ipython   | {"code": "I'm sorry,"}    | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
+      | meetkai-functionary-medium-v3.2       | 128       | test      | {}                        | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
+      | meetkai-functionary-medium-v3.2       | 128       | ipython   | {"code": "Yes,"}          | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
+      | meta-llama-Meta-Llama-3.1-8B-Instruct | 64        | test      | {}                        | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
+      | meta-llama-Meta-Llama-3.1-8B-Instruct | 16        | ipython   | {"code": ". A"}           | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
 
 
     Scenario: OAI Compatibility w/ no tool
@@ -14,6 +14,7 @@
 
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
+#include "chat-template.h"
 #include "json.hpp"
 #include "minja.hpp"
 #include "tool-call.h"
@@ -64,40 +65,30 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
     for (size_t i = 0; i < messages.size(); ++i) {
         const auto & curr_msg = messages[i];
 
-        llama_chat_msg msg;
-        msg.role = json_value(curr_msg, "role", std::string(""));
-        msg.tool = json_value(curr_msg, "tool", std::string(""));
+        std::string role = json_value(curr_msg, "role", std::string(""));
+        std::string content;
 
         if (curr_msg.contains("content")) {
             if (curr_msg["content"].is_string()) {
-                msg.content = curr_msg["content"].get<std::string>();
+                content = curr_msg["content"].get<std::string>();
             } else if (curr_msg["content"].is_array()) {
                 for (const auto & part : curr_msg["content"]) {
                     if (part.contains("text")) {
-                        msg.content += "\n" + part["text"].get<std::string>();
+                        content += "\n" + part["text"].get<std::string>();
                     }
                 }
-            } else if (!(curr_msg.is_null() && curr_msg.contains("tool_calls"))) {
-                throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367): " + curr_msg.dump());
+            } else {
+                throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
             }
         } else {
            throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
         }
-        if (curr_msg.contains("tool_calls") && curr_msg["tool_calls"].is_array()) {
-            for (const auto & tool_call : curr_msg["tool_calls"]) {
-                if (json_value(tool_call, "type", std::string("")) == "function"
-                    && tool_call.contains("function") && tool_call["function"].is_object()) {
-                    msg.tool_calls.push_back({
-                        json_value(tool_call["function"], "name", std::string("")),
-                        json_value(tool_call["function"], "arguments", std::string(""))
-                    });
-                }
-            }
-        }
-        chat.emplace_back(std::move(msg));
+
+        chat.push_back({role, content});
     }
 
-    const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true, use_jinja, tools.is_null() ? nullptr : tools.dump().c_str());
+    const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
     LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
 
     return formatted_chat;
@@ -315,38 +306,12 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
 // OAI utils
 //
 
-static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
-    std::string piece;
-    piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
-    const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
-    if (n_chars < 0) {
-        piece.resize(-n_chars);
-        int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
-        GGML_ASSERT(check == -n_chars);
-    }
-    else {
-        piece.resize(n_chars);
-    }
-
-    return piece;
-}
-
-std::string llama_model_meta_val_str(const struct llama_model * model, const char * key) {
-    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
-    if (tlen > 0) {
-        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
-        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
-            return std::string(curr_tmpl_buf.data(), tlen);
-        }
-    }
-    return "";
-}
-
 static json oaicompat_completion_params_parse(
     const struct llama_model * model,
     const json & body, /* openai api json semantics */
-    const std::string & chat_template_src,
-    bool use_jinja) {
+    const llama_chat_template & tmpl,
+    bool use_jinja)
+{
     json llama_params;
 
     llama_params["__oaicompat"] = true;
@@ -355,16 +320,15 @@ static json oaicompat_completion_params_parse(
     auto has_tools = tools.is_array() && !tools.empty();
 
     // Apply chat template to the list of messages
-    auto chat_template = chat_template_src.empty() ? llama_model_meta_val_str(model, "tokenizer.chat_template") : chat_template_src;
-    llama_params["chat_template"] = chat_template;
+    llama_params["chat_template"] = tmpl.chat_template();
+
     if (use_jinja) {
-        if (has_tools && chat_template.find("tools") == std::string::npos) {
+        if (has_tools && !tmpl.supports_tools()) {
             throw std::runtime_error("Chat template does not seem to support tools. Override the model template with --chat-template.");
         }
     } else if (has_tools) {
         throw std::runtime_error("Tools are only supported in --jinja mode");
     }
-    llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"), tools, use_jinja);
 
     // Handle "stop" field
     if (body.contains("stop") && body.at("stop").is_string()) {
@@ -399,26 +363,40 @@ static json oaicompat_completion_params_parse(
     } else if (!response_type.empty() && response_type != "text") {
         throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
     }
-    } else if (use_jinja && tool_choice != "none" && has_tools) {
-        bool parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
-        bool allow_content = tool_choice != "required";
-        auto handler = llama_tool_call_handler_init(chat_template, allow_content, parallel_tool_calls, tools);
-        for (const auto & stop : handler.additional_stop_words) {
-            llama_params["stop"].push_back(stop);
-        }
-        if (!handler.grammar_trigger_words.empty()) {
-            auto triggers = json::array();
-            for (const auto & word : handler.grammar_trigger_words) {
-                triggers.push_back(word);
-            }
-            llama_params["grammar_trigger_words"] = triggers;
-        }
-        llama_params["grammar"] = handler.grammar;
-        llama_params["parse_tool_calls"] = true;
-        llama_params["parallel_tool_calls"] = parallel_tool_calls;
+    }
+
+    if (use_jinja) {
+        bool allow_content = tool_choice != "required";
+        if (tool_choice != "none" && has_tools) {
+            bool parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
+            llama_params["parse_tool_calls"] = true;
+            llama_params["parallel_tool_calls"] = parallel_tool_calls;
+
+            auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, tools);
+
+            for (const auto & stop : handler.additional_stop_words) {
+                llama_params["stop"].push_back(stop);
+            }
+            if (!handler.grammar_trigger_words.empty()) {
+                auto triggers = json::array();
+                for (const auto & word : handler.grammar_trigger_words) {
+                    triggers.push_back(word);
+                }
+                llama_params["grammar_trigger_words"] = triggers;
+            }
+            if (handler.updated_tools.is_null()) {
+                tools = handler.updated_tools;
+            }
+            if (!handler.grammar.empty()) {
+                if (llama_params.contains("grammar")) {
+                    throw std::runtime_error("Cannot use custom grammar constraints with tools.");
+                }
+                llama_params["grammar"] = handler.grammar;
+            }
+        }
+        llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
+    } else {
+        llama_params["prompt"] = format_chat(model, tmpl.chat_template(), body.at("messages"), tools, /* use_jinja= */ false);
     }
 
     // Handle "n" field
@@ -458,7 +436,7 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
+static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, const llama_chat_template & tmpl, bool streaming = false, bool verbose = false) {
     bool stopped_word = result.count("stopped_word") != 0;
     bool stopped_eos = json_value(result, "stopped_eos", false);
     int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@@ -474,9 +452,8 @@ static json format_final_response_oaicompat(const json & request, const json & r
     auto tools = json_value(request, "tools", json::array());
     json tool_calls;
     json message_content;
-    printf("# CONTENT: %s\n\n", content.c_str());
     if (json_value(request, "parse_tool_calls", false)
-        && !(parsed_tool_calls = parse_tool_calls(tools, chat_template, content)).tool_calls.empty()) {
+        && !(parsed_tool_calls = parse_tool_calls(tmpl.tool_call_style(), tools, content)).tool_calls.empty()) {
         finish_reason = "tool";
         if (!parsed_tool_calls.content.empty()) {
             message_content = parsed_tool_calls.content;
@@ -514,7 +491,6 @@ static json format_final_response_oaicompat(const json & request, const json & r
         }},
         {"id", completion_id}
     };
-    printf("# RES: %s\n\n", res.dump(2).c_str());
 
     // extra fields for debugging purposes
     if (verbose) {
@@ -377,19 +377,9 @@ extern "C" {
     } llama_sampler_chain_params;
 
     // used in chat template
-
-    typedef struct llama_chat_message_tool_call {
-        const char * name;
-        const char * arguments;
-    } llama_chat_message_tool_call;
-
     typedef struct llama_chat_message {
         const char * role;
        const char * content;
-        const char * tool;
-
-        const llama_chat_message_tool_call * tool_calls;
-        uint32_t n_tool_calls;
     } llama_chat_message;
 
     // lora adapter
@@ -986,11 +976,7 @@ extern "C" {
            size_t n_msg,
            bool add_ass,
            char * buf,
-           int32_t length,
-           bool use_jinja,
-           const char * tools,
-           const char * bos_token,
-           const char * eos_token);
+           int32_t length);
 
     //
     // Sampling API
110 src/llama.cpp
@@ -2,8 +2,6 @@
 #include "llama-vocab.h"
 #include "llama-sampling.h"

-#include "minja.hpp"
-
 #include "unicode.h"

 #include "ggml.h"
@@ -21004,95 +21002,7 @@ int32_t llama_detokenize(
 static int32_t llama_chat_apply_template_internal(
     const std::string & tmpl,
     const std::vector<const llama_chat_message *> & chat,
-    std::string & dest, bool add_ass,
-    bool use_jinja,
-    const std::string & tools,
-    const std::string & bos_token, const std::string & eos_token) {
-
-    if (use_jinja) {
-        auto system_not_supported = tmpl.find("System role not supported") != std::string::npos;
-
-        // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
-        // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
-        auto tool_call_args_must_be_objects = tmpl.find("tool_call.arguments | items") != std::string::npos;
-
-        auto messages = json::array();
-
-        std::string pending_system;
-        auto flush_sys = [&]() {
-            if (!pending_system.empty()) {
-                messages.push_back({
-                    {"role", "user"},
-                    {"content", pending_system},
-                });
-                pending_system.clear();
-            }
-        };
-        for (const auto * msg : chat) {
-            std::string role(msg->role);
-            std::string content(msg->content);
-            if (system_not_supported) {
-                if (role == "system") {
-                    if (!pending_system.empty()) pending_system += "\n";
-                    pending_system += content;
-                    continue;
-                } else {
-                    if (role == "user") {
-                        if (!pending_system.empty()) {
-                            content = pending_system + (content.empty() ? "" : "\n" + content);
-                            pending_system.clear();
-                        }
-                    } else {
-                        flush_sys();
-                    }
-                }
-            }
-            auto message = json({
-                {"role", role},
-                {"content", content},
-            });
-            if (msg->tool) message["tool"] = msg->tool;
-            if (msg->n_tool_calls) {
-                auto tool_calls = json::array();
-                for (uint32_t i = 0; i < msg->n_tool_calls; i++) {
-                    auto args = msg->tool_calls[i].arguments;
-                    tool_calls.push_back(json({
-                        {"type", "function"},
-                        {"function", {
-                            {"name", msg->tool_calls[i].name},
-                            {"arguments", tool_call_args_must_be_objects ? json::parse(args) : args},
-                        }}
-                    }));
-                }
-                messages["tool_calls"] = tool_calls;
-            }
-            messages.push_back(message);
-        }
-        flush_sys();
-
-        auto context = minja::Context::make(json({
-            {"messages", messages},
-            {"add_generation_prompt", add_ass},
-            {"bos_token", bos_token},
-            {"eos_token", eos_token},
-        }));
-        if (!tools.empty()) {
-            auto tools_val = minja::Value(json::parse(tools));
-            context->set("tools", tools_val);
-        }
-        auto tmpl_root = minja::Parser::parse(tmpl, {
-            /* .trim_blocks = */ true,
-            /* .lstrip_blocks = */ true,
-            /* .keep_trailing_newline = */ false,
-        });
-        try {
-            dest = tmpl_root->render(context);
-            return dest.size();
-        } catch (const std::runtime_error & err) {
-            LLAMA_LOG_ERROR("Error in jinja template: %s\n", err.what());
-            return -1;
-        }
-    }
-
+    std::string & dest, bool add_ass) {
     // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
     std::stringstream ss;
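The roughly ninety lines removed here are the inline Jinja path; per the commit, template rendering is factored out behind llama_chat_template in common/chat-template.cpp instead of living inside the core library. Below is a condensed sketch of the same minja flow, kept only to show what moved out of llama.cpp; the free-function shape, name, and parameter list are illustrative, not the actual helper's interface.

#include <string>
#include "minja.hpp"
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

// Render an OAI-style message list through a Jinja chat template with minja,
// mirroring the context/options used by the removed block above.
static std::string render_jinja_chat(const std::string & tmpl_src,
                                     const json & messages,
                                     const json & tools,
                                     bool add_generation_prompt,
                                     const std::string & bos_token,
                                     const std::string & eos_token) {
    auto context = minja::Context::make(json({
        {"messages", messages},
        {"add_generation_prompt", add_generation_prompt},
        {"bos_token", bos_token},
        {"eos_token", eos_token},
    }));
    if (!tools.is_null()) {
        context->set("tools", minja::Value(tools));
    }
    auto root = minja::Parser::parse(tmpl_src, {
        /* .trim_blocks = */ true,
        /* .lstrip_blocks = */ true,
        /* .keep_trailing_newline = */ false,
    });
    return root->render(context); // throws std::runtime_error on template errors
}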
@@ -21360,11 +21270,7 @@ int32_t llama_chat_apply_template(
                    size_t   n_msg,
                      bool   add_ass,
                      char * buf,
-                  int32_t   length,
-                     bool   use_jinja,
-             const char *   tools,
-             const char *   bos_token,
-             const char *   eos_token) {
+                  int32_t   length) {
     std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
     if (tmpl == nullptr) {
         GGML_ASSERT(model != nullptr);
@@ -21379,16 +21285,6 @@ int32_t llama_chat_apply_template(
             curr_tmpl = std::string(model_template.data(), model_template.size());
         }
     }
-    std::string curr_bos_token(bos_token ? bos_token : "");
-    std::string curr_eos_token(eos_token ? eos_token : "");
-    if (bos_token == nullptr) {
-        GGML_ASSERT(model != nullptr);
-        curr_bos_token = llama_token_to_piece(model, llama_token_bos(model), true);
-    }
-    if (eos_token == nullptr) {
-        GGML_ASSERT(model != nullptr);
-        curr_eos_token = llama_token_to_piece(model, llama_token_eos(model), true);
-    }

     // format the chat to string
     std::vector<const llama_chat_message *> chat_vec;
@@ -21398,7 +21294,7 @@ int32_t llama_chat_apply_template(
     }

     std::string formatted_chat;
-    int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass, use_jinja, tools == nullptr ? "" : tools, curr_bos_token, curr_eos_token);
+    int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
     if (res < 0) {
         return res;
     }
@@ -20,9 +20,9 @@ static void assert_equals(const std::string & expected, const std::string & actu
     cmake -B build -DLLAMA_CURL=1 -DCMAKE_BUILD_TYPE=Release && cmake --build build -t test-tool-call -j && ./build/bin/test-tool-call
 */

-static void test_parse_tool_call(const json & tools, const std::string & chat_template, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) {
+static void test_parse_tool_call(llama_tool_call_style style, const json & tools, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) {
     std::cout << "# Testing: " << input << std::endl << std::flush;
-    auto result = parse_tool_calls(tools, chat_template, input);
+    auto result = parse_tool_calls(style, tools, input);
     assert_equals(expected_content, result.content);
     auto tool_calls = json::array();
     for (const auto & tc : result.tool_calls) {
@@ -59,8 +59,7 @@ int main() {
         {"tools", tools}
     };

-    std::string hermes_2_pro_like_tmpl = "Hermes 2 Pro template should have <tool_call> inside it";
-    test_parse_tool_call(tools, hermes_2_pro_like_tmpl,
+    test_parse_tool_call(llama_tool_call_style::Hermes2Pro, tools,
         "<tool_call>{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}</tool_call>",
         "",
         json {{
@@ -72,8 +71,7 @@ int main() {
             }}
         }});

-    std::string functionary_v3_like_tmpl = "Functionary 3.2 template should have <|start_header_id|> and then some >>>all inside it";
-    test_parse_tool_call(tools, functionary_v3_like_tmpl,
+    test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools,
         ">>>ipython\n{\"code\": \"print('Hello, world!')\"}",
         "",
         json {{
@@ -84,7 +82,7 @@ int main() {
             }).dump()}
         }}
     }});
-    test_parse_tool_call(tools, functionary_v3_like_tmpl,
+    test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools,
         ">>>test\n{ } \n ",
         "",
         json {{
@@ -94,8 +92,7 @@ int main() {
             }}
         }});

-    std::string functionary_v3_llama_3_1_like_tmpl = "Functionary 3.2 template for llama 3.1 should have <|start_header_id|> and then some <function=foo>{...}</function> inside it";
-    test_parse_tool_call(tools, functionary_v3_llama_3_1_like_tmpl,
+    test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama31, tools,
         "Hell<function=foo>{\"arg1\": 1}</function>o, world<function=bar>{\"arg2\": 2}</function>!",
         "Hello, world!",
         json {
@@ -116,7 +113,7 @@ int main() {
             }}
         },
     });
-    test_parse_tool_call(tools, functionary_v3_llama_3_1_like_tmpl,
+    test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama31, tools,
         "<function=test>{ } </function> ",
         " ",
         json {{
@@ -126,8 +123,7 @@ int main() {
             }}
         }});

-    std::string llama_3_1_like_tmpl = "Llama 3.1 template should have <|start_header_id|> and <|python_tag|> inside it";
-    test_parse_tool_call(tools, llama_3_1_like_tmpl,
+    test_parse_tool_call(llama_tool_call_style::Llama31, tools,
         "<|python_tag|>this could be anything",
         "",
         json {{
@@ -138,7 +134,7 @@ int main() {
             }).dump()}
         }}
     }});
-    test_parse_tool_call(tools, llama_3_1_like_tmpl,
+    test_parse_tool_call(llama_tool_call_style::Llama31, tools,
         "I'm thinking<|python_tag|>",
         "I'm thinking",
         json {{
@@ -147,7 +143,7 @@ int main() {
             {"arguments", (json {{"code", ""}}).dump()}
         }}
     }});
-    test_parse_tool_call(tools, llama_3_1_like_tmpl,
+    test_parse_tool_call(llama_tool_call_style::Llama31, tools,
         "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
         "",
         json {{
@@ -158,7 +154,7 @@ int main() {
             }).dump()}
         }}
     }});
-    test_parse_tool_call(tools, llama_3_1_like_tmpl,
+    test_parse_tool_call(llama_tool_call_style::Llama31, tools,
         "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}",
         "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array());

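With the fake template strings gone from the tests, a new case only needs a style enum and an input. A sketch of one more pass-through case in the new shape, mirroring the unknown_function case above (illustrative only, not part of the commit):

    test_parse_tool_call(llama_tool_call_style::Hermes2Pro, tools,
        "plain text without any tool-call markup",
        "plain text without any tool-call markup",
        json::array());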