diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 646b8e176..17a0e27c4 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -11,20 +11,15 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { const auto cpts = unicode_cpts_from_utf8(input_str); - const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); - llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); - llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar); + auto & stacks_cur = llama_grammar_get_stacks(grammar); size_t pos = 0; for (const auto & cpt : cpts) { - const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy - - llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache); + llama_grammar_accept(grammar, cpt); if (stacks_cur.empty()) { error_pos = pos; error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'"; - stacks_cur = stacks_prev; return false; } ++pos; @@ -83,7 +78,8 @@ int main(int argc, char** argv) { llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"); if (grammar == nullptr) { - throw std::runtime_error("Failed to initialize llama_grammar"); + fprintf(stdout, "Failed to initialize llama_grammar\n"); + return 1; } // Read the input file std::string input_str; diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 214820790..1380dfc2a 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -764,7 +764,7 @@ static void llama_grammar_advance_stack_memo( if (it != stacks_cache.end()) { advanced_stacks = it->second; } else { - // Advance stacks with memorization + // Advance stacks with memorization llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks, stacks_cache); stacks_cache.insert(make_pair(stack, advanced_stacks)); } @@ -917,20 +917,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) return grammar->stacks; } -llama_grammar_stacks_cache & llama_grammar_get_stacks_cache(struct llama_grammar * grammar) { - return grammar->stacks_cache; -} +void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) { + llama_grammar_stacks stacks_new; + stacks_new.reserve(grammar->stacks.size()); -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & stacks_new, - llama_grammar_stacks_cache & stacks_cache) { - stacks_new.clear(); - stacks_new.reserve(stacks.size()); - - for (const auto & stack : stacks) { + for (const auto & stack : grammar->stacks) { if (stack.empty()) { continue; } @@ -944,9 +935,11 @@ void llama_grammar_accept( if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } - llama_grammar_advance_stack_memo(rules, new_stack, stacks_new, stacks_cache); + llama_grammar_advance_stack_memo(grammar->rules, new_stack, stacks_new, grammar->stacks_cache); } } + + grammar->stacks = std::move(stacks_new); } llama_grammar_candidates llama_grammar_reject_candidates_for_stack( @@ -1062,7 +1055,7 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), }; + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), std::move(stacks_cache), {}, }; } struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { @@ -1141,7 +1134,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), }; + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), std::move(stacks_cache), {}, }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { @@ -1153,7 +1146,13 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { } struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { - llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, }; + llama_grammar * result = new llama_grammar { + grammar.vocab, + grammar.rules, + grammar.stacks, + grammar.stacks_cache, + grammar.partial_utf8, + }; // redirect elements in stacks to point to new rules for (size_t is = 0; is < result->stacks.size(); is++) { @@ -1161,7 +1160,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) { for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) { if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) { - result->stacks[is][ie] = &result->rules[ir0][ir1]; + result->stacks[is][ie] = &result->rules[ir0][ir1]; } } } @@ -1228,11 +1227,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; - llama_grammar_stacks stacks_new; - for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, grammar.stacks_cache); - grammar.stacks = std::move(stacks_new); + llama_grammar_accept(&grammar, *it); } grammar.partial_utf8 = decoded.second; diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 42ab06cd8..3f13fee4f 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -71,20 +71,15 @@ struct VectorPointerHash { using llama_grammar_stacks_cache = std::unordered_map; +// TODO: remove, needed for tests atm const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); - llama_grammar_stacks_cache & llama_grammar_get_stacks_cache( struct llama_grammar * grammar); // takes a set of possible pushdown stacks on a grammar, which are required to // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those // positions -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - uint32_t chr, - llama_grammar_stacks & stacks_new, - llama_grammar_stacks_cache & stacks_cache); +void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr); std::vector llama_grammar_reject_candidates_for_stack( const llama_grammar_rules & rules, @@ -128,10 +123,11 @@ struct llama_grammar { const llama_grammar_rules rules; // TODO: shared ptr llama_grammar_stacks stacks; - // buffer for partially generated UTF-8 sequence from accepted tokens - llama_partial_utf8 partial_utf8; // cache N possible stacks from a stack llama_grammar_stacks_cache stacks_cache; + + // buffer for partially generated UTF-8 sequence from accepted tokens + llama_partial_utf8 partial_utf8; }; // diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 0883120d1..e1bdbb925 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -32,14 +32,10 @@ static bool test_build_grammar_fails(const std::string & grammar_str) { static bool match_string(const std::string & input, llama_grammar * grammar) { const auto cpts = unicode_cpts_from_utf8(input); - const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); - llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); - llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar); + auto & stacks_cur = llama_grammar_get_stacks(grammar); for (const auto & cpt : cpts) { - const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy - - llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache); + llama_grammar_accept(grammar, cpt); if (stacks_cur.empty()) { // no stacks means that the grammar failed to match at this point @@ -64,7 +60,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str, auto * grammar = build_grammar(grammar_str); // Save the original grammar stacks so that we can reset after every new string we want to test - const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar); + const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar); // copy llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index 6f1374ca8..e2129206b 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -113,12 +113,10 @@ int main() } } - llama_grammar * grammar = NULL; std::vector grammar_rules(parsed_grammar.c_rules()); - grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); - if (grammar == nullptr) - { + llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + if (grammar == nullptr) { throw std::runtime_error("Failed to initialize llama_grammar"); }