tool-call: prepare possible externalization of minja + factor tool call style out of template

This commit is contained in:
Olivier Chafik 2024-10-01 23:12:24 +01:00
parent d9451fd647
commit c36a196f53
14 changed files with 626 additions and 445 deletions

View File

@ -54,8 +54,7 @@ add_library(${TARGET} STATIC
arg.cpp arg.cpp
arg.h arg.h
base64.hpp base64.hpp
chat-template.cpp chat-template.hpp
chat-template.h
common.cpp common.cpp
common.h common.h
console.cpp console.cpp

View File

@ -1,156 +0,0 @@
#include "chat-template.h"
#include "llama.h"
using json = nlohmann::ordered_json;
static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
std::string piece;
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
if (n_chars < 0) {
piece.resize(-n_chars);
int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
GGML_ASSERT(check == -n_chars);
}
else {
piece.resize(n_chars);
}
return piece;
}
static std::string llama_model_meta_val_str(const struct llama_model * model, const char * key) {
int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
if (tlen > 0) {
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
return std::string(curr_tmpl_buf.data(), tlen);
}
}
return "";
}
llama_chat_template::llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token)
: _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) {
_supports_tools = chat_template.find("tools") != std::string::npos;
_requires_object_arguments =
chat_template.find("tool_call.arguments | items") != std::string::npos
|| chat_template.find("tool_call.arguments | tojson") != std::string::npos;
_supports_system_role = chat_template.find("System role not supported") == std::string::npos;
if (chat_template.find("<tool_call>") != std::string::npos) {
_tool_call_style = Hermes2Pro;
} else if (chat_template.find(">>>all") != std::string::npos) {
_tool_call_style = FunctionaryV3Llama3;
} else if (chat_template.find("<|start_header_id|>") != std::string::npos
&& chat_template.find("<function=") != std::string::npos) {
_tool_call_style = FunctionaryV3Llama31;
} else if (chat_template.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
if (chat_template.find("<|python_tag|>") != std::string::npos) {
_tool_call_style = Llama31;
} else {
_tool_call_style = Llama32;
}
} else if (chat_template.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
_tool_call_style = CommandRPlus;
} else {
_tool_call_style = UnknownToolCallStyle;
}
_template_root = minja::Parser::parse(_chat_template, {
/* .trim_blocks = */ true,
/* .lstrip_blocks = */ true,
/* .keep_trailing_newline = */ false,
});
}
llama_chat_template llama_chat_template::from_model(
const struct llama_model * model,
const char * chat_template_override)
{
// TODO: handle "chatml"?
std::string chat_template = chat_template_override
? chat_template_override
: llama_model_meta_val_str(model, "tokenizer.chat_template");
auto bos_token = _llama_token_to_piece(model, llama_token_bos(model), true);
auto eos_token = _llama_token_to_piece(model, llama_token_eos(model), true);
return llama_chat_template(chat_template, bos_token, eos_token);
}
std::string llama_chat_template::apply(
const json & messages,
const json & tools,
bool add_generation_prompt,
const json & extra_context) const
{
auto actual_messages = messages;
// First, "fix" messages so they have a chance to be rendered correctly by the template
if (_requires_object_arguments || !_supports_system_role) {
std::string pending_system;
auto flush_sys = [&]() {
if (!pending_system.empty()) {
actual_messages.push_back({
{"role", "user"},
{"content", pending_system},
});
pending_system.clear();
}
};
for (auto & message : actual_messages) {
if (!message.contains("role") || !message.contains("content")) {
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
}
std::string role = message.at("role");
if (!message["content"].is_null() && !_supports_system_role) {
std::string content = message.at("content");
if (role == "system") {
if (!pending_system.empty()) pending_system += "\n";
pending_system += content;
continue;
} else {
if (role == "user") {
if (!pending_system.empty()) {
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
pending_system.clear();
}
} else {
flush_sys();
}
}
}
if (_requires_object_arguments && message.contains("tool_calls")) {
for (auto & tool_call : message.at("tool_calls")) {
if (tool_call["type"] == "function") {
auto & function = tool_call.at("function");
std::string arguments = function.at("arguments");
function["arguments"] = json::parse(arguments);
}
}
}
}
flush_sys();
}
auto context = minja::Context::make(json({
{"messages", actual_messages},
{"add_generation_prompt", add_generation_prompt},
{"bos_token", _bos_token},
{"eos_token", _eos_token},
}));
if (!tools.is_null()) {
auto tools_val = minja::Value(tools);
context->set("tools", tools_val);
}
if (!extra_context.is_null()) {
for (auto & kv : extra_context.items()) {
minja::Value val(kv.value());
context->set(kv.key(), val);
}
}
return _template_root->render(context);
}

View File

@ -1,53 +0,0 @@
#pragma once
#include "minja.hpp"
#include <json.hpp>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
enum llama_tool_call_style {
UnknownToolCallStyle,
Llama31,
Llama32,
FunctionaryV3Llama3,
FunctionaryV3Llama31,
Hermes2Pro,
CommandRPlus,
};
class llama_chat_template {
public:
private:
llama_tool_call_style _tool_call_style = UnknownToolCallStyle;
bool _supports_tools = true;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool _requires_object_arguments = false;
bool _supports_system_role = true;
std::string _chat_template;
std::string _bos_token;
std::string _eos_token;
std::unique_ptr<minja::TemplateNode> _template_root;
public:
llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token);
static llama_chat_template from_model(
const struct llama_model * model,
const char * chat_template_override = nullptr);
llama_tool_call_style tool_call_style() const { return _tool_call_style; }
const std::string & chat_template() const { return _chat_template; }
bool supports_tools() const { return _supports_tools; }
std::string apply(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const;
};

133
common/chat-template.hpp Normal file
View File

@ -0,0 +1,133 @@
/*
Copyright 2024 Google LLC
Use of this source code is governed by an MIT-style
license that can be found in the LICENSE file or at
https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once
#include "minja.hpp"
#include <json.hpp>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
namespace minja {
class chat_template {
public:
private:
bool _supports_tools = true;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool _requires_object_arguments = false;
bool _supports_system_role = true;
std::string _source;
std::string _bos_token;
std::string _eos_token;
std::shared_ptr<minja::TemplateNode> _template_root;
public:
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
: _source(source), _bos_token(bos_token), _eos_token(eos_token)
{
_supports_tools = source.find("tools") != std::string::npos;
_requires_object_arguments =
source.find("tool_call.arguments | items") != std::string::npos
|| source.find("tool_call.arguments | tojson") != std::string::npos;
_supports_system_role = source.find("System role not supported") == std::string::npos;
_template_root = minja::Parser::parse(_source, {
/* .trim_blocks = */ true,
/* .lstrip_blocks = */ true,
/* .keep_trailing_newline = */ false,
});
}
const std::string & source() const { return _source; }
bool supports_tools() const { return _supports_tools; }
std::string apply(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
{
auto actual_messages = messages;
// First, "fix" messages so they have a chance to be rendered correctly by the template
if (_requires_object_arguments || !_supports_system_role) {
std::string pending_system;
auto flush_sys = [&]() {
if (!pending_system.empty()) {
actual_messages.push_back({
{"role", "user"},
{"content", pending_system},
});
pending_system.clear();
}
};
for (auto & message : actual_messages) {
if (!message.contains("role") || !message.contains("content")) {
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
}
std::string role = message.at("role");
if (!message["content"].is_null() && !_supports_system_role) {
std::string content = message.at("content");
if (role == "system") {
if (!pending_system.empty()) pending_system += "\n";
pending_system += content;
continue;
} else {
if (role == "user") {
if (!pending_system.empty()) {
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
pending_system.clear();
}
} else {
flush_sys();
}
}
}
if (_requires_object_arguments && message.contains("tool_calls")) {
for (auto & tool_call : message.at("tool_calls")) {
if (tool_call["type"] == "function") {
auto & function = tool_call.at("function");
std::string arguments = function.at("arguments");
function["arguments"] = json::parse(arguments);
}
}
}
}
flush_sys();
}
auto context = minja::Context::make(json({
{"messages", actual_messages},
{"add_generation_prompt", add_generation_prompt},
{"bos_token", _bos_token},
{"eos_token", _eos_token},
}));
if (!tools.is_null()) {
auto tools_val = minja::Value(tools);
context->set("tools", tools_val);
}
if (!extra_context.is_null()) {
for (auto & kv : extra_context.items()) {
minja::Value val(kv.value());
context->set(kv.key(), val);
}
}
return _template_root->render(context);
}
};
} // namespace minja

View File

