From 55b2d0849d3ec9e45e4a4d9e480f5fa7977872a6 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Thu, 6 Jun 2024 10:07:06 +0100 Subject: [PATCH] grammars: x{min,max} repetition operator (#6640) * grammars: x{min,max} repetition operator + tweak +/*/? to avoid duplication of original over alternates * grammars: handle `x{n}` and fix `x{n,n}` * grammars: document new repetition operators * grammars: uniform use of int for min & max * grammars: refactor parser test * grammar: parsing tests w/ natural pretty print of updated expectations * grammars: much prettier print of expectations (+ TEST_GRAMMAR_PARSER_PRINT_ALL=1 to force all) * grammars: improve test pretty print again * grammars: pretty print rules and chars * grammars: fix copy rule skipping * grammars: disallow `a{,}` (not allowed in regexps) * Update common/grammar-parser.cpp Co-authored-by: Clint Herron * grammars: fix copy rule skipping (again) & display of expectations * grammars: more test cases * grammars: update reps parsing to bring ? / * / + closer to before * json: use new GBNF repetitions{m,n} syntax * grammars: update performance gotchas w/ repetition advice * Update examples/json_schema_to_grammar.py Co-authored-by: Clint Herron * Update examples/server/public/json-schema-to-grammar.mjs Co-authored-by: Clint Herron * grammars: comment on rule repetitions * grammars: ensure unambiguous number alternatives * grammar: nit typo switched error msgs * grammar: nit numbering in comment * json: update numeric rule to be unambiguous * Apply suggestions from code review Co-authored-by: Clint Herron * Update examples/server/public/json-schema-to-grammar.mjs Co-authored-by: Clint Herron * json: fix integral-part * grammar: add repetition tests --------- Co-authored-by: Clint Herron --- common/grammar-parser.cpp | 144 ++++- common/json-schema-to-grammar.cpp | 80 +-- examples/json_schema_to_grammar.py | 70 +-- examples/pydantic_models_to_grammar.py | 2 +- .../server/public/json-schema-to-grammar.mjs | 67 +- grammars/README.md | 12 +- tests/test-grammar-integration.cpp | 76 +++ tests/test-grammar-parser.cpp | 591 +++++++++++++----- tests/test-json-schema-to-grammar.cpp | 112 ++-- 9 files changed, 736 insertions(+), 418 deletions(-) diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp index b5bc7d49b..79d2b0354 100644 --- a/common/grammar-parser.cpp +++ b/common/grammar-parser.cpp @@ -46,8 +46,12 @@ namespace grammar_parser { state.rules[rule_id] = rule; } + static bool is_digit_char(char c) { + return '0' <= c && c <= '9'; + } + static bool is_word_char(char c) { - return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); } static std::pair parse_hex(const char * src, int size) { @@ -99,6 +103,17 @@ namespace grammar_parser { return pos; } + static const char * parse_int(const char * src) { + const char * pos = src; + while (is_digit_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting integer at ") + src); + } + return pos; + } + static std::pair parse_char(const char * src) { if (*src == '\\') { switch (src[1]) { @@ -137,6 +152,60 @@ namespace grammar_parser { bool is_nested) { size_t last_sym_start = out_elements.size(); const char * pos = src; + + auto handle_repetitions = [&](int min_times, int max_times) { + + if (last_sym_start == out_elements.size()) { + throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); + } + + // apply transformation to previous symbol (last_sym_start to end) according to + // the following rewrite rules: + // S{m,n} --> S S S (m times) S'(n-m) + // S'(x) ::= S S'(x-1) | + // (... n-m definitions of these S' rules ...) + // S'(1) ::= S | + // S{m,} --> S S S (m times) S' + // S' ::= S S' | + // S* --> S{0,} + // --> S' ::= S S' | + // S+ --> S{1,} + // --> S S' + // S' ::= S S' | + // S? --> S{0,1} + // --> S' + // S' ::= S | + + std::vector previous_elements(out_elements.begin() + last_sym_start, out_elements.end()); + if (min_times == 0) { + out_elements.resize(last_sym_start); + } else { + // Repeat the previous elements (min_times - 1) times + for (int i = 1; i < min_times; i++) { + out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()); + } + } + + uint32_t last_rec_rule_id = 0; + auto n_opt = max_times < 0 ? 1 : max_times - min_times; + + std::vector rec_rule(previous_elements); + for (int i = 0; i < n_opt; i++) { + rec_rule.resize(previous_elements.size()); + uint32_t rec_rule_id = generate_symbol_id(state, rule_name); + if (i > 0 || max_times < 0) { + rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); + } + rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + rec_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(state, rec_rule_id, rec_rule); + last_rec_rule_id = rec_rule_id; + } + if (n_opt > 0) { + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); + } + }; + while (*pos) { if (*pos == '"') { // literal string pos++; @@ -197,40 +266,47 @@ namespace grammar_parser { throw std::runtime_error(std::string("expecting ')' at ") + pos); } pos = parse_space(pos + 1, is_nested); - } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator - if (last_sym_start == out_elements.size()) { - throw std::runtime_error(std::string("expecting preceding item to */+/? at ") + pos); - } - - // apply transformation to previous symbol (last_sym_start to end) according to - // rewrite rules: - // S* --> S' ::= S S' | - // S+ --> S' ::= S S' | S - // S? --> S' ::= S | - uint32_t sub_rule_id = generate_symbol_id(state, rule_name); - std::vector sub_rule; - // add preceding symbol to generated rule - sub_rule.insert( - sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); - if (*pos == '*' || *pos == '+') { - // cause generated rule to recurse - sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - } - // mark start of alternate def - sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - if (*pos == '+') { - // add preceding symbol as alternate only for '+' (otherwise empty) - sub_rule.insert( - sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); - } - sub_rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, sub_rule_id, sub_rule); - - // in original rule, replace previous symbol with reference to generated rule - out_elements.resize(last_sym_start); - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - + } else if (*pos == '*') { pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, -1); + } else if (*pos == '+') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(1, -1); + } else if (*pos == '?') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, 1); + } else if (*pos == '{') { + pos = parse_space(pos + 1, is_nested); + + if (!is_digit_char(*pos)) { + throw std::runtime_error(std::string("expecting an int at ") + pos); + } + const char * int_end = parse_int(pos); + int min_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + + int max_times = -1; + + if (*pos == '}') { + max_times = min_times; + pos = parse_space(pos + 1, is_nested); + } else if (*pos == ',') { + pos = parse_space(pos + 1, is_nested); + + if (is_digit_char(*pos)) { + const char * int_end = parse_int(pos); + max_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + } + + if (*pos != '}') { + throw std::runtime_error(std::string("expecting '}' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else { + throw std::runtime_error(std::string("expecting ',' at ") + pos); + } + handle_repetitions(min_times, max_times); } else { break; } diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 9a71f5d8d..737bae27c 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -16,58 +16,27 @@ static std::string join(Iterator begin, Iterator end, const std::string & separa static std::string repeat(const std::string & str, size_t n); -static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "", bool item_rule_is_literal = false) { +static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { + auto has_max = max_items != std::numeric_limits::max(); + + if (min_items == 0 && max_items == 1) { + return item_rule + "?"; + } + if (separator_rule.empty()) { - if (min_items == 0 && max_items == 1) { - return item_rule + "?"; - } else if (min_items == 1 && max_items == std::numeric_limits::max()) { + if (min_items == 1 && !has_max) { return item_rule + "+"; - } - } - - std::string result; - if (min_items > 0) { - if (item_rule_is_literal && separator_rule.empty()) { - result = "\"" + repeat(std::string(item_rule.begin() + 1, item_rule.end() - 1), min_items) + "\""; + } else if (min_items == 0 && !has_max) { + return item_rule + "*"; } else { - std::vector items(min_items, item_rule); - result = join(items.begin(), items.end(), separator_rule.empty() ? " " : " " + separator_rule + " "); + return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}"; } } - std::function opt_repetitions = [&](int up_to_n, bool prefix_with_sep) -> std::string { - auto content = prefix_with_sep && !separator_rule.empty() ? separator_rule + " " + item_rule : item_rule; - - if (up_to_n == 0) { - return ""; - } else if (up_to_n == 1) { - return "(" + content + ")?"; - } else if (!separator_rule.empty() && !prefix_with_sep) { - return "(" + content + " " + opt_repetitions(up_to_n - 1, true) + ")?"; - } else { - std::string res = repeat("(" + content + " ", up_to_n); - // strip trailing space - res = res.substr(0, res.length() - 1); - res += repeat(")?", up_to_n); - return res; - } - }; - - if (min_items > 0 && max_items != min_items) { - result += " "; + auto result = item_rule + " " + build_repetition("(" + separator_rule + " " + item_rule + ")", min_items == 0 ? 0 : min_items - 1, has_max ? max_items - 1 : max_items); + if (min_items == 0) { + result = "(" + result + ")?"; } - - if (max_items != std::numeric_limits::max()) { - result += opt_repetitions(max_items - min_items, min_items > 0); - } else { - std::string item_operator = "(" + (separator_rule.empty() ? "" : separator_rule + " ") + item_rule + ")"; - if (min_items == 0 && !separator_rule.empty()) { - result = "(" + item_rule + " " + item_operator + "*)?"; - } else { - result += item_operator + "*"; - } - } - return result; } @@ -78,30 +47,24 @@ struct BuiltinRule { std::vector deps; }; -const std::string _up_to_15_digits = build_repetition("[0-9]", 0, 15); - std::unordered_map PRIMITIVE_RULES = { {"boolean", {"(\"true\" | \"false\") space", {}}}, - {"decimal-part", {"[0-9] " + _up_to_15_digits, {}}}, - {"integral-part", {"[0-9] | [1-9] " + _up_to_15_digits, {}}}, + {"decimal-part", {"[0-9]{1,16}", {}}}, + {"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}}, {"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}}, {"integer", {"(\"-\"? integral-part) space", {"integral-part"}}}, {"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}}, {"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}}, {"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}}, - {"uuid", {"\"\\\"\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " - "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " - "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " - "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " - "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] \"\\\"\" space", {}}}, - {"char", {"[^\"\\\\] | \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])", {}}}, + {"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}}, + {"char", {"[^\"\\\\] | \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}}, {"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}}, {"null", {"\"null\" space", {}}}, }; std::unordered_map STRING_FORMAT_RULES = { - {"date", {"[0-9] [0-9] [0-9] [0-9] \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}}, - {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9] [0-9] [0-9] )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}}, + {"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}}, + {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}}, {"date-time", {"date \"T\" time", {"date", "time"}}}, {"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}}, {"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}}, @@ -385,8 +348,7 @@ private: sub_is_literal ? "\"" + sub + "\"" : sub, min_times, max_times, - "", - sub_is_literal + "" ); seq.back().second = false; } else { diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 826cd3f72..7d889c3fe 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -6,52 +6,22 @@ import re import sys from typing import Any, Dict, List, Set, Tuple, Union -def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False): + +def _build_repetition(item_rule, min_items, max_items, separator_rule=None): + + if min_items == 0 and max_items == 1: + return f'{item_rule}?' + if not separator_rule: - if min_items == 0 and max_items == 1: - return f'{item_rule}?' - elif min_items == 1 and max_items is None: + if min_items == 1 and max_items is None: return f'{item_rule}+' - - result = '' - - if min_items > 0: - if item_rule_is_literal and separator_rule is None: - result = '"' + (item_rule[1:-1] * min_items) + '"' + elif min_items == 0 and max_items is None: + return f'{item_rule}*' else: - result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items) + return f'{item_rule}{{{min_items},{max_items if max_items is not None else ""}}}' - def opt_repetitions(up_to_n, prefix_with_sep=False): - ''' - - n=4, no sep: '(a (a (a (a)?)?)?)?' - - n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?' - - n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?' - ''' - - content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule - if up_to_n == 0: - return '' - elif up_to_n == 1: - return f'({content})?' - elif separator_rule and not prefix_with_sep: - return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?' - else: - return (f'({content} ' * up_to_n).rstrip() + (')?' * up_to_n) - - if min_items > 0 and max_items != min_items: - result += ' ' - - if max_items is not None: - result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0) - else: - item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})' - - if min_items == 0 and separator_rule: - result = f'({item_rule} {item_operator}*)?' - else: - result += f'{item_operator}*' - - return result + result = item_rule + ' ' + _build_repetition(f'({separator_rule} {item_rule})', min_items - 1 if min_items > 0 else 0, max_items - 1 if max_items is not None else None) + return f'({result})?' if min_items == 0 else result class BuiltinRule: @@ -59,31 +29,29 @@ class BuiltinRule: self.content = content self.deps = deps or [] -_up_to_15_digits = _build_repetition('[0-9]', 0, 15) - # whitespace is constrained to a single space char to prevent model "running away" in # whitespace. Also maybe improves generation quality? SPACE_RULE = '" "?' PRIMITIVE_RULES = { 'boolean' : BuiltinRule('("true" | "false") space', []), - 'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []), - 'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []), + 'decimal-part' : BuiltinRule('[0-9]{1,16}', []), + 'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []), 'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), 'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']), 'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), 'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), 'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']), - 'uuid' : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []), - 'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []), + 'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []), + 'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F]{4})', []), 'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']), 'null' : BuiltinRule('"null" space', []), } # TODO: support "uri", "email" string formats STRING_FORMAT_RULES = { - 'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), - 'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), + 'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), + 'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), 'date-time' : BuiltinRule('date "T" time', ['date', 'time']), 'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']), 'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']), @@ -333,7 +301,7 @@ class SchemaConverter: sub_rule_ids[sub] = id sub = id - seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False) + seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times), False) else: literal = '' while i < length: diff --git a/examples/pydantic_models_to_grammar.py b/examples/pydantic_models_to_grammar.py index 9acc7cc6d..f029c73a2 100644 --- a/examples/pydantic_models_to_grammar.py +++ b/examples/pydantic_models_to_grammar.py @@ -624,7 +624,7 @@ string ::= "\"" ( "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) )* "\"" ws ws ::= ([ \t\n] ws)? -float ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws +float ::= ("-"? ([0] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws integer ::= [0-9]+""" diff --git a/examples/server/public/json-schema-to-grammar.mjs b/examples/server/public/json-schema-to-grammar.mjs index 8e0be1b40..cef11eab8 100644 --- a/examples/server/public/json-schema-to-grammar.mjs +++ b/examples/server/public/json-schema-to-grammar.mjs @@ -2,57 +2,26 @@ const SPACE_RULE = '" "?'; function _buildRepetition(itemRule, minItems, maxItems, opts={}) { + if (minItems === 0 && maxItems === 1) { + return `${itemRule}?`; + } + + const separatorRule = opts.separatorRule ?? ''; const itemRuleIsLiteral = opts.itemRuleIsLiteral ?? false if (separatorRule === '') { - if (minItems === 0 && maxItems === 1) { - return `${itemRule}?`; - } else if (minItems === 1 && maxItems === undefined) { + if (minItems === 1 && maxItems === undefined) { return `${itemRule}+`; - } - } - - let result = ''; - if (minItems > 0) { - if (itemRuleIsLiteral && separatorRule === '') { - result = `"${itemRule.slice(1, -1).repeat(minItems)}"`; + } else if (minItems === 0 && maxItems === undefined) { + return `${itemRule}*`; } else { - result = Array.from({ length: minItems }, () => itemRule) - .join(separatorRule !== '' ? ` ${separatorRule} ` : ' '); + return `${itemRule}{${minItems},${maxItems !== undefined ? maxItems : ''}}`; } } - const optRepetitions = (upToN, prefixWithSep=false) => { - const content = separatorRule !== '' && prefixWithSep ? `${separatorRule} ${itemRule}` : itemRule; - if (upToN === 0) { - return ''; - } else if (upToN === 1) { - return `(${content})?`; - } else if (separatorRule !== '' && !prefixWithSep) { - return `(${content} ${optRepetitions(upToN - 1, true)})?`; - } else { - return Array.from({ length: upToN }, () => `(${content}`).join(' ').trim() + Array.from({ length: upToN }, () => ')?').join(''); - } - }; - - if (minItems > 0 && maxItems !== minItems) { - result += ' '; - } - - if (maxItems !== undefined) { - result += optRepetitions(maxItems - minItems, minItems > 0); - } else { - const itemOperator = `(${separatorRule !== '' ? separatorRule + ' ' : ''}${itemRule})`; - - if (minItems === 0 && separatorRule !== '') { - result = `(${itemRule} ${itemOperator}*)?`; - } else { - result += `${itemOperator}*`; - } - } - - return result; + const result = itemRule + ' ' + _buildRepetition(`(${separatorRule} ${itemRule})`, minItems > 0 ? minItems - 1 : 0, maxItems !== undefined ? maxItems - 1 : undefined); + return minItems === 0 ? `(${result})?` : result; } class BuiltinRule { @@ -62,27 +31,25 @@ class BuiltinRule { } } -const UP_TO_15_DIGITS = _buildRepetition('[0-9]', 0, 15); - const PRIMITIVE_RULES = { boolean : new BuiltinRule('("true" | "false") space', []), - 'decimal-part' : new BuiltinRule('[0-9] ' + UP_TO_15_DIGITS, []), - 'integral-part': new BuiltinRule('[0-9] | [1-9] ' + UP_TO_15_DIGITS, []), + 'decimal-part' : new BuiltinRule('[0-9]{1,16}', []), + 'integral-part': new BuiltinRule('[0] | [1-9] [0-9]{0,15}', []), number : new BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), integer : new BuiltinRule('("-"? integral-part) space', ['integral-part']), value : new BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), object : new BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), array : new BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']), - uuid : new BuiltinRule('"\\"" ' + [8, 4, 4, 4, 12].map(n => [...new Array(n)].map(_ => '[0-9a-fA-F]').join('')).join(' "-" ') + ' "\\"" space', []), - char : new BuiltinRule(`[^"\\\\] | "\\\\" (["\\\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])`, []), + uuid : new BuiltinRule('"\\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\\"" space', []), + char : new BuiltinRule(`[^"\\\\] | "\\\\" (["\\\\/bfnrt] | "u" [0-9a-fA-F]{4})`, []), string : new BuiltinRule(`"\\"" char* "\\"" space`, ['char']), null : new BuiltinRule('"null" space', []), }; // TODO: support "uri", "email" string formats const STRING_FORMAT_RULES = { - 'date' : new BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), - 'time' : new BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), + 'date' : new BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), + 'time' : new BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), 'date-time' : new BuiltinRule('date "T" time', ['date', 'time']), 'date-string' : new BuiltinRule('"\\"" date "\\"" space', ['date']), 'time-string' : new BuiltinRule('"\\"" time "\\"" space', ['time']), diff --git a/grammars/README.md b/grammars/README.md index 2b8384d9d..3ffc7cec0 100644 --- a/grammars/README.md +++ b/grammars/README.md @@ -59,9 +59,13 @@ Parentheses `()` can be used to group sequences, which allows for embedding alte ## Repetition and Optional Symbols -- `*` after a symbol or sequence means that it can be repeated zero or more times. -- `+` denotes that the symbol or sequence should appear one or more times. -- `?` makes the preceding symbol or sequence optional. +- `*` after a symbol or sequence means that it can be repeated zero or more times (equivalent to `{0,}`). +- `+` denotes that the symbol or sequence should appear one or more times (equivalent to `{1,}`). +- `?` makes the preceding symbol or sequence optional (equivalent to `{0,1}`). +- `{m}` repeats the precedent symbol or sequence exactly `m` times +- `{m,}` repeats the precedent symbol or sequence at least `m` times +- `{m,n}` repeats the precedent symbol or sequence at between `m` and `n` times (included) +- `{0,n}` repeats the precedent symbol or sequence at most `n` times (included) ## Comments and newlines @@ -98,4 +102,4 @@ Grammars currently have performance gotchas (see https://github.com/ggerganov/ll A common pattern is to allow repetitions of a pattern `x` up to N times. -While semantically correct, the syntax `x? x? x?.... x?` (with N repetitions) will result in extremely slow inference. Instead, you can write `(x (x (x ... (x)?...)?)?)?` (w/ N-deep nesting) +While semantically correct, the syntax `x? x? x?.... x?` (with N repetitions) may result in extremely slow sampling. Instead, you can write `x{0,N}` (or `(x (x (x ... (x)?...)?)?)?` w/ N-deep nesting in earlier llama.cpp versions). diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 01c5bb27a..9bdab05af 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -292,6 +292,82 @@ static void test_quantifiers() { "catyyy", } ); + test_grammar( + "simple exact repetition", + // Grammar + R"""( + root ::= [ab]{4} + )""", + // Passing strings + { + "aaaa", + "bbbb", + "abab", + }, + // Failing strings + { + "a", + "b", + "aaaaa", + } + ); + test_grammar( + "simple min repetition", + // Grammar + R"""( + root ::= [ab]{4,} + )""", + // Passing strings + { + "aaaa", + "aaaaab", + "bbbb", + "ababab", + }, + // Failing strings + { + "", + "aba", + } + ); + test_grammar( + "simple max repetition", + // Grammar + R"""( + root ::= [ab]{0,4} + )""", + // Passing strings + { + "", + "a", + "aa", + "aaa", + "aaab", + }, + // Failing strings + { + "aaaaa", + } + ); + test_grammar( + "min / max repetition", + // Grammar + R"""( + root ::= ("0x" [A-F0-9]{2} " "?){3,5} + )""", + // Passing strings + { + "0xFF 0x12 0xAB", + "0xFF 0x12 0xAB 0x00 0x00", + }, + // Failing strings + { + "", + "0xFF", + "0xFF 0x12", + "0xFF 0x12 0xAB 0x00 0x00 0x00", + } + ); } static void test_failure_missing_root() { diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp index 91939e276..5df5abb25 100644 --- a/tests/test-grammar-parser.cpp +++ b/tests/test-grammar-parser.cpp @@ -7,28 +7,79 @@ #include -int main() -{ - grammar_parser::parse_state parsed_grammar; +static const char * type_str(llama_gretype type) { + switch (type) { + case LLAMA_GRETYPE_CHAR: return "LLAMA_GRETYPE_CHAR"; + case LLAMA_GRETYPE_CHAR_NOT: return "LLAMA_GRETYPE_CHAR_NOT"; + case LLAMA_GRETYPE_CHAR_ALT: return "LLAMA_GRETYPE_CHAR_ALT"; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: return "LLAMA_GRETYPE_CHAR_RNG_UPPER"; + case LLAMA_GRETYPE_RULE_REF: return "LLAMA_GRETYPE_RULE_REF"; + case LLAMA_GRETYPE_ALT: return "LLAMA_GRETYPE_ALT"; + case LLAMA_GRETYPE_END: return "LLAMA_GRETYPE_END"; + default: return "?"; + } +} - const char *grammar_bytes = R"""(root ::= (expr "=" term "\n")+ -expr ::= term ([-+*/] term)* -term ::= [0-9]+)"""; +static void verify_parsing(const char *grammar_bytes, const std::vector> expected, const std::vector &expected_rules) { + uint32_t index = 0; + grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_bytes); - parsed_grammar = grammar_parser::parse(grammar_bytes); + std::map symbol_names; + for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) { + symbol_names[it->second] = it->first; + } - std::vector> expected = { - {"expr", 2}, - {"expr_5", 5}, - {"expr_6", 6}, - {"root", 0}, - {"root_1", 1}, - {"root_4", 4}, - {"term", 3}, - {"term_7", 7}, + auto print_all = [&]() { + fprintf(stderr, " verify_parsing(R\"\"\"(%s)\"\"\", {\n", grammar_bytes); + for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) { + fprintf(stderr, " {\"%s\", %u},\n", it->first.c_str(), it->second); + } + fprintf(stderr, " }, {\n"); + for (size_t i_rule = 0; i_rule < parsed_grammar.rules.size(); i_rule++) { + fprintf(stderr, " // %s (index %zu)\n", symbol_names[i_rule].c_str(), i_rule); + auto & rule = parsed_grammar.rules[i_rule]; + for (uint32_t i = 0; i < rule.size(); i++) { + std::string rule_str; + fprintf(stderr, " {%s, ", type_str(rule[i].type)); + if (rule[i].type == LLAMA_GRETYPE_CHAR || rule[i].type == LLAMA_GRETYPE_CHAR_ALT || + rule[i].type == LLAMA_GRETYPE_CHAR_NOT || rule[i].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + char c = rule[i].value; + if (c == '\n') { + fprintf(stderr, "'\\n'"); + } else if (c == '\t') { + fprintf(stderr, "'\\t'"); + } else if (c == '\r') { + fprintf(stderr, "'\\r'"); + } else if (c == '\0') { + fprintf(stderr, "'\\0'"); + } else { + fprintf(stderr, "'%c'", c); + } + } else if (rule[i].type == LLAMA_GRETYPE_RULE_REF) { + fprintf(stderr, "/* %s */ %u", symbol_names[rule[i].value].c_str(), rule[i].value); + } else { + fprintf(stderr, "%u", rule[i].value); + } + fprintf(stderr, "},\n"); + } + } + fprintf(stderr, " });\n"); }; - uint32_t index = 0; + if (getenv("TEST_GRAMMAR_PARSER_PRINT_ALL")) { + print_all(); + fprintf(stderr, "\n"); + return; + } + + fprintf(stderr, "Testing grammar:%s\n", grammar_bytes); + + if (parsed_grammar.symbol_ids.size() != expected.size()) { + fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n"); + print_all(); + assert(parsed_grammar.symbol_ids.size() == expected.size()); + } + for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) { std::string key = it->first; @@ -38,51 +89,18 @@ term ::= [0-9]+)"""; // pretty print error message before asserting if (expected_pair.first != key || expected_pair.second != value) { + fprintf(stderr, "index: %u\n", index); fprintf(stderr, "expected_pair: %s, %u\n", expected_pair.first.c_str(), expected_pair.second); fprintf(stderr, "actual_pair: %s, %u\n", key.c_str(), value); fprintf(stderr, "expected_pair != actual_pair\n"); + fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n"); + print_all(); } assert(expected_pair.first == key && expected_pair.second == value); index++; } - std::vector expected_rules = { - {LLAMA_GRETYPE_RULE_REF, 4}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 2}, - {LLAMA_GRETYPE_CHAR, 61}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_CHAR, 10}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_RULE_REF, 6}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 7}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 1}, - {LLAMA_GRETYPE_RULE_REF, 4}, - {LLAMA_GRETYPE_ALT, 0}, - {LLAMA_GRETYPE_RULE_REF, 1}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_CHAR, 45}, - {LLAMA_GRETYPE_CHAR_ALT, 43}, - {LLAMA_GRETYPE_CHAR_ALT, 42}, - {LLAMA_GRETYPE_CHAR_ALT, 47}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 5}, - {LLAMA_GRETYPE_RULE_REF, 6}, - {LLAMA_GRETYPE_ALT, 0}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_CHAR, 48}, - {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, - {LLAMA_GRETYPE_RULE_REF, 7}, - {LLAMA_GRETYPE_ALT, 0}, - {LLAMA_GRETYPE_CHAR, 48}, - {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, - {LLAMA_GRETYPE_END, 0}, - }; index = 0; for (auto rule : parsed_grammar.rules) @@ -97,28 +115,306 @@ term ::= [0-9]+)"""; if (expected_element.type != element.type || expected_element.value != element.value) { fprintf(stderr, "index: %u\n", index); - fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value); - fprintf(stderr, "actual_element: %d, %u\n", element.type, element.value); + fprintf(stderr, "expected_element: %s, %u\n", type_str(expected_element.type), expected_element.value); + fprintf(stderr, "actual_element: %s, %u\n", type_str(element.type), element.value); fprintf(stderr, "expected_element != actual_element\n"); + fprintf(stderr, "all elements:\n"); + fprintf(stderr, "Code to update expectation (set TEST_GRAMMAR_PARSER_PRINT_ALL=1 to print all):\n"); + print_all(); } assert(expected_element.type == element.type && expected_element.value == element.value); index++; } } +} - const char *longer_grammar_bytes = R"""( - root ::= (expr "=" ws term "\n")+ - expr ::= term ([-+*/] term)* - term ::= ident | num | "(" ws expr ")" ws - ident ::= [a-z] [a-z0-9_]* ws - num ::= [0-9]+ ws - ws ::= [ \t\n]* - )"""; +static void verify_failure(const char *grammar_bytes) { + fprintf(stderr, "Testing expected failure:%s\n", grammar_bytes); + auto result = grammar_parser::parse(grammar_bytes); + assert(result.rules.empty() && "should have failed"); +} - parsed_grammar = grammar_parser::parse(longer_grammar_bytes); +int main() +{ + verify_failure(R"""( + root ::= "a"{,}" + )"""); - expected = { + verify_failure(R"""( + root ::= "a"{,10}" + )"""); + + verify_parsing(R"""( + root ::= "a" + )""", { + {"root", 0}, + }, { + // root (index 0) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= "a" | [bdx-z] | [^1-3] + )""", { + {"root", 0}, + }, { + // root (index 0) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_CHAR, 'b'}, + {LLAMA_GRETYPE_CHAR_ALT, 'd'}, + {LLAMA_GRETYPE_CHAR_ALT, 'x'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_CHAR_NOT, '1'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, '3'}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= a+ + a ::= "a" + )""", { + {"a", 1}, + {"root", 0}, + {"root_2", 2}, + }, { + // root (index 0) + {LLAMA_GRETYPE_RULE_REF, /* a */ 1}, + {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2}, + {LLAMA_GRETYPE_END, 0}, + // a (index 1) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_END, 0}, + // root_2 (index 2) + {LLAMA_GRETYPE_RULE_REF, /* a */ 1}, + {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= "a"+ + )""", { + {"root", 0}, + {"root_1", 1}, + }, { + // root (index 0) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_END, 0}, + // root_1 (index 1) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= a? + a ::= "a" + )""", { + {"a", 1}, + {"root", 0}, + {"root_2", 2}, + }, { + // root (index 0) + {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2}, + {LLAMA_GRETYPE_END, 0}, + // a (index 1) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_END, 0}, + // root_2 (index 2) + {LLAMA_GRETYPE_RULE_REF, /* a */ 1}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= "a"? + )""", { + {"root", 0}, + {"root_1", 1}, + }, { + // root (index 0) + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_END, 0}, + // root_1 (index 1) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= a* + a ::= "a" + )""", { + {"a", 1}, + {"root", 0}, + {"root_2", 2}, + }, { + // root (index 0) + {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2}, + {LLAMA_GRETYPE_END, 0}, + // a (index 1) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_END, 0}, + // root_2 (index 2) + {LLAMA_GRETYPE_RULE_REF, /* a */ 1}, + {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= "a"* + )""", { + {"root", 0}, + {"root_1", 1}, + }, { + // root (index 0) + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_END, 0}, + // root_1 (index 1) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= "a"{2} + )""", { + {"root", 0}, + }, { + // root (index 0) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= "a"{2,} + )""", { + {"root", 0}, + {"root_1", 1}, + }, { + // root (index 0) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_END, 0}, + // root_1 (index 1) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= "a"{ 4} + )""", { + {"root", 0}, + }, { + // root (index 0) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= "a"{2,4} + )""", { + {"root", 0}, + {"root_1", 1}, + {"root_2", 2}, + }, { + // root (index 0) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_RULE_REF, /* root_2 */ 2}, + {LLAMA_GRETYPE_END, 0}, + // root_1 (index 1) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + // root_2 (index 2) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= (expr "=" term "\n")+ + expr ::= term ([-+*/] term)* + term ::= [0-9]+ + )""", { + {"expr", 2}, + {"expr_5", 5}, + {"expr_6", 6}, + {"root", 0}, + {"root_1", 1}, + {"root_4", 4}, + {"term", 3}, + {"term_7", 7}, + }, { + // root (index 0) + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_RULE_REF, /* root_4 */ 4}, + {LLAMA_GRETYPE_END, 0}, + // root_1 (index 1) + {LLAMA_GRETYPE_RULE_REF, /* expr */ 2}, + {LLAMA_GRETYPE_CHAR, '='}, + {LLAMA_GRETYPE_RULE_REF, /* term */ 3}, + {LLAMA_GRETYPE_CHAR, '\n'}, + {LLAMA_GRETYPE_END, 0}, + // expr (index 2) + {LLAMA_GRETYPE_RULE_REF, /* term */ 3}, + {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6}, + {LLAMA_GRETYPE_END, 0}, + // term (index 3) + {LLAMA_GRETYPE_CHAR, '0'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'}, + {LLAMA_GRETYPE_RULE_REF, /* term_7 */ 7}, + {LLAMA_GRETYPE_END, 0}, + // root_4 (index 4) + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_RULE_REF, /* root_4 */ 4}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + // expr_5 (index 5) + {LLAMA_GRETYPE_CHAR, '-'}, + {LLAMA_GRETYPE_CHAR_ALT, '+'}, + {LLAMA_GRETYPE_CHAR_ALT, '*'}, + {LLAMA_GRETYPE_CHAR_ALT, '/'}, + {LLAMA_GRETYPE_RULE_REF, /* term */ 3}, + {LLAMA_GRETYPE_END, 0}, + // expr_6 (index 6) + {LLAMA_GRETYPE_RULE_REF, /* expr_5 */ 5}, + {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + // term_7 (index 7) + {LLAMA_GRETYPE_CHAR, '0'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'}, + {LLAMA_GRETYPE_RULE_REF, /* term_7 */ 7}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); + + verify_parsing(R"""( + root ::= (expr "=" ws term "\n")+ + expr ::= term ([-+*/] term)* + term ::= ident | num | "(" ws expr ")" ws + ident ::= [a-z] [a-z0-9_]* ws + num ::= [0-9]+ ws + ws ::= [ \t\n]* + )""", { {"expr", 2}, {"expr_6", 6}, {"expr_7", 7}, @@ -132,119 +428,88 @@ term ::= [0-9]+)"""; {"term", 4}, {"ws", 3}, {"ws_12", 12}, - }; - - index = 0; - for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) - { - std::string key = it->first; - uint32_t value = it->second; - std::pair expected_pair = expected[index]; - - // pretty print error message before asserting - if (expected_pair.first != key || expected_pair.second != value) - { - fprintf(stderr, "expected_pair: %s, %u\n", expected_pair.first.c_str(), expected_pair.second); - fprintf(stderr, "actual_pair: %s, %u\n", key.c_str(), value); - fprintf(stderr, "expected_pair != actual_pair\n"); - } - - assert(expected_pair.first == key && expected_pair.second == value); - - index++; - } - expected_rules = { - {LLAMA_GRETYPE_RULE_REF, 5}, + }, { + // root (index 0) + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_RULE_REF, /* root_5 */ 5}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 2}, - {LLAMA_GRETYPE_CHAR, 61}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_RULE_REF, 4}, - {LLAMA_GRETYPE_CHAR, 10}, + // root_1 (index 1) + {LLAMA_GRETYPE_RULE_REF, /* expr */ 2}, + {LLAMA_GRETYPE_CHAR, '='}, + {LLAMA_GRETYPE_RULE_REF, /* ws */ 3}, + {LLAMA_GRETYPE_RULE_REF, /* term */ 4}, + {LLAMA_GRETYPE_CHAR, '\n'}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 4}, - {LLAMA_GRETYPE_RULE_REF, 7}, + // expr (index 2) + {LLAMA_GRETYPE_RULE_REF, /* term */ 4}, + {LLAMA_GRETYPE_RULE_REF, /* expr_7 */ 7}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 12}, + // ws (index 3) + {LLAMA_GRETYPE_RULE_REF, /* ws_12 */ 12}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 8}, + // term (index 4) + {LLAMA_GRETYPE_RULE_REF, /* ident */ 8}, {LLAMA_GRETYPE_ALT, 0}, - {LLAMA_GRETYPE_RULE_REF, 9}, + {LLAMA_GRETYPE_RULE_REF, /* num */ 9}, {LLAMA_GRETYPE_ALT, 0}, - {LLAMA_GRETYPE_CHAR, 40}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_RULE_REF, 2}, - {LLAMA_GRETYPE_CHAR, 41}, - {LLAMA_GRETYPE_RULE_REF, 3}, + {LLAMA_GRETYPE_CHAR, '('}, + {LLAMA_GRETYPE_RULE_REF, /* ws */ 3}, + {LLAMA_GRETYPE_RULE_REF, /* expr */ 2}, + {LLAMA_GRETYPE_CHAR, ')'}, + {LLAMA_GRETYPE_RULE_REF, /* ws */ 3}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 1}, - {LLAMA_GRETYPE_RULE_REF, 5}, - {LLAMA_GRETYPE_ALT, 0}, - {LLAMA_GRETYPE_RULE_REF, 1}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_CHAR, 45}, - {LLAMA_GRETYPE_CHAR_ALT, 43}, - {LLAMA_GRETYPE_CHAR_ALT, 42}, - {LLAMA_GRETYPE_CHAR_ALT, 47}, - {LLAMA_GRETYPE_RULE_REF, 4}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 6}, - {LLAMA_GRETYPE_RULE_REF, 7}, + // root_5 (index 5) + {LLAMA_GRETYPE_RULE_REF, /* root_1 */ 1}, + {LLAMA_GRETYPE_RULE_REF, /* root_5 */ 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_CHAR, 97}, - {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122}, - {LLAMA_GRETYPE_RULE_REF, 10}, - {LLAMA_GRETYPE_RULE_REF, 3}, + // expr_6 (index 6) + {LLAMA_GRETYPE_CHAR, '-'}, + {LLAMA_GRETYPE_CHAR_ALT, '+'}, + {LLAMA_GRETYPE_CHAR_ALT, '*'}, + {LLAMA_GRETYPE_CHAR_ALT, '/'}, + {LLAMA_GRETYPE_RULE_REF, /* term */ 4}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_RULE_REF, 11}, - {LLAMA_GRETYPE_RULE_REF, 3}, - {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_CHAR, 97}, - {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122}, - {LLAMA_GRETYPE_CHAR_ALT, 48}, - {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, - {LLAMA_GRETYPE_CHAR_ALT, 95}, - {LLAMA_GRETYPE_RULE_REF, 10}, + // expr_7 (index 7) + {LLAMA_GRETYPE_RULE_REF, /* expr_6 */ 6}, + {LLAMA_GRETYPE_RULE_REF, /* expr_7 */ 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_CHAR, 48}, - {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, - {LLAMA_GRETYPE_RULE_REF, 11}, - {LLAMA_GRETYPE_ALT, 0}, - {LLAMA_GRETYPE_CHAR, 48}, - {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57}, + // ident (index 8) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'}, + {LLAMA_GRETYPE_RULE_REF, /* ident_10 */ 10}, + {LLAMA_GRETYPE_RULE_REF, /* ws */ 3}, {LLAMA_GRETYPE_END, 0}, - {LLAMA_GRETYPE_CHAR, 32}, - {LLAMA_GRETYPE_CHAR_ALT, 9}, - {LLAMA_GRETYPE_CHAR_ALT, 10}, - {LLAMA_GRETYPE_RULE_REF, 12}, + // num (index 9) + {LLAMA_GRETYPE_CHAR, '0'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'}, + {LLAMA_GRETYPE_RULE_REF, /* num_11 */ 11}, + {LLAMA_GRETYPE_RULE_REF, /* ws */ 3}, + {LLAMA_GRETYPE_END, 0}, + // ident_10 (index 10) + {LLAMA_GRETYPE_CHAR, 'a'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, 'z'}, + {LLAMA_GRETYPE_CHAR_ALT, '0'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'}, + {LLAMA_GRETYPE_CHAR_ALT, '_'}, + {LLAMA_GRETYPE_RULE_REF, /* ident_10 */ 10}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}, - }; - - index = 0; - for (auto rule : parsed_grammar.rules) - { - // compare rule to expected rule - for (uint32_t i = 0; i < rule.size(); i++) - { - llama_grammar_element element = rule[i]; - llama_grammar_element expected_element = expected_rules[index]; - - // pretty print error message before asserting - if (expected_element.type != element.type || expected_element.value != element.value) - { - fprintf(stderr, "index: %u\n", index); - fprintf(stderr, "expected_element: %d, %u\n", expected_element.type, expected_element.value); - fprintf(stderr, "actual_element: %d, %u\n", element.type, element.value); - fprintf(stderr, "expected_element != actual_element\n"); - } - - assert(expected_element.type == element.type && expected_element.value == element.value); - index++; - } - } + // num_11 (index 11) + {LLAMA_GRETYPE_CHAR, '0'}, + {LLAMA_GRETYPE_CHAR_RNG_UPPER, '9'}, + {LLAMA_GRETYPE_RULE_REF, /* num_11 */ 11}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + // ws_12 (index 12) + {LLAMA_GRETYPE_CHAR, ' '}, + {LLAMA_GRETYPE_CHAR_ALT, '\t'}, + {LLAMA_GRETYPE_CHAR_ALT, '\n'}, + {LLAMA_GRETYPE_RULE_REF, /* ws_12 */ 12}, + {LLAMA_GRETYPE_ALT, 0}, + {LLAMA_GRETYPE_END, 0}, + }); return 0; } diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index c5361b5b8..052c08073 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -105,9 +105,9 @@ static void test_all(const std::string & lang, std::function