tool-call: behaviour-based detection of template features

This commit is contained in:
ochafik 2024-10-31 13:45:10 +00:00
parent e8d9d711f6
commit c395d4804f

View File

@@ -32,22 +32,45 @@ class chat_template {
// End-of-sequence token text; initialised from the constructor's `eos_token` argument.
std::string _eos_token;
// Root node of the parsed template; assigned from minja::Parser::parse(_source, ...) in the constructor.
std::shared_ptr<minja::TemplateNode> _template_root;
// Behavioural probe: renders the template through apply() with the given
// inputs and reports whether every string in `needles` occurs in the output.
//
// Returns true only if rendering succeeds and all needles are found (an empty
// needle list therefore yields true on a successful render). Returns false if
// any needle is absent, or if apply() throws a std::exception: a template
// that cannot render the probe input is treated as not rendering the needles.
bool renders_needles(
    const std::vector<std::string> & needles,
    const nlohmann::ordered_json & messages,
    const nlohmann::ordered_json & tools,
    bool add_generation_prompt,
    const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
{
    try {
        const auto rendered = apply(messages, tools, add_generation_prompt, extra_context);

        // Stop scanning at the first needle that is missing from the output.
        bool all_present = true;
        for (auto it = needles.begin(); all_present && it != needles.end(); ++it) {
            all_present = rendered.find(*it) != std::string::npos;
        }
        return all_present;
    } catch (const std::exception &) {
        // Rendering failed; the probe result is simply "no".
        return false;
    }
}
public:
// Builds a chat template from its source text and the BOS/EOS token strings,
// parses it, and detects which features the template supports.
//
// source:    template text; copied into _source and parsed with minja.
// bos_token: copied into _bos_token.
// eos_token: copied into _eos_token.
//
// Any exception thrown by minja::Parser::parse propagates to the caller.
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
    : _source(source), _bos_token(bos_token), _eos_token(eos_token)
{
    // Parse before any behavioural probing: renders_needles() renders through
    // apply(), which presumably needs _template_root to be populated.
    _template_root = minja::Parser::parse(_source, {
        /* .trim_blocks = */ true,
        /* .lstrip_blocks = */ true,
        /* .keep_trailing_newline = */ false,
    });

    // Source-text heuristics: these features are inferred from substrings of
    // the template rather than from how it actually renders. Each is assigned
    // exactly once (the previous duplicate assignments were redundant).
    _supports_tools = source.find("tools") != std::string::npos;
    _requires_object_arguments =
        source.find("tool_call.arguments | items") != std::string::npos
        || source.find("tool_call.arguments | tojson") != std::string::npos;
    _supports_parallel_tool_calls = source.find("tool_call_id") != std::string::npos;

    // Behaviour-based detection of system-role support: render a conversation
    // whose system message carries a marker and check the marker survives.
    // This replaces the old "System role not supported" substring heuristic.
    //
    // Pin the flag to a known value while the probe runs, so the probe never
    // sees a stale or unassigned value.
    // NOTE(review): this assumes apply() may consult _supports_system_role
    // (e.g. to rewrite system messages) - apply() is not visible here; confirm.
    _supports_system_role = true;
    _supports_system_role = renders_needles({"<System Needle>"}, {
        {{"role", "system"}, {"content", "<System Needle>"}},
        {{"role", "user"}, {"content", "Hey"}}
    }, {}, false);
}
const std::string & source() const { return _source; }