@ -9,7 +9,7 @@
#include "json.hpp" #include "json.hpp"
#include "json-schema-to-grammar.h" #include "json-schema-to-grammar.h"
#include "llama.h" #include "llama.h"
#include "chat-template.h" #include "chat-template.hpp"
#include <algorithm> #include <algorithm>
#include <cinttypes> #include <cinttypes>
@ -1513,13 +1513,13 @@ std::vector<llama_token> llama_tokenize(
return result; return result;
} }
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
std::string piece; std::string piece;
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
if (n_chars < 0) { if (n_chars < 0) {
piece.resize(-n_chars); piece.resize(-n_chars);
int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
GGML_ASSERT(check == -n_chars); GGML_ASSERT(check == -n_chars);
} }
else { else {
@ -1529,6 +1529,10 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t
return piece; return piece;
} }
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
return _llama_token_to_piece(llama_get_model(ctx), token, special);
}
std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) { std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
std::string text; std::string text;
text.resize(std::max(text.capacity(), tokens.size())); text.resize(std::max(text.capacity(), tokens.size()));
@ -1552,7 +1556,7 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token>
bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) { bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) { if (use_jinja) {
try { try {
auto chat_template = llama_chat_template(tmpl, "<s>", "</s>"); auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
chat_template.apply({{ chat_template.apply({{
{"role", "user"}, {"role", "user"},
{"content", "test"}, {"content", "test"},
@ -1651,6 +1655,30 @@ std::string llama_chat_format_example(const struct llama_model * model,
return llama_chat_apply_template(model, tmpl, msgs, true); return llama_chat_apply_template(model, tmpl, msgs, true);
} }
static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
if (tlen > 0) {
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
return std::string(curr_tmpl_buf.data(), tlen);
}
}
return "";
}
minja::chat_template llama_chat_template_from_model(
const struct llama_model * model,
const char * chat_template_override)
{
// TODO: handle "chatml"?
std::string chat_template = chat_template_override
? chat_template_override
: _llama_model_meta_val_str(model, "tokenizer.chat_template");
auto bos_token = _llama_token_to_piece(model, llama_token_bos(model), true);
auto eos_token = _llama_token_to_piece(model, llama_token_eos(model), true);
return {std::move(chat_template), bos_token, eos_token};
}
// //
// KV cache utils // KV cache utils
// //

View File

@ -27,6 +27,9 @@
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
// Forward declaration
namespace minja { class chat_template; }
struct llama_lora_adapter_info { struct llama_lora_adapter_info {
std::string path; std::string path;
float scale; float scale;
@ -500,6 +503,10 @@ std::string llama_chat_format_single(const struct llama_model * model,
std::string llama_chat_format_example(const struct llama_model * model, std::string llama_chat_format_example(const struct llama_model * model,
const std::string & tmpl); const std::string & tmpl);
minja::chat_template llama_chat_template_from_model(
const struct llama_model * model,
const char * chat_template_override = nullptr);
// //
// KV cache utils // KV cache utils
// //

View File

@ -1,3 +1,11 @@
/*
Copyright 2024 Google LLC
Use of this source code is governed by an MIT-style
license that can be found in the LICENSE file or at
https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once #pragma once
#include <iostream> #include <iostream>
@ -532,44 +540,44 @@ static std::string error_location_suffix(const std::string & source, size_t pos)
} }
class Context : public std::enable_shared_from_this<Context> { class Context : public std::enable_shared_from_this<Context> {
protected: protected:
Value values_; Value values_;
std::shared_ptr<Context> parent_; std::shared_ptr<Context> parent_;
public: public:
Context(Value && values, const std::shared_ptr<Context> & parent = nullptr) : values_(std::move(values)), parent_(parent) { Context(Value && values, const std::shared_ptr<Context> & parent = nullptr) : values_(std::move(values)), parent_(parent) {
if (!values_.is_object()) throw std::runtime_error("Context values must be an object: " + values_.dump()); if (!values_.is_object()) throw std::runtime_error("Context values must be an object: " + values_.dump());
} }
virtual ~Context() {} virtual ~Context() {}
static std::shared_ptr<Context> builtins(); static std::shared_ptr<Context> builtins();
static std::shared_ptr<Context> make(Value && values, const std::shared_ptr<Context> & parent = builtins()); static std::shared_ptr<Context> make(Value && values, const std::shared_ptr<Context> & parent = builtins());
std::vector<Value> keys() { std::vector<Value> keys() {
return values_.keys(); return values_.keys();
} }
virtual Value get(const Value & key) { virtual Value get(const Value & key) {
if (values_.contains(key)) return values_.at(key); if (values_.contains(key)) return values_.at(key);
if (parent_) return parent_->get(key); if (parent_) return parent_->get(key);
return Value(); return Value();
} }
virtual Value & at(const Value & key) { virtual Value & at(const Value & key) {
if (values_.contains(key)) return values_.at(key); if (values_.contains(key)) return values_.at(key);
if (parent_) return parent_->at(key); if (parent_) return parent_->at(key);
throw std::runtime_error("Undefined variable: " + key.dump()); throw std::runtime_error("Undefined variable: " + key.dump());
} }
virtual bool contains(const Value & key) { virtual bool contains(const Value & key) {
if (values_.contains(key)) return true; if (values_.contains(key)) return true;
if (parent_) return parent_->contains(key); if (parent_) return parent_->contains(key);
return false; return false;
} }
virtual void set(const Value & key, Value & value) { virtual void set(const Value & key, Value & value) {
values_.set(key, value); values_.set(key, value);
} }
}; };
struct Location { struct Location {
std::shared_ptr<std::string> source; std::shared_ptr<std::string> source;
size_t pos; size_t pos;
}; };
class Expression { class Expression {
@ -577,8 +585,8 @@ protected:
virtual Value do_evaluate(const std::shared_ptr<Context> & context) const = 0; virtual Value do_evaluate(const std::shared_ptr<Context> & context) const = 0;
public: public:
struct Arguments { struct Arguments {
std::vector<std::unique_ptr<Expression>> args; std::vector<std::shared_ptr<Expression>> args;
std::vector<std::pair<std::string, std::unique_ptr<Expression>>> kwargs; std::vector<std::pair<std::string, std::shared_ptr<Expression>>> kwargs;
void expectArgs(const std::string & method_name, const std::pair<size_t, size_t> & pos_count, const std::pair<size_t, size_t> & kw_count) const { void expectArgs(const std::string & method_name, const std::pair<size_t, size_t> & pos_count, const std::pair<size_t, size_t> & kw_count) const {
if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) { if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) {
@ -600,7 +608,7 @@ public:
} }
}; };
using Parameters = std::vector<std::pair<std::string, std::unique_ptr<Expression>>>; using Parameters = std::vector<std::pair<std::string, std::shared_ptr<Expression>>>;
Location location; Location location;
@ -687,18 +695,18 @@ struct TextTemplateToken : public TemplateToken {
}; };
struct ExpressionTemplateToken : public TemplateToken { struct ExpressionTemplateToken : public TemplateToken {
std::unique_ptr<Expression> expr; std::shared_ptr<Expression> expr;
ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr<Expression> && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {} ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {}
}; };
struct IfTemplateToken : public TemplateToken { struct IfTemplateToken : public TemplateToken {
std::unique_ptr<Expression> condition; std::shared_ptr<Expression> condition;
IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr<Expression> && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {} IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {}
}; };
struct ElifTemplateToken : public TemplateToken { struct ElifTemplateToken : public TemplateToken {
std::unique_ptr<Expression> condition; std::shared_ptr<Expression> condition;
ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr<Expression> && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {} ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {}
}; };
struct ElseTemplateToken : public TemplateToken { struct ElseTemplateToken : public TemplateToken {
@ -706,13 +714,13 @@ struct ElseTemplateToken : public TemplateToken {
}; };
struct EndIfTemplateToken : public TemplateToken { struct EndIfTemplateToken : public TemplateToken {
EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {} EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {}
}; };
struct MacroTemplateToken : public TemplateToken { struct MacroTemplateToken : public TemplateToken {
std::unique_ptr<VariableExpr> name; std::shared_ptr<VariableExpr> name;
Expression::Parameters params; Expression::Parameters params;
MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr<VariableExpr> && n, Expression::Parameters && p) MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p)
: TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {} : TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {}
}; };
@ -722,11 +730,11 @@ struct EndMacroTemplateToken : public TemplateToken {
struct ForTemplateToken : public TemplateToken { struct ForTemplateToken : public TemplateToken {
std::vector<std::string> var_names; std::vector<std::string> var_names;
std::unique_ptr<Expression> iterable; std::shared_ptr<Expression> iterable;
std::unique_ptr<Expression> condition; std::shared_ptr<Expression> condition;
bool recursive; bool recursive;
ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector<std::string> & vns, std::unique_ptr<Expression> && iter, ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector<std::string> & vns, std::shared_ptr<Expression> && iter,
std::unique_ptr<Expression> && c, bool r) std::shared_ptr<Expression> && c, bool r)
: TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {} : TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {}
}; };
@ -737,8 +745,8 @@ struct EndForTemplateToken : public TemplateToken {
struct SetTemplateToken : public TemplateToken { struct SetTemplateToken : public TemplateToken {
std::string ns; std::string ns;
std::vector<std::string> var_names; std::vector<std::string> var_names;
std::unique_ptr<Expression> value; std::shared_ptr<Expression> value;
SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector<std::string> & vns, std::unique_ptr<Expression> && v) SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
: TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {} : TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {}
}; };
@ -778,9 +786,9 @@ public:
}; };
class SequenceNode : public TemplateNode { class SequenceNode : public TemplateNode {
std::vector<std::unique_ptr<TemplateNode>> children; std::vector<std::shared_ptr<TemplateNode>> children;
public: public:
SequenceNode(const Location & location, std::vector<std::unique_ptr<TemplateNode>> && c) SequenceNode(const Location & location, std::vector<std::shared_ptr<TemplateNode>> && c)
: TemplateNode(location), children(std::move(c)) {} : TemplateNode(location), children(std::move(c)) {}
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override { void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
for (const auto& child : children) child->render(out, context); for (const auto& child : children) child->render(out, context);
@ -797,10 +805,11 @@ public:
}; };
class ExpressionNode : public TemplateNode { class ExpressionNode : public TemplateNode {
std::unique_ptr<Expression> expr; std::shared_ptr<Expression> expr;
public: public:
ExpressionNode(const Location & location, std::unique_ptr<Expression> && e) : TemplateNode(location), expr(std::move(e)) {} ExpressionNode(const Location & location, std::shared_ptr<Expression> && e) : TemplateNode(location), expr(std::move(e)) {}
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override { void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
if (!expr) throw std::runtime_error("ExpressionNode.expr is null");
auto result = expr->evaluate(context); auto result = expr->evaluate(context);
if (result.is_string()) { if (result.is_string()) {
out << result.get<std::string>(); out << result.get<std::string>();
@ -813,9 +822,9 @@ public:
}; };
class IfNode : public TemplateNode { class IfNode : public TemplateNode {
std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<TemplateNode>>> cascade; std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
public: public:
IfNode(const Location & location, std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<TemplateNode>>> && c) IfNode(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> && c)
: TemplateNode(location), cascade(std::move(c)) {} : TemplateNode(location), cascade(std::move(c)) {}
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override { void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
for (const auto& branch : cascade) { for (const auto& branch : cascade) {
@ -824,6 +833,7 @@ public:
enter_branch = branch.first->evaluate(context).to_bool(); enter_branch = branch.first->evaluate(context).to_bool();
} }
if (enter_branch) { if (enter_branch) {
if (!branch.second) throw std::runtime_error("IfNode.cascade.second is null");
branch.second->render(out, context); branch.second->render(out, context);
return; return;
} }
@ -833,18 +843,20 @@ public:
class ForNode : public TemplateNode { class ForNode : public TemplateNode {
std::vector<std::string> var_names; std::vector<std::string> var_names;
std::unique_ptr<Expression> iterable; std::shared_ptr<Expression> iterable;
std::unique_ptr<Expression> condition; std::shared_ptr<Expression> condition;
std::unique_ptr<TemplateNode> body; std::shared_ptr<TemplateNode> body;
bool recursive; bool recursive;
std::unique_ptr<TemplateNode> else_body; std::shared_ptr<TemplateNode> else_body;
public: public:
ForNode(const Location & location, std::vector<std::string> && var_names, std::unique_ptr<Expression> && iterable, ForNode(const Location & location, std::vector<std::string> && var_names, std::shared_ptr<Expression> && iterable,
std::unique_ptr<Expression> && condition, std::unique_ptr<TemplateNode> && body, bool recursive, std::unique_ptr<TemplateNode> && else_body) std::shared_ptr<Expression> && condition, std::shared_ptr<TemplateNode> && body, bool recursive, std::shared_ptr<TemplateNode> && else_body)
: TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {} : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {}
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override { void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
// https://jinja.palletsprojects.com/en/3.0.x/templates/#for // https://jinja.palletsprojects.com/en/3.0.x/templates/#for
if (!iterable) throw std::runtime_error("ForNode.iterable is null");
if (!body) throw std::runtime_error("ForNode.body is null");
auto iterable_value = iterable->evaluate(context); auto iterable_value = iterable->evaluate(context);
Value::CallableType loop_function; Value::CallableType loop_function;
@ -914,12 +926,12 @@ public:
}; };
class MacroNode : public TemplateNode { class MacroNode : public TemplateNode {
std::unique_ptr<VariableExpr> name; std::shared_ptr<VariableExpr> name;
Expression::Parameters params; Expression::Parameters params;
std::unique_ptr<TemplateNode> body; std::shared_ptr<TemplateNode> body;
std::unordered_map<std::string, size_t> named_param_positions; std::unordered_map<std::string, size_t> named_param_positions;
public: public:
MacroNode(const Location & location, std::unique_ptr<VariableExpr> && n, Expression::Parameters && p, std::unique_ptr<TemplateNode> && b) MacroNode(const Location & location, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p, std::shared_ptr<TemplateNode> && b)
: TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) { : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) {
for (size_t i = 0; i < params.size(); ++i) { for (size_t i = 0; i < params.size(); ++i) {
const auto & name = params[i].first; const auto & name = params[i].first;
@ -929,6 +941,8 @@ public:
} }
} }
void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override { void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
if (!name) throw std::runtime_error("MacroNode.name is null");
if (!body) throw std::runtime_error("MacroNode.body is null");
auto callable = Value::callable([&](const std::shared_ptr<Context> & context, Value::Arguments & args) { auto callable = Value::callable([&](const std::shared_ptr<Context> & context, Value::Arguments & args) {
auto call_context = macro_context; auto call_context = macro_context;
std::vector<bool> param_set(params.size(), false); std::vector<bool> param_set(params.size(), false);
@ -964,19 +978,12 @@ public:
class SetNode : public TemplateNode { class SetNode : public TemplateNode {
std::string ns; std::string ns;
std::vector<std::string> var_names; std::vector<std::string> var_names;
std::unique_ptr<Expression> value; std::shared_ptr<Expression> value;
std::unique_ptr<TemplateNode> template_value;
public: public:
SetNode(const Location & location, const std::string & ns, const std::vector<std::string> & vns, std::unique_ptr<Expression> && v, std::unique_ptr<TemplateNode> && tv) SetNode(const Location & location, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
: TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)), template_value(std::move(tv)) { : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {}
if (value && template_value) {
throw std::runtime_error("Cannot have both value and template value in set node");
}
if (template_value && var_names.size() != 1) {
throw std::runtime_error("Destructuring assignment is only supported with a single variable name");
}
}
void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override { void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
if (!value) throw std::runtime_error("SetNode.value is null");
if (!ns.empty()) { if (!ns.empty()) {
if (var_names.size() != 1) { if (var_names.size() != 1) {
throw std::runtime_error("Namespaced set only supports a single variable name"); throw std::runtime_error("Namespaced set only supports a single variable name");
@ -985,9 +992,6 @@ public:
auto ns_value = context->get(ns); auto ns_value = context->get(ns);
if (!ns_value.is_object()) throw std::runtime_error("Namespace '" + ns + "' is not an object"); if (!ns_value.is_object()) throw std::runtime_error("Namespace '" + ns + "' is not an object");
ns_value.set(name, this->value->evaluate(context)); ns_value.set(name, this->value->evaluate(context));
} else if (template_value) {
Value value { template_value->render(context) };
context->set(var_names[0], value);
} else { } else {
auto val = value->evaluate(context); auto val = value->evaluate(context);
destructuring_assign(var_names, context, val); destructuring_assign(var_names, context, val);
@ -995,14 +999,29 @@ public:
} }
}; };
class IfExpr : public Expression { class SetTemplateNode : public TemplateNode {
std::unique_ptr<Expression> condition; std::string name;
std::unique_ptr<Expression> then_expr; std::shared_ptr<TemplateNode> template_value;
std::unique_ptr<Expression> else_expr;
public: public:
IfExpr(const Location & location, std::unique_ptr<Expression> && c, std::unique_ptr<Expression> && t, std::unique_ptr<Expression> && e) SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr<TemplateNode> && tv)
: TemplateNode(location), name(name), template_value(std::move(tv)) {}
void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null");
Value value { template_value->render(context) };
context->set(name, value);
}
};
class IfExpr : public Expression {
std::shared_ptr<Expression> condition;
std::shared_ptr<Expression> then_expr;
std::shared_ptr<Expression> else_expr;
public:
IfExpr(const Location & location, std::shared_ptr<Expression> && c, std::shared_ptr<Expression> && t, std::shared_ptr<Expression> && e)
: Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {} : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!condition) throw std::runtime_error("IfExpr.condition is null");
if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null");
if (condition->evaluate(context).to_bool()) { if (condition->evaluate(context).to_bool()) {
return then_expr->evaluate(context); return then_expr->evaluate(context);
} }
@ -1022,13 +1041,14 @@ public:
}; };
class ArrayExpr : public Expression { class ArrayExpr : public Expression {
std::vector<std::unique_ptr<Expression>> elements; std::vector<std::shared_ptr<Expression>> elements;
public: public:
ArrayExpr(const Location & location, std::vector<std::unique_ptr<Expression>> && e) ArrayExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && e)
: Expression(location), elements(std::move(e)) {} : Expression(location), elements(std::move(e)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
auto result = Value::array(); auto result = Value::array();
for (const auto& e : elements) { for (const auto& e : elements) {
if (!e) throw std::runtime_error("Array element is null");
result.push_back(e->evaluate(context)); result.push_back(e->evaluate(context));
} }
return result; return result;
@ -1036,13 +1056,15 @@ public:
}; };
class DictExpr : public Expression { class DictExpr : public Expression {
std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>>> elements; std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
public: public:
DictExpr(const Location & location, std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>>> && e) DictExpr(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> && e)
: Expression(location), elements(std::move(e)) {} : Expression(location), elements(std::move(e)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
auto result = Value::object(); auto result = Value::object();
for (const auto& e : elements) { for (const auto& e : elements) {
if (!e.first) throw std::runtime_error("Dict key is null");
if (!e.second) throw std::runtime_error("Dict value is null");
result.set(e.first->evaluate(context), e.second->evaluate(context)); result.set(e.first->evaluate(context), e.second->evaluate(context));
} }
return result; return result;
@ -1051,8 +1073,8 @@ public:
class SliceExpr : public Expression { class SliceExpr : public Expression {
public: public:
std::unique_ptr<Expression> start, end; std::shared_ptr<Expression> start, end;
SliceExpr(const Location & location, std::unique_ptr<Expression> && s, std::unique_ptr<Expression> && e) SliceExpr(const Location & location, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
: Expression(location), start(std::move(s)), end(std::move(e)) {} : Expression(location), start(std::move(s)), end(std::move(e)) {}
Value do_evaluate(const std::shared_ptr<Context> &) const override { Value do_evaluate(const std::shared_ptr<Context> &) const override {
throw std::runtime_error("SliceExpr not implemented"); throw std::runtime_error("SliceExpr not implemented");
@ -1060,12 +1082,14 @@ public:
}; };
class SubscriptExpr : public Expression { class SubscriptExpr : public Expression {
std::unique_ptr<Expression> base; std::shared_ptr<Expression> base;
std::unique_ptr<Expression> index; std::shared_ptr<Expression> index;
public: public:
SubscriptExpr(const Location & location, std::unique_ptr<Expression> && b, std::unique_ptr<Expression> && i) SubscriptExpr(const Location & location, std::shared_ptr<Expression> && b, std::shared_ptr<Expression> && i)
: Expression(location), base(std::move(b)), index(std::move(i)) {} : Expression(location), base(std::move(b)), index(std::move(i)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!base) throw std::runtime_error("SubscriptExpr.base is null");
if (!index) throw std::runtime_error("SubscriptExpr.index is null");
auto target_value = base->evaluate(context); auto target_value = base->evaluate(context);
if (auto slice = dynamic_cast<SliceExpr*>(index.get())) { if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
if (!target_value.is_array()) throw std::runtime_error("Subscripting non-array"); if (!target_value.is_array()) throw std::runtime_error("Subscripting non-array");
@ -1094,12 +1118,13 @@ class UnaryOpExpr : public Expression {
public: public:
enum class Op { Plus, Minus, LogicalNot }; enum class Op { Plus, Minus, LogicalNot };
private: private:
std::unique_ptr<Expression> expr; std::shared_ptr<Expression> expr;
Op op; Op op;
public: public:
UnaryOpExpr(const Location & location, std::unique_ptr<Expression> && e, Op o) UnaryOpExpr(const Location & location, std::shared_ptr<Expression> && e, Op o)
: Expression(location), expr(std::move(e)), op(o) {} : Expression(location), expr(std::move(e)), op(o) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null");
auto e = expr->evaluate(context); auto e = expr->evaluate(context);
switch (op) { switch (op) {
case Op::Plus: return e; case Op::Plus: return e;
@ -1114,13 +1139,15 @@ class BinaryOpExpr : public Expression {
public: public:
enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot }; enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot };
private: private:
std::unique_ptr<Expression> left; std::shared_ptr<Expression> left;
std::unique_ptr<Expression> right; std::shared_ptr<Expression> right;
Op op; Op op;
public: public:
BinaryOpExpr(const Location & location, std::unique_ptr<Expression> && l, std::unique_ptr<Expression> && r, Op o) BinaryOpExpr(const Location & location, std::shared_ptr<Expression> && l, std::shared_ptr<Expression> && r, Op o)
: Expression(location), left(std::move(l)), right(std::move(r)), op(o) {} : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!left) throw std::runtime_error("BinaryOpExpr.left is null");
if (!right) throw std::runtime_error("BinaryOpExpr.right is null");
auto l = left->evaluate(context); auto l = left->evaluate(context);
auto do_eval = [&](const Value & l) -> Value { auto do_eval = [&](const Value & l) -> Value {
@ -1210,13 +1237,15 @@ static std::string html_escape(const std::string & s) {
} }
class MethodCallExpr : public Expression { class MethodCallExpr : public Expression {
std::unique_ptr<Expression> object; std::shared_ptr<Expression> object;
std::unique_ptr<VariableExpr> method; std::shared_ptr<VariableExpr> method;
Expression::Arguments args; Expression::Arguments args;
public: public:
MethodCallExpr(const Location & location, std::unique_ptr<Expression> && obj, std::unique_ptr<VariableExpr> && m, Expression::Arguments && a) MethodCallExpr(const Location & location, std::shared_ptr<Expression> && obj, std::shared_ptr<VariableExpr> && m, Expression::Arguments && a)
: Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {} : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!object) throw std::runtime_error("MethodCallExpr.object is null");
if (!method) throw std::runtime_error("MethodCallExpr.method is null");
auto obj = object->evaluate(context); auto obj = object->evaluate(context);
if (obj.is_array()) { if (obj.is_array()) {
if (method->get_name() == "append") { if (method->get_name() == "append") {
@ -1279,11 +1308,12 @@ public:
class CallExpr : public Expression { class CallExpr : public Expression {
public: public:
std::unique_ptr<Expression> object; std::shared_ptr<Expression> object;
Expression::Arguments args; Expression::Arguments args;
CallExpr(const Location & location, std::unique_ptr<Expression> && obj, Expression::Arguments && a) CallExpr(const Location & location, std::shared_ptr<Expression> && obj, Expression::Arguments && a)
: Expression(location), object(std::move(obj)), args(std::move(a)) {} : Expression(location), object(std::move(obj)), args(std::move(a)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!object) throw std::runtime_error("CallExpr.object is null");
auto obj = object->evaluate(context); auto obj = object->evaluate(context);
if (!obj.is_callable()) { if (!obj.is_callable()) {
throw std::runtime_error("Object is not callable: " + obj.dump(2)); throw std::runtime_error("Object is not callable: " + obj.dump(2));
@ -1294,14 +1324,15 @@ public:
}; };
class FilterExpr : public Expression { class FilterExpr : public Expression {
std::vector<std::unique_ptr<Expression>> parts; std::vector<std::shared_ptr<Expression>> parts;
public: public:
FilterExpr(const Location & location, std::vector<std::unique_ptr<Expression>> && p) FilterExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && p)
: Expression(location), parts(std::move(p)) {} : Expression(location), parts(std::move(p)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override { Value do_evaluate(const std::shared_ptr<Context> & context) const override {
Value result; Value result;
bool first = true; bool first = true;
for (const auto& part : parts) { for (const auto& part : parts) {
if (!part) throw std::runtime_error("FilterExpr.part is null");
if (first) { if (first) {
first = false; first = false;
result = part->evaluate(context); result = part->evaluate(context);
@ -1322,7 +1353,7 @@ public:
return result; return result;
} }
void prepend(std::unique_ptr<Expression> && e) { void prepend(std::shared_ptr<Expression> && e) {
parts.insert(parts.begin(), std::move(e)); parts.insert(parts.begin(), std::move(e));
} }
}; };
@ -1375,7 +1406,7 @@ private:
escape = true; escape = true;
} else if (*it == quote) { } else if (*it == quote) {
++it; ++it;
return nonstd_make_unique<std::string>(result); return nonstd_make_unique<std::string>(std::move(result));
} else { } else {
result += *it; result += *it;
} }
@ -1429,37 +1460,37 @@ private:
} }
/** integer, float, bool, string */ /** integer, float, bool, string */
std::unique_ptr<Value> parseConstant() { std::shared_ptr<Value> parseConstant() {
auto start = it; auto start = it;
consumeSpaces(); consumeSpaces();
if (it == end) return nullptr; if (it == end) return nullptr;
if (*it == '"' || *it == '\'') { if (*it == '"' || *it == '\'') {
auto str = parseString(); auto str = parseString();
if (str) return nonstd_make_unique<Value>(*str); if (str) return std::make_shared<Value>(*str);
} }
static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)"); static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)");
auto token = consumeToken(prim_tok); auto token = consumeToken(prim_tok);
if (!token.empty()) { if (!token.empty()) {
if (token == "true" || token == "True") return nonstd_make_unique<Value>(true); if (token == "true" || token == "True") return std::make_shared<Value>(true);
if (token == "false" || token == "False") return nonstd_make_unique<Value>(false); if (token == "false" || token == "False") return std::make_shared<Value>(false);
if (token == "None") return nonstd_make_unique<Value>(nullptr); if (token == "None") return std::make_shared<Value>(nullptr);
throw std::runtime_error("Unknown constant token: " + token); throw std::runtime_error("Unknown constant token: " + token);
} }
auto number = parseNumber(it, end); auto number = parseNumber(it, end);
if (!number.is_null()) return nonstd_make_unique<Value>(number); if (!number.is_null()) return std::make_shared<Value>(number);
it = start; it = start;
return nullptr; return nullptr;
} }
class expression_parsing_error : public std::runtime_error { class expression_parsing_error : public std::runtime_error {
const CharIterator it; const CharIterator it;
public: public:
expression_parsing_error(const std::string & message, const CharIterator & it) expression_parsing_error(const std::string & message, const CharIterator & it)
: std::runtime_error(message), it(it) {} : std::runtime_error(message), it(it) {}
size_t get_pos(const CharIterator & begin) const { size_t get_pos(const CharIterator & begin) const {
return std::distance(begin, it); return std::distance(begin, it);
} }
}; };
@ -1510,7 +1541,7 @@ private:
return ""; return "";
} }
std::unique_ptr<Expression> parseExpression(bool allow_if_expr = true) { std::shared_ptr<Expression> parseExpression(bool allow_if_expr = true) {
auto left = parseLogicalOr(); auto left = parseLogicalOr();
if (it == end) return left; if (it == end) return left;
@ -1523,19 +1554,19 @@ private:
auto location = get_location(); auto location = get_location();
auto if_expr = parseIfExpression(); auto if_expr = parseIfExpression();
return nonstd_make_unique<IfExpr>(location, std::move(if_expr.first), std::move(left), std::move(if_expr.second)); return std::make_shared<IfExpr>(location, std::move(if_expr.first), std::move(left), std::move(if_expr.second));
} }
Location get_location() const { Location get_location() const {
return {template_str, (size_t) std::distance(start, it)}; return {template_str, (size_t) std::distance(start, it)};
} }
std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>> parseIfExpression() { std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>> parseIfExpression() {
auto condition = parseLogicalOr(); auto condition = parseLogicalOr();
if (!condition) throw std::runtime_error("Expected condition expression"); if (!condition) throw std::runtime_error("Expected condition expression");
static std::regex else_tok(R"(else\b)"); static std::regex else_tok(R"(else\b)");
std::unique_ptr<Expression> else_expr; std::shared_ptr<Expression> else_expr;
if (!consumeToken(else_tok).empty()) { if (!consumeToken(else_tok).empty()) {
else_expr = parseExpression(); else_expr = parseExpression();
if (!else_expr) throw std::runtime_error("Expected 'else' expression"); if (!else_expr) throw std::runtime_error("Expected 'else' expression");
@ -1543,7 +1574,7 @@ private:
return std::make_pair(std::move(condition), std::move(else_expr)); return std::make_pair(std::move(condition), std::move(else_expr));
} }
std::unique_ptr<Expression> parseLogicalOr() { std::shared_ptr<Expression> parseLogicalOr() {
auto left = parseLogicalAnd(); auto left = parseLogicalAnd();
if (!left) throw std::runtime_error("Expected left side of 'logical or' expression"); if (!left) throw std::runtime_error("Expected left side of 'logical or' expression");
@ -1552,24 +1583,24 @@ private:
while (!consumeToken(or_tok).empty()) { while (!consumeToken(or_tok).empty()) {
auto right = parseLogicalAnd(); auto right = parseLogicalAnd();
if (!right) throw std::runtime_error("Expected right side of 'or' expression"); if (!right) throw std::runtime_error("Expected right side of 'or' expression");
left = nonstd_make_unique<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or); left = std::make_shared<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or);
} }
return left; return left;
} }
std::unique_ptr<Expression> parseLogicalNot() { std::shared_ptr<Expression> parseLogicalNot() {
static std::regex not_tok(R"(not\b)"); static std::regex not_tok(R"(not\b)");
auto location = get_location(); auto location = get_location();
if (!consumeToken(not_tok).empty()) { if (!consumeToken(not_tok).empty()) {
auto sub = parseLogicalNot(); auto sub = parseLogicalNot();
if (!sub) throw std::runtime_error("Expected expression after 'not' keyword"); if (!sub) throw std::runtime_error("Expected expression after 'not' keyword");
return nonstd_make_unique<UnaryOpExpr>(location, std::move(sub), UnaryOpExpr::Op::LogicalNot); return std::make_shared<UnaryOpExpr>(location, std::move(sub), UnaryOpExpr::Op::LogicalNot);
} }
return parseLogicalCompare(); return parseLogicalCompare();
} }
std::unique_ptr<Expression> parseLogicalAnd() { std::shared_ptr<Expression> parseLogicalAnd() {
auto left = parseLogicalNot(); auto left = parseLogicalNot();
if (!left) throw std::runtime_error("Expected left side of 'logical and' expression"); if (!left) throw std::runtime_error("Expected left side of 'logical and' expression");
@ -1578,12 +1609,12 @@ private:
while (!consumeToken(and_tok).empty()) { while (!consumeToken(and_tok).empty()) {
auto right = parseLogicalNot(); auto right = parseLogicalNot();
if (!right) throw std::runtime_error("Expected right side of 'and' expression"); if (!right) throw std::runtime_error("Expected right side of 'and' expression");
left = nonstd_make_unique<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::And); left = std::make_shared<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::And);
} }
return left; return left;
} }
std::unique_ptr<Expression> parseLogicalCompare() { std::shared_ptr<Expression> parseLogicalCompare() {
auto left = parseStringConcat(); auto left = parseStringConcat();
if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression"); if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");
@ -1598,7 +1629,7 @@ private:
auto identifier = parseIdentifier(); auto identifier = parseIdentifier();
if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword"); if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword");
return nonstd_make_unique<BinaryOpExpr>( return std::make_shared<BinaryOpExpr>(
left->location, left->location,
std::move(left), std::move(identifier), std::move(left), std::move(identifier),
negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is); negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is);
@ -1615,7 +1646,7 @@ private:
else if (op_str == "in") op = BinaryOpExpr::Op::In; else if (op_str == "in") op = BinaryOpExpr::Op::In;
else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn; else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn;
else throw std::runtime_error("Unknown comparison operator: " + op_str); else throw std::runtime_error("Unknown comparison operator: " + op_str);
left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op); left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
} }
return left; return left;
} }
@ -1688,16 +1719,16 @@ private:
throw std::runtime_error("Expected closing parenthesis in call args"); throw std::runtime_error("Expected closing parenthesis in call args");
} }
std::unique_ptr<VariableExpr> parseIdentifier() { std::shared_ptr<VariableExpr> parseIdentifier() {
static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)"); static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)");
auto location = get_location(); auto location = get_location();
auto ident = consumeToken(ident_regex); auto ident = consumeToken(ident_regex);
if (ident.empty()) if (ident.empty())
return nullptr; return nullptr;
return nonstd_make_unique<VariableExpr>(location, ident); return std::make_shared<VariableExpr>(location, ident);
} }
std::unique_ptr<Expression> parseStringConcat() { std::shared_ptr<Expression> parseStringConcat() {
auto left = parseMathPow(); auto left = parseMathPow();
if (!left) throw std::runtime_error("Expected left side of 'string concat' expression"); if (!left) throw std::runtime_error("Expected left side of 'string concat' expression");
@ -1705,24 +1736,24 @@ private:
if (!consumeToken(concat_tok).empty()) { if (!consumeToken(concat_tok).empty()) {
auto right = parseLogicalAnd(); auto right = parseLogicalAnd();
if (!right) throw std::runtime_error("Expected right side of 'string concat' expression"); if (!right) throw std::runtime_error("Expected right side of 'string concat' expression");
left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat); left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat);
} }
return left; return left;
} }
std::unique_ptr<Expression> parseMathPow() { std::shared_ptr<Expression> parseMathPow() {
auto left = parseMathPlusMinus(); auto left = parseMathPlusMinus();
if (!left) throw std::runtime_error("Expected left side of 'math pow' expression"); if (!left) throw std::runtime_error("Expected left side of 'math pow' expression");
while (!consumeToken("**").empty()) { while (!consumeToken("**").empty()) {
auto right = parseMathPlusMinus(); auto right = parseMathPlusMinus();
if (!right) throw std::runtime_error("Expected right side of 'math pow' expression"); if (!right) throw std::runtime_error("Expected right side of 'math pow' expression");
left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul); left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul);
} }
return left; return left;
} }
std::unique_ptr<Expression> parseMathPlusMinus() { std::shared_ptr<Expression> parseMathPlusMinus() {
static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))"); static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))");
auto left = parseMathMulDiv(); auto left = parseMathMulDiv();
@ -1732,12 +1763,12 @@ private:
auto right = parseMathMulDiv(); auto right = parseMathMulDiv();
if (!right) throw std::runtime_error("Expected right side of 'math plus/minus' expression"); if (!right) throw std::runtime_error("Expected right side of 'math plus/minus' expression");
auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub; auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub;
left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op); left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
} }
return left; return left;
} }
std::unique_ptr<Expression> parseMathMulDiv() { std::shared_ptr<Expression> parseMathMulDiv() {
auto left = parseMathUnaryPlusMinus(); auto left = parseMathUnaryPlusMinus();
if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression"); if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression");
@ -1751,7 +1782,7 @@ private:
: op_str == "/" ? BinaryOpExpr::Op::Div : op_str == "/" ? BinaryOpExpr::Op::Div
: op_str == "//" ? BinaryOpExpr::Op::DivDiv : op_str == "//" ? BinaryOpExpr::Op::DivDiv
: BinaryOpExpr::Op::Mod; : BinaryOpExpr::Op::Mod;
left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op); left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
} }
if (!consumeToken("|").empty()) { if (!consumeToken("|").empty()) {
@ -1760,20 +1791,20 @@ private:
filter->prepend(std::move(left)); filter->prepend(std::move(left));
return expr; return expr;
} else { } else {
std::vector<std::unique_ptr<Expression>> parts; std::vector<std::shared_ptr<Expression>> parts;
parts.emplace_back(std::move(left)); parts.emplace_back(std::move(left));
parts.emplace_back(std::move(expr)); parts.emplace_back(std::move(expr));
return nonstd_make_unique<FilterExpr>(get_location(), std::move(parts)); return std::make_shared<FilterExpr>(get_location(), std::move(parts));
} }
} }
return left; return left;
} }
std::unique_ptr<Expression> call_func(const std::string & name, Expression::Arguments && args) const { std::shared_ptr<Expression> call_func(const std::string & name, Expression::Arguments && args) const {
return nonstd_make_unique<CallExpr>(get_location(), nonstd_make_unique<VariableExpr>(get_location(), name), std::move(args)); return std::make_shared<CallExpr>(get_location(), std::make_shared<VariableExpr>(get_location(), name), std::move(args));
} }
std::unique_ptr<Expression> parseMathUnaryPlusMinus() { std::shared_ptr<Expression> parseMathUnaryPlusMinus() {
static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))"); static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))");
auto op_str = consumeToken(unary_plus_minus_tok); auto op_str = consumeToken(unary_plus_minus_tok);
auto expr = parseValueExpression(); auto expr = parseValueExpression();
@ -1781,19 +1812,19 @@ private:
if (!op_str.empty()) { if (!op_str.empty()) {
auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus; auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus;
return nonstd_make_unique<UnaryOpExpr>(get_location(), std::move(expr), op); return std::make_shared<UnaryOpExpr>(get_location(), std::move(expr), op);
} }
return expr; return expr;
} }
std::unique_ptr<Expression> parseValueExpression() { std::shared_ptr<Expression> parseValueExpression() {
auto parseValue = [&]() -> std::unique_ptr<Expression> { auto parseValue = [&]() -> std::shared_ptr<Expression> {
auto location = get_location(); auto location = get_location();
auto constant = parseConstant(); auto constant = parseConstant();
if (constant) return nonstd_make_unique<LiteralExpr>(location, *constant); if (constant) return std::make_shared<LiteralExpr>(location, *constant);
static std::regex null_regex(R"(null\b)"); static std::regex null_regex(R"(null\b)");
if (!consumeToken(null_regex).empty()) return nonstd_make_unique<LiteralExpr>(location, Value()); if (!consumeToken(null_regex).empty()) return std::make_shared<LiteralExpr>(location, Value());
auto identifier = parseIdentifier(); auto identifier = parseIdentifier();
if (identifier) return identifier; if (identifier) return identifier;
@ -1814,19 +1845,19 @@ private:
while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) { while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
if (!consumeToken("[").empty()) { if (!consumeToken("[").empty()) {
std::unique_ptr<Expression> index; std::shared_ptr<Expression> index;
if (!consumeToken(":").empty()) { if (!consumeToken(":").empty()) {
auto slice_end = parseExpression(); auto slice_end = parseExpression();
index = nonstd_make_unique<SliceExpr>(slice_end->location, nullptr, std::move(slice_end)); index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
} else { } else {
auto slice_start = parseExpression(); auto slice_start = parseExpression();
if (!consumeToken(":").empty()) { if (!consumeToken(":").empty()) {
consumeSpaces(); consumeSpaces();
if (peekSymbols({ "]" })) { if (peekSymbols({ "]" })) {
index = nonstd_make_unique<SliceExpr>(slice_start->location, std::move(slice_start), nullptr); index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
} else { } else {
auto slice_end = parseExpression(); auto slice_end = parseExpression();
index = nonstd_make_unique<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end)); index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
} }
} else { } else {
index = std::move(slice_start); index = std::move(slice_start);
@ -1835,7 +1866,7 @@ private:
if (!index) throw std::runtime_error("Empty index in subscript"); if (!index) throw std::runtime_error("Empty index in subscript");
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript"); if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
value = nonstd_make_unique<SubscriptExpr>(value->location, std::move(value), std::move(index)); value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
} else if (!consumeToken(".").empty()) { } else if (!consumeToken(".").empty()) {
auto identifier = parseIdentifier(); auto identifier = parseIdentifier();
if (!identifier) throw std::runtime_error("Expected identifier in subscript"); if (!identifier) throw std::runtime_error("Expected identifier in subscript");
@ -1843,10 +1874,10 @@ private:
consumeSpaces(); consumeSpaces();
if (peekSymbols({ "(" })) { if (peekSymbols({ "(" })) {
auto callParams = parseCallArgs(); auto callParams = parseCallArgs();
value = nonstd_make_unique<MethodCallExpr>(identifier->location, std::move(value), std::move(identifier), std::move(callParams)); value = std::make_shared<MethodCallExpr>(identifier->location, std::move(value), std::move(identifier), std::move(callParams));
} else { } else {
auto key = nonstd_make_unique<LiteralExpr>(identifier->location, Value(identifier->get_name())); auto key = std::make_shared<LiteralExpr>(identifier->location, Value(identifier->get_name()));
value = nonstd_make_unique<SubscriptExpr>(identifier->location, std::move(value), std::move(key)); value = std::make_shared<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
} }
} }
consumeSpaces(); consumeSpaces();
@ -1855,12 +1886,12 @@ private:
if (peekSymbols({ "(" })) { if (peekSymbols({ "(" })) {
auto location = get_location(); auto location = get_location();
auto callParams = parseCallArgs(); auto callParams = parseCallArgs();
value = nonstd_make_unique<CallExpr>(location, std::move(value), std::move(callParams)); value = std::make_shared<CallExpr>(location, std::move(value), std::move(callParams));
} }
return value; return value;
} }
std::unique_ptr<Expression> parseBracedExpressionOrArray() { std::shared_ptr<Expression> parseBracedExpressionOrArray() {
if (consumeToken("(").empty()) return nullptr; if (consumeToken("(").empty()) return nullptr;
auto expr = parseExpression(); auto expr = parseExpression();
@ -1870,7 +1901,7 @@ private:
return expr; // Drop the parentheses return expr; // Drop the parentheses
} }
std::vector<std::unique_ptr<Expression>> tuple; std::vector<std::shared_ptr<Expression>> tuple;
tuple.emplace_back(std::move(expr)); tuple.emplace_back(std::move(expr));
while (it != end) { while (it != end) {
@ -1880,18 +1911,18 @@ private:
tuple.push_back(std::move(next)); tuple.push_back(std::move(next));
if (!consumeToken(")").empty()) { if (!consumeToken(")").empty()) {
return nonstd_make_unique<ArrayExpr>(get_location(), std::move(tuple)); return std::make_shared<ArrayExpr>(get_location(), std::move(tuple));
} }
} }
throw std::runtime_error("Expected closing parenthesis"); throw std::runtime_error("Expected closing parenthesis");
} }
std::unique_ptr<Expression> parseArray() { std::shared_ptr<Expression> parseArray() {
if (consumeToken("[").empty()) return nullptr; if (consumeToken("[").empty()) return nullptr;
std::vector<std::unique_ptr<Expression>> elements; std::vector<std::shared_ptr<Expression>> elements;
if (!consumeToken("]").empty()) { if (!consumeToken("]").empty()) {
return nonstd_make_unique<ArrayExpr>(get_location(), std::move(elements)); return std::make_shared<ArrayExpr>(get_location(), std::move(elements));
} }
auto first_expr = parseExpression(); auto first_expr = parseExpression();
if (!first_expr) throw std::runtime_error("Expected first expression in array"); if (!first_expr) throw std::runtime_error("Expected first expression in array");
@ -1903,7 +1934,7 @@ private:
if (!expr) throw std::runtime_error("Expected expression in array"); if (!expr) throw std::runtime_error("Expected expression in array");
elements.push_back(std::move(expr)); elements.push_back(std::move(expr));
} else if (!consumeToken("]").empty()) { } else if (!consumeToken("]").empty()) {
return nonstd_make_unique<ArrayExpr>(get_location(), std::move(elements)); return std::make_shared<ArrayExpr>(get_location(), std::move(elements));
} else { } else {
throw std::runtime_error("Expected comma or closing bracket in array"); throw std::runtime_error("Expected comma or closing bracket in array");
} }
@ -1911,12 +1942,12 @@ private:
throw std::runtime_error("Expected closing bracket"); throw std::runtime_error("Expected closing bracket");
} }
std::unique_ptr<Expression> parseDictionary() { std::shared_ptr<Expression> parseDictionary() {
if (consumeToken("{").empty()) return nullptr; if (consumeToken("{").empty()) return nullptr;
std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>>> elements; std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
if (!consumeToken("}").empty()) { if (!consumeToken("}").empty()) {
return nonstd_make_unique<DictExpr>(get_location(), std::move(elements)); return std::make_shared<DictExpr>(get_location(), std::move(elements));
} }
auto parseKeyValuePair = [&]() { auto parseKeyValuePair = [&]() {
@ -1934,7 +1965,7 @@ private:
if (!consumeToken(",").empty()) { if (!consumeToken(",").empty()) {
parseKeyValuePair(); parseKeyValuePair();
} else if (!consumeToken("}").empty()) { } else if (!consumeToken("}").empty()) {
return nonstd_make_unique<DictExpr>(get_location(), std::move(elements)); return std::make_shared<DictExpr>(get_location(), std::move(elements));
} else { } else {
throw std::runtime_error("Expected comma or closing brace in dictionary"); throw std::runtime_error("Expected comma or closing brace in dictionary");
} }
@ -2051,7 +2082,7 @@ private:
auto iterable = parseExpression(/* allow_if_expr = */ false); auto iterable = parseExpression(/* allow_if_expr = */ false);
if (!iterable) throw std::runtime_error("Expected iterable in for block"); if (!iterable) throw std::runtime_error("Expected iterable in for block");
std::unique_ptr<Expression> condition; std::shared_ptr<Expression> condition;
if (!consumeToken(if_tok).empty()) { if (!consumeToken(if_tok).empty()) {
condition = parseExpression(); condition = parseExpression();
} }
@ -2067,7 +2098,7 @@ private:
std::string ns; std::string ns;
std::vector<std::string> var_names; std::vector<std::string> var_names;
std::unique_ptr<Expression> value; std::shared_ptr<Expression> value;
if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) { if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) {
ns = group[1]; ns = group[1];
var_names.push_back(group[2]); var_names.push_back(group[2]);
@ -2114,17 +2145,17 @@ private:
} }
} }
std::unique_ptr<TemplateNode> parseTemplate( std::shared_ptr<TemplateNode> parseTemplate(
const TemplateTokenIterator & begin, const TemplateTokenIterator & begin,
TemplateTokenIterator & it, TemplateTokenIterator & it,
const TemplateTokenIterator & end, const TemplateTokenIterator & end,
bool fully = false) const { bool fully = false) const {
std::vector<std::unique_ptr<TemplateNode>> children; std::vector<std::shared_ptr<TemplateNode>> children;
while (it != end) { while (it != end) {
const auto start = it; const auto start = it;
const auto & token = *(it++); const auto & token = *(it++);
if (auto if_token = dynamic_cast<IfTemplateToken*>(token.get())) { if (auto if_token = dynamic_cast<IfTemplateToken*>(token.get())) {
std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<TemplateNode>>> cascade; std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end)); cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end));
while (it != end && (*it)->type == TemplateToken::Type::Elif) { while (it != end && (*it)->type == TemplateToken::Type::Elif) {
@ -2138,17 +2169,17 @@ private:
if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) { if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) {
throw unterminated(**start); throw unterminated(**start);
} }
children.emplace_back(nonstd_make_unique<IfNode>(token->location, std::move(cascade))); children.emplace_back(std::make_shared<IfNode>(token->location, std::move(cascade)));
} else if (auto for_token = dynamic_cast<ForTemplateToken*>(token.get())) { } else if (auto for_token = dynamic_cast<ForTemplateToken*>(token.get())) {
auto body = parseTemplate(begin, it, end); auto body = parseTemplate(begin, it, end);
auto else_body = std::unique_ptr<TemplateNode>(); auto else_body = std::shared_ptr<TemplateNode>();
if (it != end && (*it)->type == TemplateToken::Type::Else) { if (it != end && (*it)->type == TemplateToken::Type::Else) {
else_body = parseTemplate(begin, ++it, end); else_body = parseTemplate(begin, ++it, end);
} }
if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) { if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) {
throw unterminated(**start); throw unterminated(**start);
} }
children.emplace_back(nonstd_make_unique<ForNode>(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body))); children.emplace_back(std::make_shared<ForNode>(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body)));
} else if (auto text_token = dynamic_cast<TextTemplateToken*>(token.get())) { } else if (auto text_token = dynamic_cast<TextTemplateToken*>(token.get())) {
SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep; SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep;
SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep; SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep;
@ -2173,25 +2204,28 @@ private:
static std::regex r(R"(\r?\n$)"); static std::regex r(R"(\r?\n$)");
text = std::regex_replace(text, r, ""); // Strip one trailing newline text = std::regex_replace(text, r, ""); // Strip one trailing newline
} }
children.emplace_back(nonstd_make_unique<TextNode>(token->location, text)); children.emplace_back(std::make_shared<TextNode>(token->location, text));
} else if (auto expr_token = dynamic_cast<ExpressionTemplateToken*>(token.get())) { } else if (auto expr_token = dynamic_cast<ExpressionTemplateToken*>(token.get())) {
children.emplace_back(nonstd_make_unique<ExpressionNode>(token->location, std::move(expr_token->expr))); children.emplace_back(std::make_shared<ExpressionNode>(token->location, std::move(expr_token->expr)));
} else if (auto set_token = dynamic_cast<SetTemplateToken*>(token.get())) { } else if (auto set_token = dynamic_cast<SetTemplateToken*>(token.get())) {
if (set_token->value) { if (set_token->value) {
children.emplace_back(nonstd_make_unique<SetNode>(token->location, set_token->ns, set_token->var_names, std::move(set_token->value), nullptr)); children.emplace_back(std::make_shared<SetNode>(token->location, set_token->ns, set_token->var_names, std::move(set_token->value)));
} else { } else {
auto value_template = parseTemplate(begin, it, end); auto value_template = parseTemplate(begin, it, end);
if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) { if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) {
throw unterminated(**start); throw unterminated(**start);
} }
children.emplace_back(nonstd_make_unique<SetNode>(token->location, set_token->ns, set_token->var_names, nullptr, std::move(value_template))); if (!set_token->ns.empty()) throw std::runtime_error("Namespaced set not supported in set with template value");
if (set_token->var_names.size() != 1) throw std::runtime_error("Structural assignment not supported in set with template value");
auto & name = set_token->var_names[0];
children.emplace_back(std::make_shared<SetTemplateNode>(token->location, name, std::move(value_template)));
} }
} else if (auto macro_token = dynamic_cast<MacroTemplateToken*>(token.get())) { } else if (auto macro_token = dynamic_cast<MacroTemplateToken*>(token.get())) {
auto body = parseTemplate(begin, it, end); auto body = parseTemplate(begin, it, end);
if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) { if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) {
throw unterminated(**start); throw unterminated(**start);
} }
children.emplace_back(nonstd_make_unique<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body))); children.emplace_back(std::make_shared<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
} else if (dynamic_cast<CommentTemplateToken*>(token.get())) { } else if (dynamic_cast<CommentTemplateToken*>(token.get())) {
// Ignore comments // Ignore comments
} else if (dynamic_cast<EndForTemplateToken*>(token.get()) } else if (dynamic_cast<EndForTemplateToken*>(token.get())
@ -2210,17 +2244,17 @@ private:
throw unexpected(**it); throw unexpected(**it);
} }
if (children.empty()) { if (children.empty()) {
return nonstd_make_unique<TextNode>(Location { template_str, 0 }, std::string()); return std::make_shared<TextNode>(Location { template_str, 0 }, std::string());
} else if (children.size() == 1) { } else if (children.size() == 1) {
return std::move(children[0]); return std::move(children[0]);
} else { } else {
return nonstd_make_unique<SequenceNode>(children[0]->location(), std::move(children)); return std::make_shared<SequenceNode>(children[0]->location(), std::move(children));
} }
} }
public: public:
static std::unique_ptr<TemplateNode> parse(const std::string& template_str, const Options & options) { static std::shared_ptr<TemplateNode> parse(const std::string& template_str, const Options & options) {
Parser parser(std::make_shared<std::string>(template_str), options); Parser parser(std::make_shared<std::string>(template_str), options);
auto tokens = parser.tokenize(); auto tokens = parser.tokenize();
TemplateTokenIterator begin = tokens.begin(); TemplateTokenIterator begin = tokens.begin();

View File

@ -12,6 +12,29 @@
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template & chat_template) {
const auto & src = chat_template.source();
if (src.find("<tool_call>") != std::string::npos) {
return Hermes2Pro;
} else if (src.find(">>>all") != std::string::npos) {
return FunctionaryV3Llama3;
} else if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) {
return FunctionaryV3Llama31;
} else if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
if (src.find("<|python_tag|>") != std::string::npos) {
return Llama31;
} else {
return Llama32;
}
} else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
return CommandRPlus;
} else {
return UnknownToolCallStyle;
}
}
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) { static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
// // https://json.nlohmann.me/features/parsing/sax_interface/ // // https://json.nlohmann.me/features/parsing/sax_interface/
struct json_error_locator : public nlohmann::json_sax<json> { struct json_error_locator : public nlohmann::json_sax<json> {
@ -207,7 +230,8 @@ llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tool
} }
llama_tool_call_handler llama_tool_call_handler_init( llama_tool_call_handler llama_tool_call_handler_init(
const llama_chat_template & tmpl, llama_tool_call_style style,
const minja::chat_template & tmpl,
bool allow_content, bool allow_content,
bool parallel_tool_calls, bool parallel_tool_calls,
const nlohmann::ordered_json & messages, const nlohmann::ordered_json & messages,
@ -215,18 +239,18 @@ llama_tool_call_handler llama_tool_call_handler_init(
{ {
llama_tool_call_handler handler; llama_tool_call_handler handler;
switch (tmpl.tool_call_style()) { switch (style) {
case llama_tool_call_style::Llama31: case llama_tool_call_style::Llama31:
case llama_tool_call_style::Llama32: { case llama_tool_call_style::Llama32: {
static auto builtin_tools = json {"wolfram_alpha", "brave_search"}; static auto builtin_tools = json {"wolfram_alpha", "brave_search"};
auto uses_python_tag = tmpl.tool_call_style() == llama_tool_call_style::Llama31; auto uses_python_tag = style == llama_tool_call_style::Llama31;
// Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name, // Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
// but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon // but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
// as it seems to be outputting some JSON. // as it seems to be outputting some JSON.
// TODO: make this conditional on a very small model (e.g. 1B / 3B). // TODO: make this conditional on a very small model (e.g. 1B / 3B).
auto eagerly_match_any_json = tmpl.tool_call_style() == llama_tool_call_style::Llama32; auto eagerly_match_any_json = style == llama_tool_call_style::Llama32;
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;

View File

@ -2,10 +2,20 @@
#include "ggml.h" #include "ggml.h"
#include "common.h" #include "common.h"
#include "chat-template.hpp"
// Change JSON_ASSERT from assert() to GGML_ASSERT: // Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT #define JSON_ASSERT GGML_ASSERT
#include "json.hpp" #include "json.hpp"
#include "chat-template.h"
enum llama_tool_call_style {
UnknownToolCallStyle,
Llama31,
Llama32,
FunctionaryV3Llama3,
FunctionaryV3Llama31,
Hermes2Pro,
CommandRPlus,
};
struct llama_tool_call { struct llama_tool_call {
std::string name; std::string name;
@ -24,10 +34,13 @@ struct llama_tool_call_handler {
std::vector<std::string> additional_stop_words; std::vector<std::string> additional_stop_words;
}; };
llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template & chat_template);
llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input); llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
llama_tool_call_handler llama_tool_call_handler_init( llama_tool_call_handler llama_tool_call_handler_init(
const llama_chat_template & tmpl, llama_tool_call_style style,
const minja::chat_template & tmpl,
bool allow_content, bool allow_content,
bool parallel_tool_calls, bool parallel_tool_calls,
const nlohmann::ordered_json & messages, const nlohmann::ordered_json & messages,

View File

@ -663,7 +663,7 @@ struct server_context {
llama_chat_message chat[] = {{"user", "test"}}; llama_chat_message chat[] = {{"user", "test"}};
if (use_jinja) { if (use_jinja) {
auto chat_template = llama_chat_template::from_model(model); auto chat_template = llama_chat_template_from_model(model);
try { try {
chat_template.apply({{ chat_template.apply({{
{"role", "user"}, {"role", "user"},
@ -2875,11 +2875,12 @@ int main(int argc, char ** argv) {
return; return;
} }
auto chat_template = llama_chat_template::from_model(ctx_server.model, params.chat_template.empty() ? nullptr : params.chat_template.c_str()); static auto chat_template = llama_chat_template_from_model(ctx_server.model, params.chat_template.empty() ? nullptr : params.chat_template.c_str());
static auto tool_call_style = llama_tool_call_style_detect(chat_template);
json data; json data;
try { try {
data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, params.use_jinja); data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, tool_call_style, params.use_jinja);
} catch (const std::exception & e) { } catch (const std::exception & e) {
res_error(res, format_error_response(e.what(), ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response(e.what(), ERROR_TYPE_NOT_SUPPORTED));
return; return;
@ -2897,7 +2898,7 @@ int main(int argc, char ** argv) {
ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) { ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
// multitask is never support in chat completion, there is only one result // multitask is never support in chat completion, there is only one result
try { try {
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, chat_template, /*.streaming =*/ false, verbose); json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, tool_call_style, /*.streaming =*/ false, verbose);
res_ok(res, result_oai); res_ok(res, result_oai);
} catch (const std::runtime_error & e) { } catch (const std::runtime_error & e) {
res_error(res, format_error_response(e.what(), ERROR_TYPE_SERVER)); res_error(res, format_error_response(e.what(), ERROR_TYPE_SERVER));

View File

@ -14,7 +14,6 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT: // Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT #define JSON_ASSERT GGML_ASSERT
#include "chat-template.h"
#include "json.hpp" #include "json.hpp"
#include "minja.hpp" #include "minja.hpp"
#include "tool-call.h" #include "tool-call.h"
@ -309,7 +308,8 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
static json oaicompat_completion_params_parse( static json oaicompat_completion_params_parse(
const struct llama_model * model, const struct llama_model * model,
const json & body, /* openai api json semantics */ const json & body, /* openai api json semantics */
const llama_chat_template & tmpl, const minja::chat_template & tmpl,
llama_tool_call_style tool_call_style,
bool use_jinja) bool use_jinja)
{ {
json llama_params; json llama_params;
@ -320,7 +320,7 @@ static json oaicompat_completion_params_parse(
auto has_tools = tools.is_array() && !tools.empty(); auto has_tools = tools.is_array() && !tools.empty();
// Apply chat template to the list of messages // Apply chat template to the list of messages
llama_params["chat_template"] = tmpl.chat_template(); llama_params["chat_template"] = tmpl.source();
if (use_jinja) { if (use_jinja) {
if (has_tools && !tmpl.supports_tools()) { if (has_tools && !tmpl.supports_tools()) {
@ -372,7 +372,7 @@ static json oaicompat_completion_params_parse(
llama_params["parse_tool_calls"] = true; llama_params["parse_tool_calls"] = true;
llama_params["parallel_tool_calls"] = parallel_tool_calls; llama_params["parallel_tool_calls"] = parallel_tool_calls;
auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools); auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools);
llama_params["prompt"] = handler.prompt; llama_params["prompt"] = handler.prompt;
for (const auto & stop : handler.additional_stop_words) { for (const auto & stop : handler.additional_stop_words) {
@ -395,7 +395,7 @@ static json oaicompat_completion_params_parse(
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true); llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
} }
} else { } else {
llama_params["prompt"] = format_chat(model, tmpl.chat_template(), body.at("messages")); llama_params["prompt"] = format_chat(model, tmpl.source(), body.at("messages"));
} }
// Handle "n" field // Handle "n" field
@ -435,7 +435,7 @@ static json oaicompat_completion_params_parse(
return llama_params; return llama_params;
} }
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, const llama_chat_template & tmpl, bool streaming = false, bool verbose = false) { static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, llama_tool_call_style tool_call_style, bool streaming = false, bool verbose = false) {
bool stopped_word = result.count("stopped_word") != 0; bool stopped_word = result.count("stopped_word") != 0;
bool stopped_eos = json_value(result, "stopped_eos", false); bool stopped_eos = json_value(result, "stopped_eos", false);
int num_tokens_predicted = json_value(result, "tokens_predicted", 0); int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@ -452,7 +452,7 @@ static json format_final_response_oaicompat(const json & request, const json & r
json tool_calls; json tool_calls;
json message_content; json message_content;
if (json_value(request, "parse_tool_calls", false) if (json_value(request, "parse_tool_calls", false)
&& !(parsed_tool_calls = parse_tool_calls(tmpl.tool_call_style(), tools, content)).tool_calls.empty()) { && !(parsed_tool_calls = parse_tool_calls(tool_call_style, tools, content)).tool_calls.empty()) {
finish_reason = "tool_calls"; finish_reason = "tool_calls";
if (!parsed_tool_calls.content.empty()) { if (!parsed_tool_calls.content.empty()) {
message_content = parsed_tool_calls.content; message_content = parsed_tool_calls.content;

View File

@ -0,0 +1,148 @@
#!/usr/bin/env uv run
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "jinja2",
# "huggingface_hub",
# ]
# ///
'''
Fetches the Jinja2 templates of specified models and generates prompt goldens for predefined chat contexts.
Outputs lines of arguments for a C++ test binary.
All files are written to the specified output folder.
Usage:
python ./update_jinja_goldens.py output_folder context1.json context2.json ... model_id1 model_id2 ...
Example:
python ./update_jinja_goldens.py ./test_files "microsoft/Phi-3-medium-4k-instruct" "Qwen/Qwen2-7B-Instruct"
'''
import logging
import datetime
import glob
import os
from huggingface_hub import hf_hub_download
import json
import jinja2
import jinja2.ext
import re
import argparse
import shutil
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)
def raise_exception(message: str):
raise ValueError(message)
def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
TEST_DATE = os.environ.get('TEST_DATE', '2024-07-26')
def strftime_now(format):
now = datetime.datetime.strptime(TEST_DATE, "%Y-%m-%d")
return now.strftime(format)
def handle_chat_template(output_folder, model_id, variant, template_src):
model_name = model_id.replace("/", "-")
base_name = f'{model_name}-{variant}' if variant else model_name
template_file = os.path.join(output_folder, f'{base_name}.jinja')
with open(template_file, 'w') as f:
f.write(template_src)
env = jinja2.Environment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[jinja2.ext.loopcontrols]
)
env.filters['safe'] = lambda x: x
env.filters['tojson'] = tojson
env.globals['raise_exception'] = raise_exception
env.globals['strftime_now'] = strftime_now
template_handles_tools = 'tools' in template_src
template_hates_the_system = 'System role not supported' in template_src
template = env.from_string(template_src)
context_files = glob.glob(os.path.join(output_folder, '*.json'))
for context_file in context_files:
context_name = os.path.basename(context_file).replace(".json", "")
with open(context_file, 'r') as f:
context = json.load(f)
if not template_handles_tools and 'tools' in context:
continue
if template_hates_the_system and any(m['role'] == 'system' for m in context['messages']):
continue
output_file = os.path.join(output_folder, f'{base_name}-{context_name}.txt')
render_context = json.loads(json.dumps(context))
if 'tool_call.arguments | items' in template_src or 'tool_call.arguments | tojson' in template_src:
for message in render_context['messages']:
if 'tool_calls' in message:
for tool_call in message['tool_calls']:
if tool_call.get('type') == 'function':
arguments = tool_call['function']['arguments']
tool_call['function']['arguments'] = json.loads(arguments)
try:
output = template.render(**render_context)
except Exception as e1:
for message in context["messages"]:
if message.get("content") is None:
message["content"] = ""
try:
output = template.render(**render_context)
except Exception as e2:
logger.info(f" ERROR: {e2} (after first error: {e1})")
output = f"ERROR: {e2}"
with open(output_file, 'w') as f:
f.write(output)
# Output the line of arguments for the C++ test binary
print(f"{template_file} {context_file} {output_file}")
def main():
parser = argparse.ArgumentParser(description="Generate chat templates and output test arguments.")
parser.add_argument("output_folder", help="Folder to store all output files")
parser.add_argument("model_ids", nargs="+", help="List of model IDs to process")
args = parser.parse_args()
output_folder = args.output_folder
if not os.path.isdir(output_folder):
os.makedirs(output_folder)
# Copy context files to the output folder
for context_file in glob.glob('tests/chat/contexts/*.json'):
shutil.copy(context_file, output_folder)
for model_id in args.model_ids:
try:
with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
config_str = f.read()
try:
config = json.loads(config_str)
except json.JSONDecodeError:
config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
chat_template = config['chat_template']
if isinstance(chat_template, str):
handle_chat_template(output_folder, model_id, None, chat_template)
else:
for ct in chat_template:
handle_chat_template(output_folder, model_id, ct['name'], ct['template'])
except Exception as e:
logger.error(f"Error processing model {model_id}: {e}")
if __name__ == '__main__':
main()

View File

@ -7,7 +7,7 @@
#include "llama.h" #include "llama.h"
#include "common.h" #include "common.h"
#include "chat-template.h" #include "chat-template.hpp"
#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
@ -73,7 +73,7 @@ static void test_jinja_templates() {
return "tests/chat/goldens/" + golden_name + ".txt"; return "tests/chat/goldens/" + golden_name + ".txt";
}; };
auto fail_with_golden_instructions = [&]() { auto fail_with_golden_instructions = [&]() {
throw std::runtime_error("To fetch templates and generate golden files, run `python tests/update_jinja_goldens.py`"); throw std::runtime_error("To fetch templates and generate golden files, run `python update_templates_and_goldens.py`");
}; };
if (jinja_template_files.empty()) { if (jinja_template_files.empty()) {
std::cerr << "No Jinja templates found in tests/chat/templates" << std::endl; std::cerr << "No Jinja templates found in tests/chat/templates" << std::endl;
@ -89,7 +89,7 @@ static void test_jinja_templates() {
for (const auto & ctx_file : context_files) { for (const auto & ctx_file : context_files) {
auto ctx = json::parse(read_file(ctx_file)); auto ctx = json::parse(read_file(ctx_file));
llama_chat_template tmpl( minja::chat_template tmpl(
tmpl_str, tmpl_str,
ctx.at("bos_token"), ctx.at("bos_token"),
ctx.at("eos_token")); ctx.at("eos_token"));
@ -127,20 +127,6 @@ static void test_jinja_templates() {
} }
} }
void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
auto tmpl = llama_chat_template(read_file(template_file), "<s>", "</s>");
std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
assert_equals(expected, tmpl.tool_call_style());
}
void test_tool_call_styles() {
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", FunctionaryV3Llama31);
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", FunctionaryV3Llama3);
test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31);
test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32);
test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", CommandRPlus);
}
static void test_legacy_templates() { static void test_legacy_templates() {
struct test_template { struct test_template {
std::string name; std::string name;
@ -353,7 +339,6 @@ int main(void) {
if (getenv("LLAMA_SKIP_TESTS_SLOW_ON_EMULATOR")) { if (getenv("LLAMA_SKIP_TESTS_SLOW_ON_EMULATOR")) {
fprintf(stderr, "\033[33mWARNING: Skipping slow tests on emulator.\n\033[0m"); fprintf(stderr, "\033[33mWARNING: Skipping slow tests on emulator.\n\033[0m");
} else { } else {
test_tool_call_styles();
test_jinja_templates(); test_jinja_templates();
} }

View File

@ -9,7 +9,8 @@
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
static void assert_equals(const std::string & expected, const std::string & actual) { template <class T>
static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) { if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl; std::cerr << "Expected: " << expected << std::endl;
std::cerr << "Actual: " << actual << std::endl; std::cerr << "Actual: " << actual << std::endl;
@ -242,7 +243,22 @@ static void test_parsing() {
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array()); "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array());
} }
static std::string get_message_prompt_delta(const llama_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) { void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
const minja::chat_template tmpl(read_file(template_file), "<s>", "</s>");
auto tool_call_style = llama_tool_call_style_detect(tmpl);
std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
assert_equals(expected, tool_call_style);
}
void test_tool_call_style_detection() {
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", FunctionaryV3Llama31);
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", FunctionaryV3Llama3);
test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31);
test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32);
test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", CommandRPlus);
}
static std::string get_message_prompt_delta(const minja::chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object()); auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object());
auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object()); auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object());
@ -267,7 +283,8 @@ static std::string get_message_prompt_delta(const llama_chat_template & tmpl, co
static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools) { static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools) {
std::cout << "# Testing template: " << template_file << std::endl << std::flush; std::cout << "# Testing template: " << template_file << std::endl << std::flush;
const llama_chat_template & tmpl = llama_chat_template(read_file(template_file), bos_token, eos_token); const minja::chat_template tmpl(read_file(template_file), bos_token, eos_token);
auto tool_call_style = llama_tool_call_style_detect(tmpl);
auto & tool_calls = tool_calling_message.at("tool_calls"); auto & tool_calls = tool_calling_message.at("tool_calls");
// Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false, // Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
@ -277,7 +294,7 @@ static void test_template(const std::string & template_file, const char * bos_to
{"content", "Hello, world!"} {"content", "Hello, world!"}
}; };
auto handler = llama_tool_call_handler_init(tmpl, /* allow_content= */ true, /* parallel_tool_calls= */ true, {user_message, tool_calling_message}, tools); auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, /* allow_content= */ true, /* parallel_tool_calls= */ true, {user_message, tool_calling_message}, tools);
auto grammar = build_grammar(handler.grammar); auto grammar = build_grammar(handler.grammar);
if (!grammar) { if (!grammar) {
throw std::runtime_error("Failed to build grammar"); throw std::runtime_error("Failed to build grammar");
@ -285,7 +302,7 @@ static void test_template(const std::string & template_file, const char * bos_to
auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, tool_calling_message, tools); auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, tool_calling_message, tools);
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl; std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
test_parse_tool_call(tmpl.tool_call_style(), tools, full_delta, "", tool_calls); test_parse_tool_call(tool_call_style, tools, full_delta, "", tool_calls);
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, { auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
{"role", "assistant"}, {"role", "assistant"},
@ -319,6 +336,7 @@ static void test_grammars() {
int main() { int main() {
test_grammars(); test_grammars();
test_parsing(); test_parsing();
test_tool_call_style_detection();
std::cout << "[tool-call] All tests passed!" << std::endl; std::cout << "[tool-call] All tests passed!" << std::endl;
return 0; return 0;