Added support for . (any character) token in grammar engine. (#6467)

* Added support for . (any characer) token in grammar engine.

* Add integration tests for any-character symbol.
This commit is contained in:
Clint Herron 2024-06-06 06:08:52 -07:00 committed by GitHub
parent a143c04375
commit ad675e1c67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 52 additions and 2 deletions

View File

@ -266,6 +266,10 @@ namespace grammar_parser {
throw std::runtime_error(std::string("expecting ')' at ") + pos); throw std::runtime_error(std::string("expecting ')' at ") + pos);
} }
pos = parse_space(pos + 1, is_nested); pos = parse_space(pos + 1, is_nested);
} else if (*pos == '.') { // any char
last_sym_start = out_elements.size();
out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*') { } else if (*pos == '*') {
pos = parse_space(pos + 1, is_nested); pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, -1); handle_repetitions(0, -1);
@ -401,6 +405,7 @@ namespace grammar_parser {
case LLAMA_GRETYPE_CHAR_NOT: return true; case LLAMA_GRETYPE_CHAR_NOT: return true;
case LLAMA_GRETYPE_CHAR_ALT: return true; case LLAMA_GRETYPE_CHAR_ALT: return true;
case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
case LLAMA_GRETYPE_CHAR_ANY: return true;
default: return false; default: return false;
} }
} }
@ -415,6 +420,7 @@ namespace grammar_parser {
case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
} }
switch (elem.type) { switch (elem.type) {
case LLAMA_GRETYPE_END: case LLAMA_GRETYPE_END:
@ -426,6 +432,7 @@ namespace grammar_parser {
case LLAMA_GRETYPE_CHAR_NOT: case LLAMA_GRETYPE_CHAR_NOT:
case LLAMA_GRETYPE_CHAR_RNG_UPPER: case LLAMA_GRETYPE_CHAR_RNG_UPPER:
case LLAMA_GRETYPE_CHAR_ALT: case LLAMA_GRETYPE_CHAR_ALT:
case LLAMA_GRETYPE_CHAR_ANY:
fprintf(file, "(\""); fprintf(file, "(\"");
print_grammar_char(file, elem.value); print_grammar_char(file, elem.value);
fprintf(file, "\") "); fprintf(file, "\") ");
@ -483,11 +490,15 @@ namespace grammar_parser {
} }
print_grammar_char(file, elem.value); print_grammar_char(file, elem.value);
break; break;
case LLAMA_GRETYPE_CHAR_ANY:
fprintf(file, ".");
break;
} }
if (is_char_element(elem)) { if (is_char_element(elem)) {
switch (rule[i + 1].type) { switch (rule[i + 1].type) {
case LLAMA_GRETYPE_CHAR_ALT: case LLAMA_GRETYPE_CHAR_ALT:
case LLAMA_GRETYPE_CHAR_RNG_UPPER: case LLAMA_GRETYPE_CHAR_RNG_UPPER:
case LLAMA_GRETYPE_CHAR_ANY:
break; break;
default: default:
fprintf(file, "] "); fprintf(file, "] ");

View File

@ -13640,7 +13640,7 @@ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
const uint32_t chr) { const uint32_t chr) {
bool found = false; bool found = false;
bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT
@ -13649,6 +13649,10 @@ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
// inclusive range, e.g. [a-z] // inclusive range, e.g. [a-z]
found = found || (pos->value <= chr && chr <= pos[1].value); found = found || (pos->value <= chr && chr <= pos[1].value);
pos += 2; pos += 2;
} else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
// Any character matches "."
found = true;
pos += 1;
} else { } else {
// exact char match, e.g. [a] or "a" // exact char match, e.g. [a] or "a"
found = found || pos->value == chr; found = found || pos->value == chr;
@ -13666,7 +13670,7 @@ static bool llama_grammar_match_partial_char(
const llama_grammar_element * pos, const llama_grammar_element * pos,
const llama_partial_utf8 partial_utf8) { const llama_partial_utf8 partial_utf8) {
bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
uint32_t partial_value = partial_utf8.value; uint32_t partial_value = partial_utf8.value;
@ -13696,6 +13700,9 @@ static bool llama_grammar_match_partial_char(
return is_positive_char; return is_positive_char;
} }
pos += 2; pos += 2;
} else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
// Any character matches "."
return true;
} else { } else {
// exact char match, e.g. [a] or "a" // exact char match, e.g. [a] or "a"
if (low <= pos->value && pos->value <= high) { if (low <= pos->value && pos->value <= high) {
@ -13756,6 +13763,7 @@ static void llama_grammar_advance_stack(
} }
case LLAMA_GRETYPE_CHAR: case LLAMA_GRETYPE_CHAR:
case LLAMA_GRETYPE_CHAR_NOT: case LLAMA_GRETYPE_CHAR_NOT:
case LLAMA_GRETYPE_CHAR_ANY:
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
// only add the stack if it's not a duplicate of one we already have // only add the stack if it's not a duplicate of one we already have
new_stacks.emplace_back(stack); new_stacks.emplace_back(stack);

View File

@ -365,6 +365,9 @@ extern "C" {
// modifies a preceding LLAMA_GRETYPE_CHAR or // modifies a preceding LLAMA_GRETYPE_CHAR or
// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
LLAMA_GRETYPE_CHAR_ALT = 6, LLAMA_GRETYPE_CHAR_ALT = 6,
// any character (.)
LLAMA_GRETYPE_CHAR_ANY = 7,
}; };
typedef struct llama_grammar_element { typedef struct llama_grammar_element {

View File

@ -205,6 +205,33 @@ static void test_complex_grammar() {
); );
} }
static void test_special_chars() {
// A collection of tests to exercise special characters such as "."
test_grammar(
"special characters",
// Grammar
R"""(
root ::= ... "abc" ...
)""",
// Passing strings
{
"abcabcabc",
"aaaabcccc",
// NOTE: Also ensures that multi-byte characters still count as a single character
"🔵🟠✅abc❌🟠🔵"
},
// Failing strings
{
"aaabcccc",
"aaaaabcccc",
"aaaabccc",
"aaaabccccc",
"🔵🟠✅❌abc❌✅🟠🔵"
"🔵🟠abc🟠🔵"
}
);
}
static void test_quantifiers() { static void test_quantifiers() {
// A collection of tests to exercise * + and ? quantifiers // A collection of tests to exercise * + and ? quantifiers
@ -445,6 +472,7 @@ int main() {
fprintf(stdout, "Running grammar integration tests...\n"); fprintf(stdout, "Running grammar integration tests...\n");
test_simple_grammar(); test_simple_grammar();
test_complex_grammar(); test_complex_grammar();
test_special_chars();
test_quantifiers(); test_quantifiers();
test_failure_missing_root(); test_failure_missing_root();
test_failure_missing_reference(); test_failure_missing_reference();