Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-13 04:00:16 +00:00)

tool-call: prepare possible externalization of minja + factor tool call style out of template

This commit is contained in:
parent d9451fd647
commit c36a196f53
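
Not part of the diff, just orientation: the refactor below moves the Jinja chat-template wrapper into the minja namespace (common/chat-template.hpp), drops the tool-call-style state from it, and adds llama_chat_template_from_model() in common.cpp. A minimal usage sketch, assuming a llama_model loaded elsewhere; the function name and message contents here are illustrative, the signatures are taken from the declarations added in this commit.

    // Sketch only, not part of the commit.
    #include "common.h"
    #include "chat-template.hpp"
    #include <json.hpp>
    #include <iostream>

    using json = nlohmann::ordered_json;

    static void render_example_prompt(const struct llama_model * model) {
        // Builds a minja::chat_template from the model's tokenizer.chat_template
        // metadata, resolving BOS/EOS pieces from the vocab (see common.cpp below).
        auto tmpl = llama_chat_template_from_model(model);

        json messages = json::array({
            json{{"role", "system"}, {"content", "You are a helpful assistant."}},
            json{{"role", "user"},   {"content", "Hello!"}},
        });

        // Pass a null json for tools; apply() only sets the "tools" variable
        // when it is non-null.
        std::string prompt = tmpl.apply(messages, json(), /* add_generation_prompt= */ true);
        std::cout << prompt << std::endl;
    }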
common/CMakeLists.txt
@@ -54,8 +54,7 @@ add_library(${TARGET} STATIC
    arg.cpp
    arg.h
    base64.hpp
    chat-template.cpp
    chat-template.h
    chat-template.hpp
    common.cpp
    common.h
    console.cpp
common/chat-template.cpp (deleted file)
@@ -1,156 +0,0 @@
#include "chat-template.h"
#include "llama.h"

using json = nlohmann::ordered_json;

static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
    std::string piece;
    piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
    const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
    if (n_chars < 0) {
        piece.resize(-n_chars);
        int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
        GGML_ASSERT(check == -n_chars);
    }
    else {
        piece.resize(n_chars);
    }

    return piece;
}

static std::string llama_model_meta_val_str(const struct llama_model * model, const char * key) {
    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
    if (tlen > 0) {
        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
            return std::string(curr_tmpl_buf.data(), tlen);
        }
    }
    return "";
}

llama_chat_template::llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token)
    : _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) {

    _supports_tools = chat_template.find("tools") != std::string::npos;
    _requires_object_arguments =
        chat_template.find("tool_call.arguments | items") != std::string::npos
        || chat_template.find("tool_call.arguments | tojson") != std::string::npos;
    _supports_system_role = chat_template.find("System role not supported") == std::string::npos;

    if (chat_template.find("<tool_call>") != std::string::npos) {
        _tool_call_style = Hermes2Pro;
    } else if (chat_template.find(">>>all") != std::string::npos) {
        _tool_call_style = FunctionaryV3Llama3;
    } else if (chat_template.find("<|start_header_id|>") != std::string::npos
        && chat_template.find("<function=") != std::string::npos) {
        _tool_call_style = FunctionaryV3Llama31;
    } else if (chat_template.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
        if (chat_template.find("<|python_tag|>") != std::string::npos) {
            _tool_call_style = Llama31;
        } else {
            _tool_call_style = Llama32;
        }
    } else if (chat_template.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
        _tool_call_style = CommandRPlus;
    } else {
        _tool_call_style = UnknownToolCallStyle;
    }
    _template_root = minja::Parser::parse(_chat_template, {
        /* .trim_blocks = */ true,
        /* .lstrip_blocks = */ true,
        /* .keep_trailing_newline = */ false,
    });
}

llama_chat_template llama_chat_template::from_model(
    const struct llama_model * model,
    const char * chat_template_override)
{
    // TODO: handle "chatml"?
    std::string chat_template = chat_template_override
        ? chat_template_override
        : llama_model_meta_val_str(model, "tokenizer.chat_template");
    auto bos_token = _llama_token_to_piece(model, llama_token_bos(model), true);
    auto eos_token = _llama_token_to_piece(model, llama_token_eos(model), true);
    return llama_chat_template(chat_template, bos_token, eos_token);
}

std::string llama_chat_template::apply(
    const json & messages,
    const json & tools,
    bool add_generation_prompt,
    const json & extra_context) const
{
    auto actual_messages = messages;

    // First, "fix" messages so they have a chance to be rendered correctly by the template

    if (_requires_object_arguments || !_supports_system_role) {
        std::string pending_system;
        auto flush_sys = [&]() {
            if (!pending_system.empty()) {
                actual_messages.push_back({
                    {"role", "user"},
                    {"content", pending_system},
                });
                pending_system.clear();
            }
        };
        for (auto & message : actual_messages) {
            if (!message.contains("role") || !message.contains("content")) {
                throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
            }
            std::string role = message.at("role");

            if (!message["content"].is_null() && !_supports_system_role) {
                std::string content = message.at("content");
                if (role == "system") {
                    if (!pending_system.empty()) pending_system += "\n";
                    pending_system += content;
                    continue;
                } else {
                    if (role == "user") {
                        if (!pending_system.empty()) {
                            message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
                            pending_system.clear();
                        }
                    } else {
                        flush_sys();
                    }
                }
            }
            if (_requires_object_arguments && message.contains("tool_calls")) {
                for (auto & tool_call : message.at("tool_calls")) {
                    if (tool_call["type"] == "function") {
                        auto & function = tool_call.at("function");
                        std::string arguments = function.at("arguments");
                        function["arguments"] = json::parse(arguments);
                    }
                }
            }
        }
        flush_sys();
    }

    auto context = minja::Context::make(json({
        {"messages", actual_messages},
        {"add_generation_prompt", add_generation_prompt},
        {"bos_token", _bos_token},
        {"eos_token", _eos_token},
    }));

    if (!tools.is_null()) {
        auto tools_val = minja::Value(tools);
        context->set("tools", tools_val);
    }
    if (!extra_context.is_null()) {
        for (auto & kv : extra_context.items()) {
            minja::Value val(kv.value());
            context->set(kv.key(), val);
        }
    }

    return _template_root->render(context);
}
common/chat-template.h (deleted file)
@@ -1,53 +0,0 @@
#pragma once

#include "minja.hpp"
#include <json.hpp>
#include <string>
#include <vector>

using json = nlohmann::ordered_json;


enum llama_tool_call_style {
    UnknownToolCallStyle,
    Llama31,
    Llama32,
    FunctionaryV3Llama3,
    FunctionaryV3Llama31,
    Hermes2Pro,
    CommandRPlus,
};

class llama_chat_template {
  public:

  private:
    llama_tool_call_style _tool_call_style = UnknownToolCallStyle;
    bool _supports_tools = true;
    // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
    bool _requires_object_arguments = false;
    bool _supports_system_role = true;
    std::string _chat_template;
    std::string _bos_token;
    std::string _eos_token;
    std::unique_ptr<minja::TemplateNode> _template_root;

  public:
    llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token);

    static llama_chat_template from_model(
        const struct llama_model * model,
        const char * chat_template_override = nullptr);

    llama_tool_call_style tool_call_style() const { return _tool_call_style; }

    const std::string & chat_template() const { return _chat_template; }
    bool supports_tools() const { return _supports_tools; }

    std::string apply(
        const nlohmann::ordered_json & messages,
        const nlohmann::ordered_json & tools,
        bool add_generation_prompt,
        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const;
};
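
The deleted header above kept the tool-call-style enum and its detection inside llama_chat_template; per the commit title, that detection is being factored out of the template class. The diff shown here does not include its new home, so the following is only a hypothetical sketch of what a standalone detector could look like, reusing the marker heuristics from the deleted constructor (the function name is invented for illustration).

    // Hypothetical sketch, not from this commit: same substring heuristics as the
    // deleted llama_chat_template constructor, as a free function over the source.
    static llama_tool_call_style llama_tool_call_style_detect(const std::string & src) {
        if (src.find("<tool_call>") != std::string::npos) return Hermes2Pro;
        if (src.find(">>>all") != std::string::npos) return FunctionaryV3Llama3;
        if (src.find("<|start_header_id|>") != std::string::npos &&
            src.find("<function=") != std::string::npos) return FunctionaryV3Llama31;
        if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
            return src.find("<|python_tag|>") != std::string::npos ? Llama31 : Llama32;
        }
        if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) return CommandRPlus;
        return UnknownToolCallStyle;
    }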
common/chat-template.hpp (new file, 133 lines)
@@ -0,0 +1,133 @@
/*
    Copyright 2024 Google LLC

    Use of this source code is governed by an MIT-style
    license that can be found in the LICENSE file or at
    https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once

#include "minja.hpp"
#include <json.hpp>
#include <string>
#include <vector>

using json = nlohmann::ordered_json;

namespace minja {

class chat_template {
  public:

  private:
    bool _supports_tools = true;
    // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
    bool _requires_object_arguments = false;
    bool _supports_system_role = true;
    std::string _source;
    std::string _bos_token;
    std::string _eos_token;
    std::shared_ptr<minja::TemplateNode> _template_root;

  public:
    chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
        : _source(source), _bos_token(bos_token), _eos_token(eos_token)
    {
        _supports_tools = source.find("tools") != std::string::npos;
        _requires_object_arguments =
            source.find("tool_call.arguments | items") != std::string::npos
            || source.find("tool_call.arguments | tojson") != std::string::npos;
        _supports_system_role = source.find("System role not supported") == std::string::npos;

        _template_root = minja::Parser::parse(_source, {
            /* .trim_blocks = */ true,
            /* .lstrip_blocks = */ true,
            /* .keep_trailing_newline = */ false,
        });
    }

    const std::string & source() const { return _source; }
    bool supports_tools() const { return _supports_tools; }

    std::string apply(
        const nlohmann::ordered_json & messages,
        const nlohmann::ordered_json & tools,
        bool add_generation_prompt,
        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
    {
        auto actual_messages = messages;

        // First, "fix" messages so they have a chance to be rendered correctly by the template

        if (_requires_object_arguments || !_supports_system_role) {
            std::string pending_system;
            auto flush_sys = [&]() {
                if (!pending_system.empty()) {
                    actual_messages.push_back({
                        {"role", "user"},
                        {"content", pending_system},
                    });
                    pending_system.clear();
                }
            };
            for (auto & message : actual_messages) {
                if (!message.contains("role") || !message.contains("content")) {
                    throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
                }
                std::string role = message.at("role");

                if (!message["content"].is_null() && !_supports_system_role) {
                    std::string content = message.at("content");
                    if (role == "system") {
                        if (!pending_system.empty()) pending_system += "\n";
                        pending_system += content;
                        continue;
                    } else {
                        if (role == "user") {
                            if (!pending_system.empty()) {
                                message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
                                pending_system.clear();
                            }
                        } else {
                            flush_sys();
                        }
                    }
                }
                if (_requires_object_arguments && message.contains("tool_calls")) {
                    for (auto & tool_call : message.at("tool_calls")) {
                        if (tool_call["type"] == "function") {
                            auto & function = tool_call.at("function");
                            std::string arguments = function.at("arguments");
                            function["arguments"] = json::parse(arguments);
                        }
                    }
                }
            }
            flush_sys();
        }

        auto context = minja::Context::make(json({
            {"messages", actual_messages},
            {"add_generation_prompt", add_generation_prompt},
            {"bos_token", _bos_token},
            {"eos_token", _eos_token},
        }));

        if (!tools.is_null()) {
            auto tools_val = minja::Value(tools);
            context->set("tools", tools_val);
        }
        if (!extra_context.is_null()) {
            for (auto & kv : extra_context.items()) {
                minja::Value val(kv.value());
                context->set(kv.key(), val);
            }
        }

        return _template_root->render(context);
    }
};

} // namespace minja
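
A small illustration (not in the diff) of the normalization apply() performs above when a template requires object-valued arguments, e.g. one that renders `tool_call.arguments | items`; the tool call below is made up.

    #include <json.hpp>
    using json = nlohmann::ordered_json;

    // OpenAI-style assistant message: "arguments" arrives as a stringified object.
    json assistant_msg = {
        {"role", "assistant"},
        {"content", nullptr},
        {"tool_calls", json::array({
            json{
                {"type", "function"},
                {"function", json{
                    {"name", "get_weather"},
                    {"arguments", "{\"location\": \"Paris\"}"}
                }}
            }
        })}
    };
    // After the fix-up loop in apply(), the template sees a real object:
    //   assistant_msg["tool_calls"][0]["function"]["arguments"] == json{{"location", "Paris"}}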
common/common.cpp
@@ -9,7 +9,7 @@
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "chat-template.h"
#include "chat-template.hpp"

#include <algorithm>
#include <cinttypes>
@@ -1513,13 +1513,13 @@ std::vector<llama_token> llama_tokenize(
    return result;
}

std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
    std::string piece;
    piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
    const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
    const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
    if (n_chars < 0) {
        piece.resize(-n_chars);
        int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
        int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
        GGML_ASSERT(check == -n_chars);
    }
    else {
@@ -1529,6 +1529,10 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t
    return piece;
}

std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
    return _llama_token_to_piece(llama_get_model(ctx), token, special);
}

std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
    std::string text;
    text.resize(std::max(text.capacity(), tokens.size()));
@@ -1552,7 +1556,7 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token>
bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) {
    if (use_jinja) {
        try {
            auto chat_template = llama_chat_template(tmpl, "<s>", "</s>");
            auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
            chat_template.apply({{
                {"role", "user"},
                {"content", "test"},
@@ -1651,6 +1655,30 @@ std::string llama_chat_format_example(const struct llama_model * model,
    return llama_chat_apply_template(model, tmpl, msgs, true);
}

static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
    if (tlen > 0) {
        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
            return std::string(curr_tmpl_buf.data(), tlen);
        }
    }
    return "";
}

minja::chat_template llama_chat_template_from_model(
    const struct llama_model * model,
    const char * chat_template_override)
{
    // TODO: handle "chatml"?
    std::string chat_template = chat_template_override
        ? chat_template_override
        : _llama_model_meta_val_str(model, "tokenizer.chat_template");
    auto bos_token = _llama_token_to_piece(model, llama_token_bos(model), true);
    auto eos_token = _llama_token_to_piece(model, llama_token_eos(model), true);
    return {std::move(chat_template), bos_token, eos_token};
}

//
// KV cache utils
//
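
For context (not in the diff): with this change llama_chat_verify_template parses and test-renders Jinja templates through minja::chat_template. A minimal sketch of exercising it, with a made-up template string and a helper name invented for illustration:

    #include "common.h"
    #include <cstdio>
    #include <string>

    static bool check_template(const std::string & tmpl_src) {
        // Returns true if the Jinja template parses and test-renders.
        if (!llama_chat_verify_template(tmpl_src, /* use_jinja= */ true)) {
            fprintf(stderr, "chat template failed to parse or render\n");
            return false;
        }
        return true;
    }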
common/common.h
@@ -27,6 +27,9 @@

#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"

// Forward declaration
namespace minja { class chat_template; }

struct llama_lora_adapter_info {
    std::string path;
    float scale;
@@ -500,6 +503,10 @@ std::string llama_chat_format_single(const struct llama_model * model,
std::string llama_chat_format_example(const struct llama_model * model,
    const std::string & tmpl);

minja::chat_template llama_chat_template_from_model(
    const struct llama_model * model,
    const char * chat_template_override = nullptr);

//
// KV cache utils
//
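
One reading of the minja.hpp changes that follow (my inference, not stated in the diff): moving the parsed AST from std::unique_ptr to std::shared_ptr makes expression and template nodes shareable, so owning objects such as minja::chat_template become copyable and can be returned by value, as llama_chat_template_from_model above now does. A tiny sketch with a made-up template string:

    #include "chat-template.hpp"

    int main() {
        minja::chat_template a("{{ bos_token }}{% for m in messages %}{{ m.content }}{% endfor %}", "<s>", "</s>");
        minja::chat_template b = a;  // compiles now: the parsed AST is shared, not uniquely owned
        return b.source() == a.source() ? 0 : 1;
    }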
common/minja.hpp (332 changed lines)
@@ -1,3 +1,11 @@
/*
    Copyright 2024 Google LLC

    Use of this source code is governed by an MIT-style
    license that can be found in the LICENSE file or at
    https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once

#include <iostream>
@@ -577,8 +585,8 @@ protected:
    virtual Value do_evaluate(const std::shared_ptr<Context> & context) const = 0;
public:
    struct Arguments {
        std::vector<std::unique_ptr<Expression>> args;
        std::vector<std::pair<std::string, std::unique_ptr<Expression>>> kwargs;
        std::vector<std::shared_ptr<Expression>> args;
        std::vector<std::pair<std::string, std::shared_ptr<Expression>>> kwargs;

        void expectArgs(const std::string & method_name, const std::pair<size_t, size_t> & pos_count, const std::pair<size_t, size_t> & kw_count) const {
            if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) {
@@ -600,7 +608,7 @@ public:
        }
    };

    using Parameters = std::vector<std::pair<std::string, std::unique_ptr<Expression>>>;
    using Parameters = std::vector<std::pair<std::string, std::shared_ptr<Expression>>>;

    Location location;

@@ -687,18 +695,18 @@ struct TextTemplateToken : public TemplateToken {
};

struct ExpressionTemplateToken : public TemplateToken {
    std::unique_ptr<Expression> expr;
    ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr<Expression> && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {}
    std::shared_ptr<Expression> expr;
    ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {}
};

struct IfTemplateToken : public TemplateToken {
    std::unique_ptr<Expression> condition;
    IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr<Expression> && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {}
    std::shared_ptr<Expression> condition;
    IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {}
};

struct ElifTemplateToken : public TemplateToken {
    std::unique_ptr<Expression> condition;
    ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr<Expression> && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {}
    std::shared_ptr<Expression> condition;
    ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<Expression> && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {}
};

struct ElseTemplateToken : public TemplateToken {
@@ -710,9 +718,9 @@ struct EndIfTemplateToken : public TemplateToken {
};

struct MacroTemplateToken : public TemplateToken {
    std::unique_ptr<VariableExpr> name;
    std::shared_ptr<VariableExpr> name;
    Expression::Parameters params;
    MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::unique_ptr<VariableExpr> && n, Expression::Parameters && p)
    MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p)
        : TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {}
};

@@ -722,11 +730,11 @@ struct EndMacroTemplateToken : public TemplateToken {

struct ForTemplateToken : public TemplateToken {
    std::vector<std::string> var_names;
    std::unique_ptr<Expression> iterable;
    std::unique_ptr<Expression> condition;
    std::shared_ptr<Expression> iterable;
    std::shared_ptr<Expression> condition;
    bool recursive;
    ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector<std::string> & vns, std::unique_ptr<Expression> && iter,
        std::unique_ptr<Expression> && c, bool r)
    ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector<std::string> & vns, std::shared_ptr<Expression> && iter,
        std::shared_ptr<Expression> && c, bool r)
        : TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {}
};

@@ -737,8 +745,8 @@ struct EndForTemplateToken : public TemplateToken {
struct SetTemplateToken : public TemplateToken {
    std::string ns;
    std::vector<std::string> var_names;
    std::unique_ptr<Expression> value;
    SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector<std::string> & vns, std::unique_ptr<Expression> && v)
    std::shared_ptr<Expression> value;
    SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
        : TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {}
};

@@ -778,9 +786,9 @@ public:
};

class SequenceNode : public TemplateNode {
    std::vector<std::unique_ptr<TemplateNode>> children;
    std::vector<std::shared_ptr<TemplateNode>> children;
public:
    SequenceNode(const Location & location, std::vector<std::unique_ptr<TemplateNode>> && c)
    SequenceNode(const Location & location, std::vector<std::shared_ptr<TemplateNode>> && c)
        : TemplateNode(location), children(std::move(c)) {}
    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
        for (const auto& child : children) child->render(out, context);
@@ -797,10 +805,11 @@ public:
};

class ExpressionNode : public TemplateNode {
    std::unique_ptr<Expression> expr;
    std::shared_ptr<Expression> expr;
public:
    ExpressionNode(const Location & location, std::unique_ptr<Expression> && e) : TemplateNode(location), expr(std::move(e)) {}
    ExpressionNode(const Location & location, std::shared_ptr<Expression> && e) : TemplateNode(location), expr(std::move(e)) {}
    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
        if (!expr) throw std::runtime_error("ExpressionNode.expr is null");
        auto result = expr->evaluate(context);
        if (result.is_string()) {
            out << result.get<std::string>();
@@ -813,9 +822,9 @@ public:
};

class IfNode : public TemplateNode {
    std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<TemplateNode>>> cascade;
    std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
public:
    IfNode(const Location & location, std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<TemplateNode>>> && c)
    IfNode(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> && c)
        : TemplateNode(location), cascade(std::move(c)) {}
    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
        for (const auto& branch : cascade) {
@@ -824,6 +833,7 @@ public:
                enter_branch = branch.first->evaluate(context).to_bool();
            }
            if (enter_branch) {
                if (!branch.second) throw std::runtime_error("IfNode.cascade.second is null");
                branch.second->render(out, context);
                return;
            }
@@ -833,18 +843,20 @@ public:

class ForNode : public TemplateNode {
    std::vector<std::string> var_names;
    std::unique_ptr<Expression> iterable;
    std::unique_ptr<Expression> condition;
    std::unique_ptr<TemplateNode> body;
    std::shared_ptr<Expression> iterable;
    std::shared_ptr<Expression> condition;
    std::shared_ptr<TemplateNode> body;
    bool recursive;
    std::unique_ptr<TemplateNode> else_body;
    std::shared_ptr<TemplateNode> else_body;
public:
    ForNode(const Location & location, std::vector<std::string> && var_names, std::unique_ptr<Expression> && iterable,
        std::unique_ptr<Expression> && condition, std::unique_ptr<TemplateNode> && body, bool recursive, std::unique_ptr<TemplateNode> && else_body)
    ForNode(const Location & location, std::vector<std::string> && var_names, std::shared_ptr<Expression> && iterable,
        std::shared_ptr<Expression> && condition, std::shared_ptr<TemplateNode> && body, bool recursive, std::shared_ptr<TemplateNode> && else_body)
        : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {}

    void do_render(std::ostringstream & out, const std::shared_ptr<Context> & context) const override {
        // https://jinja.palletsprojects.com/en/3.0.x/templates/#for
        if (!iterable) throw std::runtime_error("ForNode.iterable is null");
        if (!body) throw std::runtime_error("ForNode.body is null");

        auto iterable_value = iterable->evaluate(context);
        Value::CallableType loop_function;
@@ -914,12 +926,12 @@ public:
};

class MacroNode : public TemplateNode {
    std::unique_ptr<VariableExpr> name;
    std::shared_ptr<VariableExpr> name;
    Expression::Parameters params;
    std::unique_ptr<TemplateNode> body;
    std::shared_ptr<TemplateNode> body;
    std::unordered_map<std::string, size_t> named_param_positions;
public:
    MacroNode(const Location & location, std::unique_ptr<VariableExpr> && n, Expression::Parameters && p, std::unique_ptr<TemplateNode> && b)
    MacroNode(const Location & location, std::shared_ptr<VariableExpr> && n, Expression::Parameters && p, std::shared_ptr<TemplateNode> && b)
        : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) {
        for (size_t i = 0; i < params.size(); ++i) {
            const auto & name = params[i].first;
@@ -929,6 +941,8 @@ public:
        }
    }
    void do_render(std::ostringstream &, const std::shared_ptr<Context> & macro_context) const override {
        if (!name) throw std::runtime_error("MacroNode.name is null");
        if (!body) throw std::runtime_error("MacroNode.body is null");
        auto callable = Value::callable([&](const std::shared_ptr<Context> & context, Value::Arguments & args) {
            auto call_context = macro_context;
            std::vector<bool> param_set(params.size(), false);
@@ -964,19 +978,12 @@ public:
class SetNode : public TemplateNode {
    std::string ns;
    std::vector<std::string> var_names;
    std::unique_ptr<Expression> value;
    std::unique_ptr<TemplateNode> template_value;
    std::shared_ptr<Expression> value;
public:
    SetNode(const Location & location, const std::string & ns, const std::vector<std::string> & vns, std::unique_ptr<Expression> && v, std::unique_ptr<TemplateNode> && tv)
        : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)), template_value(std::move(tv)) {
        if (value && template_value) {
            throw std::runtime_error("Cannot have both value and template value in set node");
        }
        if (template_value && var_names.size() != 1) {
            throw std::runtime_error("Destructuring assignment is only supported with a single variable name");
        }
    }
    SetNode(const Location & location, const std::string & ns, const std::vector<std::string> & vns, std::shared_ptr<Expression> && v)
        : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {}
    void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
        if (!value) throw std::runtime_error("SetNode.value is null");
        if (!ns.empty()) {
            if (var_names.size() != 1) {
                throw std::runtime_error("Namespaced set only supports a single variable name");
@@ -985,9 +992,6 @@ public:
            auto ns_value = context->get(ns);
            if (!ns_value.is_object()) throw std::runtime_error("Namespace '" + ns + "' is not an object");
            ns_value.set(name, this->value->evaluate(context));
        } else if (template_value) {
            Value value { template_value->render(context) };
            context->set(var_names[0], value);
        } else {
            auto val = value->evaluate(context);
            destructuring_assign(var_names, context, val);
@@ -995,14 +999,29 @@ public:
    }
};

class IfExpr : public Expression {
    std::unique_ptr<Expression> condition;
    std::unique_ptr<Expression> then_expr;
    std::unique_ptr<Expression> else_expr;
class SetTemplateNode : public TemplateNode {
    std::string name;
    std::shared_ptr<TemplateNode> template_value;
public:
    IfExpr(const Location & location, std::unique_ptr<Expression> && c, std::unique_ptr<Expression> && t, std::unique_ptr<Expression> && e)
    SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr<TemplateNode> && tv)
        : TemplateNode(location), name(name), template_value(std::move(tv)) {}
    void do_render(std::ostringstream &, const std::shared_ptr<Context> & context) const override {
        if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null");
        Value value { template_value->render(context) };
        context->set(name, value);
    }
};

class IfExpr : public Expression {
    std::shared_ptr<Expression> condition;
    std::shared_ptr<Expression> then_expr;
    std::shared_ptr<Expression> else_expr;
public:
    IfExpr(const Location & location, std::shared_ptr<Expression> && c, std::shared_ptr<Expression> && t, std::shared_ptr<Expression> && e)
        : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
        if (!condition) throw std::runtime_error("IfExpr.condition is null");
        if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null");
        if (condition->evaluate(context).to_bool()) {
            return then_expr->evaluate(context);
        }
@@ -1022,13 +1041,14 @@ public:
};

class ArrayExpr : public Expression {
    std::vector<std::unique_ptr<Expression>> elements;
    std::vector<std::shared_ptr<Expression>> elements;
public:
    ArrayExpr(const Location & location, std::vector<std::unique_ptr<Expression>> && e)
    ArrayExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && e)
        : Expression(location), elements(std::move(e)) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
        auto result = Value::array();
        for (const auto& e : elements) {
            if (!e) throw std::runtime_error("Array element is null");
            result.push_back(e->evaluate(context));
        }
        return result;
@@ -1036,13 +1056,15 @@ public:
};

class DictExpr : public Expression {
    std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>>> elements;
    std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
public:
    DictExpr(const Location & location, std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>>> && e)
    DictExpr(const Location & location, std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> && e)
        : Expression(location), elements(std::move(e)) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
        auto result = Value::object();
        for (const auto& e : elements) {
            if (!e.first) throw std::runtime_error("Dict key is null");
            if (!e.second) throw std::runtime_error("Dict value is null");
            result.set(e.first->evaluate(context), e.second->evaluate(context));
        }
        return result;
@@ -1051,8 +1073,8 @@ public:

class SliceExpr : public Expression {
public:
    std::unique_ptr<Expression> start, end;
    SliceExpr(const Location & location, std::unique_ptr<Expression> && s, std::unique_ptr<Expression> && e)
    std::shared_ptr<Expression> start, end;
    SliceExpr(const Location & location, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
        : Expression(location), start(std::move(s)), end(std::move(e)) {}
    Value do_evaluate(const std::shared_ptr<Context> &) const override {
        throw std::runtime_error("SliceExpr not implemented");
@@ -1060,12 +1082,14 @@ public:
};

class SubscriptExpr : public Expression {
    std::unique_ptr<Expression> base;
    std::unique_ptr<Expression> index;
    std::shared_ptr<Expression> base;
    std::shared_ptr<Expression> index;
public:
    SubscriptExpr(const Location & location, std::unique_ptr<Expression> && b, std::unique_ptr<Expression> && i)
    SubscriptExpr(const Location & location, std::shared_ptr<Expression> && b, std::shared_ptr<Expression> && i)
        : Expression(location), base(std::move(b)), index(std::move(i)) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
        if (!base) throw std::runtime_error("SubscriptExpr.base is null");
        if (!index) throw std::runtime_error("SubscriptExpr.index is null");
        auto target_value = base->evaluate(context);
        if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
            if (!target_value.is_array()) throw std::runtime_error("Subscripting non-array");
@@ -1094,12 +1118,13 @@ class UnaryOpExpr : public Expression {
public:
    enum class Op { Plus, Minus, LogicalNot };
private:
    std::unique_ptr<Expression> expr;
    std::shared_ptr<Expression> expr;
    Op op;
public:
    UnaryOpExpr(const Location & location, std::unique_ptr<Expression> && e, Op o)
    UnaryOpExpr(const Location & location, std::shared_ptr<Expression> && e, Op o)
        : Expression(location), expr(std::move(e)), op(o) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
        if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null");
        auto e = expr->evaluate(context);
        switch (op) {
            case Op::Plus: return e;
@@ -1114,13 +1139,15 @@ class BinaryOpExpr : public Expression {
public:
    enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot };
private:
    std::unique_ptr<Expression> left;
    std::unique_ptr<Expression> right;
    std::shared_ptr<Expression> left;
    std::shared_ptr<Expression> right;
    Op op;
public:
    BinaryOpExpr(const Location & location, std::unique_ptr<Expression> && l, std::unique_ptr<Expression> && r, Op o)
    BinaryOpExpr(const Location & location, std::shared_ptr<Expression> && l, std::shared_ptr<Expression> && r, Op o)
        : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
        if (!left) throw std::runtime_error("BinaryOpExpr.left is null");
        if (!right) throw std::runtime_error("BinaryOpExpr.right is null");
        auto l = left->evaluate(context);

        auto do_eval = [&](const Value & l) -> Value {
@@ -1210,13 +1237,15 @@ static std::string html_escape(const std::string & s) {
}

class MethodCallExpr : public Expression {
    std::unique_ptr<Expression> object;
    std::unique_ptr<VariableExpr> method;
    std::shared_ptr<Expression> object;
    std::shared_ptr<VariableExpr> method;
    Expression::Arguments args;
public:
    MethodCallExpr(const Location & location, std::unique_ptr<Expression> && obj, std::unique_ptr<VariableExpr> && m, Expression::Arguments && a)
    MethodCallExpr(const Location & location, std::shared_ptr<Expression> && obj, std::shared_ptr<VariableExpr> && m, Expression::Arguments && a)
        : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
        if (!object) throw std::runtime_error("MethodCallExpr.object is null");
        if (!method) throw std::runtime_error("MethodCallExpr.method is null");
        auto obj = object->evaluate(context);
        if (obj.is_array()) {
            if (method->get_name() == "append") {
@@ -1279,11 +1308,12 @@ public:

class CallExpr : public Expression {
public:
    std::unique_ptr<Expression> object;
    std::shared_ptr<Expression> object;
    Expression::Arguments args;
    CallExpr(const Location & location, std::unique_ptr<Expression> && obj, Expression::Arguments && a)
    CallExpr(const Location & location, std::shared_ptr<Expression> && obj, Expression::Arguments && a)
        : Expression(location), object(std::move(obj)), args(std::move(a)) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
        if (!object) throw std::runtime_error("CallExpr.object is null");
        auto obj = object->evaluate(context);
        if (!obj.is_callable()) {
            throw std::runtime_error("Object is not callable: " + obj.dump(2));
@@ -1294,14 +1324,15 @@ public:
};

class FilterExpr : public Expression {
    std::vector<std::unique_ptr<Expression>> parts;
    std::vector<std::shared_ptr<Expression>> parts;
public:
    FilterExpr(const Location & location, std::vector<std::unique_ptr<Expression>> && p)
    FilterExpr(const Location & location, std::vector<std::shared_ptr<Expression>> && p)
        : Expression(location), parts(std::move(p)) {}
    Value do_evaluate(const std::shared_ptr<Context> & context) const override {
        Value result;
        bool first = true;
        for (const auto& part : parts) {
            if (!part) throw std::runtime_error("FilterExpr.part is null");
            if (first) {
                first = false;
                result = part->evaluate(context);
@@ -1322,7 +1353,7 @@ public:
        return result;
    }

    void prepend(std::unique_ptr<Expression> && e) {
    void prepend(std::shared_ptr<Expression> && e) {
        parts.insert(parts.begin(), std::move(e));
    }
};
@@ -1375,7 +1406,7 @@ private:
                escape = true;
            } else if (*it == quote) {
                ++it;
                return nonstd_make_unique<std::string>(result);
                return nonstd_make_unique<std::string>(std::move(result));
            } else {
                result += *it;
            }
@@ -1429,25 +1460,25 @@ private:
    }

    /** integer, float, bool, string */
    std::unique_ptr<Value> parseConstant() {
    std::shared_ptr<Value> parseConstant() {
        auto start = it;
        consumeSpaces();
        if (it == end) return nullptr;
        if (*it == '"' || *it == '\'') {
            auto str = parseString();
            if (str) return nonstd_make_unique<Value>(*str);
            if (str) return std::make_shared<Value>(*str);
        }
        static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)");
        auto token = consumeToken(prim_tok);
        if (!token.empty()) {
            if (token == "true" || token == "True") return nonstd_make_unique<Value>(true);
            if (token == "false" || token == "False") return nonstd_make_unique<Value>(false);
            if (token == "None") return nonstd_make_unique<Value>(nullptr);
            if (token == "true" || token == "True") return std::make_shared<Value>(true);
            if (token == "false" || token == "False") return std::make_shared<Value>(false);
            if (token == "None") return std::make_shared<Value>(nullptr);
            throw std::runtime_error("Unknown constant token: " + token);
        }

        auto number = parseNumber(it, end);
        if (!number.is_null()) return nonstd_make_unique<Value>(number);
        if (!number.is_null()) return std::make_shared<Value>(number);

        it = start;
        return nullptr;
@@ -1510,7 +1541,7 @@ private:
        return "";
    }

    std::unique_ptr<Expression> parseExpression(bool allow_if_expr = true) {
    std::shared_ptr<Expression> parseExpression(bool allow_if_expr = true) {
        auto left = parseLogicalOr();
        if (it == end) return left;

@@ -1523,19 +1554,19 @@ private:

        auto location = get_location();
        auto if_expr = parseIfExpression();
        return nonstd_make_unique<IfExpr>(location, std::move(if_expr.first), std::move(left), std::move(if_expr.second));
        return std::make_shared<IfExpr>(location, std::move(if_expr.first), std::move(left), std::move(if_expr.second));
    }

    Location get_location() const {
        return {template_str, (size_t) std::distance(start, it)};
    }

    std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>> parseIfExpression() {
    std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>> parseIfExpression() {
        auto condition = parseLogicalOr();
        if (!condition) throw std::runtime_error("Expected condition expression");

        static std::regex else_tok(R"(else\b)");
        std::unique_ptr<Expression> else_expr;
        std::shared_ptr<Expression> else_expr;
        if (!consumeToken(else_tok).empty()) {
            else_expr = parseExpression();
            if (!else_expr) throw std::runtime_error("Expected 'else' expression");
@@ -1543,7 +1574,7 @@ private:
        return std::make_pair(std::move(condition), std::move(else_expr));
    }

    std::unique_ptr<Expression> parseLogicalOr() {
    std::shared_ptr<Expression> parseLogicalOr() {
        auto left = parseLogicalAnd();
        if (!left) throw std::runtime_error("Expected left side of 'logical or' expression");

@@ -1552,24 +1583,24 @@ private:
        while (!consumeToken(or_tok).empty()) {
            auto right = parseLogicalAnd();
            if (!right) throw std::runtime_error("Expected right side of 'or' expression");
            left = nonstd_make_unique<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or);
            left = std::make_shared<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or);
        }
        return left;
    }

    std::unique_ptr<Expression> parseLogicalNot() {
    std::shared_ptr<Expression> parseLogicalNot() {
        static std::regex not_tok(R"(not\b)");
        auto location = get_location();

        if (!consumeToken(not_tok).empty()) {
            auto sub = parseLogicalNot();
            if (!sub) throw std::runtime_error("Expected expression after 'not' keyword");
            return nonstd_make_unique<UnaryOpExpr>(location, std::move(sub), UnaryOpExpr::Op::LogicalNot);
            return std::make_shared<UnaryOpExpr>(location, std::move(sub), UnaryOpExpr::Op::LogicalNot);
        }
        return parseLogicalCompare();
    }

    std::unique_ptr<Expression> parseLogicalAnd() {
    std::shared_ptr<Expression> parseLogicalAnd() {
        auto left = parseLogicalNot();
        if (!left) throw std::runtime_error("Expected left side of 'logical and' expression");

@@ -1578,12 +1609,12 @@ private:
        while (!consumeToken(and_tok).empty()) {
            auto right = parseLogicalNot();
            if (!right) throw std::runtime_error("Expected right side of 'and' expression");
            left = nonstd_make_unique<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::And);
            left = std::make_shared<BinaryOpExpr>(location, std::move(left), std::move(right), BinaryOpExpr::Op::And);
        }
        return left;
    }

    std::unique_ptr<Expression> parseLogicalCompare() {
    std::shared_ptr<Expression> parseLogicalCompare() {
        auto left = parseStringConcat();
        if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");

@@ -1598,7 +1629,7 @@ private:
            auto identifier = parseIdentifier();
            if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword");

            return nonstd_make_unique<BinaryOpExpr>(
            return std::make_shared<BinaryOpExpr>(
                left->location,
                std::move(left), std::move(identifier),
                negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is);
@@ -1615,7 +1646,7 @@ private:
            else if (op_str == "in") op = BinaryOpExpr::Op::In;
            else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn;
            else throw std::runtime_error("Unknown comparison operator: " + op_str);
            left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
        }
        return left;
    }
@@ -1688,16 +1719,16 @@ private:
        throw std::runtime_error("Expected closing parenthesis in call args");
    }

    std::unique_ptr<VariableExpr> parseIdentifier() {
    std::shared_ptr<VariableExpr> parseIdentifier() {
        static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)");
        auto location = get_location();
        auto ident = consumeToken(ident_regex);
        if (ident.empty())
            return nullptr;
        return nonstd_make_unique<VariableExpr>(location, ident);
        return std::make_shared<VariableExpr>(location, ident);
    }

    std::unique_ptr<Expression> parseStringConcat() {
    std::shared_ptr<Expression> parseStringConcat() {
        auto left = parseMathPow();
        if (!left) throw std::runtime_error("Expected left side of 'string concat' expression");

@@ -1705,24 +1736,24 @@ private:
        if (!consumeToken(concat_tok).empty()) {
            auto right = parseLogicalAnd();
            if (!right) throw std::runtime_error("Expected right side of 'string concat' expression");
            left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat);
            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat);
        }
        return left;
    }

    std::unique_ptr<Expression> parseMathPow() {
    std::shared_ptr<Expression> parseMathPow() {
        auto left = parseMathPlusMinus();
        if (!left) throw std::runtime_error("Expected left side of 'math pow' expression");

        while (!consumeToken("**").empty()) {
            auto right = parseMathPlusMinus();
            if (!right) throw std::runtime_error("Expected right side of 'math pow' expression");
            left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul);
            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul);
        }
        return left;
    }

    std::unique_ptr<Expression> parseMathPlusMinus() {
    std::shared_ptr<Expression> parseMathPlusMinus() {
        static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))");

        auto left = parseMathMulDiv();
@@ -1732,12 +1763,12 @@ private:
            auto right = parseMathMulDiv();
            if (!right) throw std::runtime_error("Expected right side of 'math plus/minus' expression");
            auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub;
            left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
        }
        return left;
    }

    std::unique_ptr<Expression> parseMathMulDiv() {
    std::shared_ptr<Expression> parseMathMulDiv() {
        auto left = parseMathUnaryPlusMinus();
        if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression");

@@ -1751,7 +1782,7 @@ private:
                : op_str == "/" ? BinaryOpExpr::Op::Div
                : op_str == "//" ? BinaryOpExpr::Op::DivDiv
                : BinaryOpExpr::Op::Mod;
            left = nonstd_make_unique<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
            left = std::make_shared<BinaryOpExpr>(get_location(), std::move(left), std::move(right), op);
        }

        if (!consumeToken("|").empty()) {
@@ -1760,20 +1791,20 @@ private:
                filter->prepend(std::move(left));
                return expr;
            } else {
                std::vector<std::unique_ptr<Expression>> parts;
                std::vector<std::shared_ptr<Expression>> parts;
                parts.emplace_back(std::move(left));
                parts.emplace_back(std::move(expr));
                return nonstd_make_unique<FilterExpr>(get_location(), std::move(parts));
                return std::make_shared<FilterExpr>(get_location(), std::move(parts));
            }
        }
        return left;
    }

    std::unique_ptr<Expression> call_func(const std::string & name, Expression::Arguments && args) const {
        return nonstd_make_unique<CallExpr>(get_location(), nonstd_make_unique<VariableExpr>(get_location(), name), std::move(args));
    std::shared_ptr<Expression> call_func(const std::string & name, Expression::Arguments && args) const {
        return std::make_shared<CallExpr>(get_location(), std::make_shared<VariableExpr>(get_location(), name), std::move(args));
    }

    std::unique_ptr<Expression> parseMathUnaryPlusMinus() {
    std::shared_ptr<Expression> parseMathUnaryPlusMinus() {
        static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))");
        auto op_str = consumeToken(unary_plus_minus_tok);
        auto expr = parseValueExpression();
@@ -1781,19 +1812,19 @@ private:

        if (!op_str.empty()) {
            auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus;
            return nonstd_make_unique<UnaryOpExpr>(get_location(), std::move(expr), op);
            return std::make_shared<UnaryOpExpr>(get_location(), std::move(expr), op);
        }
        return expr;
    }

    std::unique_ptr<Expression> parseValueExpression() {
        auto parseValue = [&]() -> std::unique_ptr<Expression> {
    std::shared_ptr<Expression> parseValueExpression() {
        auto parseValue = [&]() -> std::shared_ptr<Expression> {
            auto location = get_location();
            auto constant = parseConstant();
            if (constant) return nonstd_make_unique<LiteralExpr>(location, *constant);
            if (constant) return std::make_shared<LiteralExpr>(location, *constant);

            static std::regex null_regex(R"(null\b)");
            if (!consumeToken(null_regex).empty()) return nonstd_make_unique<LiteralExpr>(location, Value());
            if (!consumeToken(null_regex).empty()) return std::make_shared<LiteralExpr>(location, Value());

            auto identifier = parseIdentifier();
            if (identifier) return identifier;
@@ -1814,19 +1845,19 @@ private:

        while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
            if (!consumeToken("[").empty()) {
                std::unique_ptr<Expression> index;
                std::shared_ptr<Expression> index;
                if (!consumeToken(":").empty()) {
                    auto slice_end = parseExpression();
                    index = nonstd_make_unique<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
                    index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
                } else {
                    auto slice_start = parseExpression();
                    if (!consumeToken(":").empty()) {
                        consumeSpaces();
                        if (peekSymbols({ "]" })) {
                            index = nonstd_make_unique<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
                            index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
                        } else {
                            auto slice_end = parseExpression();
                            index = nonstd_make_unique<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
                            index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
                        }
                    } else {
                        index = std::move(slice_start);
@@ -1835,7 +1866,7 @@ private:
                if (!index) throw std::runtime_error("Empty index in subscript");
                if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");

                value = nonstd_make_unique<SubscriptExpr>(value->location, std::move(value), std::move(index));
                value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
            } else if (!consumeToken(".").empty()) {
                auto identifier = parseIdentifier();
                if (!identifier) throw std::runtime_error("Expected identifier in subscript");
@@ -1843,10 +1874,10 @@ private:
                consumeSpaces();
                if (peekSymbols({ "(" })) {
                    auto callParams = parseCallArgs();
                    value = nonstd_make_unique<MethodCallExpr>(identifier->location, std::move(value), std::move(identifier), std::move(callParams));
                    value = std::make_shared<MethodCallExpr>(identifier->location, std::move(value), std::move(identifier), std::move(callParams));
                } else {
                    auto key = nonstd_make_unique<LiteralExpr>(identifier->location, Value(identifier->get_name()));
                    value = nonstd_make_unique<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
                    auto key = std::make_shared<LiteralExpr>(identifier->location, Value(identifier->get_name()));
                    value = std::make_shared<SubscriptExpr>(identifier->location, std::move(value), std::move(key));
                }
            }
            consumeSpaces();
@@ -1855,12 +1886,12 @@ private:
        if (peekSymbols({ "(" })) {
            auto location = get_location();
            auto callParams = parseCallArgs();
            value = nonstd_make_unique<CallExpr>(location, std::move(value), std::move(callParams));
            value = std::make_shared<CallExpr>(location, std::move(value), std::move(callParams));
        }
        return value;
    }

    std::unique_ptr<Expression> parseBracedExpressionOrArray() {
    std::shared_ptr<Expression> parseBracedExpressionOrArray() {
        if (consumeToken("(").empty()) return nullptr;

        auto expr = parseExpression();
@@ -1870,7 +1901,7 @@ private:
            return expr; // Drop the parentheses
        }

        std::vector<std::unique_ptr<Expression>> tuple;
        std::vector<std::shared_ptr<Expression>> tuple;
        tuple.emplace_back(std::move(expr));

        while (it != end) {
@@ -1880,18 +1911,18 @@ private:
            tuple.push_back(std::move(next));

            if (!consumeToken(")").empty()) {
                return nonstd_make_unique<ArrayExpr>(get_location(), std::move(tuple));
                return std::make_shared<ArrayExpr>(get_location(), std::move(tuple));
            }
        }
        throw std::runtime_error("Expected closing parenthesis");
    }

    std::unique_ptr<Expression> parseArray() {
    std::shared_ptr<Expression> parseArray() {
        if (consumeToken("[").empty()) return nullptr;

        std::vector<std::unique_ptr<Expression>> elements;
        std::vector<std::shared_ptr<Expression>> elements;
        if (!consumeToken("]").empty()) {
            return nonstd_make_unique<ArrayExpr>(get_location(), std::move(elements));
            return std::make_shared<ArrayExpr>(get_location(), std::move(elements));
        }
        auto first_expr = parseExpression();
        if (!first_expr) throw std::runtime_error("Expected first expression in array");
@@ -1903,7 +1934,7 @@ private:
                if (!expr) throw std::runtime_error("Expected expression in array");
                elements.push_back(std::move(expr));
            } else if (!consumeToken("]").empty()) {
                return nonstd_make_unique<ArrayExpr>(get_location(), std::move(elements));
                return std::make_shared<ArrayExpr>(get_location(), std::move(elements));
            } else {
                throw std::runtime_error("Expected comma or closing bracket in array");
            }
@@ -1911,12 +1942,12 @@ private:
        throw std::runtime_error("Expected closing bracket");
    }

    std::unique_ptr<Expression> parseDictionary() {
    std::shared_ptr<Expression> parseDictionary() {
        if (consumeToken("{").empty()) return nullptr;

        std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<Expression>>> elements;
        std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<Expression>>> elements;
        if (!consumeToken("}").empty()) {
            return nonstd_make_unique<DictExpr>(get_location(), std::move(elements));
            return std::make_shared<DictExpr>(get_location(), std::move(elements));
        }

        auto parseKeyValuePair = [&]() {
@@ -1934,7 +1965,7 @@ private:
            if (!consumeToken(",").empty()) {
                parseKeyValuePair();
            } else if (!consumeToken("}").empty()) {
                return nonstd_make_unique<DictExpr>(get_location(), std::move(elements));
                return std::make_shared<DictExpr>(get_location(), std::move(elements));
            } else {
                throw std::runtime_error("Expected comma or closing brace in dictionary");
            }
@@ -2051,7 +2082,7 @@ private:
        auto iterable = parseExpression(/* allow_if_expr = */ false);
        if (!iterable) throw std::runtime_error("Expected iterable in for block");

        std::unique_ptr<Expression> condition;
        std::shared_ptr<Expression> condition;
        if (!consumeToken(if_tok).empty()) {
            condition = parseExpression();
        }
@@ -2067,7 +2098,7 @@ private:

        std::string ns;
        std::vector<std::string> var_names;
        std::unique_ptr<Expression> value;
        std::shared_ptr<Expression> value;
        if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) {
            ns = group[1];
            var_names.push_back(group[2]);
@@ -2114,17 +2145,17 @@ private:
        }
    }

    std::unique_ptr<TemplateNode> parseTemplate(
    std::shared_ptr<TemplateNode> parseTemplate(
        const TemplateTokenIterator & begin,
        TemplateTokenIterator & it,
        const TemplateTokenIterator & end,
        bool fully = false) const {
|
||||
std::vector<std::unique_ptr<TemplateNode>> children;
|
||||
std::vector<std::shared_ptr<TemplateNode>> children;
|
||||
while (it != end) {
|
||||
const auto start = it;
|
||||
const auto & token = *(it++);
|
||||
if (auto if_token = dynamic_cast<IfTemplateToken*>(token.get())) {
|
||||
std::vector<std::pair<std::unique_ptr<Expression>, std::unique_ptr<TemplateNode>>> cascade;
|
||||
std::vector<std::pair<std::shared_ptr<Expression>, std::shared_ptr<TemplateNode>>> cascade;
|
||||
cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end));
|
||||
|
||||
while (it != end && (*it)->type == TemplateToken::Type::Elif) {
|
||||
@ -2138,17 +2169,17 @@ private:
|
||||
if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) {
|
||||
throw unterminated(**start);
|
||||
}
|
||||
children.emplace_back(nonstd_make_unique<IfNode>(token->location, std::move(cascade)));
|
||||
children.emplace_back(std::make_shared<IfNode>(token->location, std::move(cascade)));
|
||||
} else if (auto for_token = dynamic_cast<ForTemplateToken*>(token.get())) {
|
||||
auto body = parseTemplate(begin, it, end);
|
||||
auto else_body = std::unique_ptr<TemplateNode>();
|
||||
auto else_body = std::shared_ptr<TemplateNode>();
|
||||
if (it != end && (*it)->type == TemplateToken::Type::Else) {
|
||||
else_body = parseTemplate(begin, ++it, end);
|
||||
}
|
||||
if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) {
|
||||
throw unterminated(**start);
|
||||
}
|
||||
children.emplace_back(nonstd_make_unique<ForNode>(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body)));
|
||||
children.emplace_back(std::make_shared<ForNode>(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body)));
|
||||
} else if (auto text_token = dynamic_cast<TextTemplateToken*>(token.get())) {
|
||||
SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep;
|
||||
SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep;
|
||||
@ -2173,25 +2204,28 @@ private:
|
||||
static std::regex r(R"(\r?\n$)");
|
||||
text = std::regex_replace(text, r, ""); // Strip one trailing newline
|
||||
}
|
||||
children.emplace_back(nonstd_make_unique<TextNode>(token->location, text));
|
||||
children.emplace_back(std::make_shared<TextNode>(token->location, text));
|
||||
} else if (auto expr_token = dynamic_cast<ExpressionTemplateToken*>(token.get())) {
|
||||
children.emplace_back(nonstd_make_unique<ExpressionNode>(token->location, std::move(expr_token->expr)));
|
||||
children.emplace_back(std::make_shared<ExpressionNode>(token->location, std::move(expr_token->expr)));
|
||||
} else if (auto set_token = dynamic_cast<SetTemplateToken*>(token.get())) {
|
||||
if (set_token->value) {
|
||||
children.emplace_back(nonstd_make_unique<SetNode>(token->location, set_token->ns, set_token->var_names, std::move(set_token->value), nullptr));
|
||||
children.emplace_back(std::make_shared<SetNode>(token->location, set_token->ns, set_token->var_names, std::move(set_token->value)));
|
||||
} else {
|
||||
auto value_template = parseTemplate(begin, it, end);
|
||||
if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) {
|
||||
throw unterminated(**start);
|
||||
}
|
||||
children.emplace_back(nonstd_make_unique<SetNode>(token->location, set_token->ns, set_token->var_names, nullptr, std::move(value_template)));
|
||||
if (!set_token->ns.empty()) throw std::runtime_error("Namespaced set not supported in set with template value");
|
||||
if (set_token->var_names.size() != 1) throw std::runtime_error("Structural assignment not supported in set with template value");
|
||||
auto & name = set_token->var_names[0];
|
||||
children.emplace_back(std::make_shared<SetTemplateNode>(token->location, name, std::move(value_template)));
|
||||
}
|
||||
} else if (auto macro_token = dynamic_cast<MacroTemplateToken*>(token.get())) {
|
||||
auto body = parseTemplate(begin, it, end);
|
||||
if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) {
|
||||
throw unterminated(**start);
|
||||
}
|
||||
children.emplace_back(nonstd_make_unique<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
|
||||
children.emplace_back(std::make_shared<MacroNode>(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body)));
|
||||
} else if (dynamic_cast<CommentTemplateToken*>(token.get())) {
|
||||
// Ignore comments
|
||||
} else if (dynamic_cast<EndForTemplateToken*>(token.get())
|
||||
@ -2210,17 +2244,17 @@ private:
|
||||
throw unexpected(**it);
|
||||
}
|
||||
if (children.empty()) {
|
||||
return nonstd_make_unique<TextNode>(Location { template_str, 0 }, std::string());
|
||||
return std::make_shared<TextNode>(Location { template_str, 0 }, std::string());
|
||||
} else if (children.size() == 1) {
|
||||
return std::move(children[0]);
|
||||
} else {
|
||||
return nonstd_make_unique<SequenceNode>(children[0]->location(), std::move(children));
|
||||
return std::make_shared<SequenceNode>(children[0]->location(), std::move(children));
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
static std::unique_ptr<TemplateNode> parse(const std::string& template_str, const Options & options) {
|
||||
static std::shared_ptr<TemplateNode> parse(const std::string& template_str, const Options & options) {
|
||||
Parser parser(std::make_shared<std::string>(template_str), options);
|
||||
auto tokens = parser.tokenize();
|
||||
TemplateTokenIterator begin = tokens.begin();
|
||||
|
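The hunks above switch minja's expression and template-node ownership from std::unique_ptr to std::shared_ptr (nonstd_make_unique becomes std::make_shared). A minimal, hypothetical sketch of one pattern shared ownership enables, reusing a parsed tree across consumers without cloning it; the cache and type names below are illustrative and not part of the patch:

#include <memory>
#include <string>
#include <unordered_map>

struct TemplateNode { /* stands in for minja's node hierarchy */ };

// Illustrative cache: with shared_ptr roots, repeated requests for the same template
// source can reuse one parsed tree instead of re-parsing or deep-copying it.
static std::unordered_map<std::string, std::shared_ptr<TemplateNode>> parsed_cache;

static std::shared_ptr<TemplateNode> get_or_parse(const std::string & src) {
    auto it = parsed_cache.find(src);
    if (it != parsed_cache.end()) return it->second;  // shared, no copy of the tree
    auto root = std::make_shared<TemplateNode>();     // placeholder for Parser::parse(src, options)
    parsed_cache[src] = root;
    return root;
}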
@ -12,6 +12,29 @@

using json = nlohmann::ordered_json;

llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template & chat_template) {
const auto & src = chat_template.source();

if (src.find("<tool_call>") != std::string::npos) {
return Hermes2Pro;
} else if (src.find(">>>all") != std::string::npos) {
return FunctionaryV3Llama3;
} else if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) {
return FunctionaryV3Llama31;
} else if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
if (src.find("<|python_tag|>") != std::string::npos) {
return Llama31;
} else {
return Llama32;
}
} else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
return CommandRPlus;
} else {
return UnknownToolCallStyle;
}
}
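The detection above is purely textual: it keys off marker strings that only occur in a given family of chat templates. A minimal sketch of how it could be exercised, assuming it is compiled inside the llama.cpp tree; the template source here is made up:

#include "chat-template.hpp"
#include "tool-call.h"
#include <cassert>
#include <string>

int main() {
    // Any template source containing the Hermes-2-Pro "<tool_call>" marker is classified as Hermes2Pro.
    std::string src = "{% for m in messages %}{{ m.content }}{% endfor %}<tool_call>";
    minja::chat_template tmpl(src, /* bos_token= */ "<s>", /* eos_token= */ "</s>");
    assert(llama_tool_call_style_detect(tmpl) == Hermes2Pro);
    return 0;
}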

static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
// // https://json.nlohmann.me/features/parsing/sax_interface/
struct json_error_locator : public nlohmann::json_sax<json> {
@ -207,7 +230,8 @@ llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tool
}

llama_tool_call_handler llama_tool_call_handler_init(
const llama_chat_template & tmpl,
llama_tool_call_style style,
const minja::chat_template & tmpl,
bool allow_content,
bool parallel_tool_calls,
const nlohmann::ordered_json & messages,
@ -215,18 +239,18 @@ llama_tool_call_handler llama_tool_call_handler_init(
{
llama_tool_call_handler handler;

switch (tmpl.tool_call_style()) {
switch (style) {
case llama_tool_call_style::Llama31:
case llama_tool_call_style::Llama32: {
static auto builtin_tools = json {"wolfram_alpha", "brave_search"};

auto uses_python_tag = tmpl.tool_call_style() == llama_tool_call_style::Llama31;
auto uses_python_tag = style == llama_tool_call_style::Llama31;

// Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name,
// but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon
// as it seems to be outputting some JSON.
// TODO: make this conditional on a very small model (e.g. 1B / 3B).
auto eagerly_match_any_json = tmpl.tool_call_style() == llama_tool_call_style::Llama32;
auto eagerly_match_any_json = style == llama_tool_call_style::Llama32;

handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
std::vector<std::string> tool_rules;
@ -2,10 +2,20 @@

#include "ggml.h"
#include "common.h"
#include "chat-template.hpp"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "chat-template.h"

enum llama_tool_call_style {
UnknownToolCallStyle,
Llama31,
Llama32,
FunctionaryV3Llama3,
FunctionaryV3Llama31,
Hermes2Pro,
CommandRPlus,
};

struct llama_tool_call {
std::string name;
@ -24,10 +34,13 @@ struct llama_tool_call_handler {
std::vector<std::string> additional_stop_words;
};

llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template & chat_template);

llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input);

llama_tool_call_handler llama_tool_call_handler_init(
const llama_chat_template & tmpl,
llama_tool_call_style style,
const minja::chat_template & tmpl,
bool allow_content,
bool parallel_tool_calls,
const nlohmann::ordered_json & messages,
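With the style factored out of the template class, callers are expected to detect it once from the template source and pass it explicitly to the handler initializer. A rough sketch of that two-step flow, assuming compilation inside the llama.cpp tree; the wrapper function is illustrative, and the trailing tools argument mirrors the call sites shown further down:

#include "tool-call.h"

// Illustrative helper, not part of the patch: detect the style once, then thread it through.
static llama_tool_call_handler make_tool_call_handler(
        const minja::chat_template & tmpl,
        const nlohmann::ordered_json & messages,
        const nlohmann::ordered_json & tools) {
    auto style = llama_tool_call_style_detect(tmpl);  // replaces the old tmpl.tool_call_style()
    return llama_tool_call_handler_init(style, tmpl,
                                        /* allow_content= */ true,
                                        /* parallel_tool_calls= */ false,
                                        messages, tools);
}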
@ -663,7 +663,7 @@ struct server_context {
llama_chat_message chat[] = {{"user", "test"}};

if (use_jinja) {
auto chat_template = llama_chat_template::from_model(model);
auto chat_template = llama_chat_template_from_model(model);
try {
chat_template.apply({{
{"role", "user"},
@ -2875,11 +2875,12 @@ int main(int argc, char ** argv) {
return;
}

auto chat_template = llama_chat_template::from_model(ctx_server.model, params.chat_template.empty() ? nullptr : params.chat_template.c_str());
static auto chat_template = llama_chat_template_from_model(ctx_server.model, params.chat_template.empty() ? nullptr : params.chat_template.c_str());
static auto tool_call_style = llama_tool_call_style_detect(chat_template);

json data;
try {
data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, params.use_jinja);
data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, tool_call_style, params.use_jinja);
} catch (const std::exception & e) {
res_error(res, format_error_response(e.what(), ERROR_TYPE_NOT_SUPPORTED));
return;
@ -2897,7 +2898,7 @@ int main(int argc, char ** argv) {
ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
// multitask is never supported in chat completion, there is only one result
try {
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, chat_template, /*.streaming =*/ false, verbose);
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, tool_call_style, /*.streaming =*/ false, verbose);
res_ok(res, result_oai);
} catch (const std::runtime_error & e) {
res_error(res, format_error_response(e.what(), ERROR_TYPE_SERVER));
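The handler above now builds the chat template and its detected tool-call style as function-local statics, so the template lookup and detection run once, on the first request, and C++11 guarantees that this initialization is thread-safe. A small stand-alone illustration of that pattern, with made-up names unrelated to llama.cpp's types:

#include <iostream>
#include <string>

static const std::string & load_template_once() {
    // Initialized on first use only; concurrent first calls are serialized by the runtime.
    static const std::string tmpl = [] {
        std::cout << "loading template\n";  // printed a single time
        return std::string("{{ messages }}");
    }();
    return tmpl;
}

int main() {
    load_template_once();
    load_template_once();  // no second "loading template" line
    return 0;
}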
@ -14,7 +14,6 @@

// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "chat-template.h"
#include "json.hpp"
#include "minja.hpp"
#include "tool-call.h"
@ -309,7 +308,8 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
static json oaicompat_completion_params_parse(
const struct llama_model * model,
const json & body, /* openai api json semantics */
const llama_chat_template & tmpl,
const minja::chat_template & tmpl,
llama_tool_call_style tool_call_style,
bool use_jinja)
{
json llama_params;
@ -320,7 +320,7 @@ static json oaicompat_completion_params_parse(
auto has_tools = tools.is_array() && !tools.empty();

// Apply chat template to the list of messages
llama_params["chat_template"] = tmpl.chat_template();
llama_params["chat_template"] = tmpl.source();

if (use_jinja) {
if (has_tools && !tmpl.supports_tools()) {
@ -372,7 +372,7 @@ static json oaicompat_completion_params_parse(
llama_params["parse_tool_calls"] = true;
llama_params["parallel_tool_calls"] = parallel_tool_calls;

auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools);
auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, allow_content, parallel_tool_calls, body.at("messages"), tools);
llama_params["prompt"] = handler.prompt;

for (const auto & stop : handler.additional_stop_words) {
@ -395,7 +395,7 @@ static json oaicompat_completion_params_parse(
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
}
} else {
llama_params["prompt"] = format_chat(model, tmpl.chat_template(), body.at("messages"));
llama_params["prompt"] = format_chat(model, tmpl.source(), body.at("messages"));
}

// Handle "n" field
@ -435,7 +435,7 @@ static json oaicompat_completion_params_parse(
return llama_params;
}

static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, const llama_chat_template & tmpl, bool streaming = false, bool verbose = false) {
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, llama_tool_call_style tool_call_style, bool streaming = false, bool verbose = false) {
bool stopped_word = result.count("stopped_word") != 0;
bool stopped_eos = json_value(result, "stopped_eos", false);
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@ -452,7 +452,7 @@ static json format_final_response_oaicompat(const json & request, const json & r
json tool_calls;
json message_content;
if (json_value(request, "parse_tool_calls", false)
&& !(parsed_tool_calls = parse_tool_calls(tmpl.tool_call_style(), tools, content)).tool_calls.empty()) {
&& !(parsed_tool_calls = parse_tool_calls(tool_call_style, tools, content)).tool_calls.empty()) {
finish_reason = "tool_calls";
if (!parsed_tool_calls.content.empty()) {
message_content = parsed_tool_calls.content;
148 fetch_templates_and_goldens.py Normal file
@ -0,0 +1,148 @@
#!/usr/bin/env uv run
# /// script
# requires-python = ">=3.10"
# dependencies = [
#    "jinja2",
#    "huggingface_hub",
# ]
# ///
'''
    Fetches the Jinja2 templates of specified models and generates prompt goldens for predefined chat contexts.
    Outputs lines of arguments for a C++ test binary.
    All files are written to the specified output folder.

    Usage:
        python ./fetch_templates_and_goldens.py output_folder model_id1 model_id2 ...

    Example:
        python ./fetch_templates_and_goldens.py ./test_files "microsoft/Phi-3-medium-4k-instruct" "Qwen/Qwen2-7B-Instruct"
'''

import logging
import datetime
import glob
import os
from huggingface_hub import hf_hub_download
import json
import jinja2
import jinja2.ext
import re
import argparse
import shutil

logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)

def raise_exception(message: str):
    raise ValueError(message)

def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
    return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)

TEST_DATE = os.environ.get('TEST_DATE', '2024-07-26')

def strftime_now(format):
    now = datetime.datetime.strptime(TEST_DATE, "%Y-%m-%d")
    return now.strftime(format)

def handle_chat_template(output_folder, model_id, variant, template_src):
    model_name = model_id.replace("/", "-")
    base_name = f'{model_name}-{variant}' if variant else model_name
    template_file = os.path.join(output_folder, f'{base_name}.jinja')

    with open(template_file, 'w') as f:
        f.write(template_src)

    env = jinja2.Environment(
        trim_blocks=True,
        lstrip_blocks=True,
        extensions=[jinja2.ext.loopcontrols]
    )
    env.filters['safe'] = lambda x: x
    env.filters['tojson'] = tojson
    env.globals['raise_exception'] = raise_exception
    env.globals['strftime_now'] = strftime_now

    template_handles_tools = 'tools' in template_src
    template_hates_the_system = 'System role not supported' in template_src

    template = env.from_string(template_src)

    context_files = glob.glob(os.path.join(output_folder, '*.json'))
    for context_file in context_files:
        context_name = os.path.basename(context_file).replace(".json", "")
        with open(context_file, 'r') as f:
            context = json.load(f)

        if not template_handles_tools and 'tools' in context:
            continue

        if template_hates_the_system and any(m['role'] == 'system' for m in context['messages']):
            continue

        output_file = os.path.join(output_folder, f'{base_name}-{context_name}.txt')

        render_context = json.loads(json.dumps(context))

        if 'tool_call.arguments | items' in template_src or 'tool_call.arguments | tojson' in template_src:
            for message in render_context['messages']:
                if 'tool_calls' in message:
                    for tool_call in message['tool_calls']:
                        if tool_call.get('type') == 'function':
                            arguments = tool_call['function']['arguments']
                            tool_call['function']['arguments'] = json.loads(arguments)

        try:
            output = template.render(**render_context)
        except Exception as e1:
            for message in context["messages"]:
                if message.get("content") is None:
                    message["content"] = ""

            try:
                output = template.render(**render_context)
            except Exception as e2:
                logger.info(f" ERROR: {e2} (after first error: {e1})")
                output = f"ERROR: {e2}"

        with open(output_file, 'w') as f:
            f.write(output)

        # Output the line of arguments for the C++ test binary
        print(f"{template_file} {context_file} {output_file}")

def main():
    parser = argparse.ArgumentParser(description="Generate chat templates and output test arguments.")
    parser.add_argument("output_folder", help="Folder to store all output files")
    parser.add_argument("model_ids", nargs="+", help="List of model IDs to process")
    args = parser.parse_args()

    output_folder = args.output_folder
    if not os.path.isdir(output_folder):
        os.makedirs(output_folder)

    # Copy context files to the output folder
    for context_file in glob.glob('tests/chat/contexts/*.json'):
        shutil.copy(context_file, output_folder)

    for model_id in args.model_ids:
        try:
            with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
                config_str = f.read()

            try:
                config = json.loads(config_str)
            except json.JSONDecodeError:
                config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))

            chat_template = config['chat_template']
            if isinstance(chat_template, str):
                handle_chat_template(output_folder, model_id, None, chat_template)
            else:
                for ct in chat_template:
                    handle_chat_template(output_folder, model_id, ct['name'], ct['template'])
        except Exception as e:
            logger.error(f"Error processing model {model_id}: {e}")

if __name__ == '__main__':
    main()
@ -7,7 +7,7 @@

#include "llama.h"
#include "common.h"
#include "chat-template.h"
#include "chat-template.hpp"
#include <iostream>
#include <fstream>
#include <iostream>
@ -73,7 +73,7 @@ static void test_jinja_templates() {
return "tests/chat/goldens/" + golden_name + ".txt";
};
auto fail_with_golden_instructions = [&]() {
throw std::runtime_error("To fetch templates and generate golden files, run `python tests/update_jinja_goldens.py`");
throw std::runtime_error("To fetch templates and generate golden files, run `python update_templates_and_goldens.py`");
};
if (jinja_template_files.empty()) {
std::cerr << "No Jinja templates found in tests/chat/templates" << std::endl;
@ -89,7 +89,7 @@ static void test_jinja_templates() {
for (const auto & ctx_file : context_files) {
auto ctx = json::parse(read_file(ctx_file));

llama_chat_template tmpl(
minja::chat_template tmpl(
tmpl_str,
ctx.at("bos_token"),
ctx.at("eos_token"));
@ -127,20 +127,6 @@ static void test_jinja_templates() {
}
}

void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
auto tmpl = llama_chat_template(read_file(template_file), "<s>", "</s>");
std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
assert_equals(expected, tmpl.tool_call_style());
}

void test_tool_call_styles() {
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", FunctionaryV3Llama31);
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", FunctionaryV3Llama3);
test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31);
test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32);
test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", CommandRPlus);
}

static void test_legacy_templates() {
struct test_template {
std::string name;
@ -353,7 +339,6 @@ int main(void) {
if (getenv("LLAMA_SKIP_TESTS_SLOW_ON_EMULATOR")) {
fprintf(stderr, "\033[33mWARNING: Skipping slow tests on emulator.\n\033[0m");
} else {
test_tool_call_styles();
test_jinja_templates();
}
@ -9,7 +9,8 @@

using json = nlohmann::ordered_json;

static void assert_equals(const std::string & expected, const std::string & actual) {
template <class T>
static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << "Actual: " << actual << std::endl;
@ -242,7 +243,22 @@ static void test_parsing() {
"{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array());
}

static std::string get_message_prompt_delta(const llama_chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
const minja::chat_template tmpl(read_file(template_file), "<s>", "</s>");
auto tool_call_style = llama_tool_call_style_detect(tmpl);
std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
assert_equals(expected, tool_call_style);
}

void test_tool_call_style_detection() {
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", FunctionaryV3Llama31);
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", FunctionaryV3Llama3);
test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31);
test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32);
test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", CommandRPlus);
}

static std::string get_message_prompt_delta(const minja::chat_template & tmpl, const std::vector<std::string> & end_tokens, const json & user_message, const json & delta_message, const json & tools) {
auto prefix = tmpl.apply(json::array({user_message}), tools, /* add_generation_prompt= */ true, json::object());
auto full = tmpl.apply(json::array({user_message, delta_message}), tools, /* add_generation_prompt= */ false, json::object());

@ -267,7 +283,8 @@ static std::string get_message_prompt_delta(const llama_chat_template & tmpl, co

static void test_template(const std::string & template_file, const char * bos_token, const char * eos_token, const std::vector<std::string> & end_tokens, const json & tool_calling_message, const json & tools) {
std::cout << "# Testing template: " << template_file << std::endl << std::flush;
const llama_chat_template & tmpl = llama_chat_template(read_file(template_file), bos_token, eos_token);
const minja::chat_template tmpl(read_file(template_file), bos_token, eos_token);
auto tool_call_style = llama_tool_call_style_detect(tmpl);
auto & tool_calls = tool_calling_message.at("tool_calls");

// Format the message: apply the template to 1 user message w/ add_generation_prompt=true, then w/ the extra message w/ add_generation_prompt=false,
@ -277,7 +294,7 @@ static void test_template(const std::string & template_file, const char * bos_to
{"content", "Hello, world!"}
};

auto handler = llama_tool_call_handler_init(tmpl, /* allow_content= */ true, /* parallel_tool_calls= */ true, {user_message, tool_calling_message}, tools);
auto handler = llama_tool_call_handler_init(tool_call_style, tmpl, /* allow_content= */ true, /* parallel_tool_calls= */ true, {user_message, tool_calling_message}, tools);
auto grammar = build_grammar(handler.grammar);
if (!grammar) {
throw std::runtime_error("Failed to build grammar");
@ -285,7 +302,7 @@ static void test_template(const std::string & template_file, const char * bos_to

auto full_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, tool_calling_message, tools);
std::cout << "Full delta:\n```\n" << full_delta << "\n```" << std::endl;
test_parse_tool_call(tmpl.tool_call_style(), tools, full_delta, "", tool_calls);
test_parse_tool_call(tool_call_style, tools, full_delta, "", tool_calls);

auto content_less_delta = get_message_prompt_delta(tmpl, end_tokens, user_message, {
{"role", "assistant"},
@ -319,6 +336,7 @@ static void test_grammars() {
int main() {
test_grammars();
test_parsing();
test_tool_call_style_detection();

std::cout << "[tool-call] All tests passed!" << std::endl;
return 0;