mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-10 10:41:47 +00:00
move cache stack to advance stack
This commit is contained in:
parent
cb1632b593
commit
901a3479b1
@ -15,10 +15,11 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st
|
|||||||
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
|
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
|
||||||
|
|
||||||
size_t pos = 0;
|
size_t pos = 0;
|
||||||
|
llama_grammar_stacks_cache stacks_cache;
|
||||||
for (const auto & cpt : cpts) {
|
for (const auto & cpt : cpts) {
|
||||||
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
|
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
|
||||||
|
|
||||||
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
|
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache);
|
||||||
|
|
||||||
if (stacks_cur.empty()) {
|
if (stacks_cur.empty()) {
|
||||||
error_pos = pos;
|
error_pos = pos;
|
||||||
|
@ -687,31 +687,17 @@ static bool llama_grammar_match_partial_char(
|
|||||||
// additionally memoizes the stack to its possible stacks by mapping
|
// additionally memoizes the stack to its possible stacks by mapping
|
||||||
// < llama_grammar_stack, llama_grammar_stacks >
|
// < llama_grammar_stack, llama_grammar_stacks >
|
||||||
|
|
||||||
struct VectorPointerHash {
|
|
||||||
size_t operator()(const llama_grammar_stack & v) const {
|
|
||||||
size_t seed = v.size();
|
|
||||||
for (const auto* ptr : v) {
|
|
||||||
seed ^= std::hash<const llama_grammar_element*>()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
|
||||||
}
|
|
||||||
return seed;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
static std::unordered_map<
|
|
||||||
llama_grammar_stack,
|
|
||||||
llama_grammar_stacks,
|
|
||||||
VectorPointerHash>
|
|
||||||
llama_grammar_stacks_cache = {};
|
|
||||||
|
|
||||||
static void llama_grammar_advance_stack_memo(
|
static void llama_grammar_advance_stack_memo(
|
||||||
const llama_grammar_rules & rules,
|
const llama_grammar_rules & rules,
|
||||||
const llama_grammar_stack & stack,
|
const llama_grammar_stack & stack,
|
||||||
llama_grammar_stacks & new_stacks);
|
llama_grammar_stacks & new_stacks,
|
||||||
|
llama_grammar_stacks_cache & stacks_cache);
|
||||||
|
|
||||||
static void llama_grammar_advance_stack_memo_impl(
|
static void llama_grammar_advance_stack_memo_impl(
|
||||||
const llama_grammar_rules & rules,
|
const llama_grammar_rules & rules,
|
||||||
const llama_grammar_stack & stack,
|
const llama_grammar_stack & stack,
|
||||||
llama_grammar_stacks & new_stacks) {
|
llama_grammar_stacks & new_stacks,
|
||||||
|
llama_grammar_stacks_cache & stacks_cache) {
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
|
||||||
new_stacks.emplace_back(stack);
|
new_stacks.emplace_back(stack);
|
||||||
@ -736,7 +722,7 @@ static void llama_grammar_advance_stack_memo_impl(
|
|||||||
// if alternate is nonempty, add to stack
|
// if alternate is nonempty, add to stack
|
||||||
new_stack.push_back(subpos);
|
new_stack.push_back(subpos);
|
||||||
}
|
}
|
||||||
llama_grammar_advance_stack_memo(rules, new_stack, new_stacks);
|
llama_grammar_advance_stack_memo(rules, new_stack, new_stacks, stacks_cache);
|
||||||
while (!llama_grammar_is_end_of_sequence(subpos)) {
|
while (!llama_grammar_is_end_of_sequence(subpos)) {
|
||||||
// scan to end of alternate def
|
// scan to end of alternate def
|
||||||
subpos++;
|
subpos++;
|
||||||
@ -769,17 +755,18 @@ static void llama_grammar_advance_stack_memo_impl(
|
|||||||
static void llama_grammar_advance_stack_memo(
|
static void llama_grammar_advance_stack_memo(
|
||||||
const llama_grammar_rules & rules,
|
const llama_grammar_rules & rules,
|
||||||
const llama_grammar_stack & stack,
|
const llama_grammar_stack & stack,
|
||||||
llama_grammar_stacks & new_stacks) {
|
llama_grammar_stacks & new_stacks,
|
||||||
|
llama_grammar_stacks_cache & stacks_cache) {
|
||||||
|
|
||||||
llama_grammar_stacks advanced_stacks;
|
llama_grammar_stacks advanced_stacks;
|
||||||
// Check whether the stack is already in the cache
|
// Check whether the stack is already in the cache
|
||||||
auto it = llama_grammar_stacks_cache.find(stack);
|
auto it = stacks_cache.find(stack);
|
||||||
if (it != llama_grammar_stacks_cache.end()) {
|
if (it != stacks_cache.end()) {
|
||||||
advanced_stacks = it->second;
|
advanced_stacks = it->second;
|
||||||
} else {
|
} else {
|
||||||
// Advance stacks with memoization
|
// Advance stacks with memoization
|
||||||
llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks);
|
llama_grammar_advance_stack_memo_impl(rules, stack, advanced_stacks, stacks_cache);
|
||||||
llama_grammar_stacks_cache.insert(make_pair(stack, advanced_stacks));
|
stacks_cache.insert(make_pair(stack, advanced_stacks));
|
||||||
}
|
}
|
||||||
// Add the advanced stacks to new_stacks avoiding duplicates
|
// Add the advanced stacks to new_stacks avoiding duplicates
|
||||||
for (const auto & new_stack : advanced_stacks) {
|
for (const auto & new_stack : advanced_stacks) {
|
||||||
@ -934,7 +921,8 @@ void llama_grammar_accept(
|
|||||||
const llama_grammar_rules & rules,
|
const llama_grammar_rules & rules,
|
||||||
const llama_grammar_stacks & stacks,
|
const llama_grammar_stacks & stacks,
|
||||||
const uint32_t chr,
|
const uint32_t chr,
|
||||||
llama_grammar_stacks & stacks_new) {
|
llama_grammar_stacks & stacks_new,
|
||||||
|
llama_grammar_stacks_cache & stacks_cache) {
|
||||||
stacks_new.clear();
|
stacks_new.clear();
|
||||||
stacks_new.reserve(stacks.size());
|
stacks_new.reserve(stacks.size());
|
||||||
|
|
||||||
@ -952,7 +940,7 @@ void llama_grammar_accept(
|
|||||||
if (!llama_grammar_is_end_of_sequence(pos)) {
|
if (!llama_grammar_is_end_of_sequence(pos)) {
|
||||||
new_stack.push_back(pos);
|
new_stack.push_back(pos);
|
||||||
}
|
}
|
||||||
llama_grammar_advance_stack_memo(rules, new_stack, stacks_new);
|
llama_grammar_advance_stack_memo(rules, new_stack, stacks_new, stacks_cache);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1019,8 +1007,6 @@ struct llama_grammar * llama_grammar_init_impl(
|
|||||||
const llama_grammar_element ** rules,
|
const llama_grammar_element ** rules,
|
||||||
size_t n_rules,
|
size_t n_rules,
|
||||||
size_t start_rule_index) {
|
size_t start_rule_index) {
|
||||||
// Clear stacks cache
|
|
||||||
llama_grammar_stacks_cache.clear();
|
|
||||||
const llama_grammar_element * pos;
|
const llama_grammar_element * pos;
|
||||||
|
|
||||||
// copy rule definitions into vectors
|
// copy rule definitions into vectors
|
||||||
@ -1048,6 +1034,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
|||||||
|
|
||||||
// loop over alternates of start rule to build initial stacks
|
// loop over alternates of start rule to build initial stacks
|
||||||
llama_grammar_stacks stacks;
|
llama_grammar_stacks stacks;
|
||||||
|
llama_grammar_stacks_cache stacks_cache;
|
||||||
pos = vec_rules[start_rule_index].data();
|
pos = vec_rules[start_rule_index].data();
|
||||||
do {
|
do {
|
||||||
llama_grammar_stack stack;
|
llama_grammar_stack stack;
|
||||||
@ -1055,7 +1042,7 @@ struct llama_grammar * llama_grammar_init_impl(
|
|||||||
// if alternate is nonempty, add to stack
|
// if alternate is nonempty, add to stack
|
||||||
stack.push_back(pos);
|
stack.push_back(pos);
|
||||||
}
|
}
|
||||||
llama_grammar_advance_stack_memo(vec_rules, stack, stacks);
|
llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache);
|
||||||
while (!llama_grammar_is_end_of_sequence(pos)) {
|
while (!llama_grammar_is_end_of_sequence(pos)) {
|
||||||
// scan to end of alternate def
|
// scan to end of alternate def
|
||||||
pos++;
|
pos++;
|
||||||
@ -1075,8 +1062,6 @@ struct llama_grammar * llama_grammar_init_impl(
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
|
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
|
||||||
// Clear stacks cache
|
|
||||||
llama_grammar_stacks_cache.clear();
|
|
||||||
llama_grammar_parser parser;
|
llama_grammar_parser parser;
|
||||||
|
|
||||||
// if there is a grammar, parse it
|
// if there is a grammar, parse it
|
||||||
@ -1128,6 +1113,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
|
|||||||
|
|
||||||
// loop over alternates of start rule to build initial stacks
|
// loop over alternates of start rule to build initial stacks
|
||||||
llama_grammar_stacks stacks;
|
llama_grammar_stacks stacks;
|
||||||
|
llama_grammar_stacks_cache stacks_cache;
|
||||||
pos = vec_rules[start_rule_index].data();
|
pos = vec_rules[start_rule_index].data();
|
||||||
do {
|
do {
|
||||||
llama_grammar_stack stack;
|
llama_grammar_stack stack;
|
||||||
@ -1135,7 +1121,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
|
|||||||
// if alternate is nonempty, add to stack
|
// if alternate is nonempty, add to stack
|
||||||
stack.push_back(pos);
|
stack.push_back(pos);
|
||||||
}
|
}
|
||||||
llama_grammar_advance_stack_memo(vec_rules, stack, stacks);
|
llama_grammar_advance_stack_memo(vec_rules, stack, stacks, stacks_cache);
|
||||||
while (!llama_grammar_is_end_of_sequence(pos)) {
|
while (!llama_grammar_is_end_of_sequence(pos)) {
|
||||||
// scan to end of alternate def
|
// scan to end of alternate def
|
||||||
pos++;
|
pos++;
|
||||||
@ -1239,9 +1225,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
|||||||
const auto & code_points = decoded.first;
|
const auto & code_points = decoded.first;
|
||||||
|
|
||||||
llama_grammar_stacks stacks_new;
|
llama_grammar_stacks stacks_new;
|
||||||
|
llama_grammar_stacks_cache stacks_cache;
|
||||||
|
|
||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
|
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, stacks_cache);
|
||||||
grammar.stacks = std::move(stacks_new);
|
grammar.stacks = std::move(stacks_new);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#include "llama-impl.h"
|
#include "llama-impl.h"
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
struct llama_vocab;
|
struct llama_vocab;
|
||||||
|
|
||||||
@ -61,6 +62,18 @@ using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
|
|||||||
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
|
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
|
||||||
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
|
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
|
||||||
|
|
||||||
|
struct VectorPointerHash {
|
||||||
|
size_t operator()(const llama_grammar_stack & v) const {
|
||||||
|
size_t seed = v.size();
|
||||||
|
for (const auto* ptr : v) {
|
||||||
|
seed ^= std::hash<const llama_grammar_element*>()(ptr) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
||||||
|
}
|
||||||
|
return seed;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>;
|
||||||
|
|
||||||
// takes a set of possible pushdown stacks on a grammar, which are required to
|
// takes a set of possible pushdown stacks on a grammar, which are required to
|
||||||
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
// be positioned at a character range (see `llama_grammar_advance_stack`), and
|
||||||
// produces the N possible stacks if the given char is accepted at those
|
// produces the N possible stacks if the given char is accepted at those
|
||||||
@ -69,7 +82,8 @@ void llama_grammar_accept(
|
|||||||
const llama_grammar_rules & rules,
|
const llama_grammar_rules & rules,
|
||||||
const llama_grammar_stacks & stacks,
|
const llama_grammar_stacks & stacks,
|
||||||
uint32_t chr,
|
uint32_t chr,
|
||||||
llama_grammar_stacks & stacks_new);
|
llama_grammar_stacks & stacks_new,
|
||||||
|
llama_grammar_stacks_cache & stacks_cache);
|
||||||
|
|
||||||
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
|
||||||
const llama_grammar_rules & rules,
|
const llama_grammar_rules & rules,
|
||||||
|
@ -35,10 +35,11 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
|
|||||||
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
||||||
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
|
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
|
||||||
|
|
||||||
|
llama_grammar_stacks_cache stacks_cache;
|
||||||
for (const auto & cpt : cpts) {
|
for (const auto & cpt : cpts) {
|
||||||
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
|
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
|
||||||
|
|
||||||
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
|
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache);
|
||||||
|
|
||||||
if (stacks_cur.empty()) {
|
if (stacks_cur.empty()) {
|
||||||
// no stacks means that the grammar failed to match at this point
|
// no stacks means that the grammar failed to match at this point
|
||||||
|
Loading…
Reference in New Issue
Block a user