mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
* Fix unicode in grammars (fixes #2501) * add more comments * fix test-llama-grammar
This commit is contained in:
parent
10151bee2e
commit
604b8bdfa6
161
llama.cpp
161
llama.cpp
@ -2077,37 +2077,81 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
|
|||||||
// grammar - internal
|
// grammar - internal
|
||||||
//
|
//
|
||||||
|
|
||||||
|
struct llama_partial_utf8 {
|
||||||
|
uint32_t value; // bit value so far (unshifted)
|
||||||
|
int n_remain; // num bytes remaining; -1 indicates invalid sequence
|
||||||
|
};
|
||||||
|
|
||||||
struct llama_grammar {
|
struct llama_grammar {
|
||||||
const std::vector<std::vector<llama_grammar_element>> rules;
|
const std::vector<std::vector<llama_grammar_element>> rules;
|
||||||
std::vector<std::vector<const llama_grammar_element *>> stacks;
|
std::vector<std::vector<const llama_grammar_element *>> stacks;
|
||||||
|
|
||||||
|
// buffer for partially generated UTF-8 sequence from accepted tokens
|
||||||
|
llama_partial_utf8 partial_utf8;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_grammar_candidate {
|
struct llama_grammar_candidate {
|
||||||
size_t index;
|
size_t index;
|
||||||
const uint32_t * code_points;
|
const uint32_t * code_points;
|
||||||
|
llama_partial_utf8 partial_utf8;
|
||||||
};
|
};
|
||||||
|
|
||||||
// NOTE: assumes valid utf8 (but checks for overrun)
|
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
|
||||||
// adds a terminating 0 for use as pointer
|
// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
|
||||||
std::vector<uint32_t> decode_utf8(const char * src) {
|
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
|
||||||
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
|
const char * src,
|
||||||
|
llama_partial_utf8 partial_start) {
|
||||||
|
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
|
||||||
const char * pos = src;
|
const char * pos = src;
|
||||||
std::vector<uint32_t> code_points;
|
std::vector<uint32_t> code_points;
|
||||||
|
uint32_t value = partial_start.value;
|
||||||
|
int n_remain = partial_start.n_remain;
|
||||||
|
|
||||||
|
// continue previous decode, if applicable
|
||||||
|
while (*pos != 0 && n_remain > 0) {
|
||||||
|
uint8_t next_byte = static_cast<uint8_t>(*pos);
|
||||||
|
if ((next_byte >> 6) != 2) {
|
||||||
|
// invalid sequence, abort
|
||||||
|
code_points.push_back(0);
|
||||||
|
return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
|
||||||
|
}
|
||||||
|
value = (value << 6) + (next_byte & 0x3F);
|
||||||
|
++pos;
|
||||||
|
--n_remain;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (partial_start.n_remain > 0 && n_remain == 0) {
|
||||||
|
code_points.push_back(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// decode any subsequent utf-8 sequences, which may end in an incomplete one
|
||||||
while (*pos != 0) {
|
while (*pos != 0) {
|
||||||
uint8_t first_byte = static_cast<uint8_t>(*pos);
|
uint8_t first_byte = static_cast<uint8_t>(*pos);
|
||||||
uint8_t highbits = first_byte >> 4;
|
uint8_t highbits = first_byte >> 4;
|
||||||
int len = lookup[highbits];
|
n_remain = lookup[highbits] - 1;
|
||||||
uint8_t mask = (1 << (8 - len)) - 1;
|
|
||||||
uint32_t value = first_byte & mask;
|
if (n_remain < 0) {
|
||||||
const char * end = pos + len; // may overrun!
|
// invalid sequence, abort
|
||||||
++pos;
|
code_points.clear();
|
||||||
for ( ; pos < end && *pos != 0; ++pos) {
|
code_points.push_back(0);
|
||||||
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t mask = (1 << (7 - n_remain)) - 1;
|
||||||
|
value = first_byte & mask;
|
||||||
|
++pos;
|
||||||
|
while (*pos != 0 && n_remain > 0) {
|
||||||
|
value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
|
||||||
|
++pos;
|
||||||
|
--n_remain;
|
||||||
|
}
|
||||||
|
if (n_remain == 0) {
|
||||||
|
code_points.push_back(value);
|
||||||
}
|
}
|
||||||
code_points.push_back(value);
|
|
||||||
}
|
}
|
||||||
code_points.push_back(0);
|
code_points.push_back(0);
|
||||||
return code_points;
|
|
||||||
|
return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
|
||||||
}
|
}
|
||||||
|
|
||||||
// returns true iff pos points to the end of one of the definitions of a rule
|
// returns true iff pos points to the end of one of the definitions of a rule
|
||||||
@ -2144,6 +2188,56 @@ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
|
|||||||
return std::make_pair(found == is_positive_char, pos);
|
return std::make_pair(found == is_positive_char, pos);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
|
||||||
|
// range at pos (regular or inverse range)
|
||||||
|
// asserts that pos is pointing to a char range element
|
||||||
|
static bool llama_grammar_match_partial_char(
|
||||||
|
const llama_grammar_element * pos,
|
||||||
|
const llama_partial_utf8 partial_utf8) {
|
||||||
|
|
||||||
|
bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR;
|
||||||
|
LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
|
||||||
|
|
||||||
|
uint32_t partial_value = partial_utf8.value;
|
||||||
|
int n_remain = partial_utf8.n_remain;
|
||||||
|
|
||||||
|
// invalid sequence or 7-bit char split across 2 bytes (overlong)
|
||||||
|
if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// range of possible code points this partial UTF-8 sequence could complete to
|
||||||
|
uint32_t low = partial_value << (n_remain * 6);
|
||||||
|
uint32_t high = low | ((1 << (n_remain * 6)) - 1);
|
||||||
|
|
||||||
|
if (low == 0) {
|
||||||
|
if (n_remain == 2) {
|
||||||
|
low = 1 << 11;
|
||||||
|
} else if (n_remain == 3) {
|
||||||
|
low = 1 << 16;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
do {
|
||||||
|
if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
|
||||||
|
// inclusive range, e.g. [a-z]
|
||||||
|
if (pos->value <= high && low <= pos[1].value) {
|
||||||
|
return is_positive_char;
|
||||||
|
}
|
||||||
|
pos += 2;
|
||||||
|
} else {
|
||||||
|
// exact char match, e.g. [a] or "a"
|
||||||
|
if (low <= pos->value && pos->value <= high) {
|
||||||
|
return is_positive_char;
|
||||||
|
}
|
||||||
|
pos += 1;
|
||||||
|
}
|
||||||
|
} while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
|
||||||
|
|
||||||
|
return !is_positive_char;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// transforms a grammar pushdown stack into N possible stacks, all ending
|
// transforms a grammar pushdown stack into N possible stacks, all ending
|
||||||
// at a character range (terminal element)
|
// at a character range (terminal element)
|
||||||
static void llama_grammar_advance_stack(
|
static void llama_grammar_advance_stack(
|
||||||
@ -2244,8 +2338,11 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
|
|||||||
std::vector<llama_grammar_candidate> rejects;
|
std::vector<llama_grammar_candidate> rejects;
|
||||||
|
|
||||||
if (stack.empty()) {
|
if (stack.empty()) {
|
||||||
// accept nothing; EOS is handled elsewhere
|
for (auto tok : candidates) {
|
||||||
rejects.insert(rejects.end(), candidates.begin(), candidates.end());
|
if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
|
||||||
|
rejects.push_back(tok);
|
||||||
|
}
|
||||||
|
}
|
||||||
return rejects;
|
return rejects;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2253,10 +2350,15 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
|
|||||||
|
|
||||||
std::vector<llama_grammar_candidate> next_candidates;
|
std::vector<llama_grammar_candidate> next_candidates;
|
||||||
for (auto tok : candidates) {
|
for (auto tok : candidates) {
|
||||||
if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) {
|
if (*tok.code_points == 0) {
|
||||||
if (tok.code_points[1] != 0) {
|
// reached end of full codepoints in token, reject iff it ended in a partial sequence
|
||||||
next_candidates.push_back({ tok.index, tok.code_points + 1 });
|
// that cannot satisfy this position in grammar
|
||||||
|
if (tok.partial_utf8.n_remain != 0 &&
|
||||||
|
!llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
|
||||||
|
rejects.push_back(tok);
|
||||||
}
|
}
|
||||||
|
} else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
|
||||||
|
next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
|
||||||
} else {
|
} else {
|
||||||
rejects.push_back(tok);
|
rejects.push_back(tok);
|
||||||
}
|
}
|
||||||
@ -2274,7 +2376,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
|
|||||||
|
|
||||||
auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
|
auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
|
||||||
for (auto tok : next_rejects) {
|
for (auto tok : next_rejects) {
|
||||||
rejects.push_back({ tok.index, tok.code_points - 1 });
|
rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
|
||||||
}
|
}
|
||||||
|
|
||||||
return rejects;
|
return rejects;
|
||||||
@ -2339,7 +2441,7 @@ struct llama_grammar * llama_grammar_init(
|
|||||||
}
|
}
|
||||||
} while (true);
|
} while (true);
|
||||||
|
|
||||||
return new llama_grammar{ std::move(vec_rules), std::move(stacks) };
|
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_grammar_free(struct llama_grammar * grammar) {
|
void llama_grammar_free(struct llama_grammar * grammar) {
|
||||||
@ -2645,8 +2747,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
|
|||||||
|
|
||||||
const llama_token eos = llama_token_eos();
|
const llama_token eos = llama_token_eos();
|
||||||
|
|
||||||
std::vector<std::vector<uint32_t>> candidates_decoded;
|
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
|
||||||
std::vector<llama_grammar_candidate> candidates_grammar;
|
std::vector<llama_grammar_candidate> candidates_grammar;
|
||||||
|
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
const llama_token id = candidates->data[i].id;
|
const llama_token id = candidates->data[i].id;
|
||||||
@ -2658,8 +2760,10 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
|
|||||||
} else if (*str == 0) {
|
} else if (*str == 0) {
|
||||||
candidates->data[i].logit = -INFINITY;
|
candidates->data[i].logit = -INFINITY;
|
||||||
} else {
|
} else {
|
||||||
candidates_decoded.push_back(decode_utf8(str));
|
candidates_decoded.push_back(decode_utf8(str, grammar->partial_utf8));
|
||||||
candidates_grammar.push_back({ i, candidates_decoded.back().data() });
|
candidates_grammar.push_back({
|
||||||
|
i, candidates_decoded.back().first.data(), candidates_decoded.back().second
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2860,11 +2964,14 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
|
|||||||
}
|
}
|
||||||
|
|
||||||
const char * str = llama_token_to_str(ctx, token);
|
const char * str = llama_token_to_str(ctx, token);
|
||||||
|
|
||||||
// Note terminating 0 in decoded string
|
// Note terminating 0 in decoded string
|
||||||
auto code_points = decode_utf8(str);
|
const auto decoded = decode_utf8(str, grammar->partial_utf8);
|
||||||
|
const auto & code_points = decoded.first;
|
||||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
|
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
|
||||||
}
|
}
|
||||||
|
grammar->partial_utf8 = decoded.second;
|
||||||
LLAMA_ASSERT(!grammar->stacks.empty());
|
LLAMA_ASSERT(!grammar->stacks.empty());
|
||||||
|
|
||||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
@ -199,7 +199,7 @@ int main()
|
|||||||
uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
|
uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
|
||||||
cp[0] = 37 + i;
|
cp[0] = 37 + i;
|
||||||
cp[1] = 0;
|
cp[1] = 0;
|
||||||
next_candidates[i] = {i, cp};
|
next_candidates[i] = {i, cp, {}};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
|
std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
|
||||||
|
Loading…
Reference in New Issue
Block a user