mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-11 19:21:46 +00:00
Extending grammar integration tests (#6644)
* Cleaning up integration tests to share code between tests and make it simpler to add new tests. * Add tests around quantifiers to ensure both matching and non-matching compliance. * Add slightly more complex grammar with quantifiers to test references with quantifiers. * Fixing build when C++17 is not present. * Separating test calls to give more helpful stack traces on failure. Adding verbose messages to give visibility for what is being tested. * Adding quotes around strings to explicitly show whitespace * Removing trailing whitespace. * Implementing suggestions from @ochafik -- grammars and test strings now print and flush before tests to aid in debugging segfaults and whatnot. * Cleaning up forgotten symbols. Modifying simple test to use test harness. Added comments for more verbose descriptions of what each test is accomplishing. * Unicode symbol modifications to hopefully make log easier to parse visually.
This commit is contained in:
parent
5539e6fdd1
commit
b8c1476e44
@ -10,15 +10,10 @@
|
||||
#include "unicode.h"
|
||||
#include <cassert>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
static void test_simple_grammar() {
|
||||
// Test case for a simple grammar
|
||||
const std::string grammar_str = R"""(root ::= expr
|
||||
expr ::= term ("+" term)*
|
||||
term ::= number
|
||||
number ::= [0-9]+)""";
|
||||
|
||||
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
|
||||
static llama_grammar* build_grammar(const std::string & grammar_str) {
|
||||
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
|
||||
|
||||
// Ensure we parsed correctly
|
||||
assert(!parsed_grammar.rules.empty());
|
||||
@ -30,8 +25,10 @@ number ::= [0-9]+)""";
|
||||
llama_grammar* grammar = llama_grammar_init(
|
||||
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||
|
||||
std::string input = "123+456";
|
||||
return grammar;
|
||||
}
|
||||
|
||||
static bool match_string(const std::string & input, llama_grammar* grammar) {
|
||||
auto decoded = decode_utf8(input, {});
|
||||
|
||||
const auto & code_points = decoded.first;
|
||||
@ -39,52 +36,117 @@ number ::= [0-9]+)""";
|
||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||
auto prev_stacks = grammar->stacks;
|
||||
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
|
||||
assert(!grammar->stacks.empty());
|
||||
if (grammar->stacks.empty()) {
|
||||
// no stacks means that the grammar failed to match at this point
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool completed_grammar = false;
|
||||
|
||||
for (const auto & stack : grammar->stacks) {
|
||||
if (stack.empty()) {
|
||||
completed_grammar = true;
|
||||
break;
|
||||
// An empty stack means that the grammar has been completed
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
assert(completed_grammar);
|
||||
return false;
|
||||
}
|
||||
|
||||
static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
|
||||
fprintf(stderr, "⚫ Testing %s. Grammar: %s\n", test_desc.c_str(), grammar_str.c_str());
|
||||
fflush(stderr);
|
||||
|
||||
auto grammar = build_grammar(grammar_str);
|
||||
|
||||
// Save the original grammar stacks so that we can reset after every new string we want to test
|
||||
auto original_stacks = grammar->stacks;
|
||||
|
||||
fprintf(stderr, " 🔵 Valid strings:\n");
|
||||
|
||||
// Passing strings
|
||||
for (const auto & test_string : passing_strings) {
|
||||
fprintf(stderr, " \"%s\" ", test_string.c_str());
|
||||
fflush(stderr);
|
||||
|
||||
bool matched = match_string(test_string, grammar);
|
||||
|
||||
if (!matched) {
|
||||
fprintf(stderr, "❌ (failed to match)\n");
|
||||
} else {
|
||||
fprintf(stdout, "✅︎\n");
|
||||
}
|
||||
|
||||
assert(matched);
|
||||
|
||||
// Reset the grammar stacks
|
||||
grammar->stacks = original_stacks;
|
||||
}
|
||||
|
||||
fprintf(stderr, " 🟠 Invalid strings:\n");
|
||||
|
||||
// Failing strings
|
||||
for (const auto & test_string : failing_strings) {
|
||||
fprintf(stderr, " \"%s\" ", test_string.c_str());
|
||||
fflush(stderr);
|
||||
|
||||
bool matched = match_string(test_string, grammar);
|
||||
|
||||
if (matched) {
|
||||
fprintf(stderr, "❌ (incorrectly matched)\n");
|
||||
} else {
|
||||
fprintf(stdout, "✅︎\n");
|
||||
}
|
||||
assert(!matched);
|
||||
|
||||
// Reset the grammar stacks
|
||||
grammar->stacks = original_stacks;
|
||||
}
|
||||
|
||||
// Clean up allocated memory
|
||||
llama_grammar_free(grammar);
|
||||
}
|
||||
|
||||
static void test_simple_grammar() {
|
||||
// Test case for a simple grammar
|
||||
test_grammar(
|
||||
"simple grammar",
|
||||
R"""(
|
||||
root ::= expr
|
||||
expr ::= term ("+" term)*
|
||||
term ::= number
|
||||
number ::= [0-9]+)""",
|
||||
// Passing strings
|
||||
{
|
||||
"42",
|
||||
"1+2+3+4+5",
|
||||
"123+456",
|
||||
},
|
||||
// Failing strings
|
||||
{
|
||||
"+",
|
||||
"/ 3",
|
||||
"1+2+3+4+5+",
|
||||
"12a45",
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
static void test_complex_grammar() {
|
||||
// Test case for a more complex grammar, with both failure strings and success strings
|
||||
const std::string grammar_str = R"""(root ::= expression
|
||||
test_grammar(
|
||||
"medium complexity grammar",
|
||||
// Grammar
|
||||
R"""(
|
||||
root ::= expression
|
||||
expression ::= term ws (("+"|"-") ws term)*
|
||||
term ::= factor ws (("*"|"/") ws factor)*
|
||||
factor ::= number | variable | "(" expression ")" | function-call
|
||||
number ::= [0-9]+
|
||||
variable ::= [a-zA-Z_][a-zA-Z0-9_]*
|
||||
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
|
||||
ws ::= [ \t\n\r]?)""";
|
||||
|
||||
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
|
||||
|
||||
// Ensure we parsed correctly
|
||||
assert(!parsed_grammar.rules.empty());
|
||||
|
||||
// Ensure we have a root node
|
||||
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
|
||||
|
||||
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
|
||||
llama_grammar* grammar = llama_grammar_init(
|
||||
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||
|
||||
// Save the original grammar stacks so that we can reset after every new string we want to test
|
||||
auto original_stacks = grammar->stacks;
|
||||
|
||||
// Test a few strings
|
||||
std::vector<std::string> test_strings_pass = {
|
||||
ws ::= [ \t\n\r]?)""",
|
||||
// Passing strings
|
||||
{
|
||||
"42",
|
||||
"1*2*3*4*5",
|
||||
"x",
|
||||
@ -105,9 +167,9 @@ ws ::= [ \t\n\r]?)""";
|
||||
"123+456",
|
||||
"123*456*789-123/456+789*123",
|
||||
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
|
||||
};
|
||||
|
||||
std::vector<std::string> test_strings_fail = {
|
||||
},
|
||||
// Failing strings
|
||||
{
|
||||
"+",
|
||||
"/ 3x",
|
||||
"x + + y",
|
||||
@ -126,82 +188,101 @@ ws ::= [ \t\n\r]?)""";
|
||||
"a * (b + c) - d /",
|
||||
"f(g(x), h(y, z)",
|
||||
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
|
||||
};
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
static void test_quantifiers() {
|
||||
// A collection of tests to exercise * + and ? quantifiers
|
||||
|
||||
test_grammar(
|
||||
"* quantifier",
|
||||
// Grammar
|
||||
R"""(root ::= "a"*)""",
|
||||
// Passing strings
|
||||
for (const auto & test_string : test_strings_pass) {
|
||||
auto decoded = decode_utf8(test_string, {});
|
||||
|
||||
const auto & code_points = decoded.first;
|
||||
|
||||
int pos = 0;
|
||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||
++pos;
|
||||
auto prev_stacks = grammar->stacks;
|
||||
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
|
||||
|
||||
// Expect that each code point will not cause the grammar to fail
|
||||
if (grammar->stacks.empty()) {
|
||||
fprintf(stdout, "Error at position %d\n", pos);
|
||||
fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
|
||||
fprintf(stderr, "Input string is %s:\n", test_string.c_str());
|
||||
}
|
||||
assert(!grammar->stacks.empty());
|
||||
}
|
||||
|
||||
bool completed_grammar = false;
|
||||
|
||||
for (const auto & stack : grammar->stacks) {
|
||||
if (stack.empty()) {
|
||||
completed_grammar = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
assert(completed_grammar);
|
||||
|
||||
// Reset the grammar stacks
|
||||
grammar->stacks = original_stacks;
|
||||
}
|
||||
|
||||
{
|
||||
"",
|
||||
"a",
|
||||
"aaaaa",
|
||||
"aaaaaaaaaaaaaaaaaa",
|
||||
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
},
|
||||
// Failing strings
|
||||
for (const auto & test_string : test_strings_fail) {
|
||||
auto decoded = decode_utf8(test_string, {});
|
||||
|
||||
const auto & code_points = decoded.first;
|
||||
bool parse_failed = false;
|
||||
|
||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||
auto prev_stacks = grammar->stacks;
|
||||
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
|
||||
if (grammar->stacks.empty()) {
|
||||
parse_failed = true;
|
||||
break;
|
||||
{
|
||||
"b",
|
||||
"ab",
|
||||
"aab",
|
||||
"ba",
|
||||
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
|
||||
}
|
||||
assert(!grammar->stacks.empty());
|
||||
);
|
||||
test_grammar(
|
||||
"+ quantifier",
|
||||
// Grammar
|
||||
R"""(root ::= "a"+)""",
|
||||
// Passing strings
|
||||
{
|
||||
"a",
|
||||
"aaaaa",
|
||||
"aaaaaaaaaaaaaaaaaa",
|
||||
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
},
|
||||
// Failing strings
|
||||
{
|
||||
"",
|
||||
"b",
|
||||
"ab",
|
||||
"aab",
|
||||
"ba",
|
||||
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
|
||||
}
|
||||
|
||||
bool completed_grammar = false;
|
||||
|
||||
for (const auto & stack : grammar->stacks) {
|
||||
if (stack.empty()) {
|
||||
completed_grammar = true;
|
||||
break;
|
||||
);
|
||||
test_grammar(
|
||||
"? quantifier",
|
||||
// Grammar
|
||||
R"""(root ::= "a"?)""",
|
||||
// Passing strings
|
||||
{
|
||||
"",
|
||||
"a"
|
||||
},
|
||||
// Failing strings
|
||||
{
|
||||
"b",
|
||||
"ab",
|
||||
"aa",
|
||||
"ba",
|
||||
}
|
||||
);
|
||||
test_grammar(
|
||||
"mixed quantifiers",
|
||||
// Grammar
|
||||
R"""(
|
||||
root ::= cons+ vowel* cons? (vowel cons)*
|
||||
vowel ::= [aeiouy]
|
||||
cons ::= [bcdfghjklmnpqrstvwxyz]
|
||||
)""",
|
||||
// Passing strings
|
||||
{
|
||||
"yes",
|
||||
"no",
|
||||
"noyes",
|
||||
"crwth",
|
||||
"four",
|
||||
"bryyyy",
|
||||
},
|
||||
// Failing strings
|
||||
{
|
||||
"yess",
|
||||
"yesno",
|
||||
"forty",
|
||||
"catyyy",
|
||||
}
|
||||
|
||||
// Ensure that the grammar is not completed, or that each string failed to match as-expected
|
||||
assert((!completed_grammar) || parse_failed);
|
||||
|
||||
// Reset the grammar stacks
|
||||
grammar->stacks = original_stacks;
|
||||
}
|
||||
|
||||
// Clean up allocated memory
|
||||
llama_grammar_free(grammar);
|
||||
);
|
||||
}
|
||||
|
||||
static void test_failure_missing_root() {
|
||||
fprintf(stderr, "⚫ Testing missing root node:\n");
|
||||
// Test case for a grammar that is missing a root rule
|
||||
const std::string grammar_str = R"""(rot ::= expr
|
||||
expr ::= term ("+" term)*
|
||||
@ -215,11 +296,15 @@ number ::= [0-9]+)""";
|
||||
|
||||
// Ensure we do NOT have a root node
|
||||
assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
|
||||
fprintf(stderr, " ✅︎ Passed\n");
|
||||
}
|
||||
|
||||
static void test_failure_missing_reference() {
|
||||
fprintf(stderr, "⚫ Testing missing reference node:\n");
|
||||
|
||||
// Test case for a grammar that is missing a referenced rule
|
||||
const std::string grammar_str = R"""(root ::= expr
|
||||
const std::string grammar_str =
|
||||
R"""(root ::= expr
|
||||
expr ::= term ("+" term)*
|
||||
term ::= numero
|
||||
number ::= [0-9]+)""";
|
||||
@ -231,13 +316,17 @@ number ::= [0-9]+)""";
|
||||
// Ensure we did NOT parsed correctly
|
||||
assert(parsed_grammar.rules.empty());
|
||||
|
||||
fprintf(stderr, "End of expected error. Test successful.\n");
|
||||
fprintf(stderr, " End of expected error.\n");
|
||||
fprintf(stderr, " ✅︎ Passed\n");
|
||||
}
|
||||
|
||||
int main() {
|
||||
fprintf(stdout, "Running grammar integration tests...\n");
|
||||
test_simple_grammar();
|
||||
test_complex_grammar();
|
||||
test_quantifiers();
|
||||
test_failure_missing_root();
|
||||
test_failure_missing_reference();
|
||||
fprintf(stdout, "All tests passed.\n");
|
||||
return 0;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user