// Mirror of https://github.com/ggerganov/llama.cpp.git
// Synced 2025-01-12 03:31:46 +00:00
// 408 lines, 11 KiB, C++
#ifdef NDEBUG
|
|
#undef NDEBUG
|
|
#endif
|
|
|
|
#include "llama.h"
|
|
#include "llama-grammar.h"
|
|
|
|
#include <cassert>
|
|
#include <stdexcept>
|
|
|
|
int main()
|
|
{
|
|
llama_grammar_parser parsed_grammar;
|
|
|
|
std::vector<std::pair<std::string, uint32_t>> expected = {
|
|
{"expr", 2},
|
|
{"expr_6", 6},
|
|
{"expr_7", 7},
|
|
{"ident", 8},
|
|
{"ident_10", 10},
|
|
{"num", 9},
|
|
{"num_11", 11},
|
|
{"root", 0},
|
|
{"root_1", 1},
|
|
{"root_5", 5},
|
|
{"term", 4},
|
|
{"ws", 3},
|
|
{"ws_12", 12},
|
|
};
|
|
|
|
std::vector<std::vector<llama_grammar_element>> expected_rules = {
|
|
{{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}},
|
|
{
|
|
{LLAMA_GRETYPE_RULE_REF, 2},
|
|
{LLAMA_GRETYPE_CHAR, 61},
|
|
{LLAMA_GRETYPE_RULE_REF, 3},
|
|
{LLAMA_GRETYPE_RULE_REF, 4},
|
|
{LLAMA_GRETYPE_CHAR, 10},
|
|
{LLAMA_GRETYPE_END, 0},
|
|
},
|
|
{{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_END, 0}},
|
|
{{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_END, 0}},
|
|
{
|
|
{LLAMA_GRETYPE_RULE_REF, 8},
|
|
{LLAMA_GRETYPE_ALT, 0},
|
|
{LLAMA_GRETYPE_RULE_REF, 9},
|
|
{LLAMA_GRETYPE_ALT, 0},
|
|
{LLAMA_GRETYPE_CHAR, 40},
|
|
{LLAMA_GRETYPE_RULE_REF, 3},
|
|
{LLAMA_GRETYPE_RULE_REF, 2},
|
|
{LLAMA_GRETYPE_CHAR, 41},
|
|
{LLAMA_GRETYPE_RULE_REF, 3},
|
|
{LLAMA_GRETYPE_END, 0},
|
|
},
|
|
{{LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_END, 0}},
|
|
{
|
|
{LLAMA_GRETYPE_CHAR, 45},
|
|
{LLAMA_GRETYPE_CHAR_ALT, 43},
|
|
{LLAMA_GRETYPE_CHAR_ALT, 42},
|
|
{LLAMA_GRETYPE_CHAR_ALT, 47},
|
|
{LLAMA_GRETYPE_RULE_REF, 4},
|
|
{LLAMA_GRETYPE_END, 0},
|
|
},
|
|
{{LLAMA_GRETYPE_RULE_REF, 6}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}},
|
|
{
|
|
{LLAMA_GRETYPE_CHAR, 97},
|
|
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
|
|
{LLAMA_GRETYPE_RULE_REF, 10},
|
|
{LLAMA_GRETYPE_RULE_REF, 3},
|
|
{LLAMA_GRETYPE_END, 0},
|
|
},
|
|
{{LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}},
|
|
{
|
|
{LLAMA_GRETYPE_CHAR, 97},
|
|
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
|
|
{LLAMA_GRETYPE_CHAR_ALT, 48},
|
|
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
|
|
{LLAMA_GRETYPE_CHAR_ALT, 95},
|
|
{LLAMA_GRETYPE_RULE_REF, 10},
|
|
{LLAMA_GRETYPE_ALT, 0},
|
|
{LLAMA_GRETYPE_END, 0},
|
|
},
|
|
{
|
|
{LLAMA_GRETYPE_CHAR, 48},
|
|
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
|
|
{LLAMA_GRETYPE_RULE_REF, 11},
|
|
{LLAMA_GRETYPE_ALT, 0},
|
|
{LLAMA_GRETYPE_CHAR, 48},
|
|
{LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
|
|
{LLAMA_GRETYPE_END, 0},
|
|
},
|
|
{
|
|
{LLAMA_GRETYPE_CHAR, 32},
|
|
{LLAMA_GRETYPE_CHAR_ALT, 9},
|
|
{LLAMA_GRETYPE_CHAR_ALT, 10},
|
|
{LLAMA_GRETYPE_RULE_REF, 12},
|
|
{LLAMA_GRETYPE_ALT, 0},
|
|
{LLAMA_GRETYPE_END, 0},
|
|
},
|
|
};
|
|
|
|
for (auto pair : expected)
|
|
{
|
|
parsed_grammar.symbol_ids[pair.first] = pair.second;
|
|
}
|
|
|
|
for (auto rule : expected_rules)
|
|
{
|
|
parsed_grammar.rules.emplace_back();
|
|
for (auto element : rule)
|
|
{
|
|
parsed_grammar.rules.back().push_back(element);
|
|
}
|
|
}
|
|
|
|
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
|
|
|
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
|
if (grammar == nullptr) {
|
|
throw std::runtime_error("Failed to initialize llama_grammar");
|
|
}
|
|
|
|
std::vector<std::vector<llama_grammar_element>> expected_stacks = {
|
|
{
|
|
{LLAMA_GRETYPE_RULE_REF, 5},
|
|
{LLAMA_GRETYPE_CHAR, 61},
|
|
{LLAMA_GRETYPE_RULE_REF, 7},
|
|
{LLAMA_GRETYPE_CHAR, 97},
|
|
},
|
|
{
|
|
{LLAMA_GRETYPE_RULE_REF, 5},
|
|
{LLAMA_GRETYPE_CHAR, 61},
|
|
{LLAMA_GRETYPE_RULE_REF, 7},
|
|
{LLAMA_GRETYPE_RULE_REF, 3},
|
|
{LLAMA_GRETYPE_CHAR, 48},
|
|
},
|
|
{
|
|
{LLAMA_GRETYPE_RULE_REF, 5},
|
|
{LLAMA_GRETYPE_CHAR, 61},
|
|
{LLAMA_GRETYPE_RULE_REF, 7},
|
|
{LLAMA_GRETYPE_RULE_REF, 3},
|
|
{LLAMA_GRETYPE_CHAR, 48},
|
|
},
|
|
{
|
|
{LLAMA_GRETYPE_RULE_REF, 5},
|
|
{LLAMA_GRETYPE_CHAR, 61},
|
|
{LLAMA_GRETYPE_RULE_REF, 7},
|
|
{LLAMA_GRETYPE_CHAR, 40},
|
|
},
|
|
{
|
|
{LLAMA_GRETYPE_CHAR, 61},
|
|
{LLAMA_GRETYPE_RULE_REF, 7},
|
|
{LLAMA_GRETYPE_CHAR, 97},
|
|
},
|
|
{
|
|
{LLAMA_GRETYPE_CHAR, 61},
|
|
{LLAMA_GRETYPE_RULE_REF, 7},
|
|
{LLAMA_GRETYPE_RULE_REF, 3},
|
|
{LLAMA_GRETYPE_CHAR, 48},
|
|
},
|
|
{
|
|
{LLAMA_GRETYPE_CHAR, 61},
|
|
{LLAMA_GRETYPE_RULE_REF, 7},
|
|
{LLAMA_GRETYPE_RULE_REF, 3},
|
|
{LLAMA_GRETYPE_CHAR, 48},
|
|
},
|
|
{
|
|
{LLAMA_GRETYPE_CHAR, 61},
|
|
{LLAMA_GRETYPE_RULE_REF, 7},
|
|
{LLAMA_GRETYPE_CHAR, 40},
|
|
}};
|
|
|
|
auto index = 0;
|
|
for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar))
|
|
{
|
|
// compare stack to expected_stack
|
|
for (uint32_t i = 0; i < stack.size(); i++)
|
|
{
|
|
const llama_grammar_element * element = stack[i];
|
|
const llama_grammar_element & expected_element = expected_stacks[index][i];
|
|
|
|
// pretty print error message before asserting
|
|
if (expected_element.type != element->type || expected_element.value != element->value)
|
|
{
|
|
fprintf(stderr, "index: %d\n", index);
|
|
fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value);
|
|
fprintf(stderr, "actual_element: %d, %u\n", element->type, element->value);
|
|
fprintf(stderr, "expected_element != actual_element\n");
|
|
}
|
|
|
|
assert(expected_element.type == element->type && expected_element.value == element->value);
|
|
}
|
|
index++;
|
|
}
|
|
|
|
std::vector<llama_grammar_candidate> next_candidates;
|
|
next_candidates.resize(24);
|
|
|
|
for (size_t i = 0; i < 24; ++i)
|
|
{
|
|
uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
|
|
cp[0] = 37 + i;
|
|
cp[1] = 0;
|
|
next_candidates[i] = {i, cp, {}};
|
|
}
|
|
|
|
std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
|
|
{
|
|
{0, 37},
|
|
{1, 38},
|
|
{2, 39},
|
|
{3, 40},
|
|
{4, 41},
|
|
{5, 42},
|
|
{6, 43},
|
|
{7, 44},
|
|
{8, 45},
|
|
{9, 46},
|
|
{10, 47},
|
|
{11, 48},
|
|
{12, 49},
|
|
{13, 50},
|
|
{14, 51},
|
|
{15, 52},
|
|
{16, 53},
|
|
{17, 54},
|
|
{18, 55},
|
|
{19, 56},
|
|
{20, 57},
|
|
{21, 58},
|
|
{22, 59},
|
|
{23, 60},
|
|
},
|
|
{
|
|
{0, 37},
|
|
{1, 38},
|
|
{2, 39},
|
|
{3, 40},
|
|
{4, 41},
|
|
{5, 42},
|
|
{6, 43},
|
|
{7, 44},
|
|
{8, 45},
|
|
{9, 46},
|
|
{10, 47},
|
|
{21, 58},
|
|
{22, 59},
|
|
{23, 60},
|
|
},
|
|
{
|
|
{0, 37},
|
|
{1, 38},
|
|
{2, 39},
|
|
{3, 40},
|
|
{4, 41},
|
|
{5, 42},
|
|
{6, 43},
|
|
{7, 44},
|
|
{8, 45},
|
|
{9, 46},
|
|
{10, 47},
|
|
{21, 58},
|
|
{22, 59},
|
|
{23, 60},
|
|
},
|
|
{
|
|
{0, 37},
|
|
{1, 38},
|
|
{2, 39},
|
|
{4, 41},
|
|
{5, 42},
|
|
{6, 43},
|
|
{7, 44},
|
|
{8, 45},
|
|
{9, 46},
|
|
{10, 47},
|
|
{11, 48},
|
|
{12, 49},
|
|
{13, 50},
|
|
{14, 51},
|
|
{15, 52},
|
|
{16, 53},
|
|
{17, 54},
|
|
{18, 55},
|
|
{19, 56},
|
|
{20, 57},
|
|
{21, 58},
|
|
{22, 59},
|
|
{23, 60},
|
|
},
|
|
{
|
|
{0, 37},
|
|
{1, 38},
|
|
{2, 39},
|
|
{3, 40},
|
|
{4, 41},
|
|
{5, 42},
|
|
{6, 43},
|
|
{7, 44},
|
|
{8, 45},
|
|
{9, 46},
|
|
{10, 47},
|
|
{11, 48},
|
|
{12, 49},
|
|
{13, 50},
|
|
{14, 51},
|
|
{15, 52},
|
|
{16, 53},
|
|
{17, 54},
|
|
{18, 55},
|
|
{19, 56},
|
|
{20, 57},
|
|
{21, 58},
|
|
{22, 59},
|
|
{23, 60},
|
|
},
|
|
{
|
|
{0, 37},
|
|
{1, 38},
|
|
{2, 39},
|
|
{3, 40},
|
|
{4, 41},
|
|
{5, 42},
|
|
{6, 43},
|
|
{7, 44},
|
|
{8, 45},
|
|
{9, 46},
|
|
{10, 47},
|
|
{21, 58},
|
|
{22, 59},
|
|
{23, 60},
|
|
},
|
|
{
|
|
{0, 37},
|
|
{1, 38},
|
|
{2, 39},
|
|
{3, 40},
|
|
{4, 41},
|
|
{5, 42},
|
|
{6, 43},
|
|
{7, 44},
|
|
{8, 45},
|
|
{9, 46},
|
|
{10, 47},
|
|
{21, 58},
|
|
{22, 59},
|
|
{23, 60},
|
|
},
|
|
{
|
|
{0, 37},
|
|
{1, 38},
|
|
{2, 39},
|
|
{4, 41},
|
|
{5, 42},
|
|
{6, 43},
|
|
{7, 44},
|
|
{8, 45},
|
|
{9, 46},
|
|
{10, 47},
|
|
{11, 48},
|
|
{12, 49},
|
|
{13, 50},
|
|
{14, 51},
|
|
{15, 52},
|
|
{16, 53},
|
|
{17, 54},
|
|
{18, 55},
|
|
{19, 56},
|
|
{20, 57},
|
|
{21, 58},
|
|
{22, 59},
|
|
{23, 60},
|
|
},
|
|
};
|
|
|
|
std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[0], next_candidates);
|
|
|
|
std::vector<std::vector<llama_grammar_candidate>> all_rejects;
|
|
|
|
for (std::size_t count = 0; count < llama_grammar_get_stacks(grammar).size(); ++count)
|
|
{
|
|
rejects = llama_grammar_reject_candidates_for_stack(llama_grammar_get_rules(grammar), llama_grammar_get_stacks(grammar)[count], next_candidates);
|
|
all_rejects.push_back(rejects);
|
|
}
|
|
|
|
index = 0;
|
|
for (auto rej : all_rejects)
|
|
{
|
|
for (uint32_t i = 0; i < rej.size(); i++)
|
|
{
|
|
auto element = rej[i];
|
|
auto expected_element = expected_reject[index][i];
|
|
assert(element.index == expected_element.first && *element.code_points == expected_element.second);
|
|
}
|
|
index++;
|
|
}
|
|
|
|
for (auto &candidate : next_candidates)
|
|
{
|
|
delete[] candidate.code_points;
|
|
candidate.code_points = nullptr;
|
|
}
|
|
|
|
llama_grammar_free_impl(grammar);
|
|
|
|
return 0;
|
|
}
|