Extending grammar integration tests (#6644)

* Cleaning up integration tests to share code between tests and make it simpler to add new tests.

* Add tests around quantifiers to ensure both matching and non-matching compliance.

* Add slightly more complex grammar with quantifiers to test references with quantifiers.

* Fixing build when C++17 is not present.

* Separating test calls to give more helpful stack traces on failure. Adding verbose messages to give visibility for what is being tested.

* Adding quotes around strings to explicitly show whitespace

* Removing trailing whitespace.

* Implementing suggestions from @ochafik -- grammars and test strings now print and flush before tests to aid in debugging segfaults and whatnot.

* Cleaning up forgotten symbols. Modifying simple test to use test harness. Added comments for more verbose descriptions of what each test is accomplishing.

* Unicode symbol modifications to hopefully make log easier to parse visually.
This commit is contained in:
Clint Herron 2024-04-29 14:40:14 -04:00 committed by GitHub
parent 5539e6fdd1
commit b8c1476e44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -10,15 +10,10 @@
#include "unicode.h" #include "unicode.h"
#include <cassert> #include <cassert>
#include <string> #include <string>
#include <vector>
static void test_simple_grammar() { static llama_grammar* build_grammar(const std::string & grammar_str) {
// Test case for a simple grammar auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
const std::string grammar_str = R"""(root ::= expr
expr ::= term ("+" term)*
term ::= number
number ::= [0-9]+)""";
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
// Ensure we parsed correctly // Ensure we parsed correctly
assert(!parsed_grammar.rules.empty()); assert(!parsed_grammar.rules.empty());
@ -30,8 +25,10 @@ number ::= [0-9]+)""";
llama_grammar* grammar = llama_grammar_init( llama_grammar* grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
std::string input = "123+456"; return grammar;
}
static bool match_string(const std::string & input, llama_grammar* grammar) {
auto decoded = decode_utf8(input, {}); auto decoded = decode_utf8(input, {});
const auto & code_points = decoded.first; const auto & code_points = decoded.first;
@ -39,52 +36,117 @@ number ::= [0-9]+)""";
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks; auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks); llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
assert(!grammar->stacks.empty()); if (grammar->stacks.empty()) {
// no stacks means that the grammar failed to match at this point
return false;
}
} }
bool completed_grammar = false;
for (const auto & stack : grammar->stacks) { for (const auto & stack : grammar->stacks) {
if (stack.empty()) { if (stack.empty()) {
completed_grammar = true; // An empty stack means that the grammar has been completed
break; return true;
} }
} }
assert(completed_grammar); return false;
}
static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
fprintf(stderr, "⚫ Testing %s. Grammar: %s\n", test_desc.c_str(), grammar_str.c_str());
fflush(stderr);
auto grammar = build_grammar(grammar_str);
// Save the original grammar stacks so that we can reset after every new string we want to test
auto original_stacks = grammar->stacks;
fprintf(stderr, " 🔵 Valid strings:\n");
// Passing strings
for (const auto & test_string : passing_strings) {
fprintf(stderr, " \"%s\" ", test_string.c_str());
fflush(stderr);
bool matched = match_string(test_string, grammar);
if (!matched) {
fprintf(stderr, "❌ (failed to match)\n");
} else {
fprintf(stdout, "✅︎\n");
}
assert(matched);
// Reset the grammar stacks
grammar->stacks = original_stacks;
}
fprintf(stderr, " 🟠 Invalid strings:\n");
// Failing strings
for (const auto & test_string : failing_strings) {
fprintf(stderr, " \"%s\" ", test_string.c_str());
fflush(stderr);
bool matched = match_string(test_string, grammar);
if (matched) {
fprintf(stderr, "❌ (incorrectly matched)\n");
} else {
fprintf(stdout, "✅︎\n");
}
assert(!matched);
// Reset the grammar stacks
grammar->stacks = original_stacks;
}
// Clean up allocated memory // Clean up allocated memory
llama_grammar_free(grammar); llama_grammar_free(grammar);
} }
static void test_simple_grammar() {
// Test case for a simple grammar
test_grammar(
"simple grammar",
R"""(
root ::= expr
expr ::= term ("+" term)*
term ::= number
number ::= [0-9]+)""",
// Passing strings
{
"42",
"1+2+3+4+5",
"123+456",
},
// Failing strings
{
"+",
"/ 3",
"1+2+3+4+5+",
"12a45",
}
);
}
static void test_complex_grammar() { static void test_complex_grammar() {
// Test case for a more complex grammar, with both failure strings and success strings // Test case for a more complex grammar, with both failure strings and success strings
const std::string grammar_str = R"""(root ::= expression test_grammar(
"medium complexity grammar",
// Grammar
R"""(
root ::= expression
expression ::= term ws (("+"|"-") ws term)* expression ::= term ws (("+"|"-") ws term)*
term ::= factor ws (("*"|"/") ws factor)* term ::= factor ws (("*"|"/") ws factor)*
factor ::= number | variable | "(" expression ")" | function-call factor ::= number | variable | "(" expression ")" | function-call
number ::= [0-9]+ number ::= [0-9]+
variable ::= [a-zA-Z_][a-zA-Z0-9_]* variable ::= [a-zA-Z_][a-zA-Z0-9_]*
function-call ::= variable ws "(" (expression ("," ws expression)*)? ")" function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
ws ::= [ \t\n\r]?)"""; ws ::= [ \t\n\r]?)""",
// Passing strings
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); {
// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());
// Ensure we have a root node
assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
llama_grammar* grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
// Save the original grammar stacks so that we can reset after every new string we want to test
auto original_stacks = grammar->stacks;
// Test a few strings
std::vector<std::string> test_strings_pass = {
"42", "42",
"1*2*3*4*5", "1*2*3*4*5",
"x", "x",
@ -105,9 +167,9 @@ ws ::= [ \t\n\r]?)""";
"123+456", "123+456",
"123*456*789-123/456+789*123", "123*456*789-123/456+789*123",
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456" "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
}; },
// Failing strings
std::vector<std::string> test_strings_fail = { {
"+", "+",
"/ 3x", "/ 3x",
"x + + y", "x + + y",
@ -126,82 +188,101 @@ ws ::= [ \t\n\r]?)""";
"a * (b + c) - d /", "a * (b + c) - d /",
"f(g(x), h(y, z)", "f(g(x), h(y, z)",
"123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/", "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
}; }
);
}
static void test_quantifiers() {
// A collection of tests to exercise * + and ? quantifiers
test_grammar(
"* quantifier",
// Grammar
R"""(root ::= "a"*)""",
// Passing strings // Passing strings
for (const auto & test_string : test_strings_pass) { {
auto decoded = decode_utf8(test_string, {}); "",
"a",
const auto & code_points = decoded.first; "aaaaa",
"aaaaaaaaaaaaaaaaaa",
int pos = 0; "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { },
++pos;
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
// Expect that each code point will not cause the grammar to fail
if (grammar->stacks.empty()) {
fprintf(stdout, "Error at position %d\n", pos);
fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
fprintf(stderr, "Input string is %s:\n", test_string.c_str());
}
assert(!grammar->stacks.empty());
}
bool completed_grammar = false;
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
completed_grammar = true;
break;
}
}
assert(completed_grammar);
// Reset the grammar stacks
grammar->stacks = original_stacks;
}
// Failing strings // Failing strings
for (const auto & test_string : test_strings_fail) { {
auto decoded = decode_utf8(test_string, {}); "b",
"ab",
const auto & code_points = decoded.first; "aab",
bool parse_failed = false; "ba",
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
if (grammar->stacks.empty()) {
parse_failed = true;
break;
} }
assert(!grammar->stacks.empty()); );
test_grammar(
"+ quantifier",
// Grammar
R"""(root ::= "a"+)""",
// Passing strings
{
"a",
"aaaaa",
"aaaaaaaaaaaaaaaaaa",
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
},
// Failing strings
{
"",
"b",
"ab",
"aab",
"ba",
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
} }
);
bool completed_grammar = false; test_grammar(
"? quantifier",
for (const auto & stack : grammar->stacks) { // Grammar
if (stack.empty()) { R"""(root ::= "a"?)""",
completed_grammar = true; // Passing strings
break; {
"",
"a"
},
// Failing strings
{
"b",
"ab",
"aa",
"ba",
} }
);
test_grammar(
"mixed quantifiers",
// Grammar
R"""(
root ::= cons+ vowel* cons? (vowel cons)*
vowel ::= [aeiouy]
cons ::= [bcdfghjklmnpqrstvwxyz]
)""",
// Passing strings
{
"yes",
"no",
"noyes",
"crwth",
"four",
"bryyyy",
},
// Failing strings
{
"yess",
"yesno",
"forty",
"catyyy",
} }
);
// Ensure that the grammar is not completed, or that each string failed to match as-expected
assert((!completed_grammar) || parse_failed);
// Reset the grammar stacks
grammar->stacks = original_stacks;
}
// Clean up allocated memory
llama_grammar_free(grammar);
} }
static void test_failure_missing_root() { static void test_failure_missing_root() {
fprintf(stderr, "⚫ Testing missing root node:\n");
// Test case for a grammar that is missing a root rule // Test case for a grammar that is missing a root rule
const std::string grammar_str = R"""(rot ::= expr const std::string grammar_str = R"""(rot ::= expr
expr ::= term ("+" term)* expr ::= term ("+" term)*
@ -215,11 +296,15 @@ number ::= [0-9]+)""";
// Ensure we do NOT have a root node // Ensure we do NOT have a root node
assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()); assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
fprintf(stderr, " ✅︎ Passed\n");
} }
static void test_failure_missing_reference() { static void test_failure_missing_reference() {
fprintf(stderr, "⚫ Testing missing reference node:\n");
// Test case for a grammar that is missing a referenced rule // Test case for a grammar that is missing a referenced rule
const std::string grammar_str = R"""(root ::= expr const std::string grammar_str =
R"""(root ::= expr
expr ::= term ("+" term)* expr ::= term ("+" term)*
term ::= numero term ::= numero
number ::= [0-9]+)"""; number ::= [0-9]+)""";
@ -231,13 +316,17 @@ number ::= [0-9]+)""";
// Ensure we did NOT parsed correctly // Ensure we did NOT parsed correctly
assert(parsed_grammar.rules.empty()); assert(parsed_grammar.rules.empty());
fprintf(stderr, "End of expected error. Test successful.\n"); fprintf(stderr, " End of expected error.\n");
fprintf(stderr, " ✅︎ Passed\n");
} }
int main() { int main() {
fprintf(stdout, "Running grammar integration tests...\n");
test_simple_grammar(); test_simple_grammar();
test_complex_grammar(); test_complex_grammar();
test_quantifiers();
test_failure_missing_root(); test_failure_missing_root();
test_failure_missing_reference(); test_failure_missing_reference();
fprintf(stdout, "All tests passed.\n");
return 0; return 0;
} }