add stacks cache into llama_grammar

This commit is contained in:
Clarissa Miranda 2024-10-17 14:30:07 +11:00
parent 901a3479b1
commit 2aa6dd273a
4 changed files with 15 additions and 9 deletions

View File

@ -13,9 +13,9 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar);
size_t pos = 0; size_t pos = 0;
llama_grammar_stacks_cache stacks_cache;
for (const auto & cpt : cpts) { for (const auto & cpt : cpts) {
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy

View File

@ -917,6 +917,10 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
return grammar->stacks; return grammar->stacks;
} }
llama_grammar_stacks_cache & llama_grammar_get_stacks_cache(struct llama_grammar * grammar) {
return grammar->stacks_cache;
}
void llama_grammar_accept( void llama_grammar_accept(
const llama_grammar_rules & rules, const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks, const llama_grammar_stacks & stacks,
@ -1058,7 +1062,7 @@ struct llama_grammar * llama_grammar_init_impl(
// Important: vec_rules has to be moved here, not copied, because stacks contains // Important: vec_rules has to be moved here, not copied, because stacks contains
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
// then the pointers would be invalidated when the local vec_rules goes out of scope. // then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), };
} }
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
@ -1137,7 +1141,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
// Important: vec_rules has to be moved here, not copied, because stacks contains // Important: vec_rules has to be moved here, not copied, because stacks contains
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
// then the pointers would be invalidated when the local vec_rules goes out of scope. // then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), };
} }
void llama_grammar_free_impl(struct llama_grammar * grammar) { void llama_grammar_free_impl(struct llama_grammar * grammar) {
@ -1225,10 +1229,9 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
const auto & code_points = decoded.first; const auto & code_points = decoded.first;
llama_grammar_stacks stacks_new; llama_grammar_stacks stacks_new;
llama_grammar_stacks_cache stacks_cache;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, stacks_cache); llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, grammar.stacks_cache);
grammar.stacks = std::move(stacks_new); grammar.stacks = std::move(stacks_new);
} }

View File

@ -59,9 +59,6 @@ using llama_grammar_rules = std::vector<llama_grammar_rule>;
using llama_grammar_stacks = std::vector<llama_grammar_stack>; using llama_grammar_stacks = std::vector<llama_grammar_stack>;
using llama_grammar_candidates = std::vector<llama_grammar_candidate>; using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
struct VectorPointerHash { struct VectorPointerHash {
size_t operator()(const llama_grammar_stack & v) const { size_t operator()(const llama_grammar_stack & v) const {
size_t seed = v.size(); size_t seed = v.size();
@ -74,6 +71,10 @@ struct VectorPointerHash {
using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>; using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>;
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
llama_grammar_stacks_cache & llama_grammar_get_stacks_cache( struct llama_grammar * grammar);
// takes a set of possible pushdown stacks on a grammar, which are required to // takes a set of possible pushdown stacks on a grammar, which are required to
// be positioned at a character range (see `llama_grammar_advance_stack`), and // be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those // produces the N possible stacks if the given char is accepted at those
@ -129,6 +130,8 @@ struct llama_grammar {
// buffer for partially generated UTF-8 sequence from accepted tokens // buffer for partially generated UTF-8 sequence from accepted tokens
llama_partial_utf8 partial_utf8; llama_partial_utf8 partial_utf8;
// cache N possible stacks from a stack
llama_grammar_stacks_cache stacks_cache;
}; };
// //

View File

@ -34,8 +34,8 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar);
llama_grammar_stacks_cache stacks_cache;
for (const auto & cpt : cpts) { for (const auto & cpt : cpts) {
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy