swift : fix concatenation method to avoid invalid UTF8 stringfication (#4325)

This commit is contained in:
Miwa / Ensan 2023-12-05 01:03:49 +09:00 committed by GitHub
parent 5c9f90cba1
commit d208995c6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -11,6 +11,8 @@ actor LlamaContext {
private var context: OpaquePointer private var context: OpaquePointer
private var batch: llama_batch private var batch: llama_batch
private var tokens_list: [llama_token] private var tokens_list: [llama_token]
/// This variable is used to store temporarily invalid cchars
private var temporary_invalid_cchars: [CChar]
var n_len: Int32 = 512 var n_len: Int32 = 512
var n_cur: Int32 = 0 var n_cur: Int32 = 0
@ -21,6 +23,7 @@ actor LlamaContext {
self.context = context self.context = context
self.tokens_list = [] self.tokens_list = []
self.batch = llama_batch_init(512, 0, 1) self.batch = llama_batch_init(512, 0, 1)
self.temporary_invalid_cchars = []
} }
deinit { deinit {
@ -61,6 +64,7 @@ actor LlamaContext {
print("attempting to complete \"\(text)\"") print("attempting to complete \"\(text)\"")
tokens_list = tokenize(text: text, add_bos: true) tokens_list = tokenize(text: text, add_bos: true)
temporary_invalid_cchars = []
let n_ctx = llama_n_ctx(context) let n_ctx = llama_n_ctx(context)
let n_kv_req = tokens_list.count + (Int(n_len) - tokens_list.count) let n_kv_req = tokens_list.count + (Int(n_len) - tokens_list.count)
@ -72,7 +76,7 @@ actor LlamaContext {
} }
for id in tokens_list { for id in tokens_list {
print(token_to_piece(token: id)) print(String(cString: token_to_piece(token: id) + [0]))
} }
// batch = llama_batch_init(512, 0) // done in init() // batch = llama_batch_init(512, 0) // done in init()
@ -115,10 +119,25 @@ actor LlamaContext {
if new_token_id == llama_token_eos(context) || n_cur == n_len { if new_token_id == llama_token_eos(context) || n_cur == n_len {
print("\n") print("\n")
return "" let new_token_str = String(cString: temporary_invalid_cchars + [0])
temporary_invalid_cchars.removeAll()
return new_token_str
} }
let new_token_str = token_to_piece(token: new_token_id) let new_token_cchars = token_to_piece(token: new_token_id)
temporary_invalid_cchars.append(contentsOf: new_token_cchars)
let new_token_str: String
if let string = String(validatingUTF8: temporary_invalid_cchars + [0]) {
temporary_invalid_cchars.removeAll()
new_token_str = string
} else if (0 ..< temporary_invalid_cchars.count).contains(where: {$0 != 0 && String(validatingUTF8: Array(temporary_invalid_cchars.suffix($0)) + [0]) != nil}) {
// in this case, at least the suffix of the temporary_invalid_cchars can be interpreted as UTF8 string
let string = String(cString: temporary_invalid_cchars + [0])
temporary_invalid_cchars.removeAll()
new_token_str = string
} else {
new_token_str = ""
}
print(new_token_str) print(new_token_str)
// tokens_list.append(new_token_id) // tokens_list.append(new_token_id)
@ -144,6 +163,7 @@ actor LlamaContext {
func clear() { func clear() {
tokens_list.removeAll() tokens_list.removeAll()
temporary_invalid_cchars.removeAll()
} }
private func tokenize(text: String, add_bos: Bool) -> [llama_token] { private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
@ -162,7 +182,8 @@ actor LlamaContext {
return swiftTokens return swiftTokens
} }
private func token_to_piece(token: llama_token) -> String { /// - note: The result does not contain null-terminator
private func token_to_piece(token: llama_token) -> [CChar] {
let result = UnsafeMutablePointer<Int8>.allocate(capacity: 8) let result = UnsafeMutablePointer<Int8>.allocate(capacity: 8)
result.initialize(repeating: Int8(0), count: 8) result.initialize(repeating: Int8(0), count: 8)
defer { defer {
@ -176,10 +197,12 @@ actor LlamaContext {
defer { defer {
newResult.deallocate() newResult.deallocate()
} }
_ = llama_token_to_piece(model, token, newResult, -nTokens) let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens)
return String(cString: newResult) let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
return Array(bufferPointer)
} else { } else {
return String(cString: result) let bufferPointer = UnsafeBufferPointer(start: result, count: Int(nTokens))
return Array(bufferPointer)
} }
} }
} }