tool-call: prepare possible externalization of minja + factor tool call style out of template

This commit is contained in:
Olivier Chafik 2024-10-01 23:12:24 +01:00
parent d9451fd647
commit c36a196f53
14 changed files with 626 additions and 445 deletions

View File

@ -54,8 +54,7 @@ add_library(${TARGET} STATIC
arg.cpp
arg.h
base64.hpp
chat-template.cpp
chat-template.h
chat-template.hpp
common.cpp
common.h
console.cpp

View File

@ -1,156 +0,0 @@
#include "chat-template.h"
#include "llama.h"
using json = nlohmann::ordered_json;
static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
std::string piece;
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
if (n_chars < 0) {
piece.resize(-n_chars);
int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
GGML_ASSERT(check == -n_chars);
}
else {
piece.resize(n_chars);
}
return piece;
}
static std::string llama_model_meta_val_str(const struct llama_model * model, const char * key) {
int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
if (tlen > 0) {
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
return std::string(curr_tmpl_buf.data(), tlen);
}
}
return "";
}
llama_chat_template::llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token)
: _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) {
_supports_tools = chat_template.find("tools") != std::string::npos;
_requires_object_arguments =
chat_template.find("tool_call.arguments | items") != std::string::npos
|| chat_template.find("tool_call.arguments | tojson") != std::string::npos;
_supports_system_role = chat_template.find("System role not supported") == std::string::npos;
if (chat_template.find("<tool_call>") != std::string::npos) {
_tool_call_style = Hermes2Pro;
} else if (chat_template.find(">>>all") != std::string::npos) {
_tool_call_style = FunctionaryV3Llama3;
} else if (chat_template.find("<|start_header_id|>") != std::string::npos
&& chat_template.find("<function=") != std::string::npos) {
_tool_call_style = FunctionaryV3Llama31;
} else if (chat_template.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
if (chat_template.find("<|python_tag|>") != std::string::npos) {
_tool_call_style = Llama31;
} else {
_tool_call_style = Llama32;
}
} else if (chat_template.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
_tool_call_style = CommandRPlus;
} else {
_tool_call_style = UnknownToolCallStyle;
}
_template_root = minja::Parser::parse(_chat_template, {
/* .trim_blocks = */ true,
/* .lstrip_blocks = */ true,
/* .keep_trailing_newline = */ false,
});
}
llama_chat_template llama_chat_template::from_model(
const struct llama_model * model,
const char * chat_template_override)
{
// TODO: handle "chatml"?
std::string chat_template = chat_template_override
? chat_template_override
: llama_model_meta_val_str(model, "tokenizer.chat_template");
auto bos_token = _llama_token_to_piece(model, llama_token_bos(model), true);
auto eos_token = _llama_token_to_piece(model, llama_token_eos(model), true);
return llama_chat_template(chat_template, bos_token, eos_token);
}
std::string llama_chat_template::apply(
const json & messages,
const json & tools,
bool add_generation_prompt,
const json & extra_context) const
{
auto actual_messages = messages;
// First, "fix" messages so they have a chance to be rendered correctly by the template
if (_requires_object_arguments || !_supports_system_role) {
std::string pending_system;
auto flush_sys = [&]() {
if (!pending_system.empty()) {
actual_messages.push_back({
{"role", "user"},
{"content", pending_system},
});
pending_system.clear();
}
};
for (auto & message : actual_messages) {
if (!message.contains("role") || !message.contains("content")) {
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
}
std::string role = message.at("role");
if (!message["content"].is_null() && !_supports_system_role) {
std::string content = message.at("content");
if (role == "system") {
if (!pending_system.empty()) pending_system += "\n";
pending_system += content;
continue;
} else {
if (role == "user") {
if (!pending_system.empty()) {
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
pending_system.clear();
}
} else {
flush_sys();
}
}
}
if (_requires_object_arguments && message.contains("tool_calls")) {
for (auto & tool_call : message.at("tool_calls")) {
if (tool_call["type"] == "function") {
auto & function = tool_call.at("function");
std::string arguments = function.at("arguments");
function["arguments"] = json::parse(arguments);
}
}
}
}
flush_sys();
}
auto context = minja::Context::make(json({
{"messages", actual_messages},
{"add_generation_prompt", add_generation_prompt},
{"bos_token", _bos_token},
{"eos_token", _eos_token},
}));
if (!tools.is_null()) {
auto tools_val = minja::Value(tools);
context->set("tools", tools_val);
}
if (!extra_context.is_null()) {
for (auto & kv : extra_context.items()) {
minja::Value val(kv.value());
context->set(kv.key(), val);
}
}
return _template_root->render(context);
}

View File

@ -1,53 +0,0 @@
#pragma once
#include "minja.hpp"
#include <json.hpp>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
enum llama_tool_call_style {
UnknownToolCallStyle,
Llama31,
Llama32,
FunctionaryV3Llama3,
FunctionaryV3Llama31,
Hermes2Pro,
CommandRPlus,
};
class llama_chat_template {
public:
private:
llama_tool_call_style _tool_call_style = UnknownToolCallStyle;
bool _supports_tools = true;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool _requires_object_arguments = false;
bool _supports_system_role = true;
std::string _chat_template;
std::string _bos_token;
std::string _eos_token;
std::unique_ptr<minja::TemplateNode> _template_root;
public:
llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token);
static llama_chat_template from_model(
const struct llama_model * model,
const char * chat_template_override = nullptr);
llama_tool_call_style tool_call_style() const { return _tool_call_style; }
const std::string & chat_template() const { return _chat_template; }
bool supports_tools() const { return _supports_tools; }
std::string apply(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const;
};

133
common/chat-template.hpp Normal file
View File

@ -0,0 +1,133 @@
/*
Copyright 2024 Google LLC
Use of this source code is governed by an MIT-style
license that can be found in the LICENSE file or at
https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once
#include "minja.hpp"
#include <json.hpp>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
namespace minja {
class chat_template {
public:
private:
bool _supports_tools = true;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool _requires_object_arguments = false;
bool _supports_system_role = true;
std::string _source;
std::string _bos_token;
std::string _eos_token;
std::shared_ptr<minja::TemplateNode> _template_root;
public:
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
: _source(source), _bos_token(bos_token), _eos_token(eos_token)
{
_supports_tools = source.find("tools") != std::string::npos;
_requires_object_arguments =
source.find("tool_call.arguments | items") != std::string::npos
|| source.find("tool_call.arguments | tojson") != std::string::npos;
_supports_system_role = source.find("System role not supported") == std::string::npos;
_template_root = minja::Parser::parse(_source, {
/* .trim_blocks = */ true,
/* .lstrip_blocks = */ true,
/* .keep_trailing_newline = */ false,
});
}
const std::string & source() const { return _source; }
bool supports_tools() const { return _supports_tools; }
std::string apply(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
{
auto actual_messages = messages;
// First, "fix" messages so they have a chance to be rendered correctly by the template
if (_requires_object_arguments || !_supports_system_role) {
std::string pending_system;
auto flush_sys = [&]() {
if (!pending_system.empty()) {
actual_messages.push_back({
{"role", "user"},
{"content", pending_system},
});
pending_system.clear();
}
};
for (auto & message : actual_messages) {
if (!message.contains("role") || !message.contains("content")) {
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
}
std::string role = message.at("role");
if (!message["content"].is_null() && !_supports_system_role) {
std::string content = message.at("content");
if (role == "system") {
if (!pending_system.empty()) pending_system += "\n";
pending_system += content;
continue;
} else {
if (role == "user") {
if (!pending_system.empty()) {
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
pending_system.clear();
}
} else {
flush_sys();
}
}
}
if (_requires_object_arguments && message.contains("tool_calls")) {
for (auto & tool_call : message.at("tool_calls")) {
if (tool_call["type"] == "function") {
auto & function = tool_call.at("function");
std::string arguments = function.at("arguments");
function["arguments"] = json::parse(arguments);
}
}
}
}
flush_sys();
}
auto context = minja::Context::make(json({
{"messages", actual_messages},
{"add_generation_prompt", add_generation_prompt},
{"bos_token", _bos_token},
{"eos_token", _eos_token},
}));
if (!tools.is_null()) {
auto tools_val = minja::Value(tools);
context->set("tools", tools_val);
}
if (!extra_context.is_null()) {
for (auto & kv : extra_context.items()) {
minja::Value val(kv.value());
context->set(kv.key(), val);
}
}
return _template_root->render(context);
}
};
} // namespace minja

