mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-10 10:41:47 +00:00
add stacks cache into llama_grammar
This commit is contained in:
parent
901a3479b1
commit
2aa6dd273a
@ -13,9 +13,9 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st
|
|||||||
|
|
||||||
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
||||||
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
|
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
|
||||||
|
llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar);
|
||||||
|
|
||||||
size_t pos = 0;
|
size_t pos = 0;
|
||||||
llama_grammar_stacks_cache stacks_cache;
|
|
||||||
for (const auto & cpt : cpts) {
|
for (const auto & cpt : cpts) {
|
||||||
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
|
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
|
||||||
|
|
||||||
|
@ -917,6 +917,10 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
|
|||||||
return grammar->stacks;
|
return grammar->stacks;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_grammar_stacks_cache & llama_grammar_get_stacks_cache(struct llama_grammar * grammar) {
|
||||||
|
return grammar->stacks_cache;
|
||||||
|
}
|
||||||
|
|
||||||
void llama_grammar_accept(
|
void llama_grammar_accept(
|
||||||
const llama_grammar_rules & rules,
|
const llama_grammar_rules & rules,
|
||||||
const llama_grammar_stacks & stacks,
|
const llama_grammar_stacks & stacks,
|
||||||
@ -1058,7 +1062,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
|||||||
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
||||||
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
||||||
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||||
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
|
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), };
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
|
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
|
||||||
@ -1137,7 +1141,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
|
|||||||
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
// Important: vec_rules has to be moved here, not copied, because stacks contains
|
||||||
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
|
||||||
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
// then the pointers would be invalidated when the local vec_rules goes out of scope.
|
||||||
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
|
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), };
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_grammar_free_impl(struct llama_grammar * grammar) {
|
void llama_grammar_free_impl(struct llama_grammar * grammar) {
|
||||||
@ -1225,10 +1229,9 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
|||||||
const auto & code_points = decoded.first;
|
const auto & code_points = decoded.first;
|
||||||
|
|
||||||
llama_grammar_stacks stacks_new;
|
llama_grammar_stacks stacks_new;
|
||||||
llama_grammar_stacks_cache stacks_cache;
|
|
||||||
|
|
||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, stacks_cache);
|
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, grammar.stacks_cache);
|
||||||
grammar.stacks = std::move(stacks_new);
|
grammar.stacks = std::move(stacks_new);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,9 +59,6 @@ using llama_grammar_rules = std::vector<llama_grammar_rule>;
|
|||||||
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
|
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
|
||||||
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
|
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
|
||||||
|
|
||||||
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
|
|
||||||
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
|
|
||||||
|
|
||||||
struct VectorPointerHash {
|
struct VectorPointerHash {
|
||||||
size_t operator()(const llama_grammar_stack & v) const {
|
size_t operator()(const llama_grammar_stack & v) const {
|
||||||
size_t seed = v.size();
|
size_t seed = v.size();
|
||||||
@ -74,6 +71,10 @@ struct VectorPointerHash {
|
|||||||
|
|
||||||
using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>;
|
using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>;
|
||||||
|
|
||||||
|
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
|
||||||
|
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
|
||||||
|
llama_grammar_stacks_cache & llama_grammar_get_stacks_cache( struct llama_grammar * grammar);
|
||||||
|
|
||||||
// takes a set of possible pushdown stacks on a grammar, which are required to
|
// takes a set of possible pushdown stacks on a grammar, which are required to
|
||||||
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
||||||
// produces the N possible stacks if the given char is accepted at those
|
// produces the N possible stacks if the given char is accepted at those
|
||||||
@ -129,6 +130,8 @@ struct llama_grammar {
|
|||||||
|
|
||||||
// buffer for partially generated UTF-8 sequence from accepted tokens
|
// buffer for partially generated UTF-8 sequence from accepted tokens
|
||||||
llama_partial_utf8 partial_utf8;
|
llama_partial_utf8 partial_utf8;
|
||||||
|
// cache N possible stacks from a stack
|
||||||
|
llama_grammar_stacks_cache stacks_cache;
|
||||||
};
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -34,8 +34,8 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
|
|||||||
|
|
||||||
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
||||||
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
|
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
|
||||||
|
llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar);
|
||||||
|
|
||||||
llama_grammar_stacks_cache stacks_cache;
|
|
||||||
for (const auto & cpt : cpts) {
|
for (const auto & cpt : cpts) {
|
||||||
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
|
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user