From f837c3a992b2b6146936cb120871a8cf9d0e3857 Mon Sep 17 00:00:00 2001 From: Marcus Dunn <51931484+MarcusDunn@users.noreply.github.com> Date: Sat, 25 Nov 2023 08:58:23 -0800 Subject: [PATCH] llama : grammar `reserve` space in `decode_utf8` (#4210) * reserve space for codepoints * improvement for the appended 0 --- llama.cpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index c5f4053f2..f2b5967d7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6420,10 +6420,13 @@ struct llama_grammar_candidate { // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. static std::pair, llama_partial_utf8> decode_utf8( const char * src, + size_t n_src, llama_partial_utf8 partial_start) { static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; const char * pos = src; std::vector code_points; + // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0. + code_points.reserve(n_src + 1); uint32_t value = partial_start.value; int n_remain = partial_start.n_remain; @@ -6474,6 +6477,13 @@ static std::pair, llama_partial_utf8> decode_utf8( return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); } +static std::pair, llama_partial_utf8> decode_utf8( + std::string src, + llama_partial_utf8 partial_start +) { + return decode_utf8(src.c_str(), src.size(), partial_start); +} + // returns true iff pos points to the end of one of the definitions of a rule static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) { switch (pos->type) { @@ -7123,7 +7133,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c } else if (piece.empty() || piece[0] == 0) { candidates->data[i].logit = -INFINITY; } else { - candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8)); + candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8)); candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); } } @@ -7330,7 +7340,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar const std::string piece = llama_token_to_piece(ctx, token); // Note terminating 0 in decoded string - const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8); + const auto decoded = decode_utf8(piece, grammar->partial_utf8); const auto & code_points = decoded.first; for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);