From 69f2fafebcfffb123f67ec4233b0aa6aa85453e6 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sun, 3 Sep 2023 15:25:53 +0300
Subject: [PATCH] speculative : add grammar support

---
 examples/speculative/speculative.cpp | 81 +++++++++++++++++++++++++++-
 llama.cpp                            | 19 +++++++
 llama.h                              |  2 +
 3 files changed, 100 insertions(+), 2 deletions(-)

diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index f0400c13f..594d4f5d6 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -6,6 +6,7 @@
 
 #include "common.h"
 #include "llama.h"
+#include "grammar-parser.h"
 
 #include <cmath>
 #include <cstdio>
@@ -109,6 +110,41 @@ int main(int argc, char ** argv) {
     // used to determine end of generation
     bool has_eos = false;
 
+    // grammar stuff
+    struct llama_grammar * grammar_dft = NULL;
+    struct llama_grammar * grammar_tgt = NULL;
+
+    grammar_parser::parse_state parsed_grammar_dft;
+    grammar_parser::parse_state parsed_grammar_tgt;
+
+    std::vector<llama_grammar *> grammar_mem(n_draft, NULL);
+
+    if (!params.grammar.empty()) {
+        // dft
+        {
+            parsed_grammar_dft = grammar_parser::parse(params.grammar.c_str());
+            // will be empty (default) if there are parse errors
+            if (parsed_grammar_dft.rules.empty()) {
+                return 1;
+            }
+
+            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar_dft.c_rules());
+            grammar_dft = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar_dft.symbol_ids.at("root"));
+        }
+
+        // tgt
+        {
+            parsed_grammar_tgt = grammar_parser::parse(params.grammar.c_str());
+            // will be empty (default) if there are parse errors
+            if (parsed_grammar_tgt.rules.empty()) {
+                return 1;
+            }
+
+            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar_tgt.c_rules());
+            grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar_tgt.symbol_ids.at("root"));
+        }
+    }
+
     const auto t_dec_start = ggml_time_us();
 
     while (true) {
@@ -117,7 +153,7 @@
         // sample from the drafted tokens if any
         int i_dft = 0;
         while (true) {
-            const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft);
+            const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
 
             last_tokens.erase(last_tokens.begin());
             last_tokens.push_back(id);
@@ -144,6 +180,24 @@
                 continue;
             }
 
+            if (i_dft < (int) drafted.size()) {
+                LOG("drafted token %d rejected\n", id);
+
+                if (grammar_mem[i_dft]) {
+                    grammar_dft = llama_grammar_copy(grammar_mem[i_dft]);
+                    LOG("restored grammar %d\n", i_dft);
+                }
+            }
+
+            for (auto & g : grammar_mem) {
+                if (g) {
+                    llama_grammar_free(g);
+                    g = NULL;
+                }
+            }
+
+            LOG("i_dft = %d, drafted.size() = %d\n", i_dft, (int) drafted.size());
+
             // the drafted token was rejected or we are out of drafted tokens
             llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
             ++n_past_dft;
@@ -151,6 +205,10 @@
             drafted.clear();
             drafted.push_back(id);
 
+            if (grammar_dft != NULL) {
+                llama_grammar_accept_token(ctx_dft, grammar_dft, id);
+            }
+
             break;
         }
 
@@ -161,6 +219,11 @@
         // sample n_draft tokens from the draft model picking the best token
        int n_past_cur = n_past_dft;
         for (int i = 0; i < n_draft; ++i) {
+            // remember the grammar state
+            if (grammar_dft != NULL) {
+                grammar_mem[i] = llama_grammar_copy(grammar_dft);
+            }
+
             float * logits = llama_get_logits(ctx_dft);
 
             candidates.clear();
@@ -170,6 +233,10 @@
 
             llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
 
+            if (grammar_dft != NULL) {
+                llama_sample_grammar(ctx_dft, &cur_p, grammar_dft);
+            }
+
             // computes softmax and sorts the candidates
             llama_sample_softmax(ctx_dft, &cur_p);
 
@@ -182,7 +249,13 @@
                 break;
             }
 
-            drafted.push_back(cur_p.data[0].id);
+            const llama_token id = cur_p.data[0].id;
+
+            if (grammar_dft != NULL) {
+                llama_grammar_accept_token(ctx_dft, grammar_dft, id);
+            }
+
+            drafted.push_back(id);
             ++n_drafted;
 
             if (i < n_draft - 1) {
@@ -226,6 +299,10 @@
     llama_free(ctx_dft);
     llama_free_model(model_dft);
 
+    if (grammar_dft != NULL) {
+        llama_grammar_free(grammar_dft);
+        llama_grammar_free(grammar_tgt);
+    }
     llama_backend_free();
 
     fprintf(stderr, "\n\n");
diff --git a/llama.cpp b/llama.cpp
index c97c1462f..cbf255115 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3850,6 +3850,25 @@ void llama_grammar_free(struct llama_grammar * grammar) {
     delete grammar;
 }
 
+struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
+    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
+
+    // redirect elements in stacks to point to new rules
+    for (size_t is = 0; is < result->stacks.size(); is++) {
+        for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
+            for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
+                for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
+                    if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
+                        result->stacks[is][ie] = &result->rules[ir0][ir1];
+                    }
+                }
+            }
+        }
+    }
+
+    return result;
+}
+
 //
 // sampling
 //
diff --git a/llama.h b/llama.h
index 422f28527..5b95aaa87 100644
--- a/llama.h
+++ b/llama.h
@@ -410,6 +410,8 @@ extern "C" {
 
     LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
 
+    LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
+
     //
     // Sampling functions
     //
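
Note on the speculative.cpp changes above: before drafting token i, the draft grammar state is snapshotted via llama_grammar_copy into grammar_mem[i]; if the target model later rejects the drafted token at position i_dft, the draft grammar is restored from that snapshot, and all snapshots are freed before the next drafting round. A minimal standalone sketch of this snapshot/rollback pattern, with a plain string standing in for the mutable grammar state (all names here are illustrative, not llama.cpp API):

#include <cstdio>
#include <string>
#include <vector>

int main() {
    std::string state = "root";               // stands in for the draft grammar state
    const int n_draft = 3;

    std::vector<std::string> snapshots(n_draft);

    // drafting: remember the state before speculating each token
    for (int i = 0; i < n_draft; ++i) {
        snapshots[i] = state;                  // analogous to llama_grammar_copy
        state += " tok" + std::to_string(i);   // speculatively advance the state
    }

    // verification: suppose the target model rejects the token drafted at i = 1
    const int i_dft = 1;
    state = snapshots[i_dft];                  // roll back to the saved snapshot

    printf("state after rollback: '%s'\n", state.c_str()); // prints: root tok0
}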
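Note on llama_grammar_copy: the nested loops are needed because grammar->stacks holds raw pointers into the grammar's own rules storage, so after copying both containers each stack pointer still points into the original and must be redirected to the matching element of the copy. A standalone sketch of the same pointer-fixup idea (types and names invented for illustration):

#include <cstdio>
#include <vector>

struct elem { int value; };

struct grammar_like {
    std::vector<std::vector<elem>>         rules;  // owning storage
    std::vector<std::vector<const elem *>> stacks; // pointers into rules
};

static grammar_like copy_grammar_like(const grammar_like & src) {
    grammar_like dst{ src.rules, src.stacks };

    // redirect each stack element from src.rules to the same position in dst.rules
    for (size_t is = 0; is < dst.stacks.size(); is++) {
        for (size_t ie = 0; ie < dst.stacks[is].size(); ie++) {
            for (size_t ir0 = 0; ir0 < src.rules.size(); ir0++) {
                for (size_t ir1 = 0; ir1 < src.rules[ir0].size(); ir1++) {
                    if (src.stacks[is][ie] == &src.rules[ir0][ir1]) {
                        dst.stacks[is][ie] = &dst.rules[ir0][ir1];
                    }
                }
            }
        }
    }

    return dst;
}

int main() {
    grammar_like g;
    g.rules  = { { {1}, {2} }, { {3} } };
    g.stacks = { { &g.rules[0][1], &g.rules[1][0] } };

    grammar_like g2 = copy_grammar_like(g);

    printf("%d %d\n", g2.stacks[0][0]->value, g2.stacks[0][1]->value);             // 2 3
    printf("independent: %s\n", g2.stacks[0][0] == g.stacks[0][0] ? "no" : "yes"); // yes
}

Storing indices instead of pointers would avoid the quadratic scan, but the patch keeps the existing pointer representation and simply fixes the pointers up after the copy.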