View File

@ -9,7 +9,7 @@
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "chat-template.h"
#include "chat-template.hpp"
#include <algorithm>
#include <cinttypes>
@ -1513,13 +1513,13 @@ std::vector<llama_token> llama_tokenize(
return result;
}
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
std::string piece;
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
if (n_chars < 0) {
piece.resize(-n_chars);
int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
GGML_ASSERT(check == -n_chars);
}
else {
@ -1529,6 +1529,10 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t
return piece;
}
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
return _llama_token_to_piece(llama_get_model(ctx), token, special);
}
std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
std::string text;
text.resize(std::max(text.capacity(), tokens.size()));
@ -1552,7 +1556,7 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token>
bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {
auto chat_template = llama_chat_template(tmpl, "<s>", "</s>");
auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
chat_template.apply({{
{"role", "user"},
{"content", "test"},
@ -1651,6 +1655,30 @@ std::string llama_chat_format_example(const struct llama_model * model,
return llama_chat_apply_template(model, tmpl, msgs, true);
}
static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
if (tlen > 0) {
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
return std::string(curr_tmpl_buf.data(), tlen);
}
}
return "";
}
minja::chat_template llama_chat_template_from_model(
const struct llama_model * model,
const char * chat_template_override)
{
// TODO: handle "chatml"?
std::string chat_template = chat_template_override
? chat_template_override
: _llama_model_meta_val_str(model, "tokenizer.chat_template");
auto bos_token = _llama_token_to_piece(model, llama_token_bos(model), true);
auto eos_token = _llama_token_to_piece(model, llama_token_eos(model), true);
return {std::move(chat_template), bos_token, eos_token};
}
//
// KV cache utils
//

View File

@ -27,6 +27,9 @@
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
// Forward declaration
namespace minja { class chat_template; }
struct llama_lora_adapter_info {
std::string path;
float scale;
@ -500,6 +503,10 @@ std::string llama_chat_format_single(const struct llama_model * model,
std::string llama_chat_format_example(const struct llama_model * model,
const std::string & tmpl);
minja::chat_template llama_chat_template_from_model(
const struct llama_model * model,
const char * chat_template_override = nullptr);
//
// KV cache utils
//

View File

@ -1,3 +1,11 @@
/*
Copyright 2024 Google LLC
Use of this source code is governed by an MIT-style
license that can be found in the LICENSE file or at
https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
@ -577,8 +585,8 @@ protected:
virtual Value do_evaluate(const std::shared_ptr<Context> & context) const = 0;
public:
struct Arguments {
std::vector<std::unique_ptr<Expression>> args;
std::vector<std::pair<std::string, std::unique_ptr<Expression>>> kwargs;
std::vector<std::shared_ptr<Expression>> args;
std::vector<std::pair<std::string, std::shared_ptr<Expression>>> kwargs;
void expectArgs(const std::string & method_name, const std::pair<size_t, size_t> & pos_count, const std::pair<size_t, size_t> & kw_count) const {
if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) {
@ -600,7 +608,7 @@ public:
}
};
using Parameters = std::vector<std::pair<std::string, std::unique_ptr<Expression>>>;
using Parameters = std::vector<std::pair<std::string, std::shared_ptr<Expression>>>;
Location location;
@ -687,18 +695,18 @@ struct TextTemplateToken : public TemplateToken {
};
struct ExpressionTemplateToken : public TemplateToken {
std::unique_ptr<Expression> expr;
ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr<Expression> && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {}
std::shared_ptr<Expression> expr;
ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {}
};
struct IfTemplateToken : public TemplateToken {
std::unique_ptr<Expression> condition;
IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr<Expression> && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {}
std::shared_ptr<Expression> condition;
IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {}
};
struct ElifTemplateToken : public TemplateToken {
std::unique_ptr<Expression> condition;
ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr<Expression> && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {}
std::shared_ptr<Expression> condition;
ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {}
};
struct ElseTemplateToken : public TemplateToken {
@ -710,9 +718,9 @@ struct EndIfTemplateToken : public TemplateToken {
};
struct MacroTemplateToken : public TemplateToken {
std::unique_ptr<VariableExpr> name;
std::shared_ptr<VariableExpr> name;
Expression::Parameters params;
MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr<VariableExpr> && n, Expression::Parameters && p)
MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p)
: TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {}
};
@ -722,11 +730,11 @@ struct EndMacroTemplateToken : public TemplateToken {
struct ForTemplateToken : public TemplateToken {
std::vector<std::string> var_names;
std::unique_ptr<Expression> iterable;
std::unique_ptr<Expression> condition;
std::shared_ptr<Expression> iterable;
std::shared_ptr<Expression> condition;
bool recursive;
ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector<std::string> & vns, std::unique_ptr<Expression> && iter,
std::unique_ptr<Expression> && c, bool r)
ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector<std::string> & vns, std::shared_ptr<Expression> && iter,
std::shared_ptr<Expression> && c, bool r)
: TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {}
};
@ -737,8 +745,8 @@ struct EndForTemplateToken : public TemplateToken {
struct SetTemplateToken : public TemplateToken {
std::string ns;
std::vector<std::string> var_names;
std::unique_ptr<Expression> value;
SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector<std::string> & vns, std::unique_ptr<Expression> && v)
std::shared_ptr<Expression> value;
SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
: TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {}
};
@ -778,9 +786,9 @@ public:
};
class SequenceNode : public TemplateNode {
std::vector<std::unique_ptr<TemplateNode>> children;
std::vector<std::shared_ptr<TemplateNode>> children;
public:
SequenceNode(const Location & location, std::vector<std::unique_ptr<TemplateNode>> && c)
SequenceNode(const Location & location, std::vector<std::shared_ptr<TemplateNode>> && c)
: TemplateNode(location), children(std::move(c)) {}
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
for (const auto& child : children) child->render(out, context);
@ -797,10 +805,11 @@ public:
};
class ExpressionNode : public TemplateNode {
std::unique_ptr<Expression> expr;
std::shared_ptr<Expression> expr;
public:
ExpressionNode(const Location & location, std::unique_ptr<Expression> && e) : TemplateNode(location), expr(std::move(e)) {}
ExpressionNode(const Location & location, std::shared_ptr<Expression> && e) : TemplateNode(location), expr(std::move(e)) {}
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
if (!expr) throw std::runtime_error("ExpressionNode.expr is null");
auto result = expr->evaluate(context);
if (result.is_string()) {
out << result.get<std::string>();
@ -813,9 +822,9 @@ public:
};
class IfNode : public TemplateNode {
std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<TemplateNode>>> cascade;
std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
public:
IfNode(const Location & location, std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<TemplateNode>>> && c)
IfNode(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> && c)
: TemplateNode(location), cascade(std::move(c)) {}
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
for (const auto& branch : cascade) {
@ -824,6 +833,7 @@ public:
enter_branch = branch.first->evaluate(context).to_bool();
}
if (enter_branch) {
if (!branch.second) throw std::runtime_error("IfNode.cascade.second is null");
branch.second->render(out, context);
return;
}
@ -833,18 +843,20 @@ public:
class ForNode : public TemplateNode {
std::vector<std::string> var_names;
std::unique_ptr<Expression> iterable;
std::unique_ptr<Expression> condition;
std::unique_ptr<TemplateNode> body;
std::shared_ptr<Expression> iterable;
std::shared_ptr<Expression> condition;
std::shared_ptr<TemplateNode> body;
bool recursive;
std::unique_ptr<TemplateNode> else_body;
std::shared_ptr<TemplateNode> else_body;
public:
ForNode(const Location & location, std::vector<std::string> && var_names, std::unique_ptr<Expression> && iterable,
std::unique_ptr<Expression> && condition, std::unique_ptr<TemplateNode> && body, bool recursive, std::unique_ptr<TemplateNode> && else_body)
ForNode(const Location & location, std::vector<std::string> && var_names, std::shared_ptr<Expression> && iterable,
std::shared_ptr<Expression> && condition, std::shared_ptr<TemplateNode> && body, bool recursive, std::shared_ptr<TemplateNode> && else_body)
: TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {}
void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
// https://jinja.palletsprojects.com/en/3.0.x/templates/#for
if (!iterable) throw std::runtime_error("ForNode.iterable is null");
if (!body) throw std::runtime_error("ForNode.body is null");
auto iterable_value = iterable->evaluate(context);
Value::CallableType loop_function;
@ -914,12 +926,12 @@ public:
};
class MacroNode : public TemplateNode {
std::unique_ptr<VariableExpr> name;
std::shared_ptr<VariableExpr> name;
Expression::Parameters params;
std::unique_ptr<TemplateNode> body;
std::shared_ptr<TemplateNode> body;
std::unordered_map<std::string, size_t> named_param_positions;
public:
MacroNode(const Location & location, std::unique_ptr<VariableExpr> && n, Expression::Parameters && p, std::unique_ptr<TemplateNode> && b)
MacroNode(const Location & location, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p, std::shared_ptr<TemplateNode> && b)
: TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) {
for (size_t i = 0; i < params.size(); ++i) {
const auto & name = params[i].first;
@ -929,6 +941,8 @@ public:
}
}
void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
if (!name) throw std::runtime_error("MacroNode.name is null");
if (!body) throw std::runtime_error("MacroNode.body is null");
auto callable = Value::callable([&](const std::shared_ptr<Context> & context, Value::Arguments & args) {
auto call_context = macro_context;
std::vector<bool> param_set(params.size(), false);
@ -964,19 +978,12 @@ public:
class SetNode : public TemplateNode {
std::string ns;
std::vector<std::string> var_names;
std::unique_ptr<Expression> value;
std::unique_ptr<TemplateNode> template_value;
std::shared_ptr<Expression> value;
public:
SetNode(const Location & location, const std::string & ns, const std::vector<std::string> & vns, std::unique_ptr<Expression> && v, std::unique_ptr<TemplateNode> && tv)
: TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)), template_value(std::move(tv)) {
if (value && template_value) {
throw std::runtime_error("Cannot have both value and template value in set node");
}
if (template_value && var_names.size() != 1) {
throw std::runtime_error("Destructuring assignment is only supported with a single variable name");
}
}
SetNode(const Location & location, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
: TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {}
void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
if (!value) throw std::runtime_error("SetNode.value is null");
if (!ns.empty()) {
if (var_names.size() != 1) {
throw std::runtime_error("Namespaced set only supports a single variable name");
@ -985,9 +992,6 @@ public:
auto ns_value = context->get(ns);
if (!ns_value.is_object()) throw std::runtime_error("Namespace '" + ns + "' is not an object");
ns_value.set(name, this->value->evaluate(context));
} else if (template_value) {
Value value { template_value->render(context) };
context->set(var_names[0], value);
} else {
auto val = value->evaluate(context);
destructuring_assign(var_names, context, val);
@ -995,14 +999,29 @@ public:
}
};
class IfExpr : public Expression {
std::unique_ptr<Expression> condition;
std::unique_ptr<Expression> then_expr;
std::unique_ptr<Expression> else_expr;
class SetTemplateNode : public TemplateNode {
std::string name;
std::shared_ptr<TemplateNode> template_value;
public:
IfExpr(const Location & location, std::unique_ptr<Expression> && c, std::unique_ptr<Expression> && t, std::unique_ptr<Expression> && e)
SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr<TemplateNode> && tv)
: TemplateNode(location), name(name), template_value(std::move(tv)) {}
void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null");
Value value { template_value->render(context) };
context->set(name, value);
}
};
class IfExpr : public Expression {
std::shared_ptr<Expression> condition;
std::shared_ptr<Expression> then_expr;
std::shared_ptr<Expression> else_expr;
public:
IfExpr(const Location & location, std::shared_ptr<Expression> && c, std::shared_ptr<Expression> && t, std::shared_ptr<Expression> && e)
: Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!condition) throw std::runtime_error("IfExpr.condition is null");
if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null");
if (condition->evaluate(context).to_bool()) {
return then_expr->evaluate(context);
}
@ -1022,13 +1041,14 @@ public:
};
class ArrayExpr : public Expression {
std::vector<std::unique_ptr<Expression>> elements;
std::vector<std::shared_ptr<Expression>> elements;
public:
ArrayExpr(const Location & location, std::vector<std::unique_ptr<Expression>> && e)
ArrayExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && e)
: Expression(location), elements(std::move(e)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
auto result = Value::array();
for (const auto& e : elements) {
if (!e) throw std::runtime_error("Array element is null");
result.push_back(e->evaluate(context));
}
return result;
@ -1036,13 +1056,15 @@ public:
};
class DictExpr : public Expression {
std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>>> elements;
std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
public:
DictExpr(const Location & location, std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>>> && e)
DictExpr(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> && e)
: Expression(location), elements(std::move(e)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
auto result = Value::object();
for (const auto& e : elements) {
if (!e.first) throw std::runtime_error("Dict key is null");
if (!e.second) throw std::runtime_error("Dict value is null");
result.set(e.first->evaluate(context), e.second->evaluate(context));
}
return result;
@ -1051,8 +1073,8 @@ public:
class SliceExpr : public Expression {
public:
std::unique_ptr<Expression> start, end;
SliceExpr(const Location & location, std::unique_ptr<Expression> && s, std::unique_ptr<Expression> && e)
std::shared_ptr<Expression> start, end;
SliceExpr(const Location & location, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
: Expression(location), start(std::move(s)), end(std::move(e)) {}
Value do_evaluate(const std::shared_ptr<Context> &) const override {
throw std::runtime_error("SliceExpr not implemented");
@ -1060,12 +1082,14 @@ public:
};
class SubscriptExpr : public Expression {
std::unique_ptr<Expression> base;
std::unique_ptr<Expression> index;
std::shared_ptr<Expression> base;
std::shared_ptr<Expression> index;
public:
SubscriptExpr(const Location & location, std::unique_ptr<Expression> && b, std::unique_ptr<Expression> && i)
SubscriptExpr(const Location & location, std::shared_ptr<Expression> && b, std::shared_ptr<Expression> && i)
: Expression(location), base(std::move(b)), index(std::move(i)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!base) throw std::runtime_error("SubscriptExpr.base is null");
if (!index) throw std::runtime_error("SubscriptExpr.index is null");
auto target_value = base->evaluate(context);
if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
if (!target_value.is_array()) throw std::runtime_error("Subscripting non-array");
@ -1094,12 +1118,13 @@ class UnaryOpExpr : public Expression {
public:
enum class Op { Plus, Minus, LogicalNot };
private:
std::unique_ptr<Expression> expr;
std::shared_ptr<Expression> expr;
Op op;
public:
UnaryOpExpr(const Location & location, std::unique_ptr<Expression> && e, Op o)
UnaryOpExpr(const Location & location, std::shared_ptr<Expression> && e, Op o)
: Expression(location), expr(std::move(e)), op(o) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null");
auto e = expr->evaluate(context);
switch (op) {
case Op::Plus: return e;
@ -1114,13 +1139,15 @@ class BinaryOpExpr : public Expression {
public:
enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot };
private:
std::unique_ptr<Expression> left;
std::unique_ptr<Expression> right;
std::shared_ptr<Expression> left;
std::shared_ptr<Expression> right;
Op op;
public:
BinaryOpExpr(const Location & location, std::unique_ptr<Expression> && l, std::unique_ptr<Expression> && r, Op o)
BinaryOpExpr(const Location & location, std::shared_ptr<Expression> && l, std::shared_ptr<Expression> && r, Op o)
: Expression(location), left(std::move(l)), right(std::move(r)), op(o) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!left) throw std::runtime_error("BinaryOpExpr.left is null");
if (!right) throw std::runtime_error("BinaryOpExpr.right is null");
auto l = left->evaluate(context);
auto do_eval = [&](const Value & l) -> Value {
@ -1210,13 +1237,15 @@ static std::string html_escape(const std::string & s) {
}
class MethodCallExpr : public Expression {
std::unique_ptr<Expression> object;
std::unique_ptr<VariableExpr> method;
std::shared_ptr<Expression> object;
std::shared_ptr<VariableExpr> method;
Expression::Arguments args;
public:
MethodCallExpr(const Location & location, std::unique_ptr<Expression> && obj, std::unique_ptr<VariableExpr> && m, Expression::Arguments && a)
MethodCallExpr(const Location & location, std::shared_ptr<Expression> && obj, std::shared_ptr<VariableExpr> && m, Expression::Arguments && a)
: Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!object) throw std::runtime_error("MethodCallExpr.object is null");
if (!method) throw std::runtime_error("MethodCallExpr.method is null");
auto obj = object->evaluate(context);
if (obj.is_array()) {
if (method->get_name() == "append") {
@ -1279,11 +1308,12 @@ public:
class CallExpr : public Expression {
public:
std::unique_ptr<Expression> object;
std::shared_ptr<Expression> object;
Expression::Arguments args;
CallExpr(const Location & location, std::unique_ptr<Expression> && obj, Expression::Arguments && a)
CallExpr(const Location & location, std::shared_ptr<Expression> && obj, Expression::Arguments && a)
: Expression(location), object(std::move(obj)), args(std::move(a)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
if (!object) throw std::runtime_error("CallExpr.object is null");
auto obj = object->evaluate(context);
if (!obj.is_callable()) {
throw std::runtime_error("Object is not callable: " + obj.dump(2));
@ -1294,14 +1324,15 @@ public:
};
class FilterExpr : public Expression {
std::vector<std::unique_ptr<Expression>> parts;
std::vector<std::shared_ptr<Expression>> parts;
public:
FilterExpr(const Location & location, std::vector<std::unique_ptr<Expression>> && p)
FilterExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && p)
: Expression(location), parts(std::move(p)) {}
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
Value result;
bool first = true;
for (const auto& part : parts) {
if (!part) throw std::runtime_error("FilterExpr.part is null");
if (first) {
first = false;
result = part->evaluate(context);
@ -1322,7 +1353,7 @@ public:
return result;
}
void prepend(std::unique_ptr<Expression> && e) {
void prepend(std::shared_ptr<Expression> && e) {
parts.insert(parts.begin(), std::move(e));
}
};
@ -1375,7 +1406,7 @@ private:
escape = true;
} else if (*it == quote) {
++it;
return nonstd_make_unique<std::string>(result);
return nonstd_make_unique<std::string>(std::move(result));
} else {
result += *it;
}
@ -1429,25 +1460,25 @@ private:
}
/** integer, float, bool, string */
std::unique_ptr<Value> parseConstant() {
std::shared_ptr<Value> parseConstant() {
auto start = it;
consumeSpaces();
if (it == end) return nullptr;
if (*it == '"' || *it == '\'') {
auto str = parseString();
if (str) return nonstd_make_unique<Value>(*str);
if (str) return std::make_shared<Value>(*str);
}
static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)");
auto token = consumeToken(prim_tok);
if (!token.empty()) {
if (token == "true" || token == "True") return nonstd_make_unique<Value>(true);
if (token == "false" || token == "False") return nonstd_make_unique<Value>(false);
if (token == "None") return nonstd_make_unique<Value>(nullptr);
if (token == "true" || token == "True") return std::make_shared<Value>(true);
if (token == "false" || token == "False") return std::make_shared<Value>(false);
if (token == "None") return std::make_shared<Value>(nullptr);
throw std::runtime_error("Unknown constant token: " + token);
}
auto number = parseNumber(it, end);
if (!number.is_null()) return nonstd_make_unique<Value>(number);
if (!number.is_null()) return std::make_shared<Value>(number);
it = start;
return nullptr;
@ -1510,7 +1541,7 @@ private:
return "";
}
std::unique_ptr<Expression> parseExpression(bool allow_if_expr = true) {
std::shared_ptr<Expression> parseExpression(bool allow_if_expr = true) {
auto left = parseLogicalOr();
if (it == end) return left;
@ -1523,19 +1554,19 @@ private:
auto location = get_location();
auto if_expr = parseIfExpression();
return nonstd_make_unique<IfExpr>(location, std::move(if_expr.first), std::move(left), std::move(if_expr.second));
return std::make_shared<IfExpr>(location, std::move(if_expr.first), std::move(left), std::move(if_expr.second));
}
Location get_location() const {
return {template_str, (size_t) std::distance(start, it)};
}
std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>> parseIfExpression() {
std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>> parseIfExpression() {
auto condition = parseLogicalOr();
if (!condition) throw std::runtime_error("Expected condition expression");
static std::regex else_tok(R"(else\b)");
std::unique_ptr<Expression> else_expr;
std::shared_ptr<Expression> else_expr;
if (!consumeToken(else_tok).empty()) {
else_expr = parseExpression();
if (!else_expr) throw std::runtime_error("Expected 'else' expression");
@ -1543,7 +1574,7 @@ private:
return std::make_pair(std::move(condition), std::move(else_expr));
}
std::unique_ptr<Expression> parseLogicalOr() {
std::shared_ptr<Expression> parseLogicalOr() {
auto left = parseLogicalAnd();
if (!left) throw std::runtime_error("Expected left side of 'logical or' expression");
@ -1552,24 +1583,24 @@ private:
while (!consumeToken(or_tok).empty()) {
auto right = parseLogicalAnd();
if (!right) throw std::runtime_error("Expected right side of 'or' expression");
left = nonstd_make_unique<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or);
left = std::make_shared<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or);
}
return left;
}
std::unique_ptr<Expression> parseLogicalNot() {
std::shared_ptr<Expression> parseLogicalNot() {
static std::regex not_tok(R"(not\b)");
auto location = get_location();
if (!consumeToken(not_tok).empty()) {
auto sub = parseLogicalNot();
if (!sub) throw std::runtime_error("Expected expression after 'not' keyword");
return nonstd_make_unique<UnaryOpExpr>(location, std::move(sub), UnaryOpExpr::Op::LogicalNot);
return std::make_shared<UnaryOpExpr>(location, std::move(sub), UnaryOpExpr::Op::LogicalNot);
}
return parseLogicalCompare();
}
std::unique_ptr<Expression> parseLogicalAnd() {
std::shared_ptr<Expression> parseLogicalAnd() {
auto left = parseLogicalNot();
if (!left) throw std::runtime_error("Expected left side of 'logical and' expression");
@ -1578,12 +1609,12 @@ private:
while (!consumeToken(and_tok).empty()) {
auto right = parseLogicalNot();
if (!right) throw std::runtime_error("Expected right side of 'and' expression");
left = nonstd_make_unique<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::And);
left = std::make_shared<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::And);
}
return left;
}
std::unique_ptr<Expression> parseLogicalCompare() {
std::shared_ptr<Expression> parseLogicalCompare() {
auto left = parseStringConcat();
if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");
@ -1598,7 +1629,7 @@ private:
auto identifier = parseIdentifier();
if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword");
return nonstd_make_unique<BinaryOpExpr>(
return std::make_shared<BinaryOpExpr>(
left->location,
std::move(left), std::move(identifier),
negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is);
@ -1615,7 +1646,7 @@ private:
else if (op_str == "in") op = BinaryOpExpr::Op::In;
else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn;
else throw std::runtime_error("Unknown comparison operator: " + op_str);
left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
}
return left;
}
@ -1688,16 +1719,16 @@ private:
throw std::runtime_error("Expected closing parenthesis in call args");
}
std::unique_ptr<VariableExpr> parseIdentifier() {
std::shared_ptr<VariableExpr> parseIdentifier() {
static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)");
auto location = get_location();
auto ident = consumeToken(ident_regex);
if (ident.empty())
return nullptr;
return nonstd_make_unique<VariableExpr>(location, ident);
return std::make_shared<VariableExpr>(location, ident);
}
std::unique_ptr<Expression> parseStringConcat() {
std::shared_ptr<Expression> parseStringConcat() {
auto left = parseMathPow();
if (!left) throw std::runtime_error("Expected left side of 'string concat' expression");
@ -1705,24 +1736,24 @@ private:
if (!consumeToken(concat_tok).empty()) {
auto right = parseLogicalAnd();
if (!right) throw std::runtime_error("Expected right side of 'string concat' expression");
left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat);
left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat);
}
return left;
}
std::unique_ptr<Expression> parseMathPow() {
std::shared_ptr<Expression> parseMathPow() {
auto left = parseMathPlusMinus();
if (!left) throw std::runtime_error("Expected left side of 'math pow' expression");
while (!consumeToken("**").empty()) {
auto right = parseMathPlusMinus();
if (!right) throw std::runtime_error("Expected right side of 'math pow' expression");
left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul);
left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul);
}
return left;
}
std::unique_ptr<Expression> parseMathPlusMinus() {
std::shared_ptr<Expression> parseMathPlusMinus() {
static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))");
auto left = parseMathMulDiv();
@ -1732,12 +1763,12 @@ private:
auto right = parseMathMulDiv();
if (!right) throw std::runtime_error("Expected right side of 'math plus/minus' expression");
auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub;
left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
}
return left;
}
std::unique_ptr<Expression> parseMathMulDiv() {
std::shared_ptr<Expression> parseMathMulDiv() {
auto left = parseMathUnaryPlusMinus();
if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression");
@ -1751,7 +1782,7 @@ private:
: op_str == "/" ? BinaryOpExpr::Op::Div
: op_str == "//" ? BinaryOpExpr::Op::DivDiv
: BinaryOpExpr::Op::Mod;
left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
}
if (!consumeToken("|").empty()) {
@ -1760,20 +1791,20 @@ private:
filter->prepend(std::move(left));
return expr;
} else {
std::vector<std::unique_ptr<Expression>> parts;
std::vector<std::shared_ptr<Expression>> parts;
parts.emplace_back(std::move(left));
parts.emplace_back(std::move(expr));
return nonstd_make_unique<FilterExpr>(get_location(), std::move(parts));
return std::make_shared<FilterExpr>(get_location(), std::move(parts));
}
}
return left;
}
std::unique_ptr<Expression> call_func(const std::string & name, Expression::Arguments && args) const {
return nonstd_make_unique<CallExpr>(get_location(), nonstd_make_unique<VariableExpr>(get_location(), name), std::move(args));
std::shared_ptr<Expression> call_func(const std::string & name, Expression::Arguments && args) const {
return std::make_shared<CallExpr>(get_location(), std::make_shared<VariableExpr>(get_location(), name), std::move(args));
}
std::unique_ptr<Expression> parseMathUnaryPlusMinus() {
std::shared_ptr<Expression> parseMathUnaryPlusMinus() {
static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))");
auto op_str = consumeToken(unary_plus_minus_tok);
auto expr = parseValueExpression();
@ -1781,19 +1812,19 @@ private:
if (!op_str.empty()) {
auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus;
return nonstd_make_unique<UnaryOpExpr>(get_location(), std::move(expr), op);
return std::make_shared<UnaryOpExpr>(get_location(), std::move(expr), op);
}
return expr;
}
std::unique_ptr<Expression> parseValueExpression() {
auto parseValue = [&]() -> std::unique_ptr<Expression> {
std::shared_ptr<Expression> parseValueExpression() {
auto parseValue = [&]() -> std::shared_ptr<Expression> {
auto location = get_location();
auto constant = parseConstant();
if (constant) return nonstd_make_unique<LiteralExpr>(location, *constant);
if (constant) return std::make_shared<LiteralExpr>(location, *constant);
static std::regex null_regex(R"(null\b)");
if (!consumeToken(null_regex).empty()) return nonstd_make_unique<LiteralExpr>(location, Value());
if (!consumeToken(null_regex).empty()) return std::make_shared<LiteralExpr>(location, Value());
auto identifier = parseIdentifier();
if (identifier) return identifier;
@ -1814,19 +1845,19 @@ private:
while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
if (!consumeToken("[").empty()) {
std::unique_ptr<Expression> index;
std::shared_ptr<Expression> index;
if (!consumeToken(":").empty()) {
auto slice_end = parseExpression();
index = nonstd_make_unique<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
} else {
auto slice_start = parseExpression();
if (!consumeToken(":").empty()) {
consumeSpaces();
if (peekSymbols({ "]" })) {
index = nonstd_make_unique<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
} else {
auto slice_end = parseExpression();
index = nonstd_make_unique<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
}
} else {
index = std::move(slice_start);
@ -1835,7 +1866,7 @@ private:
if (!index) throw std::runtime_error("Empty index in subscript");
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
value = nonstd_make_unique<SubscriptExpr>(value->location, std::move(value), std::move(index));
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
} else if (!consumeToken(".").empty()) {
auto identifier = parseIdentifier();
if (!identifier) throw std::runtime_error("Expected identifier in subscript");
@ -1843,10 +1874,10 @@ private:
consumeSpaces();
if (peekSymbols({ "(" })) {
auto callParams = parseCallArgs();
value = nonstd_make_unique<MethodCallExpr>(identifier->location, std::move(value), std::move(identifier), std::move(callParams));
value = std::make_shared<MethodCallExpr>(identifier->location, std::move(value), std::move(identifier), std::move(callParams));
} else {
auto key = nonstd_make_unique<LiteralExpr>(identifier->location, Value(identifier->get_name()));
value = nonstd_make_unique<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
auto key = std::make_shared<LiteralExpr>(identifier->location, Value(identifier->get_name()));
value = std::make_shared<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
}
}
consumeSpaces();
@ -1855,12 +1886,12 @@ private:
if (peekSymbols({ "(" })) {
auto location = get_location();
auto callParams = parseCallArgs();
value = nonstd_make_unique<CallExpr>(location, std::move(value), std::move(callParams));
value = std::make_shared<CallExpr>(location, std::move(value), std::move(callParams));
}
return value;
}
std::unique_ptr<Expression> parseBracedExpressionOrArray() {
std::shared_ptr<Expression> parseBracedExpressionOrArray() {
if (consumeToken("(").empty()) return nullptr;
auto expr = parseExpression();
@ -1870,7 +1901,7 @@ private:
return expr; // Drop the parentheses
}
std::vector<std::unique_ptr<Expression>> tuple;
std::vector<std::shared_ptr<Expression>> tuple;
tuple.emplace_back(std::move(expr));
while (it != end) {
@ -1880,18 +1911,18 @@ private:
tuple.push_back(std::move(next));
if (!consumeToken(")").empty()) {
return nonstd_make_unique<ArrayExpr>(get_location(), std::move(tuple));
return std::make_shared<ArrayExpr>(get_location(), std::move(tuple));
}
}
throw std::runtime_error("Expected closing parenthesis");
}
std::unique_ptr<Expression> parseArray() {
std::shared_ptr<Expression> parseArray() {
if (consumeToken("[").empty()) return nullptr;
std::vector<std::unique_ptr<Expression>> elements;
std::vector<std::shared_ptr<Expression>> elements;
if (!consumeToken("]").empty()) {
return nonstd_make_unique<ArrayExpr>(get_location(), std::move(elements));
return std::make_shared<ArrayExpr>(get_location(), std::move(elements));
}
auto first_expr = parseExpression();
if (!first_expr) throw std::runtime_error("Expected first expression in array");
@ -1903,7 +1934,7 @@ private:
if (!expr) throw std::runtime_error("Expected expression in array");
elements.push_back(std::move(expr));
} else if (!consumeToken("]").empty()) {
return nonstd_make_unique<ArrayExpr>(get_location(), std::move(elements));
return std::make_shared<ArrayExpr>(get_location(), std::move(elements));
} else {
throw std::runtime_error("Expected comma or closing bracket in array");
}
@ -1911,12 +1942,12 @@ private:
throw std::runtime_error("Expected closing bracket");
}
std::unique_ptr<Expression> parseDictionary() {
std::shared_ptr<Expression> parseDictionary() {
if (consumeToken("{").empty()) return nullptr;
std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>>> elements;
std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
if (!consumeToken("}").empty()) {
return nonstd_make_unique<DictExpr>(get_location(), std::move(elements));
return std::make_shared<DictExpr>(get_location(), std::move(elements));
}
auto parseKeyValuePair = [&]() {
@ -1934,7 +1965,7 @@ private:
if (!consumeToken(",").empty()) {
parseKeyValuePair();
} else if (!consumeToken("}").empty()) {
return nonstd_make_unique<DictExpr>(get_location(), std::move(elements));
return std::make_shared<DictExpr>(get_location(), std::move(elements));
} else {
throw std::runtime_error("Expected comma or closing brace in dictionary");
}
@ -2051,7 +2082,7 @@ private:
auto iterable = parseExpression(/* allow_if_expr = */ false);
if (!iterable) throw std::runtime_error("Expected iterable in for block");
std::unique_ptr<Expression> condition;
std::shared_ptr<Expression> condition;
if (!consumeToken(if_tok).empty()) {
condition = parseExpression();
}
@ -2067,7 +2098,7 @@ private:
std::string ns;
std::vector<std::string> var_names;
std::unique_ptr<Expression> value;
std::shared_ptr<Expression> value;
if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) {
ns = group[1];
var_names.push_back(group[2]);
@ -2114,17 +2145,17 @@ private:
}
}
std::unique_ptr<TemplateNode> parseTemplate(
std::shared_ptr<TemplateNode> parseTemplate(
const TemplateTokenIterator & begin,
TemplateTokenIterator & it,
const TemplateTokenIterator & end,
bool fully = false) const {
std::vector<std::unique_ptr<TemplateNode>> children;
std::vector<std::shared_ptr<TemplateNode>> children;
while (it != end) {
const auto start = it;
const auto & token = *(it++);
if (auto if_token = dynamic_cast<IfTemplateToken*>(token.get())) {
std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<TemplateNode>>> cascade;
std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end));
while (it != end && (*it)->type == TemplateToken::Type::Elif) {
@ -2138,17 +2169,17 @@ private:
if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) {
throw unterminated(**start);
}
children.emplace_back(nonstd_make_unique<IfNode>(token->location, std::move(cascade)));
children.emplace_back(std::make_shared<IfNode>(token->location, std::move(cascade)));
} else if (auto for_token = dynamic_cast<ForTemplateToken*>(token.get())) {
auto body = parseTemplate(begin, it, end);
auto else_body = std::unique_ptr<TemplateNode>();
auto else_body = std::shared_ptr<TemplateNode>();
if (it != end && (*it)->type == TemplateToken::Type::Else) {
else_body = parseTemplate(begin, ++it, end);
}
if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) {
throw unterminated(**start);
}
children.emplace_back(nonstd_make_unique<ForNode>(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body)));
children.emplace_back(std::make_shared<ForNode>(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body)));
} else if (auto text_token = dynamic_cast<TextTemplateToken*>(token.get())) {
SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep;
SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep;
@ -2173,25 +2204,28 @@ private:
static std::regex r(R"(\r?\n$)");
text = std::regex_replace(text, r, ""); // Strip one trailing newline
}
children.emplace_back(nonstd_make_unique<TextNode>(token->location, text));
children.emplace_back(std::make_shared<TextNode>(token->location, text));
} else if (auto expr_token = dynamic_cast<ExpressionTemplateToken*>(token.get())) {
children.emplace_back(nonstd_make_unique<ExpressionNode>(token->location, std::move(expr_token->expr)));
children.emplace_back(std::make_shared<ExpressionNode>(token->location, std::move(expr_token->expr)));
} else if (auto set_token = dynamic_cast<SetTemplateToken*>(token.get())) {
if (set_token->value) {
children.emplace_back(nonstd_make_unique<SetNode>(token->location, set_token->ns, set_token->var_names, std::move(set_token->value), nullptr));
children.emplace_back(std::make_shared<SetNode>(token->location, set_token->ns, set_token->var_names, std::move(set_token->value)));
} else {
auto value_template = parseTemplate(begin, it, end);
if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) {
throw unterminated(**start);
}
children.emplace_back(nonstd_make_unique<SetNode>(token->location, set_token->ns, set_token->var_names, nullptr, std::move(value_template)));
if (!set_token->ns.empty()) throw std::runtime_error("Namespaced set not supported in set with template value");
if (set_token->var_names.size() != 1) throw std::runtime_error("Structural assignment not supported in set with template value");
auto & name = set_token->var_names[0];
children.emplace_back(std::make_shared<SetTemplateNode>(token->location, name, std::move(value_template)));
}
} else if (auto macro_token = dynamic_cast<MacroTemplateToken*>(token.get())) {
auto body = parseTemplate(begin, it, end);
if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) {
throw unterminated(**start);
}
children.emplace_back(nonstd_make_unique<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
children.emplace_back(std::make_shared<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
} else if (dynamic_cast<CommentTemplateToken*>(token.get())) {
// Ignore comments
} else if (dynamic_cast<EndForTemplateToken*>(token.get())
@ -2210,17 +2244,17 @@ private:
throw unexpected(**it);
}
if (children.empty()) {
return nonstd_make_unique<TextNode>(Location { template_str, 0 }, std::string());
return std::make_shared<TextNode>(Location { template_str, 0 }, std::string());
} else if (children.size() == 1) {
return std::move(children[0]);
} else {
return nonstd_make_unique<SequenceNode>(children[0]->location(), std::move(children));
return std::make_shared<SequenceNode>(children[0]->location(), std::move(children));
}
}
public:
static std::unique_ptr<TemplateNode> parse(const std::string& template_str, const Options & options) {
static std::shared_ptr<TemplateNode> parse(const std::string& template_str, const Options & options) {
Parser parser(std::make_shared<std::string>(template_str), options);
auto tokens = parser.tokenize();
TemplateTokenIterator begin = tokens.begin();

View File

@ -12,6 +12,29 @@
using json = nlohmann::ordered_json;
llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template & chat_template) {
const auto & src = chat_template.source();
if (src.find("<tool_call>") != std::string::npos) {
return Hermes2Pro;
} else if (src.find(">>>all") != std::string::npos) {
return FunctionaryV3Llama3;
} else if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) {
return FunctionaryV3Llama31;
} else if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
if (src.find("<|python_tag|>") != std::string::npos) {
return Llama31;
} else {
return Llama32;
}
} else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
return CommandRPlus;
} else {
return UnknownToolCallStyle;
}
}
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
// // https://json.nlohmann.me/features/parsing/sax_interface/
struct json_error_locator : public nlohmann::json_sax<json> {
@ -207,7 +230,8 @@ llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tool
}
llama_tool_call_handler llama_tool_call_handler_init(
const llama_chat_template & tmpl,
llama_tool_call_style style,
const minja::chat_template & tmpl,
bool allow_content,
bool parallel_tool_calls,
const nlohmann::ordered_json & messages,
@ -215,18 +239,18 @@ llama_tool_call_handler llama_tool_call_handler_init(
{
llama_tool_call_handler handler;
switch (tmpl.tool_call_style()) {
switch (style) {
case llama_tool_call_style::Llama31:
case llama_tool_call_style::Llama32: {
static auto builtin_tools = json {"wolfram_alpha", "brave_search"};
auto uses_python_tag = tmpl.tool_call_style() == llama_tool_call_style::Llama31;
auto uses_python_tag = style == llama_tool_call_style::Llama31;
// Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
// but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
// as it seems to be outputting some JSON.
// TODO: make this conditional on a very small model (e.g. 1B / 3B).
auto eagerly_match_any_json = tmpl.tool_call_style() == llama_tool_call_style::Llama32;
auto eagerly_match_any_json = style == llama_tool_call_style::Llama32;
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
std::vector<std::string> tool_rules;

View File

@ -2,10 +2,20 @@
#include "ggml.h"
#include "common.h"
#include "chat-template.hpp"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "chat-template.h"
enum llama_tool_call_style {
UnknownToolCallStyle,
Llama31,
Llama32,
FunctionaryV3Llama3,
FunctionaryV3Llama31,
Hermes2Pro,
CommandRPlus,
};
struct llama_tool_call {
std::string name;
@ -24,10 +34,13 @@ struct llama_tool_call_handler {
std::vector<std::string> additional_stop_words;
};
llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template & chat_template);
llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);
llama_tool_call_handler llama_tool_call_handler_init(
const llama_chat_template & tmpl,
llama_tool_call_style style,
const minja::chat_template & tmpl,
bool allow_content,
bool parallel_tool_calls,
const nlohmann::ordered_json & messages,

View File

@ -663,7 +663,7 @@ struct server_context {
llama_chat_message chat[] = {{"user", "test"}};
if (use_jinja) {
auto chat_template = llama_chat_template::from_model(model);
auto chat_template = llama_chat_template_from_model(model);
try {
chat_template.apply({{
{"role", "user"},
@ -2875,11 +2875,12 @@ int main(int argc, char ** argv) {
return;
}
auto chat_template = llama_chat_template::from_model(ctx_server.model, params.chat_template.empty() ? nullptr : params.chat_template.c_str());
static auto chat_template = llama_chat_template_from_model(ctx_server.model, params.chat_template.empty() ? nullptr : params.chat_template.c_str());
static auto tool_call_style = llama_tool_call_style_detect(chat_template);
json data;
try {
data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, params.use_jinja);
data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, tool_call_style, params.use_jinja);
} catch (const std::exception & e) {
res_error(res, format_error_response(e.what(), ERROR_TYPE_NOT_SUPPORTED));
return;
@ -2897,7 +2898,7 @@ int main(int argc, char ** argv) {
ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
// multitask is never support in chat completion, there is only one result
try {
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, chat_template, /*.streaming =*/ false, verbose);
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, tool_call_style, /*.streaming =*/ false, verbose);
res_ok(res, result_oai);
} catch (const std::runtime_error & e) {
res_error(res, format_error_response(e.what(), ERROR_TYPE_SERVER));

View File

@ -14,7 +14,6 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "chat-template.h"
#include "json.hpp"
#include "minja.hpp"
#include "tool-call.h"
@ -309,7 +308,8 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
static json oaicompat_completion_params_parse(
const struct llama_model * model,
const json & body, /* openai api json semantics */
const llama_chat_template & tmpl,
const minja::chat_template & tmpl,
llama_tool_call_style tool_call_style,
bool use_jinja)
{
json llama_params;
@ -320,7 +320,7 @@ static json oaicompat_completion_params_parse(
auto has_tools = tools.is_array() && !tools.empty();
// Apply chat template to the list of messages
llama_params["chat_template"] = tmpl.chat_template();
llama_params["chat_template"] = tmpl.source();
if (use_jinja) {
if (has_tools && !tmpl.supports_tools()) {
@ -372,7 +372,7 @@ static json oaicompat_completion_params_parse(
llama_params["parse_tool_calls"] = true;
llama_params["parallel_tool_calls"] = parallel_tool_calls;
auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools);
auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools);
llama_params["prompt"] = handler.prompt;
for (const auto & stop : handler.additional_stop_words) {
@ -395,7 +395,7 @@ static json oaicompat_completion_params_parse(
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
}
} else {
llama_params["prompt"] = format_chat(model, tmpl.chat_template(), body.at("messages"));
llama_params["prompt"] = format_chat(model, tmpl.source(), body.at("messages"));
}
// Handle "n" field
@ -435,7 +435,7 @@ static json oaicompat_completion_params_parse(
return llama_params;
}
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, const llama_chat_template & tmpl, bool streaming = false, bool verbose = false) {
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, llama_tool_call_style tool_call_style, bool streaming = false, bool verbose = false) {
bool stopped_word = result.count("stopped_word") != 0;
bool stopped_eos = json_value(result, "stopped_eos", false);
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@ -452,7 +452,7 @@ static json format_final_response_oaicompat(const json & request, const json & r
json tool_calls;
json message_content;
if (json_value(request, "parse_tool_calls", false)
&& !(parsed_tool_calls = parse_tool_calls(tmpl.tool_call_style(), tools, content)).tool_calls.empty()) {
&& !(parsed_tool_calls = parse_tool_calls(tool_call_style, tools, content)).tool_calls.empty()) {
finish_reason = "tool_calls";
if (!parsed_tool_calls.content.empty()) {
message_content = parsed_tool_calls.content;

View File

@ -0,0 +1,148 @@
#!/usr/bin/env uv run
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "jinja2",
# "huggingface_hub",
# ]
# ///
'''
Fetches the Jinja2 templates of specified models and generates prompt goldens for predefined chat contexts.
Outputs lines of arguments for a C++ test binary.
All files are written to the specified output folder.
Usage:
python ./update_jinja_goldens.py output_folder context1.json context2.json ... model_id1 model_id2 ...
Example:
python ./update_jinja_goldens.py ./test_files "microsoft/Phi-3-medium-4k-instruct" "Qwen/Qwen2-7B-Instruct"
'''
import logging
import datetime
import glob
import os
from huggingface_hub import hf_hub_download
import json
import jinja2
import jinja2.ext
import re
import argparse
import shutil
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)
def raise_exception(message: str):
raise ValueError(message)
def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
TEST_DATE = os.environ.get('TEST_DATE', '2024-07-26')
def strftime_now(format):
now = datetime.datetime.strptime(TEST_DATE, "%Y-%m-%d")
return now.strftime(format)
def handle_chat_template(output_folder, model_id, variant, template_src):
model_name = model_id.replace("/", "-")
base_name = f'{model_name}-{variant}' if variant else model_name
template_file = os.path.join(output_folder, f'{base_name}.jinja')
with open(template_file, 'w') as f:
f.write(template_src)
env = jinja2.Environment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[jinja2.ext.loopcontrols]
)
env.filters['safe'] = lambda x: x
env.filters['tojson'] = tojson
env.globals['raise_exception'] = raise_exception
env.globals['strftime_now'] = strftime_now
template_handles_tools = 'tools' in template_src
template_hates_the_system = 'System role not supported' in template_src
template = env.from_string(template_src)
context_files = glob.glob(os.path.join(output_folder, '*.json'))
for context_file in context_files:
context_name = os.path.basename(context_file).replace(".json", "")
with open(context_file, 'r') as f:
context = json.load(f)
if not template_handles_tools and 'tools' in context:
continue
if template_hates_the_system and any(m['role'] == 'system' for m in context['messages']):
continue
output_file = os.path.join(output_folder, f'{base_name}-{context_name}.txt')
render_context = json.loads(json.dumps(context))
if 'tool_call.arguments | items' in template_src or 'tool_call.arguments | tojson' in template_src:
for message in render_context['messages']:
if 'tool_calls' in message:
for tool_call in message['tool_calls']:
if tool_call.get('type') == 'function':
arguments = tool_call['function']['arguments']
tool_call['function']['arguments'] = json.loads(arguments)
try:
output = template.render(**render_context)
except Exception as e1:
for message in context["messages"]:
if message.get("content") is None:
message["content"] = ""
try:
output = template.render(**render_context)
except Exception as e2:
logger.info(f" ERROR: {e2} (after first error: {e1})")
output = f"ERROR: {e2}"
with open(output_file, 'w') as f:
f.write(output)
# Output the line of arguments for the C++ test binary
print(f"{template_file} {context_file} {output_file}")
def main():
parser = argparse.ArgumentParser(description="Generate chat templates and output test arguments.")
parser.add_argument("output_folder", help="Folder to store all output files")
parser.add_argument("model_ids", nargs="+", help="List of model IDs to process")
args = parser.parse_args()
output_folder = args.output_folder
if not os.path.isdir(output_folder):
os.makedirs(output_folder)
# Copy context files to the output folder
for context_file in glob.glob('tests/chat/contexts/*.json'):
shutil.copy(context_file, output_folder)
for model_id in args.model_ids:
try:
with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
config_str = f.read()
try:
config = json.loads(config_str)
except json.JSONDecodeError:
config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
chat_template = config['chat_template']
if isinstance(chat_template, str):
handle_chat_template(output_folder, model_id, None, chat_template)
else:
for ct in chat_template:
handle_chat_template(output_folder, model_id, ct['name'], ct['template'])
except Exception as e:
logger.error(f"Error processing model {model_id}: {e}")
if __name__ == '__main__':
main()

View File

@ -7,7 +7,7 @@
#include "llama.h"
#include "common.h"
#include "chat-template.h"
#include "chat-template.hpp"
#include <iostream>
#include <fstream>
#include <iostream>
@ -73,7 +73,7 @@ static void test_jinja_templates() {
return "tests/chat/goldens/" + golden_name + ".txt";
};
auto fail_with_golden_instructions = [&]() {
throw std::runtime_error("To fetch templates and generate golden files, run `python tests/update_jinja_goldens.py`");
throw std::runtime_error("To fetch templates and generate golden files, run `python update_templates_and_goldens.py`");
};
if (jinja_template_files.empty()) {
std::cerr << "No Jinja templates found in tests/chat/templates" << std::endl;
@ -89,7 +89,7 @@ static void test_jinja_templates() {
for (const auto & ctx_file : context_files) {
auto ctx = json::parse(read_file(ctx_file));
llama_chat_template tmpl(
minja::chat_template tmpl(
tmpl_str,
ctx.at("bos_token"),
ctx.at("eos_token"));
@ -127,20 +127,6 @@ static void test_jinja_templates() {
}
}
void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
auto tmpl = llama_chat_template(read_file(template_file), "<s>", "</s>");
std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
assert_equals(expected, tmpl.tool_call_style());
}
void test_tool_call_styles() {
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", FunctionaryV3Llama31);
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", FunctionaryV3Llama3);
test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31);
test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32);
test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", CommandRPlus);
}
static void test_legacy_templates() {
struct test_template {
std::string name;
@ -353,7 +339,6 @@ int main(void) {
if (getenv("LLAMA_SKIP_TESTS_SLOW_ON_EMULATOR")) {
fprintf(stderr, "\033[33mWARNING: Skipping slow tests on emulator.\n\033[0m");
} else {
test_tool_call_styles();
test_jinja_templates();
}

View File

@ -9,7 +9,8 @@
using json = nlohmann::ordered_json;
static void assert_equals(const std::string & expected, const std::string & actual) {
template <class T>
static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << "Actual: " << actual << std::endl;
@ -242,7 +243,22 @@ static void test_parsing() {
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array());
}
static std::string get_message_prompt_delta(const llama_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
const minja::chat_template tmpl(read_file(template_file), "<s>", "</s>");
auto tool_call_style = llama_tool_call_style_detect(tmpl);
std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
assert_equals(expected, tool_call_style);
}
void test_tool_call_style_detection() {
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", FunctionaryV3Llama31);
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", FunctionaryV3Llama3);
test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31);
test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32);
test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", CommandRPlus);
}
static std::string get_message_prompt_delta(const minja::chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object());
auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object());
@ -267,7 +283,8 @@ static std::string get_message_prompt_delta(const llama_chat_template & tmpl, co
static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools) {
std::cout << "# Testing template: " << template_file << std::endl << std::flush;
const llama_chat_template & tmpl = llama_chat_template(read_file(template_file), bos_token, eos_token);
const minja::chat_template tmpl(read_file(template_file), bos_token, eos_token);
auto tool_call_style = llama_tool_call_style_detect(tmpl);
auto & tool_calls = tool_calling_message.at("tool_calls");
// Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
@ -277,7 +294,7 @@ static void test_template(const std::string & template_file, const char * bos_to
{"content", "Hello, world!"}
};
auto handler = llama_tool_call_handler_init(tmpl, /* allow_content= */ true, /* parallel_tool_calls= */ true, {user_message, tool_calling_message}, tools);
auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, /* allow_content= */ true, /* parallel_tool_calls= */ true, {user_message, tool_calling_message}, tools);
auto grammar = build_grammar(handler.grammar);
if (!grammar) {
throw std::runtime_error("Failed to build grammar");
@ -285,7 +302,7 @@ static void test_template(const std::string & template_file, const char * bos_to
auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, tool_calling_message, tools);
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
test_parse_tool_call(tmpl.tool_call_style(), tools, full_delta, "", tool_calls);
test_parse_tool_call(tool_call_style, tools, full_delta, "", tool_calls);
auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
{"role", "assistant"},
@ -319,6 +336,7 @@ static void test_grammars() {
int main() {
test_grammars();
test_parsing();
test_tool_call_style_detection();
std::cout << "[tool-call] All tests passed!" << std::endl;
return 0;