llama : adds llama-grammar memorization stacks (#4218)
This commit is contained in:
parent 7eee341bee
commit cb1632b593
@@ -682,6 +682,114 @@ static bool llama_grammar_match_partial_char(
     return !is_positive_char;
 }

+// transforms a grammar pushdown stack into N possible stacks, all ending
+// at a character range (terminal element)
+// additionally memorizes the stack to its possible stacks by mapping
+// < llama_grammar_stack, llama_grammar_stacks >
+
+struct VectorPointerHash {
+    size_t operator()(const llama_grammar_stack & v) const {
+        size_t seed = v.size();
+        for (const auto * ptr : v) {
+            seed ^= std::hash<const llama_grammar_element *>()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+        }
+        return seed;
+    }
+};
+
+static std::unordered_map<
+    llama_grammar_stack,
+    llama_grammar_stacks,
+    VectorPointerHash>
+    llama_grammar_stacks_cache = {};
+
+static void llama_grammar_advance_stack_memo(
+        const llama_grammar_rules & rules,
+        const llama_grammar_stack & stack,
+        llama_grammar_stacks & new_stacks);
+
+static void llama_grammar_advance_stack_memo_impl(
+        const llama_grammar_rules & rules,
+        const llama_grammar_stack & stack,
+        llama_grammar_stacks & new_stacks) {
+    if (stack.empty()) {
+        if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
+            new_stacks.emplace_back(stack);
+        }
+        return;
+    }
+
+    const llama_grammar_element * pos = stack.back();
+
+    switch (pos->type) {
+        case LLAMA_GRETYPE_RULE_REF: {
+            const size_t rule_id = static_cast<size_t>(pos->value);
+            const llama_grammar_element * subpos = rules[rule_id].data();
+            do {
+                // init new stack without the top (pos)
+                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
+                    // if this rule ref is followed by another element, add that to stack
+                    new_stack.push_back(pos + 1);
+                }
+                if (!llama_grammar_is_end_of_sequence(subpos)) {
+                    // if alternate is nonempty, add to stack
+                    new_stack.push_back(subpos);
+                }
+                llama_grammar_advance_stack_memo(rules, new_stack, new_stacks);
+                while (!llama_grammar_is_end_of_sequence(subpos)) {
+                    // scan to end of alternate def
+                    subpos++;
+                }
+                if (subpos->type == LLAMA_GRETYPE_ALT) {
+                    // there's another alternate def of this rule to process
+                    subpos++;
+                } else {
+                    break;
+                }
+            } while (true);
+            break;
+        }
+        case LLAMA_GRETYPE_CHAR:
+        case LLAMA_GRETYPE_CHAR_NOT:
+        case LLAMA_GRETYPE_CHAR_ANY:
+            if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
+                // only add the stack if it's not a duplicate of one we already have
+                new_stacks.emplace_back(stack);
+            }
+            break;
+        default:
+            // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
+            // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
+            // those
+            GGML_ABORT("fatal error");
+    }
+}
+
+static void llama_grammar_advance_stack_memo(
+        const llama_grammar_rules & rules,
+        const llama_grammar_stack & stack,
+        llama_grammar_stacks & new_stacks) {
+
+    llama_grammar_stacks advanced_stacks;
+    // Look if stack is already in memory
+    auto it = llama_grammar_stacks_cache.find(stack);
+    if (it != llama_grammar_stacks_cache.end()) {
+        advanced_stacks = it->second;
+    } else {
+        // Advance stacks with memorization
+        llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks);
+        llama_grammar_stacks_cache.insert(make_pair(stack, advanced_stacks));
+    }
+    // Add the advanced stacks to new_stacks avoiding duplicates
+    for (const auto & new_stack : advanced_stacks) {
+        if (std::find(new_stacks.begin(), new_stacks.end(), new_stack) == new_stacks.end()) {
+            new_stacks.emplace_back(new_stack);
+        }
+    }
+}
+
 // transforms a grammar pushdown stack into N possible stacks, all ending
 // at a character range (terminal element)
 static void llama_grammar_advance_stack(
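For reference, the VectorPointerHash added above is the common boost-style hash_combine idiom: each element's hash is folded into a running seed together with the golden-ratio constant 0x9e3779b9 and two shifts, so a stack of element pointers can key an unordered_map. A minimal standalone sketch of the same idiom, with illustrative names and plain int pointers rather than llama_grammar_element pointers:

    #include <cstddef>
    #include <cstdio>
    #include <functional>
    #include <vector>

    // Illustrative stand-in for VectorPointerHash: boost-style hash_combine
    // mixing over a vector of pointers.
    static std::size_t hash_ptr_vector(const std::vector<const int *> & v) {
        std::size_t seed = v.size();
        for (const auto * ptr : v) {
            // fold each element hash into the seed (golden-ratio constant + shifts)
            seed ^= std::hash<const int *>()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        }
        return seed;
    }

    int main() {
        int a = 0;
        int b = 0;
        std::vector<const int *> v1 = { &a, &b };
        std::vector<const int *> v2 = { &b, &a };
        // order matters: the two vectors will normally hash to different values
        std::printf("%zu %zu\n", hash_ptr_vector(v1), hash_ptr_vector(v2));
        return 0;
    }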
@@ -844,7 +952,7 @@ void llama_grammar_accept(
             if (!llama_grammar_is_end_of_sequence(pos)) {
                 new_stack.push_back(pos);
             }
-            llama_grammar_advance_stack(rules, new_stack, stacks_new);
+            llama_grammar_advance_stack_memo(rules, new_stack, stacks_new);
         }
     }
 }
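The change above routes llama_grammar_accept through llama_grammar_advance_stack_memo, i.e. a look-up-or-compute-and-store wrapper around the original stack advance. A minimal sketch of that memoization pattern in isolation, with made-up names and a trivial computation standing in for the grammar expansion (not the actual llama.cpp API):

    #include <cstdio>
    #include <unordered_map>
    #include <utility>

    // Illustrative memoization wrapper: check the cache first, otherwise
    // compute with the "impl" function and store the result for next time.
    static std::unordered_map<int, long> g_cache;

    // stand-in for the expensive expansion (here: factorial)
    static long expand_impl(int n) {
        return n <= 1 ? 1 : n * expand_impl(n - 1);
    }

    static long expand_memo(int n) {
        auto it = g_cache.find(n);
        if (it != g_cache.end()) {
            return it->second;                    // cache hit: reuse the stored result
        }
        long result = expand_impl(n);             // cache miss: compute once
        g_cache.insert(std::make_pair(n, result));
        return result;
    }

    int main() {
        std::printf("%ld\n", expand_memo(10));    // computed and memorized
        std::printf("%ld\n", expand_memo(10));    // served from the cache
        return 0;
    }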
@@ -911,6 +1019,8 @@ struct llama_grammar * llama_grammar_init_impl(
         const llama_grammar_element ** rules,
         size_t n_rules,
         size_t start_rule_index) {
+    // Clear stacks cache
+    llama_grammar_stacks_cache.clear();
     const llama_grammar_element * pos;

     // copy rule definitions into vectors
@@ -945,7 +1055,7 @@ struct llama_grammar * llama_grammar_init_impl(
             // if alternate is nonempty, add to stack
             stack.push_back(pos);
         }
-        llama_grammar_advance_stack(vec_rules, stack, stacks);
+        llama_grammar_advance_stack_memo(vec_rules, stack, stacks);
         while (!llama_grammar_is_end_of_sequence(pos)) {
             // scan to end of alternate def
             pos++;
@@ -965,6 +1075,8 @@ struct llama_grammar * llama_grammar_init_impl(
 }

 struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
+    // Clear stacks cache
+    llama_grammar_stacks_cache.clear();
     llama_grammar_parser parser;

     // if there is a grammar, parse it
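Both llama_grammar_init_impl overloads now clear llama_grammar_stacks_cache up front, presumably because the cached keys and values hold raw llama_grammar_element pointers into a specific grammar's rule storage, which go stale once a new grammar is built. A small sketch of that invalidation pattern under this assumption (hypothetical types, not from the patch):

    #include <cstddef>
    #include <unordered_map>
    #include <vector>

    // Illustrative only: a cache keyed by raw pointers into some owner's
    // storage has to be reset whenever that storage is rebuilt, mirroring
    // the clear() calls added in both init paths above.
    struct owner {
        std::vector<int> elems;   // stand-in for a grammar's rule storage
    };

    static std::unordered_map<const int *, int> g_ptr_cache;

    static owner make_owner(std::size_t n) {
        g_ptr_cache.clear();      // stale keys would point into storage that no longer exists
        owner o;
        o.elems.assign(n, 0);
        return o;
    }

    int main() {
        owner a = make_owner(4);
        g_ptr_cache[a.elems.data()] = 42;   // cache an entry keyed by a pointer into 'a'
        owner b = make_owner(8);            // rebuilding clears the cache, dropping stale keys
        (void)b;
        return 0;
    }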
@@ -1023,7 +1135,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
             // if alternate is nonempty, add to stack
             stack.push_back(pos);
         }
-        llama_grammar_advance_stack(vec_rules, stack, stacks);
+        llama_grammar_advance_stack_memo(vec_rules, stack, stacks);
         while (!llama_grammar_is_end_of_sequence(pos)) {
             // scan to end of alternate def
             pos++;