diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 2cf3bb047..646b8e176 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -13,9 +13,9 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); + llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar); size_t pos = 0; - llama_grammar_stacks_cache stacks_cache; for (const auto & cpt : cpts) { const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index af72de9e0..214820790 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -917,6 +917,10 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) return grammar->stacks; } +llama_grammar_stacks_cache & llama_grammar_get_stacks_cache(struct llama_grammar * grammar) { + return grammar->stacks_cache; +} + void llama_grammar_accept( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, @@ -1058,7 +1062,7 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), }; } struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { @@ -1137,7 +1141,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { @@ -1225,10 +1229,9 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token const auto & code_points = decoded.first; llama_grammar_stacks stacks_new; - llama_grammar_stacks_cache stacks_cache; for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, stacks_cache); + llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, grammar.stacks_cache); grammar.stacks = std::move(stacks_new); } diff --git a/src/llama-grammar.h b/src/llama-grammar.h index de5e16874..42ab06cd8 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -59,9 +59,6 @@ using llama_grammar_rules = std::vector; using llama_grammar_stacks = std::vector; using llama_grammar_candidates = std::vector; -const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); - llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); - struct VectorPointerHash { size_t operator()(const llama_grammar_stack & v) const { size_t seed = v.size(); @@ -74,6 +71,10 @@ struct VectorPointerHash { using llama_grammar_stacks_cache = std::unordered_map; +const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); + llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); + llama_grammar_stacks_cache & llama_grammar_get_stacks_cache( struct llama_grammar * grammar); + // takes a set of possible pushdown stacks on a grammar, which are required to // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those @@ -129,6 +130,8 @@ struct llama_grammar { // buffer for partially generated UTF-8 sequence from accepted tokens llama_partial_utf8 partial_utf8; + // cache N possible stacks from a stack + llama_grammar_stacks_cache stacks_cache; }; // diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index dc260b55a..0883120d1 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -34,8 +34,8 @@ static bool match_string(const std::string & input, llama_grammar * grammar) { const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); + llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar); - llama_grammar_stacks_cache stacks_cache; for (const auto & cpt : cpts) { const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy