diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt
index 3fb2865ca..fe8fff2af 100644
--- a/common/CMakeLists.txt
+++ b/common/CMakeLists.txt
@@ -54,8 +54,7 @@ add_library(${TARGET} STATIC
     arg.cpp
     arg.h
     base64.hpp
-    chat-template.cpp
-    chat-template.h
+    chat-template.hpp
     common.cpp
     common.h
     console.cpp
diff --git a/common/chat-template.cpp b/common/chat-template.cpp
deleted file mode 100644
index 514c0baf2..000000000
--- a/common/chat-template.cpp
+++ /dev/null
@@ -1,156 +0,0 @@
-#include "chat-template.h"
-#include "llama.h"
-
-using json = nlohmann::ordered_json;
-
-static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
-    std::string piece;
-    piece.resize(piece.capacity());  // using string internal cache, 15 bytes + '\n'
-    const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
-    if (n_chars < 0) {
-        piece.resize(-n_chars);
-        int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
-        GGML_ASSERT(check == -n_chars);
-    }
-    else {
-        piece.resize(n_chars);
-    }
-
-    return piece;
-}
-
-static std::string llama_model_meta_val_str(const struct llama_model * model, const char * key) {
-    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
-    if (tlen > 0) {
-        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
-        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
-            return std::string(curr_tmpl_buf.data(), tlen);
-        }
-    }
-    return "";
-}
-
-llama_chat_template::llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token)
-    : _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) {
-
-    _supports_tools = chat_template.find("tools") != std::string::npos;
-    _requires_object_arguments =
-        chat_template.find("tool_call.arguments | items") != std::string::npos
-        || chat_template.find("tool_call.arguments | tojson") != std::string::npos;
-    _supports_system_role = chat_template.find("System role not supported") == std::string::npos;
-
-    if (chat_template.find("<tool_call>") != std::string::npos) {
-        _tool_call_style = Hermes2Pro;
-    } else if (chat_template.find(">>>all") != std::string::npos) {
-        _tool_call_style = FunctionaryV3Llama3;
-    } else if (chat_template.find("<|start_header_id|>") != std::string::npos
-            && chat_template.find("ipython<|end_header_id|>") != std::string::npos) {
-        if (chat_template.find("<|python_tag|>") != std::string::npos) {
-            _tool_call_style = Llama31;
-        } else {
-            _tool_call_style = Llama32;
-        }
-    } else if (chat_template.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
-        _tool_call_style = CommandRPlus;
-    } else {
-        _tool_call_style = UnknownToolCallStyle;
-    }
-    _template_root = minja::Parser::parse(_chat_template, {
-        /* .trim_blocks = */ true,
-        /* .lstrip_blocks = */ true,
-        /* .keep_trailing_newline = */ false,
-    });
-}
-
-llama_chat_template llama_chat_template::from_model(
-    const struct llama_model * model,
-    const char * chat_template_override)
-{
-    // TODO: handle "chatml"?
-    std::string chat_template = chat_template_override
-        ? chat_template_override
-        : llama_model_meta_val_str(model, "tokenizer.chat_template");
-    auto bos_token = _llama_token_to_piece(model, llama_token_bos(model), true);
-    auto eos_token = _llama_token_to_piece(model, llama_token_eos(model), true);
-    return llama_chat_template(chat_template, bos_token, eos_token);
-}
-
-std::string llama_chat_template::apply(
-    const json & messages,
-    const json & tools,
-    bool add_generation_prompt,
-    const json & extra_context) const
-{
-    auto actual_messages = messages;
-
-    // First, "fix" messages so they have a chance to be rendered correctly by the template
-
-    if (_requires_object_arguments || !_supports_system_role) {
-        std::string pending_system;
-        auto flush_sys = [&]() {
-            if (!pending_system.empty()) {
-                actual_messages.push_back({
-                    {"role", "user"},
-                    {"content", pending_system},
-                });
-                pending_system.clear();
-            }
-        };
-        for (auto & message : actual_messages) {
-            if (!message.contains("role") || !message.contains("content")) {
-                throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
-            }
-            std::string role = message.at("role");
-
-            if (!message["content"].is_null() && !_supports_system_role) {
-                std::string content = message.at("content");
-                if (role == "system") {
-                    if (!pending_system.empty()) pending_system += "\n";
-                    pending_system += content;
-                    continue;
-                } else {
-                    if (role == "user") {
-                        if (!pending_system.empty()) {
-                            message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
-                            pending_system.clear();
-                        }
-                    } else {
-                        flush_sys();
-                    }
-                }
-            }
-            if (_requires_object_arguments && message.contains("tool_calls")) {
-                for (auto & tool_call : message.at("tool_calls")) {
-                    if (tool_call["type"] == "function") {
-                        auto & function = tool_call.at("function");
-                        std::string arguments = function.at("arguments");
-                        function["arguments"] = json::parse(arguments);
-                    }
-                }
-            }
-        }
-        flush_sys();
-    }
-
-    auto context = minja::Context::make(json({
-        {"messages", actual_messages},
-        {"add_generation_prompt", add_generation_prompt},
-        {"bos_token", _bos_token},
-        {"eos_token", _eos_token},
-    }));
-
-    if (!tools.is_null()) {
-        auto tools_val = minja::Value(tools);
-        context->set("tools", tools_val);
-    }
-    if (!extra_context.is_null()) {
-        for (auto & kv : extra_context.items()) {
-            minja::Value val(kv.value());
-            context->set(kv.key(), val);
-        }
-    }
-
-    return _template_root->render(context);
-}
diff --git a/common/chat-template.h b/common/chat-template.h
deleted file mode 100644
index 128d3bea9..000000000
--- a/common/chat-template.h
+++ /dev/null
@@ -1,53 +0,0 @@
-#pragma once
-
-#include "minja.hpp"
-#include <json.hpp>
-#include <string>
-#include <vector>
-
-using json = nlohmann::ordered_json;
-
-
-enum llama_tool_call_style {
-    UnknownToolCallStyle,
-    Llama31,
-    Llama32,
-    FunctionaryV3Llama3,
-    FunctionaryV3Llama31,
-    Hermes2Pro,
-    CommandRPlus,
-};
-
-class llama_chat_template {
-  public:
-
-  private:
-    llama_tool_call_style _tool_call_style = UnknownToolCallStyle;
-    bool _supports_tools = true;
-    // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
-    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
-    bool _requires_object_arguments = false;
-    bool _supports_system_role = true;
-    std::string _chat_template;
-    std::string _bos_token;
-    std::string _eos_token;
-    std::unique_ptr<minja::TemplateNode> _template_root;
-
-  public:
-    llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token);
-
-    static llama_chat_template from_model(
-        const struct llama_model * model,
-        const char * chat_template_override = nullptr);
-
-    llama_tool_call_style tool_call_style() const { return _tool_call_style; }
-
-    const std::string & chat_template() const { return _chat_template; }
-    bool supports_tools() const { return _supports_tools; }
-
-    std::string apply(
-        const nlohmann::ordered_json & messages,
-        const nlohmann::ordered_json & tools,
-        bool add_generation_prompt,
-        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const;
-};
diff --git a/common/chat-template.hpp b/common/chat-template.hpp
new file mode 100644
index 000000000..47ec0d402
--- /dev/null
+++ b/common/chat-template.hpp
@@ -0,0 +1,133 @@
+/*
+    Copyright 2024 Google LLC
+
+    Use of this source code is governed by an MIT-style
+    license that can be found in the LICENSE file or at
+    https://opensource.org/licenses/MIT.
+*/
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include "minja.hpp"
+#include <json.hpp>
+#include <string>
+#include <vector>
+
+using json = nlohmann::ordered_json;
+
+namespace minja {
+
+class chat_template {
+  public:
+
+  private:
+    bool _supports_tools = true;
+    // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
+    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
+    bool _requires_object_arguments = false;
+    bool _supports_system_role = true;
+    std::string _source;
+    std::string _bos_token;
+    std::string _eos_token;
+    std::shared_ptr<minja::TemplateNode> _template_root;
+
+  public:
+    chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
+        : _source(source), _bos_token(bos_token), _eos_token(eos_token)
+    {
+        _supports_tools = source.find("tools") != std::string::npos;
+        _requires_object_arguments =
+            source.find("tool_call.arguments | items") != std::string::npos
+            || source.find("tool_call.arguments | tojson") != std::string::npos;
+        _supports_system_role = source.find("System role not supported") == std::string::npos;
+
+        _template_root = minja::Parser::parse(_source, {
+            /* .trim_blocks = */ true,
+            /* .lstrip_blocks = */ true,
+            /* .keep_trailing_newline = */ false,
+        });
+    }
+
+    const std::string & source() const { return _source; }
+    bool supports_tools() const { return _supports_tools; }
+
+    std::string apply(
+        const nlohmann::ordered_json & messages,
+        const nlohmann::ordered_json & tools,
+        bool add_generation_prompt,
+        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
+    {
+        auto actual_messages = messages;
+
+        // First, "fix" messages so they have a chance to be rendered correctly by the template
+
+        if (_requires_object_arguments || !_supports_system_role) {
+            std::string pending_system;
+            auto flush_sys = [&]() {
+                if (!pending_system.empty()) {
+                    actual_messages.push_back({
+                        {"role", "user"},
+                        {"content", pending_system},
+                    });
+                    pending_system.clear();
+                }
+            };
+            for (auto & message : actual_messages) {
+                if (!message.contains("role") || !message.contains("content")) {
+                    throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
+                }
+                std::string role = message.at("role");
+
+                if (!message["content"].is_null() && !_supports_system_role) {
+                    std::string content = message.at("content");
+                    if (role == "system") {
+                        if (!pending_system.empty()) pending_system += "\n";
+                        pending_system += content;
+                        continue;
+                    } else {
+                        if (role == "user") {
+                            if (!pending_system.empty()) {
+                                message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
+                                pending_system.clear();
+                            }
+                        } else {
+                            flush_sys();
+                        }
+                    }
+                }
+                if (_requires_object_arguments && message.contains("tool_calls")) {
+                    for (auto & tool_call : message.at("tool_calls")) {
+                        if (tool_call["type"] == "function") {
+                            auto & function = tool_call.at("function");
+                            std::string arguments = function.at("arguments");
+                            function["arguments"] = json::parse(arguments);
+                        }
+                    }
+                }
+            }
+            flush_sys();
+        }
+
+        auto context = minja::Context::make(json({
+            {"messages", actual_messages},
+            {"add_generation_prompt", add_generation_prompt},
+            {"bos_token", _bos_token},
+            {"eos_token", _eos_token},
+        }));
+
+        if (!tools.is_null()) {
+            auto tools_val = minja::Value(tools);
+            context->set("tools", tools_val);
+        }
+        if (!extra_context.is_null()) {
+            for (auto & kv : extra_context.items()) {
+                minja::Value val(kv.value());
+                context->set(kv.key(), val);
+            }
+        }
+
+        return _template_root->render(context);
+    }
+};
+
+}  // namespace minja
diff --git a/common/common.cpp b/common/common.cpp
index 78263da85..909aa1970 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -9,7 +9,7 @@
 #include "json.hpp"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
-#include "chat-template.h"
+#include "chat-template.hpp"
 
 #include <algorithm>
 #include <cinttypes>
@@ -1513,13 +1513,13 @@ std::vector<llama_token> llama_tokenize(
     return result;
 }
 
-std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
     std::string piece;
     piece.resize(piece.capacity());  // using string internal cache, 15 bytes + '\n'
-    const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+    const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
     if (n_chars < 0) {
         piece.resize(-n_chars);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+        int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
         GGML_ASSERT(check == -n_chars);
     }
     else {
@@ -1529,6 +1529,10 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t
     return piece;
 }
 
+std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+    return _llama_token_to_piece(llama_get_model(ctx), token, special);
+}
+
 std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
     std::string text;
     text.resize(std::max(text.capacity(), tokens.size()));
@@ -1552,7 +1556,7 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token>
 bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
     if (use_jinja) {
         try {
-            auto chat_template = llama_chat_template(tmpl, "", "");
+            auto chat_template = minja::chat_template(tmpl, "", "");
             chat_template.apply({{
                 {"role", "user"},
                 {"content", "test"},
@@ -1651,6 +1655,30 @@ std::string llama_chat_format_example(const struct llama_model * model,
     return llama_chat_apply_template(model, tmpl, msgs, true);
 }
 
+static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
+    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
+    if (tlen > 0) {
+        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
+        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
+            return std::string(curr_tmpl_buf.data(), tlen);
+        }
+    }
+    return "";
+}
+
+minja::chat_template llama_chat_template_from_model(
+    const struct llama_model * model,
+    const char * chat_template_override)
+{
+    // TODO: handle "chatml"?
+    std::string chat_template = chat_template_override
+        ? chat_template_override
+        : _llama_model_meta_val_str(model, "tokenizer.chat_template");
+    auto bos_token = _llama_token_to_piece(model, llama_token_bos(model), true);
+    auto eos_token = _llama_token_to_piece(model, llama_token_eos(model), true);
+    return {std::move(chat_template), bos_token, eos_token};
+}
+
 //
 // KV cache utils
 //
diff --git a/common/common.h b/common/common.h
index 8681899ce..3c9cc80eb 100644
--- a/common/common.h
+++ b/common/common.h
@@ -27,6 +27,9 @@
 
 #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
 
+// Forward declaration
+namespace minja { class chat_template; }
+
 struct llama_lora_adapter_info {
     std::string path;
     float scale;
@@ -500,6 +503,10 @@ std::string llama_chat_format_single(const struct llama_model * model,
 std::string llama_chat_format_example(const struct llama_model * model,
                                       const std::string & tmpl);
 
+minja::chat_template llama_chat_template_from_model(
+    const struct llama_model * model,
+    const char * chat_template_override = nullptr);
+
 //
 // KV cache utils
 //
diff --git a/common/minja.hpp b/common/minja.hpp
index 7d4f4ae54..77d0ca450 100644
--- a/common/minja.hpp
+++ b/common/minja.hpp
@@ -1,3 +1,11 @@
+/*
+    Copyright 2024 Google LLC
+
+    Use of this source code is governed by an MIT-style
+    license that can be found in the LICENSE file or at
+    https://opensource.org/licenses/MIT.
+*/
+// SPDX-License-Identifier: MIT
 #pragma once
 
 #include <algorithm>
@@ -532,44 +540,44 @@ static std::string error_location_suffix(const std::string & source, size_t pos)
 }
 
 class Context : public std::enable_shared_from_this<Context> {
-  protected:
-  Value values_;
-  std::shared_ptr<Context> parent_;
-public:
-  Context(Value && values, const std::shared_ptr<Context> & parent = nullptr) : values_(std::move(values)), parent_(parent) {
-    if (!values_.is_object()) throw std::runtime_error("Context values must be an object: " + values_.dump());
-  }
-  virtual ~Context() {}
+  protected:
+    Value values_;
+    std::shared_ptr<Context> parent_;
+  public:
+    Context(Value && values, const std::shared_ptr<Context> & parent = nullptr) : values_(std::move(values)), parent_(parent) {
+        if (!values_.is_object()) throw std::runtime_error("Context values must be an object: " + values_.dump());
+    }
+    virtual ~Context() {}
 
-  static std::shared_ptr<Context> builtins();
-  static std::shared_ptr<Context> make(Value && values, const std::shared_ptr<Context> & parent = builtins());
+    static std::shared_ptr<Context> builtins();
+    static std::shared_ptr<Context> make(Value && values, const std::shared_ptr<Context> & parent = builtins());
 
-  std::vector<Value> keys() {
-    return values_.keys();
-  }
-  virtual Value get(const Value & key) {
-    if (values_.contains(key)) return values_.at(key);
-    if (parent_) return parent_->get(key);
-    return Value();
-  }
-  virtual Value & at(const Value & key) {
-    if (values_.contains(key)) return values_.at(key);
-    if (parent_) return parent_->at(key);
-    throw std::runtime_error("Undefined variable: " + key.dump());
-  }
-  virtual bool contains(const Value & key) {
-    if (values_.contains(key)) return true;
-    if (parent_) return parent_->contains(key);
-    return false;
-  }
-  virtual void set(const Value & key, Value & value) {
-    values_.set(key, value);
-  }
+    std::vector<Value> keys() {
+        return values_.keys();
+    }
+    virtual Value get(const Value & key) {
+        if (values_.contains(key)) return values_.at(key);
+        if (parent_) return parent_->get(key);
+        return Value();
+    }
+    virtual Value & at(const Value & key) {
+        if (values_.contains(key)) return values_.at(key);
+        if (parent_) return parent_->at(key);
+        throw std::runtime_error("Undefined variable: " + key.dump());
+    }
+    virtual bool contains(const Value & key) {
+        if (values_.contains(key)) return true;
+        if (parent_) return parent_->contains(key);
+        return false;
+    }
+    virtual void set(const Value & key, Value & value) {
+        values_.set(key, value);
+    }
 };
 
 struct Location {
-  std::shared_ptr<std::string> source;
-  size_t pos;
+    std::shared_ptr<std::string> source;
+    size_t pos;
 };
 
 class Expression {
@@ -577,8 +585,8 @@ protected:
    virtual Value do_evaluate(const std::shared_ptr<Context> & context) const = 0;
public:
    struct Arguments {
-        std::vector<std::unique_ptr<Expression>> args;
-        std::vector<std::pair<std::string, std::unique_ptr<Expression>>> kwargs;
+        std::vector<std::shared_ptr<Expression>> args;
+        std::vector<std::pair<std::string, std::shared_ptr<Expression>>> kwargs;
 
        void expectArgs(const std::string & method_name, const std::pair<size_t, size_t> & pos_count, const std::pair<size_t, size_t> & kw_count) const {
          if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) {
@@ -600,7 +608,7 @@ public:
        }
    };
 
-    using Parameters = std::vector<std::pair<std::string, std::unique_ptr<Expression>>>;
+    using Parameters = std::vector<std::pair<std::string, std::shared_ptr<Expression>>>;
 
    Location location;
 
@@ -687,18 +695,18 @@ struct TextTemplateToken : public TemplateToken {
};

struct ExpressionTemplateToken : public TemplateToken {
-    std::unique_ptr<Expression> expr;
-    ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr<Expression> && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {}
+    std::shared_ptr<Expression> expr;
+    ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {}
};

struct IfTemplateToken : public TemplateToken {
-    std::unique_ptr<Expression> condition;
-    IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr<Expression> && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {}
+    std::shared_ptr<Expression> condition;
+    IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {}
};

struct ElifTemplateToken : public TemplateToken {
-    std::unique_ptr<Expression> condition;
-    ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr<Expression> && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {}
+    std::shared_ptr<Expression> condition;
+    ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {}
};

struct ElseTemplateToken : public TemplateToken {
@@ -706,13 +714,13 @@ struct ElseTemplateToken : public TemplateToken {
};

struct EndIfTemplateToken : public TemplateToken {
-  EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {}
+    EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {}
};

struct MacroTemplateToken : public TemplateToken {
-    std::unique_ptr<VariableExpr> name;
+    std::shared_ptr<VariableExpr> name;
    Expression::Parameters params;
-    MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr<VariableExpr> && n, Expression::Parameters && p)
+    MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p)
      : TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {}
};

@@ -722,11 +730,11 @@ struct EndMacroTemplateToken : public TemplateToken {

struct ForTemplateToken : public TemplateToken {
    std::vector<std::string> var_names;
-    std::unique_ptr<Expression> iterable;
-    std::unique_ptr<Expression> condition;
+    std::shared_ptr<Expression> iterable;
+    std::shared_ptr<Expression> condition;
    bool recursive;
-    ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector<std::string> & vns, std::unique_ptr<Expression> && iter,
-      std::unique_ptr<Expression> && c, bool r)
+    ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector<std::string> & vns, std::shared_ptr<Expression> && iter,
+      std::shared_ptr<Expression> && c, bool r)
      : TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {}
};

@@ -737,8 +745,8 @@ struct EndForTemplateToken : public TemplateToken {
struct SetTemplateToken : public TemplateToken {
    std::string ns;
    std::vector<std::string> var_names;
-    std::unique_ptr<Expression> value;
-    SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector<std::string> & vns, std::unique_ptr<Expression> && v)
+    std::shared_ptr<Expression> value;
+    SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
      : TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {}
};

@@ -778,9 +786,9 @@ public:
};

class SequenceNode : public TemplateNode {
-    std::vector<std::unique_ptr<TemplateNode>> children;
+    std::vector<std::shared_ptr<TemplateNode>> children;
public:
-    SequenceNode(const Location & location, std::vector<std::unique_ptr<TemplateNode>> && c)
+    SequenceNode(const Location & location, std::vector<std::shared_ptr<TemplateNode>> && c)
      : TemplateNode(location), children(std::move(c)) {}
    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
        for (const auto& child : children) child->render(out, context);
@@ -797,10 +805,11 @@ public:
};

class ExpressionNode : public TemplateNode {
-    std::unique_ptr<Expression> expr;
+    std::shared_ptr<Expression> expr;
public:
-    ExpressionNode(const Location & location, std::unique_ptr<Expression> && e) : TemplateNode(location), expr(std::move(e)) {}
+    ExpressionNode(const Location & location, std::shared_ptr<Expression> && e) : TemplateNode(location), expr(std::move(e)) {}
    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
+        if (!expr) throw std::runtime_error("ExpressionNode.expr is null");
        auto result = expr->evaluate(context);
        if (result.is_string()) {
            out << result.get<std::string>();
@@ -813,9 +822,9 @@ public:
};

class IfNode : public TemplateNode {
-    std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<TemplateNode>>> cascade;
+    std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
public:
-    IfNode(const Location & location, std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<TemplateNode>>> && c)
+    IfNode(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> && c)
        : TemplateNode(location), cascade(std::move(c)) {}
    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
      for (const auto& branch : cascade) {
@@ -824,6 +833,7 @@ public:
              enter_branch = branch.first->evaluate(context).to_bool();
          }
          if (enter_branch) {
+            if (!branch.second) throw std::runtime_error("IfNode.cascade.second is null");
              branch.second->render(out, context);
              return;
          }
@@ -833,18 +843,20 @@ public:

class ForNode : public TemplateNode {
    std::vector<std::string> var_names;
-    std::unique_ptr<Expression> iterable;
-    std::unique_ptr<Expression> condition;
-    std::unique_ptr<TemplateNode> body;
+    std::shared_ptr<Expression> iterable;
+    std::shared_ptr<Expression> condition;
+    std::shared_ptr<TemplateNode> body;
    bool recursive;
-    std::unique_ptr<TemplateNode> else_body;
+    std::shared_ptr<TemplateNode> else_body;
public:
-    ForNode(const Location & location, std::vector<std::string> && var_names, std::unique_ptr<Expression> && iterable,
-      std::unique_ptr<Expression> && condition, std::unique_ptr<TemplateNode> && body, bool recursive, std::unique_ptr<TemplateNode> && else_body)
+    ForNode(const Location & location, std::vector<std::string> && var_names, std::shared_ptr<Expression> && iterable,
+      std::shared_ptr<Expression> && condition, std::shared_ptr<TemplateNode> && body, bool recursive, std::shared_ptr<TemplateNode> && else_body)
            : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {}

    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
      // https://jinja.palletsprojects.com/en/3.0.x/templates/#for
+      if (!iterable) throw std::runtime_error("ForNode.iterable is null");
+      if (!body) throw std::runtime_error("ForNode.body is null");
      auto iterable_value = iterable->evaluate(context);
      Value::CallableType loop_function;
@@ -914,12 +926,12 @@ public:
};

class MacroNode : public TemplateNode {
-    std::unique_ptr<VariableExpr> name;
+    std::shared_ptr<VariableExpr> name;
    Expression::Parameters params;
-    std::unique_ptr<TemplateNode> body;
+    std::shared_ptr<TemplateNode> body;
    std::unordered_map<std::string, size_t> named_param_positions;
public:
-    MacroNode(const Location & location, std::unique_ptr<VariableExpr> && n, Expression::Parameters && p, std::unique_ptr<TemplateNode> && b)
+    MacroNode(const Location & location, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p, std::shared_ptr<TemplateNode> && b)
        : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) {
        for (size_t i = 0; i < params.size(); ++i) {
          const auto & name = params[i].first;
@@ -929,6 +941,8 @@ public:
        }
    }
    void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
+        if (!name) throw std::runtime_error("MacroNode.name is null");
+        if (!body) throw std::runtime_error("MacroNode.body is null");
        auto callable = Value::callable([&](const std::shared_ptr<Context> & context, Value::Arguments & args) {
            auto call_context = macro_context;
            std::vector<bool> param_set(params.size(), false);
@@ -964,19 +978,12 @@ public:
class SetNode : public TemplateNode {
    std::string ns;
    std::vector<std::string> var_names;
-    std::unique_ptr<Expression> value;
-    std::unique_ptr<TemplateNode> template_value;
+    std::shared_ptr<Expression> value;
public:
-    SetNode(const Location & location, const std::string & ns, const std::vector<std::string> & vns, std::unique_ptr<Expression> && v, std::unique_ptr<TemplateNode> && tv)
-        : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)), template_value(std::move(tv)) {
-        if (value && template_value) {
-            throw std::runtime_error("Cannot have both value and template value in set node");
-        }
-        if (template_value && var_names.size() != 1) {
-            throw std::runtime_error("Destructuring assignment is only supported with a single variable name");
-        }
-    }
+    SetNode(const Location & location, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
+        : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {}
    void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
+        if (!value) throw std::runtime_error("SetNode.value is null");
        if (!ns.empty()) {
          if (var_names.size() != 1) {
            throw std::runtime_error("Namespaced set only supports a single variable name");
@@ -985,9 +992,6 @@ public:
          auto ns_value = context->get(ns);
          if (!ns_value.is_object()) throw std::runtime_error("Namespace '" + ns + "' is not an object");
          ns_value.set(name, this->value->evaluate(context));
-        } else if (template_value) {
-            Value value { template_value->render(context) };
-            context->set(var_names[0], value);
        } else {
          auto val = value->evaluate(context);
          destructuring_assign(var_names, context, val);
@@ -995,14 +999,29 @@ public:
    }
};

-class IfExpr : public Expression {
-    std::unique_ptr<Expression> condition;
-    std::unique_ptr<Expression> then_expr;
-    std::unique_ptr<Expression> else_expr;
+class SetTemplateNode : public TemplateNode {
+    std::string name;
+    std::shared_ptr<TemplateNode> template_value;
public:
-    IfExpr(const Location & location, std::unique_ptr<Expression> && c, std::unique_ptr<Expression> && t, std::unique_ptr<Expression> && e)
+    SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr<TemplateNode> && tv)
+        : TemplateNode(location), name(name), template_value(std::move(tv)) {}
+    void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
+        if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null");
+        Value value { template_value->render(context) };
+        context->set(name, value);
+    }
+};
+
+class IfExpr : public Expression {
+    std::shared_ptr<Expression> condition;
+    std::shared_ptr<Expression> then_expr;
+    std::shared_ptr<Expression> else_expr;
+public:
+    IfExpr(const Location & location, std::shared_ptr<Expression> && c, std::shared_ptr<Expression> && t, std::shared_ptr<Expression> && e)
        : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+      if (!condition) throw std::runtime_error("IfExpr.condition is null");
+      if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null");
      if (condition->evaluate(context).to_bool()) {
        return then_expr->evaluate(context);
      }
@@ -1022,13 +1041,14 @@ public:
};

class ArrayExpr : public Expression {
-    std::vector<std::unique_ptr<Expression>> elements;
+    std::vector<std::shared_ptr<Expression>> elements;
public:
-    ArrayExpr(const Location & location, std::vector<std::unique_ptr<Expression>> && e)
+    ArrayExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && e)
      : Expression(location), elements(std::move(e)) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
        auto result = Value::array();
        for (const auto& e : elements) {
+            if (!e) throw std::runtime_error("Array element is null");
            result.push_back(e->evaluate(context));
        }
        return result;
@@ -1036,13 +1056,15 @@ public:
};

class DictExpr : public Expression {
-    std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>>> elements;
+    std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
public:
-    DictExpr(const Location & location, std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>>> && e)
+    DictExpr(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> && e)
      : Expression(location), elements(std::move(e)) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
        auto result = Value::object();
        for (const auto& e : elements) {
+            if (!e.first) throw std::runtime_error("Dict key is null");
+            if (!e.second) throw std::runtime_error("Dict value is null");
            result.set(e.first->evaluate(context), e.second->evaluate(context));
        }
        return result;
@@ -1051,8 +1073,8 @@ public:

class SliceExpr : public Expression {
public:
-    std::unique_ptr<Expression> start, end;
-    SliceExpr(const Location & location, std::unique_ptr<Expression> && s, std::unique_ptr<Expression> && e)
+    std::shared_ptr<Expression> start, end;
+    SliceExpr(const Location & location, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
      : Expression(location), start(std::move(s)), end(std::move(e)) {}
    Value do_evaluate(const std::shared_ptr<Context> &) const override {
        throw std::runtime_error("SliceExpr not implemented");
@@ -1060,12 +1082,14 @@ public:
};

class SubscriptExpr : public Expression {
-    std::unique_ptr<Expression> base;
-    std::unique_ptr<Expression> index;
+    std::shared_ptr<Expression> base;
+    std::shared_ptr<Expression> index;
public:
-    SubscriptExpr(const Location & location, std::unique_ptr<Expression> && b, std::unique_ptr<Expression> && i)
+    SubscriptExpr(const Location & location, std::shared_ptr<Expression> && b, std::shared_ptr<Expression> && i)
        : Expression(location), base(std::move(b)), index(std::move(i)) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!base) throw std::runtime_error("SubscriptExpr.base is null");
+        if (!index) throw std::runtime_error("SubscriptExpr.index is null");
        auto target_value = base->evaluate(context);
        if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
          if (!target_value.is_array()) throw std::runtime_error("Subscripting non-array");
@@ -1094,12 +1118,13 @@ class UnaryOpExpr : public Expression {
public:
    enum class Op { Plus, Minus, LogicalNot };
private:
-    std::unique_ptr<Expression> expr;
+    std::shared_ptr<Expression> expr;
    Op op;
public:
-    UnaryOpExpr(const Location & location, std::unique_ptr<Expression> && e, Op o)
+    UnaryOpExpr(const Location & location, std::shared_ptr<Expression> && e, Op o)
      : Expression(location), expr(std::move(e)), op(o) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null");
        auto e = expr->evaluate(context);
        switch (op) {
            case Op::Plus: return e;
@@ -1114,13 +1139,15 @@ class BinaryOpExpr : public Expression {
public:
    enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot };
private:
-    std::unique_ptr<Expression> left;
-    std::unique_ptr<Expression> right;
+    std::shared_ptr<Expression> left;
+    std::shared_ptr<Expression> right;
    Op op;
public:
-    BinaryOpExpr(const Location & location, std::unique_ptr<Expression> && l, std::unique_ptr<Expression> && r, Op o)
+    BinaryOpExpr(const Location & location, std::shared_ptr<Expression> && l, std::shared_ptr<Expression> && r, Op o)
        : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!left) throw std::runtime_error("BinaryOpExpr.left is null");
+        if (!right) throw std::runtime_error("BinaryOpExpr.right is null");
        auto l = left->evaluate(context);

        auto do_eval = [&](const Value & l) -> Value {
@@ -1210,13 +1237,15 @@ static std::string html_escape(const std::string & s) {
}

class MethodCallExpr : public Expression {
-    std::unique_ptr<Expression> object;
-    std::unique_ptr<VariableExpr> method;
+    std::shared_ptr<Expression> object;
+    std::shared_ptr<VariableExpr> method;
    Expression::Arguments args;
public:
-    MethodCallExpr(const Location & location, std::unique_ptr<Expression> && obj, std::unique_ptr<VariableExpr> && m, Expression::Arguments && a)
+    MethodCallExpr(const Location & location, std::shared_ptr<Expression> && obj, std::shared_ptr<VariableExpr> && m, Expression::Arguments && a)
      : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!object) throw std::runtime_error("MethodCallExpr.object is null");
+        if (!method) throw std::runtime_error("MethodCallExpr.method is null");
        auto obj = object->evaluate(context);
        if (obj.is_array()) {
          if (method->get_name() == "append") {
@@ -1279,11 +1308,12 @@ public:

class CallExpr : public Expression {
public:
-    std::unique_ptr<Expression> object;
+    std::shared_ptr<Expression> object;
    Expression::Arguments args;
-    CallExpr(const Location & location, std::unique_ptr<Expression> && obj, Expression::Arguments && a)
+    CallExpr(const Location & location, std::shared_ptr<Expression> && obj, Expression::Arguments && a)
      : Expression(location), object(std::move(obj)), args(std::move(a)) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
+        if (!object) throw std::runtime_error("CallExpr.object is null");
        auto obj = object->evaluate(context);
        if (!obj.is_callable()) {
          throw std::runtime_error("Object is not callable: " + obj.dump(2));
@@ -1294,14 +1324,15 @@ public:
};

class FilterExpr : public Expression {
-    std::vector<std::unique_ptr<Expression>> parts;
+    std::vector<std::shared_ptr<Expression>> parts;
public:
-    FilterExpr(const Location & location, std::vector<std::unique_ptr<Expression>> && p)
+    FilterExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && p)
      : Expression(location), parts(std::move(p)) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
        Value result;
        bool first = true;
        for (const auto& part : parts) {
+          if (!part) throw std::runtime_error("FilterExpr.part is null");
          if (first) {
            first = false;
            result = part->evaluate(context);
@@ -1322,7 +1353,7 @@ public:
        return result;
    }

-    void prepend(std::unique_ptr<Expression> && e) {
+    void prepend(std::shared_ptr<Expression> && e) {
        parts.insert(parts.begin(), std::move(e));
    }
};
@@ -1375,7 +1406,7 @@ private:
                escape = true;
            } else if (*it == quote) {
                ++it;
-                return nonstd_make_unique<std::string>(result);
+                return nonstd_make_unique<std::string>(std::move(result));
            } else {
                result += *it;
            }
@@ -1429,37 +1460,37 @@ private:
    }

    /** integer, float, bool, string */
-    std::unique_ptr<Value> parseConstant() {
+    std::shared_ptr<Value> parseConstant() {
      auto start = it;
      consumeSpaces();
      if (it == end) return nullptr;
      if (*it == '"' || *it == '\'') {
        auto str = parseString();
-        if (str) return nonstd_make_unique<Value>(*str);
+        if (str) return std::make_shared<Value>(*str);
      }
      static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)");
      auto token = consumeToken(prim_tok);
      if (!token.empty()) {
-        if (token == "true" || token == "True") return nonstd_make_unique<Value>(true);
-        if (token == "false" || token == "False") return nonstd_make_unique<Value>(false);
-        if (token == "None") return nonstd_make_unique<Value>(nullptr);
+        if (token == "true" || token == "True") return std::make_shared<Value>(true);
+        if (token == "false" || token == "False") return std::make_shared<Value>(false);
+        if (token == "None") return std::make_shared<Value>(nullptr);
        throw std::runtime_error("Unknown constant token: " + token);
      }

      auto number = parseNumber(it, end);
-      if (!number.is_null()) return nonstd_make_unique<Value>(number);
+      if (!number.is_null()) return std::make_shared<Value>(number);

      it = start;
      return nullptr;
    }

    class expression_parsing_error : public std::runtime_error {
-      const CharIterator it;
-     public:
-      expression_parsing_error(const std::string & message, const CharIterator & it)
-        : std::runtime_error(message), it(it) {}
-      size_t get_pos(const CharIterator & begin) const {
-        return std::distance(begin, it);
+        const CharIterator it;
+      public:
+        expression_parsing_error(const std::string & message, const CharIterator & it)
+          : std::runtime_error(message), it(it) {}
+        size_t get_pos(const CharIterator & begin) const {
+            return std::distance(begin, it);
        }
    };

@@ -1510,7 +1541,7 @@ private:
        return "";
    }

-    std::unique_ptr<Expression> parseExpression(bool allow_if_expr = true) {
+    std::shared_ptr<Expression> parseExpression(bool allow_if_expr = true) {
        auto left = parseLogicalOr();
        if (it == end) return left;

@@ -1523,19 +1554,19 @@ private:

        auto location = get_location();
        auto if_expr = parseIfExpression();
-        return nonstd_make_unique<IfExpr>(location, std::move(if_expr.first), std::move(left), std::move(if_expr.second));
+        return std::make_shared<IfExpr>(location, std::move(if_expr.first), std::move(left), std::move(if_expr.second));
    }

    Location get_location() const {
        return {template_str, (size_t) std::distance(start, it)};
    }

-    std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>> parseIfExpression() {
+    std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>> parseIfExpression() {
        auto condition = parseLogicalOr();
        if (!condition) throw std::runtime_error("Expected condition expression");

        static std::regex else_tok(R"(else\b)");
-        std::unique_ptr<Expression> else_expr;
+        std::shared_ptr<Expression> else_expr;
        if (!consumeToken(else_tok).empty()) {
          else_expr = parseExpression();
          if (!else_expr) throw std::runtime_error("Expected 'else' expression");
@@ -1543,7 +1574,7 @@ private:
        return std::make_pair(std::move(condition), std::move(else_expr));
    }

-    std::unique_ptr<Expression> parseLogicalOr() {
+    std::shared_ptr<Expression> parseLogicalOr() {
        auto left = parseLogicalAnd();
        if (!left) throw std::runtime_error("Expected left side of 'logical or' expression");

@@ -1552,24 +1583,24 @@ private:
        while (!consumeToken(or_tok).empty()) {
          auto right = parseLogicalAnd();
          if (!right) throw std::runtime_error("Expected right side of 'or' expression");
-          left = nonstd_make_unique<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or);
+          left = std::make_shared<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or);
        }
        return left;
    }

-    std::unique_ptr<Expression> parseLogicalNot() {
+    std::shared_ptr<Expression> parseLogicalNot() {
        static std::regex not_tok(R"(not\b)");
        auto location = get_location();

        if (!consumeToken(not_tok).empty()) {
          auto sub = parseLogicalNot();
          if (!sub) throw std::runtime_error("Expected expression after 'not' keyword");
-          return nonstd_make_unique<UnaryOpExpr>(location, std::move(sub), UnaryOpExpr::Op::LogicalNot);
+          return std::make_shared<UnaryOpExpr>(location, std::move(sub), UnaryOpExpr::Op::LogicalNot);
        }
        return parseLogicalCompare();
    }

-    std::unique_ptr<Expression> parseLogicalAnd() {
+    std::shared_ptr<Expression> parseLogicalAnd() {
        auto left = parseLogicalNot();
        if (!left) throw std::runtime_error("Expected left side of 'logical and' expression");

@@ -1578,12 +1609,12 @@ private:
        while (!consumeToken(and_tok).empty()) {
          auto right = parseLogicalNot();
          if (!right) throw std::runtime_error("Expected right side of 'and' expression");
-          left = nonstd_make_unique<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::And);
+          left = std::make_shared<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::And);
        }
        return left;
    }

-    std::unique_ptr<Expression> parseLogicalCompare() {
+    std::shared_ptr<Expression> parseLogicalCompare() {
        auto left = parseStringConcat();
        if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");

@@ -1598,7 +1629,7 @@ private:
            auto identifier = parseIdentifier();
            if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword");

-            return nonstd_make_unique<BinaryOpExpr>(
+            return std::make_shared<BinaryOpExpr>(
                left->location,
                std::move(left), std::move(identifier),
                negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is);
@@ -1615,7 +1646,7 @@ private:
          else if (op_str == "in") op = BinaryOpExpr::Op::In;
          else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn;
          else throw std::runtime_error("Unknown comparison operator: " + op_str);
-          left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
+          left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
        }
        return left;
    }
@@ -1688,16 +1719,16 @@ private:
        throw std::runtime_error("Expected closing parenthesis in call args");
    }

-    std::unique_ptr<VariableExpr> parseIdentifier() {
+    std::shared_ptr<VariableExpr> parseIdentifier() {
        static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)");
        auto location = get_location();
        auto ident = consumeToken(ident_regex);
        if (ident.empty())
          return nullptr;
-        return nonstd_make_unique<VariableExpr>(location, ident);
+        return std::make_shared<VariableExpr>(location, ident);
    }

-    std::unique_ptr<Expression> parseStringConcat() {
+    std::shared_ptr<Expression> parseStringConcat() {
        auto left = parseMathPow();
        if (!left) throw std::runtime_error("Expected left side of 'string concat' expression");

@@ -1705,24 +1736,24 @@ private:
        if (!consumeToken(concat_tok).empty()) {
          auto right = parseLogicalAnd();
          if (!right) throw std::runtime_error("Expected right side of 'string concat' expression");
-          left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat);
+          left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat);
        }
        return left;
    }

-    std::unique_ptr<Expression> parseMathPow() {
+    std::shared_ptr<Expression> parseMathPow() {
        auto left = parseMathPlusMinus();
        if (!left) throw std::runtime_error("Expected left side of 'math pow' expression");

        while (!consumeToken("**").empty()) {
          auto right = parseMathPlusMinus();
          if (!right) throw std::runtime_error("Expected right side of 'math pow' expression");
-          left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul);
+          left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul);
        }
        return left;
    }

-    std::unique_ptr<Expression> parseMathPlusMinus() {
+    std::shared_ptr<Expression> parseMathPlusMinus() {
        static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))");

        auto left = parseMathMulDiv();
@@ -1732,12 +1763,12 @@ private:
          auto right = parseMathMulDiv();
          if (!right) throw std::runtime_error("Expected right side of 'math plus/minus' expression");
          auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub;
-          left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
+          left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
        }
        return left;
    }

-    std::unique_ptr<Expression> parseMathMulDiv() {
+    std::shared_ptr<Expression> parseMathMulDiv() {
        auto left = parseMathUnaryPlusMinus();
        if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression");

@@ -1751,7 +1782,7 @@ private:
            : op_str == "/" ? BinaryOpExpr::Op::Div
            : op_str == "//" ? BinaryOpExpr::Op::DivDiv
            : BinaryOpExpr::Op::Mod;
-          left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
+          left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
        }

        if (!consumeToken("|").empty()) {
@@ -1760,20 +1791,20 @@ private:
          if (auto filter = dynamic_cast<FilterExpr*>(expr.get())) {
            filter->prepend(std::move(left));
            return expr;
          } else {
-            std::vector<std::unique_ptr<Expression>> parts;
+            std::vector<std::shared_ptr<Expression>> parts;
            parts.emplace_back(std::move(left));
            parts.emplace_back(std::move(expr));
-            return nonstd_make_unique<FilterExpr>(get_location(), std::move(parts));
+            return std::make_shared<FilterExpr>(get_location(), std::move(parts));
          }
        }
        return left;
    }

-    std::unique_ptr<CallExpr> call_func(const std::string & name, Expression::Arguments && args) const {
-        return nonstd_make_unique<CallExpr>(get_location(), nonstd_make_unique<VariableExpr>(get_location(), name), std::move(args));
+    std::shared_ptr<CallExpr> call_func(const std::string & name, Expression::Arguments && args) const {
+        return std::make_shared<CallExpr>(get_location(), std::make_shared<VariableExpr>(get_location(), name), std::move(args));
    }

-    std::unique_ptr<Expression> parseMathUnaryPlusMinus() {
+    std::shared_ptr<Expression> parseMathUnaryPlusMinus() {
        static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))");
        auto op_str = consumeToken(unary_plus_minus_tok);
        auto expr = parseValueExpression();

        if (!op_str.empty()) {
          auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus;
-          return nonstd_make_unique<UnaryOpExpr>(get_location(), std::move(expr), op);
+          return std::make_shared<UnaryOpExpr>(get_location(), std::move(expr), op);
        }
        return expr;
    }

-    std::unique_ptr<Expression> parseValueExpression() {
-      auto parseValue = [&]() -> std::unique_ptr<Expression> {
+    std::shared_ptr<Expression> parseValueExpression() {
+      auto parseValue = [&]() -> std::shared_ptr<Expression> {
        auto location = get_location();
        auto constant = parseConstant();
-        if (constant) return nonstd_make_unique<LiteralExpr>(location, *constant);
+        if (constant) return std::make_shared<LiteralExpr>(location, *constant);

        static std::regex null_regex(R"(null\b)");
-        if (!consumeToken(null_regex).empty()) return nonstd_make_unique<LiteralExpr>(location, Value());
+        if (!consumeToken(null_regex).empty()) return std::make_shared<LiteralExpr>(location, Value());

        auto identifier = parseIdentifier();
        if (identifier) return identifier;
@@ -1814,19 +1845,19 @@ private:

      while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
        if (!consumeToken("[").empty()) {
-          std::unique_ptr<Expression> index;
+          std::shared_ptr<Expression> index;
          if (!consumeToken(":").empty()) {
            auto slice_end = parseExpression();
-            index = nonstd_make_unique<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
+            index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
          } else {
            auto slice_start = parseExpression();
            if (!consumeToken(":").empty()) {
              consumeSpaces();
              if (peekSymbols({ "]" })) {
-                index = nonstd_make_unique<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
+                index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
              } else {
                auto slice_end = parseExpression();
-                index = nonstd_make_unique<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
+                index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
              }
            } else {
              index = std::move(slice_start);
@@ -1835,7 +1866,7 @@ private:
          if (!index) throw std::runtime_error("Empty index in subscript");
          if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");

-          value = nonstd_make_unique<SubscriptExpr>(value->location, std::move(value), std::move(index));
+          value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
        } else if (!consumeToken(".").empty()) {
          auto identifier = parseIdentifier();
          if (!identifier) throw std::runtime_error("Expected identifier in subscript");
@@ -1843,10 +1874,10 @@ private:
          consumeSpaces();
          if (peekSymbols({ "(" })) {
            auto callParams = parseCallArgs();
-            value = nonstd_make_unique<MethodCallExpr>(identifier->location, std::move(value), std::move(identifier), std::move(callParams));
+            value = std::make_shared<MethodCallExpr>(identifier->location, std::move(value), std::move(identifier), std::move(callParams));
          } else {
-            auto key = nonstd_make_unique<LiteralExpr>(identifier->location, Value(identifier->get_name()));
-            value = nonstd_make_unique<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
+            auto key = std::make_shared<LiteralExpr>(identifier->location, Value(identifier->get_name()));
+            value = std::make_shared<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
          }
        }
        consumeSpaces();
@@ -1855,12 +1886,12 @@ private:
      if (peekSymbols({ "(" })) {
        auto location = get_location();
        auto callParams = parseCallArgs();
-        value = nonstd_make_unique<CallExpr>(location, std::move(value), std::move(callParams));
+        value = std::make_shared<CallExpr>(location, std::move(value), std::move(callParams));
      }
      return value;
    }

-    std::unique_ptr<Expression> parseBracedExpressionOrArray() {
+    std::shared_ptr<Expression> parseBracedExpressionOrArray() {
        if (consumeToken("(").empty()) return nullptr;

        auto expr = parseExpression();
@@ -1870,7 +1901,7 @@ private:
            return expr;  // Drop the parentheses
        }

-        std::vector<std::unique_ptr<Expression>> tuple;
+        std::vector<std::shared_ptr<Expression>> tuple;
        tuple.emplace_back(std::move(expr));

        while (it != end) {
@@ -1880,18 +1911,18 @@ private:
          tuple.push_back(std::move(next));

          if (!consumeToken(")").empty()) {
-            return nonstd_make_unique<ArrayExpr>(get_location(), std::move(tuple));
+            return std::make_shared<ArrayExpr>(get_location(), std::move(tuple));
          }
        }
        throw std::runtime_error("Expected closing parenthesis");
    }

-    std::unique_ptr<Expression> parseArray() {
+    std::shared_ptr<Expression> parseArray() {
        if (consumeToken("[").empty()) return nullptr;

-        std::vector<std::unique_ptr<Expression>> elements;
+        std::vector<std::shared_ptr<Expression>> elements;
        if (!consumeToken("]").empty()) {
-          return nonstd_make_unique<ArrayExpr>(get_location(), std::move(elements));
+          return std::make_shared<ArrayExpr>(get_location(), std::move(elements));
        }
        auto first_expr = parseExpression();
        if (!first_expr) throw std::runtime_error("Expected first expression in array");
@@ -1903,7 +1934,7 @@ private:
            if (!expr) throw std::runtime_error("Expected expression in array");
            elements.push_back(std::move(expr));
          } else if (!consumeToken("]").empty()) {
-            return nonstd_make_unique<ArrayExpr>(get_location(), std::move(elements));
+            return std::make_shared<ArrayExpr>(get_location(), std::move(elements));
          } else {
            throw std::runtime_error("Expected comma or closing bracket in array");
          }
@@ -1911,12 +1942,12 @@ private:
        throw std::runtime_error("Expected closing bracket");
    }

-    std::unique_ptr<Expression> parseDictionary() {
+    std::shared_ptr<Expression> parseDictionary() {
        if (consumeToken("{").empty()) return nullptr;

-        std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>>> elements;
+        std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
        if (!consumeToken("}").empty()) {
-          return nonstd_make_unique<DictExpr>(get_location(), std::move(elements));
+          return std::make_shared<DictExpr>(get_location(), std::move(elements));
        }

        auto parseKeyValuePair = [&]() {
@@ -1934,7 +1965,7 @@ private:
          if (!consumeToken(",").empty()) {
            parseKeyValuePair();
          } else if (!consumeToken("}").empty()) {
-            return nonstd_make_unique<DictExpr>(get_location(), std::move(elements));
+            return std::make_shared<DictExpr>(get_location(), std::move(elements));
          } else {
            throw std::runtime_error("Expected comma or closing brace in dictionary");
          }
@@ -2051,7 +2082,7 @@ private:
            auto iterable = parseExpression(/* allow_if_expr = */ false);
            if (!iterable) throw std::runtime_error("Expected iterable in for block");

-            std::unique_ptr<Expression> condition;
+            std::shared_ptr<Expression> condition;
            if (!consumeToken(if_tok).empty()) {
              condition = parseExpression();
            }
@@ -2067,7 +2098,7 @@ private:

            std::string ns;
            std::vector<std::string> var_names;
-            std::unique_ptr<Expression> value;
+            std::shared_ptr<Expression> value;
            if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) {
              ns = group[1];
              var_names.push_back(group[2]);
@@ -2114,17 +2145,17 @@ private:
        }
    }

-    std::unique_ptr<TemplateNode> parseTemplate(
+    std::shared_ptr<TemplateNode> parseTemplate(
          const TemplateTokenIterator & begin,
          TemplateTokenIterator & it,
          const TemplateTokenIterator & end,
          bool fully = false) const {
-        std::vector<std::unique_ptr<TemplateNode>> children;
+        std::vector<std::shared_ptr<TemplateNode>> children;
        while (it != end) {
          const auto start = it;
          const auto & token = *(it++);
          if (auto if_token = dynamic_cast<IfTemplateToken*>(token.get())) {
-            std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<TemplateNode>>> cascade;
+            std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
            cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end));

            while (it != end && (*it)->type == TemplateToken::Type::Elif) {
@@ -2138,17 +2169,17 @@ private:
            if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) {
                throw unterminated(**start);
            }
-            children.emplace_back(nonstd_make_unique<IfNode>(token->location, std::move(cascade)));
+            children.emplace_back(std::make_shared<IfNode>(token->location, std::move(cascade)));
          } else if (auto for_token = dynamic_cast<ForTemplateToken*>(token.get())) {
            auto body = parseTemplate(begin, it, end);
-            auto else_body = std::unique_ptr<TemplateNode>();
+            auto else_body = std::shared_ptr<TemplateNode>();
            if (it != end && (*it)->type == TemplateToken::Type::Else) {
              else_body = parseTemplate(begin, ++it, end);
            }
            if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) {
                throw unterminated(**start);
            }
-            children.emplace_back(nonstd_make_unique<ForNode>(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body)));
+            children.emplace_back(std::make_shared<ForNode>(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body)));
          } else if (auto text_token = dynamic_cast<TextTemplateToken*>(token.get())) {
            SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep;
            SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep;
@@ -2173,25 +2204,28 @@ private:
              static std::regex r(R"(\r?\n$)");
              text = std::regex_replace(text, r, "");  // Strip one trailing newline
            }
-            children.emplace_back(nonstd_make_unique<TextNode>(token->location, text));
+            children.emplace_back(std::make_shared<TextNode>(token->location, text));
          } else if (auto expr_token = dynamic_cast<ExpressionTemplateToken*>(token.get())) {
-            children.emplace_back(nonstd_make_unique<ExpressionNode>(token->location, std::move(expr_token->expr)));
+            children.emplace_back(std::make_shared<ExpressionNode>(token->location, std::move(expr_token->expr)));
          } else if (auto set_token = dynamic_cast<SetTemplateToken*>(token.get())) {
            if (set_token->value) {
-              children.emplace_back(nonstd_make_unique<SetNode>(token->location, set_token->ns, set_token->var_names, std::move(set_token->value), nullptr));
+              children.emplace_back(std::make_shared<SetNode>(token->location, set_token->ns, set_token->var_names, std::move(set_token->value)));
            } else {
              auto value_template = parseTemplate(begin, it, end);
              if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) {
                  throw unterminated(**start);
              }
-              children.emplace_back(nonstd_make_unique<SetNode>(token->location, set_token->ns, set_token->var_names, nullptr, std::move(value_template)));
+              if (!set_token->ns.empty()) throw std::runtime_error("Namespaced set not supported in set with template value");
+              if (set_token->var_names.size() != 1) throw std::runtime_error("Structural assignment not supported in set with template value");
+              auto & name = set_token->var_names[0];
+              children.emplace_back(std::make_shared<SetTemplateNode>(token->location, name, std::move(value_template)));
            }
          } else if (auto macro_token = dynamic_cast<MacroTemplateToken*>(token.get())) {
            auto body = parseTemplate(begin, it, end);
            if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) {
                throw unterminated(**start);
            }
-            children.emplace_back(nonstd_make_unique<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
+            children.emplace_back(std::make_shared<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
          } else if (dynamic_cast<CommentTemplateToken*>(token.get())) {
            // Ignore comments
          } else if (dynamic_cast<EndForTemplateToken*>(token.get())
@@ -2210,17 +2244,17 @@ private:
            throw unexpected(**it);
          }
        }
        if (children.empty()) {
-          return nonstd_make_unique<TextNode>(Location { template_str, 0 }, std::string());
+          return std::make_shared<TextNode>(Location { template_str, 0 }, std::string());
        } else if (children.size() == 1) {
          return std::move(children[0]);
        } else {
-          return nonstd_make_unique<SequenceNode>(children[0]->location(), std::move(children));
+          return std::make_shared<SequenceNode>(children[0]->location(), std::move(children));
        }
    }

public:

-    static std::unique_ptr<TemplateNode> parse(const std::string& template_str, const Options & options) {
+    static std::shared_ptr<TemplateNode> parse(const std::string& template_str, const Options & options) {
        Parser parser(std::make_shared<std::string>(template_str), options);
        auto tokens = parser.tokenize();
        TemplateTokenIterator begin = tokens.begin();
diff --git a/common/tool-call.cpp b/common/tool-call.cpp
index 55d5cae59..1c713a3a1 100644
--- a/common/tool-call.cpp
+++ b/common/tool-call.cpp
@@ -12,6 +12,29 @@
 
 using json = nlohmann::ordered_json;
 
+llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template & chat_template) {
+    const auto & src = chat_template.source();
+
+    if (src.find("<tool_call>") != std::string::npos) {
+        return Hermes2Pro;
+    } else if (src.find(">>>all") != std::string::npos) {
+        return FunctionaryV3Llama3;
+    } else if (src.find("<|start_header_id|>") != std::string::npos
if (src.find("<|start_header_id|>") != std::string::npos + && src.find("ipython<|end_header_id|>") != std::string::npos) { + if (src.find("<|python_tag|>") != std::string::npos) { + return Llama31; + } else { + return Llama32; + } + } else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) { + return CommandRPlus; + } else { + return UnknownToolCallStyle; + } +} + static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) { // // https://json.nlohmann.me/features/parsing/sax_interface/ struct json_error_locator : public nlohmann::json_sax { @@ -207,7 +230,8 @@ llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tool } llama_tool_call_handler llama_tool_call_handler_init( - const llama_chat_template & tmpl, + llama_tool_call_style style, + const minja::chat_template & tmpl, bool allow_content, bool parallel_tool_calls, const nlohmann::ordered_json & messages, @@ -215,18 +239,18 @@ llama_tool_call_handler llama_tool_call_handler_init( { llama_tool_call_handler handler; - switch (tmpl.tool_call_style()) { + switch (style) { case llama_tool_call_style::Llama31: case llama_tool_call_style::Llama32: { static auto builtin_tools = json {"wolfram_alpha", "brave_search"}; - auto uses_python_tag = tmpl.tool_call_style() == llama_tool_call_style::Llama31; + auto uses_python_tag = style == llama_tool_call_style::Llama31; // Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name, // but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon // as it seems to be outputting some JSON. // TODO: make this conditional on a very small model (e.g. 1B / 3B). - auto eagerly_match_any_json = tmpl.tool_call_style() == llama_tool_call_style::Llama32; + auto eagerly_match_any_json = style == llama_tool_call_style::Llama32; handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { std::vector tool_rules; diff --git a/common/tool-call.h b/common/tool-call.h index 27ec089af..dc505ba2d 100644 --- a/common/tool-call.h +++ b/common/tool-call.h @@ -2,10 +2,20 @@ #include "ggml.h" #include "common.h" +#include "chat-template.hpp" // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -#include "chat-template.h" + +enum llama_tool_call_style { + UnknownToolCallStyle, + Llama31, + Llama32, + FunctionaryV3Llama3, + FunctionaryV3Llama31, + Hermes2Pro, + CommandRPlus, +}; struct llama_tool_call { std::string name; @@ -24,10 +34,13 @@ struct llama_tool_call_handler { std::vector additional_stop_words; }; +llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template & chat_template); + llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input); llama_tool_call_handler llama_tool_call_handler_init( - const llama_chat_template & tmpl, + llama_tool_call_style style, + const minja::chat_template & tmpl, bool allow_content, bool parallel_tool_calls, const nlohmann::ordered_json & messages, diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 10913e7d8..61b900a08 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -663,7 +663,7 @@ struct server_context { llama_chat_message chat[] = {{"user", "test"}}; if (use_jinja) { - auto chat_template = llama_chat_template::from_model(model); + auto chat_template = llama_chat_template_from_model(model); try { 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 10913e7d8..61b900a08 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -663,7 +663,7 @@ struct server_context {
        llama_chat_message chat[] = {{"user", "test"}};

        if (use_jinja) {
-            auto chat_template = llama_chat_template::from_model(model);
+            auto chat_template = llama_chat_template_from_model(model);
            try {
                chat_template.apply({{
                    {"role", "user"},
@@ -2875,11 +2875,12 @@ int main(int argc, char ** argv) {
            return;
        }

-        auto chat_template = llama_chat_template::from_model(ctx_server.model, params.chat_template.empty() ? nullptr : params.chat_template.c_str());
+        static auto chat_template = llama_chat_template_from_model(ctx_server.model, params.chat_template.empty() ? nullptr : params.chat_template.c_str());
+        static auto tool_call_style = llama_tool_call_style_detect(chat_template);

        json data;
        try {
-            data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, params.use_jinja);
+            data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, tool_call_style, params.use_jinja);
        } catch (const std::exception & e) {
            res_error(res, format_error_response(e.what(), ERROR_TYPE_NOT_SUPPORTED));
            return;
@@ -2897,7 +2898,7 @@ int main(int argc, char ** argv) {
        ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
            // multitask is never support in chat completion, there is only one result
            try {
-                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, chat_template, /*.streaming =*/ false, verbose);
+                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, tool_call_style, /*.streaming =*/ false, verbose);
                res_ok(res, result_oai);
            } catch (const std::runtime_error & e) {
                res_error(res, format_error_response(e.what(), ERROR_TYPE_SERVER));
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index a19e7ce99..aff2a9554 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -14,7 +14,6 @@

 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
-#include "chat-template.h"
 #include "json.hpp"
 #include "minja.hpp"
 #include "tool-call.h"
@@ -309,7 +308,8 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
 static json oaicompat_completion_params_parse(
     const struct llama_model * model,
     const json & body, /* openai api json semantics */
-    const llama_chat_template & tmpl,
+    const minja::chat_template & tmpl,
+    llama_tool_call_style tool_call_style,
     bool use_jinja)
 {
     json llama_params;
@@ -320,7 +320,7 @@ static json oaicompat_completion_params_parse(
     auto has_tools = tools.is_array() && !tools.empty();

     // Apply chat template to the list of messages
-    llama_params["chat_template"] = tmpl.chat_template();
+    llama_params["chat_template"] = tmpl.source();

     if (use_jinja) {
        if (has_tools && !tmpl.supports_tools()) {
@@ -372,7 +372,7 @@ static json oaicompat_completion_params_parse(
            llama_params["parse_tool_calls"] = true;
            llama_params["parallel_tool_calls"] = parallel_tool_calls;

-            auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools);
+            auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools);
            llama_params["prompt"] = handler.prompt;

            for (const auto & stop : handler.additional_stop_words) {
@@ -395,7 +395,7 @@ static json oaicompat_completion_params_parse(
            llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
        }
     } else {
-        llama_params["prompt"] = format_chat(model, tmpl.chat_template(), body.at("messages"));
+        llama_params["prompt"] = format_chat(model, tmpl.source(), body.at("messages"));
     }

     // Handle "n" field
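oaicompat_completion_params_parse now routes the prompt three ways. A self-contained sketch of that branching with the server state stubbed out (the field names mirror this diff; the values are illustrative):

    #include "json.hpp"
    #include <iostream>

    using json = nlohmann::ordered_json;

    int main() {
        json body = {
            {"messages", json::array({ { {"role", "user"}, {"content", "hi"} } })},
            {"tools",    json::array()},
        };
        const bool use_jinja = true;
        const bool has_tools = body["tools"].is_array() && !body["tools"].empty();
        if (use_jinja && has_tools) {
            std::cout << "grammar path: llama_tool_call_handler_init(style, tmpl, ...)\n";
        } else if (use_jinja) {
            std::cout << "jinja path: tmpl.apply(messages, tools, /* add_generation_prompt= */ true)\n";
        } else {
            std::cout << "legacy path: format_chat(model, tmpl.source(), messages)\n";
        }
        return 0;
    }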
@@ -435,7 +435,7 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }

-static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, const llama_chat_template & tmpl, bool streaming = false, bool verbose = false) {
+static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, llama_tool_call_style tool_call_style, bool streaming = false, bool verbose = false) {
     bool stopped_word = result.count("stopped_word") != 0;
     bool stopped_eos = json_value(result, "stopped_eos", false);
     int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@@ -452,7 +452,7 @@ static json format_final_response_oaicompat(const json & request, const json & r
     json tool_calls;
     json message_content;
     if (json_value(request, "parse_tool_calls", false)
-        && !(parsed_tool_calls = parse_tool_calls(tmpl.tool_call_style(), tools, content)).tool_calls.empty()) {
+        && !(parsed_tool_calls = parse_tool_calls(tool_call_style, tools, content)).tool_calls.empty()) {
        finish_reason = "tool_calls";
        if (!parsed_tool_calls.content.empty()) {
            message_content = parsed_tool_calls.content;
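On the response side, the detected style is all parse_tool_calls needs. A hedged sketch of the finish_reason decision implemented above (the helper is illustrative; llama_tool_calls fields follow the usage in this hunk):

    #include "tool-call.h"

    static const char * finish_reason_for(llama_tool_call_style style,
                                          const nlohmann::ordered_json & tools,
                                          const std::string & generated_text) {
        llama_tool_calls parsed = parse_tool_calls(style, tools, generated_text);
        // A non-empty tool_calls vector flips the OpenAI-style finish_reason.
        return parsed.tool_calls.empty() ? "stop" : "tool_calls";
    }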
diff --git a/fetch_templates_and_goldens.py b/fetch_templates_and_goldens.py
new file mode 100644
index 000000000..7eb83003d
--- /dev/null
+++ b/fetch_templates_and_goldens.py
@@ -0,0 +1,148 @@
+#!/usr/bin/env uv run
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "jinja2",
+#     "huggingface_hub",
+# ]
+# ///
+'''
+  Fetches the Jinja2 templates of specified models and generates prompt goldens for predefined chat contexts.
+  Outputs lines of arguments for a C++ test binary.
+  All files are written to the specified output folder.
+
+  Usage:
+    python ./fetch_templates_and_goldens.py output_folder model_id1 model_id2 ...
+
+  Example:
+    python ./fetch_templates_and_goldens.py ./test_files "microsoft/Phi-3-medium-4k-instruct" "Qwen/Qwen2-7B-Instruct"
+'''
+
+import logging
+import datetime
+import glob
+import os
+from huggingface_hub import hf_hub_download
+import json
+import jinja2
+import jinja2.ext
+import re
+import argparse
+import shutil
+
+logging.basicConfig(level=logging.INFO, format='%(message)s')
+logger = logging.getLogger(__name__)
+
+def raise_exception(message: str):
+    raise ValueError(message)
+
+def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
+    return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
+
+TEST_DATE = os.environ.get('TEST_DATE', '2024-07-26')
+
+def strftime_now(format):
+    now = datetime.datetime.strptime(TEST_DATE, "%Y-%m-%d")
+    return now.strftime(format)
+
+def handle_chat_template(output_folder, model_id, variant, template_src):
+    model_name = model_id.replace("/", "-")
+    base_name = f'{model_name}-{variant}' if variant else model_name
+    template_file = os.path.join(output_folder, f'{base_name}.jinja')
+
+    with open(template_file, 'w') as f:
+        f.write(template_src)
+
+    env = jinja2.Environment(
+        trim_blocks=True,
+        lstrip_blocks=True,
+        extensions=[jinja2.ext.loopcontrols]
+    )
+    env.filters['safe'] = lambda x: x
+    env.filters['tojson'] = tojson
+    env.globals['raise_exception'] = raise_exception
+    env.globals['strftime_now'] = strftime_now
+
+    template_handles_tools = 'tools' in template_src
+    template_hates_the_system = 'System role not supported' in template_src
+
+    template = env.from_string(template_src)
+
+    context_files = glob.glob(os.path.join(output_folder, '*.json'))
+    for context_file in context_files:
+        context_name = os.path.basename(context_file).replace(".json", "")
+        with open(context_file, 'r') as f:
+            context = json.load(f)
+
+        if not template_handles_tools and 'tools' in context:
+            continue
+
+        if template_hates_the_system and any(m['role'] == 'system' for m in context['messages']):
+            continue
+
+        output_file = os.path.join(output_folder, f'{base_name}-{context_name}.txt')
+
+        render_context = json.loads(json.dumps(context))
+
+        if 'tool_call.arguments | items' in template_src or 'tool_call.arguments | tojson' in template_src:
+            for message in render_context['messages']:
+                if 'tool_calls' in message:
+                    for tool_call in message['tool_calls']:
+                        if tool_call.get('type') == 'function':
+                            arguments = tool_call['function']['arguments']
+                            tool_call['function']['arguments'] = json.loads(arguments)
+
+        try:
+            output = template.render(**render_context)
+        except Exception as e1:
+            # Some templates raise if a message's "content" is null; retry with empty strings.
+            for message in render_context["messages"]:
+                if message.get("content") is None:
+                    message["content"] = ""
+
+            try:
+                output = template.render(**render_context)
+            except Exception as e2:
+                logger.info(f"  ERROR: {e2} (after first error: {e1})")
+                output = f"ERROR: {e2}"
+
+        with open(output_file, 'w') as f:
+            f.write(output)
+
+        # Output the line of arguments for the C++ test binary
+        print(f"{template_file} {context_file} {output_file}")
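+
+# NOTE: each printed line above is one (template, context, golden-output) triple; the
+# C++ test binary mentioned in the module docstring consumes these lines as arguments.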
+
+def main():
+    parser = argparse.ArgumentParser(description="Generate chat templates and output test arguments.")
+    parser.add_argument("output_folder", help="Folder to store all output files")
+    parser.add_argument("model_ids", nargs="+", help="List of model IDs to process")
+    args = parser.parse_args()
+
+    output_folder = args.output_folder
+    if not os.path.isdir(output_folder):
+        os.makedirs(output_folder)
+
+    # Copy context files to the output folder
+    for context_file in glob.glob('tests/chat/contexts/*.json'):
+        shutil.copy(context_file, output_folder)
+
+    for model_id in args.model_ids:
+        try:
+            with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
+                config_str = f.read()
+
+            try:
+                config = json.loads(config_str)
+            except json.JSONDecodeError:
+                # Some configs ship with a stray closing brace; strip it before parsing.
+                config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
+
+            chat_template = config['chat_template']
+            if isinstance(chat_template, str):
+                handle_chat_template(output_folder, model_id, None, chat_template)
+            else:
+                for ct in chat_template:
+                    handle_chat_template(output_folder, model_id, ct['name'], ct['template'])
+        except Exception as e:
+            logger.error(f"Error processing model {model_id}: {e}")
+
+if __name__ == '__main__':
+    main()
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index 64fb5b3c4..999681152 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -7,7 +7,7 @@

 #include "llama.h"
 #include "common.h"
-#include "chat-template.h"
+#include "chat-template.hpp"
 #include <string>
 #include <vector>
 #include <sstream>
@@ -73,7 +73,7 @@ static void test_jinja_templates() {
        return "tests/chat/goldens/" + golden_name + ".txt";
    };
    auto fail_with_golden_instructions = [&]() {
-        throw std::runtime_error("To fetch templates and generate golden files, run `python tests/update_jinja_goldens.py`");
+        throw std::runtime_error("To fetch templates and generate golden files, run `python fetch_templates_and_goldens.py`");
    };
    if (jinja_template_files.empty()) {
        std::cerr << "No Jinja templates found in tests/chat/templates" << std::endl;
@@ -89,7 +89,7 @@ static void test_jinja_templates() {
        for (const auto & ctx_file : context_files) {
            auto ctx = json::parse(read_file(ctx_file));

-            llama_chat_template tmpl(
+            minja::chat_template tmpl(
                tmpl_str,
                ctx.at("bos_token"),
                ctx.at("eos_token"));
@@ -127,20 +127,6 @@ static void test_jinja_templates() {
    }
 }

-void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
-    auto tmpl = llama_chat_template(read_file(template_file), "", "");
-    std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
-    assert_equals(expected, tmpl.tool_call_style());
-}
-
-void test_tool_call_styles() {
-    test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", FunctionaryV3Llama31);
-    test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", FunctionaryV3Llama3);
-    test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31);
-    test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32);
-    test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", CommandRPlus);
-}
-
 static void test_legacy_templates() {
    struct test_template {
        std::string name;
@@ -353,7 +339,6 @@ int main(void) {
    if (getenv("LLAMA_SKIP_TESTS_SLOW_ON_EMULATOR")) {
        fprintf(stderr, "\033[33mWARNING: Skipping slow tests on emulator.\n\033[0m");
    } else {
-        test_tool_call_styles();
        test_jinja_templates();
    }
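The golden test boils down to rendering each template against each context and byte-comparing with a stored file. A minimal sketch of one such check, assuming the minja::chat_template API and the context-file schema used in this diff (read_file and check_golden are illustrative helpers):

    #include "chat-template.hpp"
    #include <fstream>
    #include <sstream>
    #include <stdexcept>

    static std::string read_file(const std::string & path) {
        std::ifstream f(path);
        if (!f) throw std::runtime_error("failed to open: " + path);
        std::ostringstream ss;
        ss << f.rdbuf();
        return ss.str();
    }

    static void check_golden(const std::string & template_file, const nlohmann::ordered_json & ctx,
                             const std::string & golden_file) {
        minja::chat_template tmpl(read_file(template_file), ctx.at("bos_token"), ctx.at("eos_token"));
        auto actual = tmpl.apply(ctx.at("messages"),
                                 ctx.value("tools", nlohmann::ordered_json()),
                                 ctx.value("add_generation_prompt", true));
        if (actual != read_file(golden_file)) {
            throw std::runtime_error("golden mismatch: " + golden_file);
        }
    }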
<< "Expected: " << expected << std::endl; std::cerr << "Actual: " << actual << std::endl; @@ -242,7 +243,22 @@ static void test_parsing() { "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array()); } -static std::string get_message_prompt_delta(const llama_chat_template & tmpl, const std::vector & end_tokens, const json & user_message, const json & delta_message, const json & tools) { +void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) { + const minja::chat_template tmpl(read_file(template_file), "", ""); + auto tool_call_style = llama_tool_call_style_detect(tmpl); + std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush; + assert_equals(expected, tool_call_style); +} + +void test_tool_call_style_detection() { + test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", FunctionaryV3Llama31); + test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", FunctionaryV3Llama3); + test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31); + test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32); + test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", CommandRPlus); +} + +static std::string get_message_prompt_delta(const minja::chat_template & tmpl, const std::vector & end_tokens, const json & user_message, const json & delta_message, const json & tools) { auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object()); auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object()); @@ -267,7 +283,8 @@ static std::string get_message_prompt_delta(const llama_chat_template & tmpl, co static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector & end_tokens, const json & tool_calling_message, const json & tools) { std::cout << "# Testing template: " << template_file << std::endl << std::flush; - const llama_chat_template & tmpl = llama_chat_template(read_file(template_file), bos_token, eos_token); + const minja::chat_template tmpl(read_file(template_file), bos_token, eos_token); + auto tool_call_style = llama_tool_call_style_detect(tmpl); auto & tool_calls = tool_calling_message.at("tool_calls"); // Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false, @@ -277,7 +294,7 @@ static void test_template(const std::string & template_file, const char * bos_to {"content", "Hello, world!"} }; - auto handler = llama_tool_call_handler_init(tmpl, /* allow_content= */ true, /* parallel_tool_calls= */ true, {user_message, tool_calling_message}, tools); + auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, /* allow_content= */ true, /* parallel_tool_calls= */ true, {user_message, tool_calling_message}, tools); auto grammar = build_grammar(handler.grammar); if (!grammar) { throw std::runtime_error("Failed to build grammar"); @@ -285,7 +302,7 @@ static void test_template(const std::string & template_file, const char * bos_to auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, tool_calling_message, tools); std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl; - test_parse_tool_call(tmpl.tool_call_style(), tools, full_delta, "", 
+    test_parse_tool_call(tool_call_style, tools, full_delta, "", tool_calls);

    auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
        {"role", "assistant"},
@@ -319,6 +336,7 @@ static void test_grammars() {

 int main() {
    test_grammars();
    test_parsing();
+    test_tool_call_style_detection();

    std::cout << "[tool-call] All tests passed!" << std::endl;
    return 0;
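The delta trick in get_message_prompt_delta is a plain prefix strip: render the conversation with and without the candidate assistant message, and the difference is exactly what the model itself would have to generate. A toy, self-contained version (the strings are made up; the real function additionally truncates at end tokens):

    #include <cassert>
    #include <iostream>
    #include <string>

    int main() {
        // prefix: rendered with add_generation_prompt=true and no assistant message.
        // full:   rendered again with the assistant message appended.
        std::string prefix = "<|user|>Hello<|assistant|>";
        std::string full   = "<|user|>Hello<|assistant|>{\"name\": \"ipython\"}<|end|>";
        assert(full.compare(0, prefix.size(), prefix) == 0);
        std::string delta = full.substr(prefix.size()); // what the model must emit
        std::cout << delta << std::endl;
        return 0;
    }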