mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
tool-call
: behaviour-based detection of template features
This commit is contained in:
parent
e8d9d711f6
commit
c395d4804f
@ -32,22 +32,45 @@ class chat_template {
|
||||
std::string _eos_token;
|
||||
std::shared_ptr<minja::TemplateNode> _template_root;
|
||||
|
||||
bool renders_needles(
|
||||
const std::vector<std::string> & needles,
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools,
|
||||
bool add_generation_prompt,
|
||||
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
|
||||
{
|
||||
try {
|
||||
auto prompt = apply(messages, tools, add_generation_prompt, extra_context);
|
||||
for (const auto & needle : needles) {
|
||||
if (prompt.find(needle) == std::string::npos) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
} catch (const std::exception & e) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
|
||||
: _source(source), _bos_token(bos_token), _eos_token(eos_token)
|
||||
{
|
||||
_supports_tools = source.find("tools") != std::string::npos;
|
||||
_requires_object_arguments =
|
||||
source.find("tool_call.arguments | items") != std::string::npos
|
||||
|| source.find("tool_call.arguments | tojson") != std::string::npos;
|
||||
_supports_system_role = source.find("System role not supported") == std::string::npos;
|
||||
_supports_parallel_tool_calls = source.find("tool_call_id") != std::string::npos;
|
||||
|
||||
_template_root = minja::Parser::parse(_source, {
|
||||
/* .trim_blocks = */ true,
|
||||
/* .lstrip_blocks = */ true,
|
||||
/* .keep_trailing_newline = */ false,
|
||||
});
|
||||
_supports_tools = source.find("tools") != std::string::npos;
|
||||
_requires_object_arguments =
|
||||
source.find("tool_call.arguments | items") != std::string::npos
|
||||
|| source.find("tool_call.arguments | tojson") != std::string::npos;
|
||||
_supports_parallel_tool_calls = source.find("tool_call_id") != std::string::npos;
|
||||
|
||||
_supports_system_role = renders_needles({"<System Needle>"}, {
|
||||
{{"role", "system"}, {"content", "<System Needle>"}},
|
||||
{{"role", "user"}, {"content", "Hey"}}
|
||||
}, {}, false);
|
||||
}
|
||||
|
||||
const std::string & source() const { return _source; }
|
||||
|
Loading…
Reference in New Issue
Block a user