llama : minor llama_grammar refactoring

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-10-17 12:19:28 +03:00
parent 2aa6dd273a
commit 17b3a3e8cc
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
5 changed files with 33 additions and 51 deletions

View File

@ -11,20 +11,15 @@
static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
const auto cpts = unicode_cpts_from_utf8(input_str); const auto cpts = unicode_cpts_from_utf8(input_str);
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); auto & stacks_cur = llama_grammar_get_stacks(grammar);
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar);
size_t pos = 0; size_t pos = 0;
for (const auto & cpt : cpts) { for (const auto & cpt : cpts) {
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy llama_grammar_accept(grammar, cpt);
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache);
if (stacks_cur.empty()) { if (stacks_cur.empty()) {
error_pos = pos; error_pos = pos;
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'"; error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
stacks_cur = stacks_prev;
return false; return false;
} }
++pos; ++pos;
@ -83,7 +78,8 @@ int main(int argc, char** argv) {
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"); llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
if (grammar == nullptr) { if (grammar == nullptr) {
throw std::runtime_error("Failed to initialize llama_grammar"); fprintf(stdout, "Failed to initialize llama_grammar\n");
return 1;
} }
// Read the input file // Read the input file
std::string input_str; std::string input_str;

View File

@ -917,20 +917,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
return grammar->stacks; return grammar->stacks;
} }
llama_grammar_stacks_cache & llama_grammar_get_stacks_cache(struct llama_grammar * grammar) { void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
return grammar->stacks_cache; llama_grammar_stacks stacks_new;
} stacks_new.reserve(grammar->stacks.size());
void llama_grammar_accept( for (const auto & stack : grammar->stacks) {
const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks,
const uint32_t chr,
llama_grammar_stacks & stacks_new,
llama_grammar_stacks_cache & stacks_cache) {
stacks_new.clear();
stacks_new.reserve(stacks.size());
for (const auto & stack : stacks) {
if (stack.empty()) { if (stack.empty()) {
continue; continue;
} }
@ -944,9 +935,11 @@ void llama_grammar_accept(
if (!llama_grammar_is_end_of_sequence(pos)) { if (!llama_grammar_is_end_of_sequence(pos)) {
new_stack.push_back(pos); new_stack.push_back(pos);
} }
llama_grammar_advance_stack_memo(rules, new_stack, stacks_new, stacks_cache); llama_grammar_advance_stack_memo(grammar->rules, new_stack, stacks_new, grammar->stacks_cache);
} }
} }
grammar->stacks = std::move(stacks_new);
} }
llama_grammar_candidates llama_grammar_reject_candidates_for_stack( llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
@ -1062,7 +1055,7 @@ struct llama_grammar * llama_grammar_init_impl(
// Important: vec_rules has to be moved here, not copied, because stacks contains // Important: vec_rules has to be moved here, not copied, because stacks contains
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
// then the pointers would be invalidated when the local vec_rules goes out of scope. // then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), }; return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), std::move(stacks_cache), {}, };
} }
struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
@ -1141,7 +1134,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
// Important: vec_rules has to be moved here, not copied, because stacks contains // Important: vec_rules has to be moved here, not copied, because stacks contains
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
// then the pointers would be invalidated when the local vec_rules goes out of scope. // then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, std::move(stacks_cache), }; return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), std::move(stacks_cache), {}, };
} }
void llama_grammar_free_impl(struct llama_grammar * grammar) { void llama_grammar_free_impl(struct llama_grammar * grammar) {
@ -1153,7 +1146,13 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
} }
struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, }; llama_grammar * result = new llama_grammar {
grammar.vocab,
grammar.rules,
grammar.stacks,
grammar.stacks_cache,
grammar.partial_utf8,
};
// redirect elements in stacks to point to new rules // redirect elements in stacks to point to new rules
for (size_t is = 0; is < result->stacks.size(); is++) { for (size_t is = 0; is < result->stacks.size(); is++) {
@ -1228,11 +1227,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto decoded = decode_utf8(piece, grammar.partial_utf8);
const auto & code_points = decoded.first; const auto & code_points = decoded.first;
llama_grammar_stacks stacks_new;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new, grammar.stacks_cache); llama_grammar_accept(&grammar, *it);
grammar.stacks = std::move(stacks_new);
} }
grammar.partial_utf8 = decoded.second; grammar.partial_utf8 = decoded.second;

View File

@ -71,20 +71,15 @@ struct VectorPointerHash {
using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>; using llama_grammar_stacks_cache = std::unordered_map<llama_grammar_stack, llama_grammar_stacks, VectorPointerHash>;
// TODO: remove, needed for tests atm
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
llama_grammar_stacks_cache & llama_grammar_get_stacks_cache( struct llama_grammar * grammar);
// takes a set of possible pushdown stacks on a grammar, which are required to // takes a set of possible pushdown stacks on a grammar, which are required to
// be positioned at a character range (see `llama_grammar_advance_stack`), and // be positioned at a character range (see `llama_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those // produces the N possible stacks if the given char is accepted at those
// positions // positions
void llama_grammar_accept( void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);
const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks,
uint32_t chr,
llama_grammar_stacks & stacks_new,
llama_grammar_stacks_cache & stacks_cache);
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack( std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
const llama_grammar_rules & rules, const llama_grammar_rules & rules,
@ -128,10 +123,11 @@ struct llama_grammar {
const llama_grammar_rules rules; // TODO: shared ptr const llama_grammar_rules rules; // TODO: shared ptr
llama_grammar_stacks stacks; llama_grammar_stacks stacks;
// buffer for partially generated UTF-8 sequence from accepted tokens
llama_partial_utf8 partial_utf8;
// cache N possible stacks from a stack // cache N possible stacks from a stack
llama_grammar_stacks_cache stacks_cache; llama_grammar_stacks_cache stacks_cache;
// buffer for partially generated UTF-8 sequence from accepted tokens
llama_partial_utf8 partial_utf8;
}; };
// //

View File

@ -32,14 +32,10 @@ static bool test_build_grammar_fails(const std::string & grammar_str) {
static bool match_string(const std::string & input, llama_grammar * grammar) { static bool match_string(const std::string & input, llama_grammar * grammar) {
const auto cpts = unicode_cpts_from_utf8(input); const auto cpts = unicode_cpts_from_utf8(input);
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); auto & stacks_cur = llama_grammar_get_stacks(grammar);
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
llama_grammar_stacks_cache & stacks_cache = llama_grammar_get_stacks_cache(grammar);
for (const auto & cpt : cpts) { for (const auto & cpt : cpts) {
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy llama_grammar_accept(grammar, cpt);
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur, stacks_cache);
if (stacks_cur.empty()) { if (stacks_cur.empty()) {
// no stacks means that the grammar failed to match at this point // no stacks means that the grammar failed to match at this point
@ -64,7 +60,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
auto * grammar = build_grammar(grammar_str); auto * grammar = build_grammar(grammar_str);
// Save the original grammar stacks so that we can reset after every new string we want to test // Save the original grammar stacks so that we can reset after every new string we want to test
const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar); const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar); // copy
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);

View File

@ -113,12 +113,10 @@ int main()
} }
} }
llama_grammar * grammar = NULL;
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules()); std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
if (grammar == nullptr) if (grammar == nullptr) {
{
throw std::runtime_error("Failed to initialize llama_grammar"); throw std::runtime_error("Failed to initialize llama_grammar");
} }