From 901a3479b10c3e71a29b86050bcae25d98102908 Mon Sep 17 00:00:00 2001 From: Clarissa Miranda Date: Mon, 14 Oct 2024 17:13:40 +1100 Subject: [PATCH] move cache stack to advance stack --- examples/gbnf-validator/gbnf-validator.cpp | 3 +- src/llama-grammar.cpp | 53 ++++++++-------------- src/llama-grammar.h | 16 ++++++- tests/test-grammar-integration.cpp | 3 +- 4 files changed, 39 insertions(+), 36 deletions(-) diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 7493af9d3..2cf3bb047 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -15,10 +15,11 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); size_t pos = 0; + llama_grammar_stacks_cache stacks_cache; for (const auto & cpt : cpts) { const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy - llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur); + llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache); if (stacks_cur.empty()) { error_pos = pos; diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 22c63ebfe..af72de9e0 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -687,31 +687,17 @@ static bool llama_grammar_match_partial_char( // additionally memorizes the stack to its possible stacks by mapping // < llama_grammar_stack, llama_grammar_stacks > -struct VectorPointerHash { - size_t operator()(const llama_grammar_stack & v) const { - size_t seed = v.size(); - for (const auto* ptr : v) { - seed ^= std::hash()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2); - } - return seed; - } -}; - -static std::unordered_map< - llama_grammar_stack, - llama_grammar_stacks, - VectorPointerHash> - llama_grammar_stacks_cache = {}; - static void llama_grammar_advance_stack_memo( const llama_grammar_rules & rules, const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks); + llama_grammar_stacks & new_stacks, + llama_grammar_stacks_cache & stacks_cache); static void llama_grammar_advance_stack_memo_impl( const llama_grammar_rules & rules, const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks) { + llama_grammar_stacks & new_stacks, + llama_grammar_stacks_cache & stacks_cache) { if (stack.empty()) { if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { new_stacks.emplace_back(stack); @@ -736,7 +722,7 @@ static void llama_grammar_advance_stack_memo_impl( // if alternate is nonempty, add to stack new_stack.push_back(subpos); } - llama_grammar_advance_stack_memo(rules, new_stack, new_stacks); + llama_grammar_advance_stack_memo(rules, new_stack, new_stacks, stacks_cache); while (!llama_grammar_is_end_of_sequence(subpos)) { // scan to end of alternate def subpos++; @@ -769,17 +755,18 @@ static void llama_grammar_advance_stack_memo_impl( static void llama_grammar_advance_stack_memo( const llama_grammar_rules & rules, const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks) { + llama_grammar_stacks & new_stacks, + llama_grammar_stacks_cache & stacks_cache) { llama_grammar_stacks advanced_stacks; // Look if stack is already in memory - auto it = llama_grammar_stacks_cache.find(stack); - if (it != llama_grammar_stacks_cache.end()) { + auto it = stacks_cache.find(stack); + if (it != stacks_cache.end()) { advanced_stacks = it->second; } else { // Advance stacks with memorization - llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks); - llama_grammar_stacks_cache.insert(make_pair(stack, advanced_stacks)); + llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks, stacks_cache); + stacks_cache.insert(make_pair(stack, advanced_stacks)); } // Add the advanced stacks to new_stacks avoiding duplicates for (const auto & new_stack : advanced_stacks) { @@ -934,7 +921,8 @@ void llama_grammar_accept( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, const uint32_t chr, - llama_grammar_stacks & stacks_new) { + llama_grammar_stacks & stacks_new, + llama_grammar_stacks_cache & stacks_cache) { stacks_new.clear(); stacks_new.reserve(stacks.size()); @@ -952,7 +940,7 @@ void llama_grammar_accept( if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } - llama_grammar_advance_stack_memo(rules, new_stack, stacks_new); + llama_grammar_advance_stack_memo(rules, new_stack, stacks_new, stacks_cache); } } } @@ -1019,8 +1007,6 @@ struct llama_grammar * llama_grammar_init_impl( const llama_grammar_element ** rules, size_t n_rules, size_t start_rule_index) { - // Clear stacks cache - llama_grammar_stacks_cache.clear(); const llama_grammar_element * pos; // copy rule definitions into vectors @@ -1048,6 +1034,7 @@ struct llama_grammar * llama_grammar_init_impl( // loop over alternates of start rule to build initial stacks llama_grammar_stacks stacks; + llama_grammar_stacks_cache stacks_cache; pos = vec_rules[start_rule_index].data(); do { llama_grammar_stack stack; @@ -1055,7 +1042,7 @@ struct llama_grammar * llama_grammar_init_impl( // if alternate is nonempty, add to stack stack.push_back(pos); } - llama_grammar_advance_stack_memo(vec_rules, stack, stacks); + llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache); while (!llama_grammar_is_end_of_sequence(pos)) { // scan to end of alternate def pos++; @@ -1075,8 +1062,6 @@ struct llama_grammar * llama_grammar_init_impl( } struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { - // Clear stacks cache - llama_grammar_stacks_cache.clear(); llama_grammar_parser parser; // if there is a grammar, parse it @@ -1128,6 +1113,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, // loop over alternates of start rule to build initial stacks llama_grammar_stacks stacks; + llama_grammar_stacks_cache stacks_cache; pos = vec_rules[start_rule_index].data(); do { llama_grammar_stack stack; @@ -1135,7 +1121,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, // if alternate is nonempty, add to stack stack.push_back(pos); } - llama_grammar_advance_stack_memo(vec_rules, stack, stacks); + llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache); while (!llama_grammar_is_end_of_sequence(pos)) { // scan to end of alternate def pos++; @@ -1239,9 +1225,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token const auto & code_points = decoded.first; llama_grammar_stacks stacks_new; + llama_grammar_stacks_cache stacks_cache; for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new); + llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, stacks_cache); grammar.stacks = std::move(stacks_new); } diff --git a/src/llama-grammar.h b/src/llama-grammar.h index f529ce351..de5e16874 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -3,6 +3,7 @@ #include "llama-impl.h" #include +#include struct llama_vocab; @@ -61,6 +62,18 @@ using llama_grammar_candidates = std::vector; const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); +struct VectorPointerHash { + size_t operator()(const llama_grammar_stack & v) const { + size_t seed = v.size(); + for (const auto* ptr : v) { + seed ^= std::hash()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + return seed; + } +}; + +using llama_grammar_stacks_cache = std::unordered_map; + // takes a set of possible pushdown stacks on a grammar, which are required to // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those @@ -69,7 +82,8 @@ void llama_grammar_accept( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, uint32_t chr, - llama_grammar_stacks & stacks_new); + llama_grammar_stacks & stacks_new, + llama_grammar_stacks_cache & stacks_cache); std::vector llama_grammar_reject_candidates_for_stack( const llama_grammar_rules & rules, diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 5cc0cdb04..dc260b55a 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -35,10 +35,11 @@ static bool match_string(const std::string & input, llama_grammar * grammar) { const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); + llama_grammar_stacks_cache stacks_cache; for (const auto & cpt : cpts) { const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy - llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur); + llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache); if (stacks_cur.empty()) { // no stacks means that the grammar failed to match at this point