From cb1632b5938ec36ed5f46ce929a6681e168ac216 Mon Sep 17 00:00:00 2001
From: Clarissa Miranda
Date: Fri, 11 Oct 2024 12:20:48 +1100
Subject: [PATCH] llama : adds llama-grammar memorization stacks (#4218)

---
 src/llama-grammar.cpp | 118 ++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 115 insertions(+), 3 deletions(-)

diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp
index 74e9f64b3..22c63ebfe 100644
--- a/src/llama-grammar.cpp
+++ b/src/llama-grammar.cpp
@@ -682,6 +682,114 @@ static bool llama_grammar_match_partial_char(
     return !is_positive_char;
 }
 
+// transforms a grammar pushdown stack into N possible stacks, all ending
+// at a character range (terminal element)
+// additionally memorizes the stack to its possible stacks by mapping
+// < llama_grammar_stack, llama_grammar_stacks >
+
+struct VectorPointerHash {
+    size_t operator()(const llama_grammar_stack & v) const {
+        size_t seed = v.size();
+        for (const auto* ptr : v) {
+            seed ^= std::hash<const llama_grammar_element *>()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+        }
+        return seed;
+    }
+};
+
+static std::unordered_map<
+    llama_grammar_stack,
+    llama_grammar_stacks,
+    VectorPointerHash>
+    llama_grammar_stacks_cache = {};
+
+static void llama_grammar_advance_stack_memo(
+        const llama_grammar_rules  & rules,
+        const llama_grammar_stack  & stack,
+              llama_grammar_stacks & new_stacks);
+
+static void llama_grammar_advance_stack_memo_impl(
+        const llama_grammar_rules  & rules,
+        const llama_grammar_stack  & stack,
+              llama_grammar_stacks & new_stacks) {
+    if (stack.empty()) {
+        if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
+            new_stacks.emplace_back(stack);
+        }
+        return;
+    }
+
+    const llama_grammar_element * pos = stack.back();
+
+    switch (pos->type) {
+        case LLAMA_GRETYPE_RULE_REF: {
+            const size_t                  rule_id = static_cast<size_t>(pos->value);
+            const llama_grammar_element * subpos  = rules[rule_id].data();
+            do {
+                // init new stack without the top (pos)
+                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
+                    // if this rule ref is followed by another element, add that to stack
+                    new_stack.push_back(pos + 1);
+                }
+                if (!llama_grammar_is_end_of_sequence(subpos)) {
+                    // if alternate is nonempty, add to stack
+                    new_stack.push_back(subpos);
+                }
+                llama_grammar_advance_stack_memo(rules, new_stack, new_stacks);
+                while (!llama_grammar_is_end_of_sequence(subpos)) {
+                    // scan to end of alternate def
+                    subpos++;
+                }
+                if (subpos->type == LLAMA_GRETYPE_ALT) {
+                    // there's another alternate def of this rule to process
+                    subpos++;
+                } else {
+                    break;
+                }
+            } while (true);
+            break;
+        }
+        case LLAMA_GRETYPE_CHAR:
+        case LLAMA_GRETYPE_CHAR_NOT:
+        case LLAMA_GRETYPE_CHAR_ANY:
+            if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
+                // only add the stack if it's not a duplicate of one we already have
+                new_stacks.emplace_back(stack);
+            }
+            break;
+        default:
+            // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
+            // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
+            // those
+            GGML_ABORT("fatal error");
+    }
+}
+
+static void llama_grammar_advance_stack_memo(
+        const llama_grammar_rules  & rules,
+        const llama_grammar_stack  & stack,
+              llama_grammar_stacks & new_stacks) {
+
+    llama_grammar_stacks advanced_stacks;
+    // Look if stack is already in memory
+    auto it = llama_grammar_stacks_cache.find(stack);
+    if (it != llama_grammar_stacks_cache.end()) {
+        advanced_stacks = it->second;
+    } else {
+        // Advance stacks with memorization
+        llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks);
+        llama_grammar_stacks_cache.insert(make_pair(stack, advanced_stacks));
+    }
+    // Add the advanced stacks to new_stacks avoiding duplicates
+    for (const auto & new_stack : advanced_stacks) {
+        if (std::find(new_stacks.begin(), new_stacks.end(), new_stack) == new_stacks.end()) {
+            new_stacks.emplace_back(new_stack);
+        }
+    }
+
+}
+
 // transforms a grammar pushdown stack into N possible stacks, all ending
 // at a character range (terminal element)
 static void llama_grammar_advance_stack(
@@ -844,7 +952,7 @@ void llama_grammar_accept(
             if (!llama_grammar_is_end_of_sequence(pos)) {
                 new_stack.push_back(pos);
             }
-            llama_grammar_advance_stack(rules, new_stack, stacks_new);
+            llama_grammar_advance_stack_memo(rules, new_stack, stacks_new);
         }
     }
 }
@@ -911,6 +1019,8 @@ struct llama_grammar * llama_grammar_init_impl(
         const llama_grammar_element ** rules,
         size_t n_rules,
         size_t start_rule_index) {
+    // Clear stacks cache
+    llama_grammar_stacks_cache.clear();
     const llama_grammar_element * pos;
 
     // copy rule definitions into vectors
@@ -945,7 +1055,7 @@ struct llama_grammar * llama_grammar_init_impl(
                 // if alternate is nonempty, add to stack
                 stack.push_back(pos);
             }
-            llama_grammar_advance_stack(vec_rules, stack, stacks);
+            llama_grammar_advance_stack_memo(vec_rules, stack, stacks);
             while (!llama_grammar_is_end_of_sequence(pos)) {
                 // scan to end of alternate def
                 pos++;
@@ -965,6 +1075,8 @@ struct llama_grammar * llama_grammar_init_impl(
 }
 
 struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
+    // Clear stacks cache
+    llama_grammar_stacks_cache.clear();
     llama_grammar_parser parser;
 
     // if there is a grammar, parse it
@@ -1023,7 +1135,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
                 // if alternate is nonempty, add to stack
                 stack.push_back(pos);
             }
-            llama_grammar_advance_stack(vec_rules, stack, stacks);
+            llama_grammar_advance_stack_memo(vec_rules, stack, stacks);
             while (!llama_grammar_is_end_of_sequence(pos)) {
                 // scan to end of alternate def
                 pos++;