From ad675e1c67a05b16e4e12abe30dbecfc808e7b7e Mon Sep 17 00:00:00 2001 From: Clint Herron Date: Thu, 6 Jun 2024 06:08:52 -0700 Subject: [PATCH] Added support for . (any character) token in grammar engine. (#6467) * Added support for . (any characer) token in grammar engine. * Add integration tests for any-character symbol. --- common/grammar-parser.cpp | 11 +++++++++++ llama.cpp | 12 ++++++++++-- llama.h | 3 +++ tests/test-grammar-integration.cpp | 28 ++++++++++++++++++++++++++++ 4 files changed, 52 insertions(+), 2 deletions(-) diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp index 79d2b0354..a518b766d 100644 --- a/common/grammar-parser.cpp +++ b/common/grammar-parser.cpp @@ -266,6 +266,10 @@ namespace grammar_parser { throw std::runtime_error(std::string("expecting ')' at ") + pos); } pos = parse_space(pos + 1, is_nested); + } else if (*pos == '.') { // any char + last_sym_start = out_elements.size(); + out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); + pos = parse_space(pos + 1, is_nested); } else if (*pos == '*') { pos = parse_space(pos + 1, is_nested); handle_repetitions(0, -1); @@ -401,6 +405,7 @@ namespace grammar_parser { case LLAMA_GRETYPE_CHAR_NOT: return true; case LLAMA_GRETYPE_CHAR_ALT: return true; case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; + case LLAMA_GRETYPE_CHAR_ANY: return true; default: return false; } } @@ -415,6 +420,7 @@ namespace grammar_parser { case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; + case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; } switch (elem.type) { case LLAMA_GRETYPE_END: @@ -426,6 +432,7 @@ namespace grammar_parser { case LLAMA_GRETYPE_CHAR_NOT: case LLAMA_GRETYPE_CHAR_RNG_UPPER: case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "(\""); print_grammar_char(file, elem.value); fprintf(file, "\") "); @@ -483,11 +490,15 @@ namespace grammar_parser { } print_grammar_char(file, elem.value); break; + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "."); + break; } if (is_char_element(elem)) { switch (rule[i + 1].type) { case LLAMA_GRETYPE_CHAR_ALT: case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ANY: break; default: fprintf(file, "] "); diff --git a/llama.cpp b/llama.cpp index cefb4d1d5..32264a008 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13640,7 +13640,7 @@ static std::pair llama_grammar_match_char( const uint32_t chr) { bool found = false; - bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT @@ -13649,6 +13649,10 @@ static std::pair llama_grammar_match_char( // inclusive range, e.g. [a-z] found = found || (pos->value <= chr && chr <= pos[1].value); pos += 2; + } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { + // Any character matches "." + found = true; + pos += 1; } else { // exact char match, e.g. [a] or "a" found = found || pos->value == chr; @@ -13666,7 +13670,7 @@ static bool llama_grammar_match_partial_char( const llama_grammar_element * pos, const llama_partial_utf8 partial_utf8) { - bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); uint32_t partial_value = partial_utf8.value; @@ -13696,6 +13700,9 @@ static bool llama_grammar_match_partial_char( return is_positive_char; } pos += 2; + } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { + // Any character matches "." + return true; } else { // exact char match, e.g. [a] or "a" if (low <= pos->value && pos->value <= high) { @@ -13756,6 +13763,7 @@ static void llama_grammar_advance_stack( } case LLAMA_GRETYPE_CHAR: case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_ANY: if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { // only add the stack if it's not a duplicate of one we already have new_stacks.emplace_back(stack); diff --git a/llama.h b/llama.h index 9dcd67bef..62908261f 100644 --- a/llama.h +++ b/llama.h @@ -365,6 +365,9 @@ extern "C" { // modifies a preceding LLAMA_GRETYPE_CHAR or // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) LLAMA_GRETYPE_CHAR_ALT = 6, + + // any character (.) + LLAMA_GRETYPE_CHAR_ANY = 7, }; typedef struct llama_grammar_element { diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 9bdab05af..8787fb1ec 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -205,6 +205,33 @@ static void test_complex_grammar() { ); } +static void test_special_chars() { + // A collection of tests to exercise special characters such as "." + test_grammar( + "special characters", + // Grammar + R"""( + root ::= ... "abc" ... + )""", + // Passing strings + { + "abcabcabc", + "aaaabcccc", + // NOTE: Also ensures that multi-byte characters still count as a single character + "🔵🟠✅abc❌🟠🔵" + }, + // Failing strings + { + "aaabcccc", + "aaaaabcccc", + "aaaabccc", + "aaaabccccc", + "🔵🟠✅❌abc❌✅🟠🔵" + "🔵🟠abc🟠🔵" + } + ); +} + static void test_quantifiers() { // A collection of tests to exercise * + and ? quantifiers @@ -445,6 +472,7 @@ int main() { fprintf(stdout, "Running grammar integration tests...\n"); test_simple_grammar(); test_complex_grammar(); + test_special_chars(); test_quantifiers(); test_failure_missing_root(); test_failure_missing_reference();