android : fix utf8 decoding error (#5935)

* examples: fix utf8 decoding error

some models have a tokenizer that decodes an id into an incomplete utf8 sequence, need to validate and wait for next token
one example would be: https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat-GGUF/resolve/main/qwen1_5-1_8b-chat-q4_0.gguf and and an example of the token is 18137

* android : minor

---------

Co-authored-by: zhangfuwen <zhangfuwen@foxmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Dean 2024-03-11 04:03:17 +08:00 committed by GitHub
parent d9f65c97c3
commit 7ab7b733bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 53 additions and 4 deletions

View File

@ -33,6 +33,45 @@ jclass la_int_var;
jmethodID la_int_var_value; jmethodID la_int_var_value;
jmethodID la_int_var_inc; jmethodID la_int_var_inc;
std::string cached_token_chars;
bool is_valid_utf8(const char * string) {
if (!string) {
return true;
}
const unsigned char * bytes = (const unsigned char *)string;
int num;
while (*bytes != 0x00) {
if ((*bytes & 0x80) == 0x00) {
// U+0000 to U+007F
num = 1;
} else if ((*bytes & 0xE0) == 0xC0) {
// U+0080 to U+07FF
num = 2;
} else if ((*bytes & 0xF0) == 0xE0) {
// U+0800 to U+FFFF
num = 3;
} else if ((*bytes & 0xF8) == 0xF0) {
// U+10000 to U+10FFFF
num = 4;
} else {
return false;
}
bytes += 1;
for (int i = 1; i < num; ++i) {
if ((*bytes & 0xC0) != 0x80) {
return false;
}
bytes += 1;
}
}
return true;
}
static void log_callback(ggml_log_level level, const char * fmt, void * data) { static void log_callback(ggml_log_level level, const char * fmt, void * data) {
if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data); if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data); else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
@ -295,6 +334,8 @@ Java_com_example_llama_Llm_completion_1init(
jint n_len jint n_len
) { ) {
cached_token_chars.clear();
const auto text = env->GetStringUTFChars(jtext, 0); const auto text = env->GetStringUTFChars(jtext, 0);
const auto context = reinterpret_cast<llama_context *>(context_pointer); const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
@ -372,8 +413,16 @@ Java_com_example_llama_Llm_completion_1loop(
} }
auto new_token_chars = llama_token_to_piece(context, new_token_id); auto new_token_chars = llama_token_to_piece(context, new_token_id);
LOGi("new_token_chars: `%s`", new_token_chars.c_str()); cached_token_chars += new_token_chars;
auto new_token = env->NewStringUTF(new_token_chars.c_str());
jstring new_token = nullptr;
if (is_valid_utf8(cached_token_chars.c_str())) {
new_token = env->NewStringUTF(cached_token_chars.c_str());
LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
cached_token_chars.clear();
} else {
new_token = env->NewStringUTF("");
}
llama_batch_clear(*batch); llama_batch_clear(*batch);
llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true); llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);

View File

@ -71,7 +71,7 @@ class Llm {
batch: Long, batch: Long,
nLen: Int, nLen: Int,
ncur: IntVar ncur: IntVar
): String ): String?
private external fun kv_cache_clear(context: Long) private external fun kv_cache_clear(context: Long)
@ -115,7 +115,7 @@ class Llm {
val ncur = IntVar(completion_init(state.context, state.batch, message, nlen)) val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
while (ncur.value <= nlen) { while (ncur.value <= nlen) {
val str = completion_loop(state.context, state.batch, nlen, ncur) val str = completion_loop(state.context, state.batch, nlen, ncur)
if (str.isEmpty()) { if (str == null) {
break break
} }
emit(str) emit(str)