diff --git a/examples/llama.android/app/src/main/cpp/llama-android.cpp b/examples/llama.android/app/src/main/cpp/llama-android.cpp index 2beb1e0d5..ce8ab3b70 100644 --- a/examples/llama.android/app/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/app/src/main/cpp/llama-android.cpp @@ -33,6 +33,45 @@ jclass la_int_var; jmethodID la_int_var_value; jmethodID la_int_var_inc; +std::string cached_token_chars; + +bool is_valid_utf8(const char * string) { + if (!string) { + return true; + } + + const unsigned char * bytes = (const unsigned char *)string; + int num; + + while (*bytes != 0x00) { + if ((*bytes & 0x80) == 0x00) { + // U+0000 to U+007F + num = 1; + } else if ((*bytes & 0xE0) == 0xC0) { + // U+0080 to U+07FF + num = 2; + } else if ((*bytes & 0xF0) == 0xE0) { + // U+0800 to U+FFFF + num = 3; + } else if ((*bytes & 0xF8) == 0xF0) { + // U+10000 to U+10FFFF + num = 4; + } else { + return false; + } + + bytes += 1; + for (int i = 1; i < num; ++i) { + if ((*bytes & 0xC0) != 0x80) { + return false; + } + bytes += 1; + } + } + + return true; +} + static void log_callback(ggml_log_level level, const char * fmt, void * data) { if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data); else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data); @@ -295,6 +334,8 @@ Java_com_example_llama_Llm_completion_1init( jint n_len ) { + cached_token_chars.clear(); + const auto text = env->GetStringUTFChars(jtext, 0); const auto context = reinterpret_cast(context_pointer); const auto batch = reinterpret_cast(batch_pointer); @@ -372,8 +413,16 @@ Java_com_example_llama_Llm_completion_1loop( } auto new_token_chars = llama_token_to_piece(context, new_token_id); - LOGi("new_token_chars: `%s`", new_token_chars.c_str()); - auto new_token = env->NewStringUTF(new_token_chars.c_str()); + cached_token_chars += new_token_chars; + + jstring new_token = nullptr; + if (is_valid_utf8(cached_token_chars.c_str())) { + new_token = env->NewStringUTF(cached_token_chars.c_str()); + LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id); + cached_token_chars.clear(); + } else { + new_token = env->NewStringUTF(""); + } llama_batch_clear(*batch); llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true); diff --git a/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt b/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt index 5f3270372..d86afee37 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/Llm.kt @@ -71,7 +71,7 @@ class Llm { batch: Long, nLen: Int, ncur: IntVar - ): String + ): String? private external fun kv_cache_clear(context: Long) @@ -115,7 +115,7 @@ class Llm { val ncur = IntVar(completion_init(state.context, state.batch, message, nlen)) while (ncur.value <= nlen) { val str = completion_loop(state.context, state.batch, nlen, ncur) - if (str.isEmpty()) { + if (str == null) { break } emit(str)