llama.cpp/tests/test-grammar-integration.cpp
Olivier Chafik cbaadc9294
grammars: 1.5x faster inference w/ complex grammars (vector reserves / reuses) (#6609)
* grammars: reserve rejects & next candidates

* grammars: reuse new_stacks

* grammars: fix missing sig change in llama.h

* grammars: fix test (api changed)

* grammars: update gbnf-validator.cpp

* grammars: simpler syntax (no swap)
2024-04-11 19:47:34 +01:00

244 lines
7.2 KiB
C++

#ifdef NDEBUG
#undef NDEBUG
#endif
#define LLAMA_API_INTERNAL
#include "ggml.h"
#include "llama.h"
#include "grammar-parser.h"
#include "unicode.h"
#include <cassert>
#include <string>
static void test_simple_grammar() {
// Test case for a simple grammar
const std::string grammar_str = R"""(root ::= expr
expr ::= term ("+" term)*
term ::= number
number ::= [0-9]+)""";
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());
// Ensure we have a root node
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
llama_grammar* grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
std::string input = "123+456";
auto decoded = decode_utf8(input, {});
const auto & code_points = decoded.first;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
assert(!grammar->stacks.empty());
}
bool completed_grammar = false;
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
}
}
assert(completed_grammar);
// Clean up allocated memory
llama_grammar_free(grammar);
}
static void test_complex_grammar() {
// Test case for a more complex grammar, with both failure strings and success strings
const std::string grammar_str = R"""(root ::= expression
expression ::= term ws (("+"|"-") ws term)*
term ::= factor ws (("*"|"/") ws factor)*
factor ::= number | variable | "(" expression ")" | function-call
number ::= [0-9]+
variable ::= [a-zA-Z_][a-zA-Z0-9_]*
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
ws ::= [ \t\n\r]?)""";
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());
// Ensure we have a root node
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
llama_grammar* grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
// Save the original grammar stacks so that we can reset after every new string we want to test
auto original_stacks = grammar->stacks;
// Test a few strings
std::vector<std::string> test_strings_pass = {
"42",
"1*2*3*4*5",
"x",
"x+10",
"x1+y2",
"(a+b)*(c-d)",
"func()",
"func(x,y+2)",
"a*(b+c)-d/e",
"f(g(x),h(y,z))",
"x + 10",
"x1 + y2",
"(a + b) * (c - d)",
"func()",
"func(x, y + 2)",
"a * (b + c) - d / e",
"f(g(x), h(y, z))",
"123+456",
"123*456*789-123/456+789*123",
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
};
std::vector<std::string> test_strings_fail = {
"+",
"/ 3x",
"x + + y",
"a * / b",
"func(,)",
"func(x y)",
"(a + b",
"x + y)",
"a + b * (c - d",
"42 +",
"x +",
"x + 10 +",
"(a + b) * (c - d",
"func(",
"func(x, y + 2",
"a * (b + c) - d /",
"f(g(x), h(y, z)",
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
};
// Passing strings
for (const auto & test_string : test_strings_pass) {
auto decoded = decode_utf8(test_string, {});
const auto & code_points = decoded.first;
int pos = 0;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
++pos;
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
// Expect that each code point will not cause the grammar to fail
if (grammar->stacks.empty()) {
fprintf(stdout, "Error at position %d\n", pos);
fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
fprintf(stderr, "Input string is %s:\n", test_string.c_str());
}
assert(!grammar->stacks.empty());
}
bool completed_grammar = false;
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
}
}
assert(completed_grammar);
// Reset the grammar stacks
grammar->stacks = original_stacks;
}
// Failing strings
for (const auto & test_string : test_strings_fail) {
auto decoded = decode_utf8(test_string, {});
const auto & code_points = decoded.first;
bool parse_failed = false;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
if (grammar->stacks.empty()) {
parse_failed = true;
break;
}
assert(!grammar->stacks.empty());
}
bool completed_grammar = false;
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
}
}
// Ensure that the grammar is not completed, or that each string failed to match as-expected
assert((!completed_grammar) || parse_failed);
// Reset the grammar stacks
grammar->stacks = original_stacks;
}
// Clean up allocated memory
llama_grammar_free(grammar);
}
static void test_failure_missing_root() {
// Test case for a grammar that is missing a root rule
const std::string grammar_str = R"""(rot ::= expr
expr ::= term ("+" term)*
term ::= number
number ::= [0-9]+)""";
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());
// Ensure we do NOT have a root node
assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
}
static void test_failure_missing_reference() {
// Test case for a grammar that is missing a referenced rule
const std::string grammar_str = R"""(root ::= expr
expr ::= term ("+" term)*
term ::= numero
number ::= [0-9]+)""";
fprintf(stderr, "Expected error: ");
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
// Ensure we did NOT parsed correctly
assert(parsed_grammar.rules.empty());
fprintf(stderr, "End of expected error. Test successful.\n");
}
int main() {
test_simple_grammar();
test_complex_grammar();
test_failure_missing_root();
test_failure_missing_reference();
return 0;
}