From e0f556186b6e1f2b7032a1479edf5e89e2b1bd86 Mon Sep 17 00:00:00 2001 From: Haggai Nuchi Date: Mon, 13 May 2024 22:25:56 -0700 Subject: [PATCH] Add left recursion check: quit early instead of going into an infinite loop (#7083) * Add left recursion check: quit early instead of going into an infinite loop * Remove custom enum, rename left recursion check and move to "grammar internal" section, add handling for edge case where a leftmost nonterminal may be empty * Remove unnecessary declaration --- llama.cpp | 68 ++++++++++++++++++++++++++++++ tests/test-grammar-integration.cpp | 46 ++++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/llama.cpp b/llama.cpp index 202bf94c8..01a35dfb6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13182,6 +13182,58 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates( return rejects; } +static bool llama_grammar_detect_left_recursion( + const std::vector<std::vector<llama_grammar_element>> & rules, + size_t rule_index, + std::vector<bool> * rules_visited, + std::vector<bool> * rules_in_progress, + std::vector<bool> * rules_may_be_empty) { + if ((*rules_in_progress)[rule_index]) { + return true; + } + + (*rules_in_progress)[rule_index] = true; + + const std::vector<llama_grammar_element> & rule = rules[rule_index]; + + // First check if the rule might produce the empty string. This could be done combined with the second + // step but it's more readable as two steps. 
+ bool at_rule_start = true; + for (size_t i = 0; i < rule.size(); i++) { + if (llama_grammar_is_end_of_sequence(&rule[i])) { + if (at_rule_start) { + (*rules_may_be_empty)[rule_index] = true; + break; + } + at_rule_start = true; + } else { + at_rule_start = false; + } + } + + // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may + // be empty) + bool recurse_into_nonterminal = true; + for (size_t i = 0; i < rule.size(); i++) { + if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) { + if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) { + return true; + } + if (!((*rules_may_be_empty)[(size_t)rule[i].value])) { + recurse_into_nonterminal = false; + } + } else if (llama_grammar_is_end_of_sequence(&rule[i])) { + recurse_into_nonterminal = true; + } else { + recurse_into_nonterminal = false; + } + } + + (*rules_in_progress)[rule_index] = false; + (*rules_visited)[rule_index] = true; + return false; +} + // // grammar - external // @@ -13201,6 +13253,19 @@ struct llama_grammar * llama_grammar_init( vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); } + // Check for left recursion + std::vector<bool> rules_visited(n_rules); + std::vector<bool> rules_in_progress(n_rules); + std::vector<bool> rules_may_be_empty(n_rules); + for (size_t i = 0; i < n_rules; i++) { + if (rules_visited[i]) { + continue; + } + if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { + throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %zu", i)); + } + } + // loop over alternates of start rule to build initial stacks std::vector<std::vector<const llama_grammar_element *>> stacks; pos = vec_rules[start_rule_index].data(); @@ -13223,6 +13288,9 @@ struct llama_grammar * llama_grammar_init( } } while (true); + // Important: vec_rules has to be moved here, not copied, because stacks contains + // pointers to 
elements of vec_rules. If vec_rules were copied into llama_grammar + // then the pointers would be invalidated when the local vec_rules goes out of scope. return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; } diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 1a4004e2a..01c5bb27a 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -28,6 +28,19 @@ static llama_grammar* build_grammar(const std::string & grammar_str) { return grammar; } +static bool test_build_grammar_fails(const std::string & grammar_str) { + fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str()); + bool grammar_fails = false; + try { + build_grammar(grammar_str); + fprintf(stderr, " ❌ Expected build failure, but succeeded\n"); + } catch (const std::exception & err) { + grammar_fails = true; + fprintf(stdout, " ✅︎\n"); + } + return grammar_fails; +} + static bool match_string(const std::string & input, llama_grammar* grammar) { auto decoded = decode_utf8(input, {}); @@ -320,6 +333,38 @@ number ::= [0-9]+)"""; fprintf(stderr, " ✅︎ Passed\n"); } +static void test_failure_left_recursion() { + fprintf(stderr, "⚫ Testing left recursion detection:\n"); + + // Test simple left recursion detection + const std::string simple_str = R"""(root ::= "a" | root "a")"""; + assert(test_build_grammar_fails(simple_str)); + + // Test more complicated left recursion detection + const std::string medium_str = R"""( +root ::= asdf +asdf ::= "a" | asdf "a" +)"""; + assert(test_build_grammar_fails(medium_str)); + + // Test even more complicated left recursion detection + const std::string hard_str = R"""( +root ::= asdf +asdf ::= "a" | foo "b" +foo ::= "c" | asdf "d" | "e")"""; + assert(test_build_grammar_fails(hard_str)); + + // Test yet even more complicated left recursion detection + const std::string hardest_str = R"""( +root ::= asdf +asdf ::= "a" | foo "b" +foo ::= "c" | empty asdf "d" | "e" +empty 
::= "blah" | )"""; + assert(test_build_grammar_fails(hardest_str)); + + fprintf(stderr, " ✅︎ Passed\n"); +} + int main() { fprintf(stdout, "Running grammar integration tests...\n"); test_simple_grammar(); @@ -327,6 +372,7 @@ int main() { test_quantifiers(); test_failure_missing_root(); test_failure_missing_reference(); + test_failure_left_recursion(); fprintf(stdout, "All tests passed.\n"); return 0; }