mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
swift : fix concatenation method to avoid invalid UTF8 stringification (#4325)
This commit is contained in:
parent
5c9f90cba1
commit
d208995c6d
@ -11,6 +11,8 @@ actor LlamaContext {
|
|||||||
private var context: OpaquePointer
|
private var context: OpaquePointer
|
||||||
private var batch: llama_batch
|
private var batch: llama_batch
|
||||||
private var tokens_list: [llama_token]
|
private var tokens_list: [llama_token]
|
||||||
|
/// Buffers CChars that do not yet form a valid UTF-8 string, so they can be combined with the CChars of later tokens
|
||||||
|
private var temporary_invalid_cchars: [CChar]
|
||||||
|
|
||||||
var n_len: Int32 = 512
|
var n_len: Int32 = 512
|
||||||
var n_cur: Int32 = 0
|
var n_cur: Int32 = 0
|
||||||
@ -21,6 +23,7 @@ actor LlamaContext {
|
|||||||
self.context = context
|
self.context = context
|
||||||
self.tokens_list = []
|
self.tokens_list = []
|
||||||
self.batch = llama_batch_init(512, 0, 1)
|
self.batch = llama_batch_init(512, 0, 1)
|
||||||
|
self.temporary_invalid_cchars = []
|
||||||
}
|
}
|
||||||
|
|
||||||
deinit {
|
deinit {
|
||||||
@ -61,6 +64,7 @@ actor LlamaContext {
|
|||||||
print("attempting to complete \"\(text)\"")
|
print("attempting to complete \"\(text)\"")
|
||||||
|
|
||||||
tokens_list = tokenize(text: text, add_bos: true)
|
tokens_list = tokenize(text: text, add_bos: true)
|
||||||
|
temporary_invalid_cchars = []
|
||||||
|
|
||||||
let n_ctx = llama_n_ctx(context)
|
let n_ctx = llama_n_ctx(context)
|
||||||
let n_kv_req = tokens_list.count + (Int(n_len) - tokens_list.count)
|
let n_kv_req = tokens_list.count + (Int(n_len) - tokens_list.count)
|
||||||
@ -72,7 +76,7 @@ actor LlamaContext {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for id in tokens_list {
|
for id in tokens_list {
|
||||||
print(token_to_piece(token: id))
|
print(String(cString: token_to_piece(token: id) + [0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
// batch = llama_batch_init(512, 0) // done in init()
|
// batch = llama_batch_init(512, 0) // done in init()
|
||||||
@ -115,10 +119,25 @@ actor LlamaContext {
|
|||||||
|
|
||||||
if new_token_id == llama_token_eos(context) || n_cur == n_len {
|
if new_token_id == llama_token_eos(context) || n_cur == n_len {
|
||||||
print("\n")
|
print("\n")
|
||||||
return ""
|
let new_token_str = String(cString: temporary_invalid_cchars + [0])
|
||||||
|
temporary_invalid_cchars.removeAll()
|
||||||
|
return new_token_str
|
||||||
}
|
}
|
||||||
|
|
||||||
let new_token_str = token_to_piece(token: new_token_id)
|
let new_token_cchars = token_to_piece(token: new_token_id)
|
||||||
|
temporary_invalid_cchars.append(contentsOf: new_token_cchars)
|
||||||
|
let new_token_str: String
|
||||||
|
if let string = String(validatingUTF8: temporary_invalid_cchars + [0]) {
|
||||||
|
temporary_invalid_cchars.removeAll()
|
||||||
|
new_token_str = string
|
||||||
|
} else if (0 ..< temporary_invalid_cchars.count).contains(where: {$0 != 0 && String(validatingUTF8: Array(temporary_invalid_cchars.suffix($0)) + [0]) != nil}) {
|
||||||
|
// In this case, at least one suffix of temporary_invalid_cchars can be interpreted as a UTF-8 string
|
||||||
|
let string = String(cString: temporary_invalid_cchars + [0])
|
||||||
|
temporary_invalid_cchars.removeAll()
|
||||||
|
new_token_str = string
|
||||||
|
} else {
|
||||||
|
new_token_str = ""
|
||||||
|
}
|
||||||
print(new_token_str)
|
print(new_token_str)
|
||||||
// tokens_list.append(new_token_id)
|
// tokens_list.append(new_token_id)
|
||||||
|
|
||||||
@ -144,6 +163,7 @@ actor LlamaContext {
|
|||||||
|
|
||||||
func clear() {
|
func clear() {
|
||||||
tokens_list.removeAll()
|
tokens_list.removeAll()
|
||||||
|
temporary_invalid_cchars.removeAll()
|
||||||
}
|
}
|
||||||
|
|
||||||
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
|
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
|
||||||
@ -162,7 +182,8 @@ actor LlamaContext {
|
|||||||
return swiftTokens
|
return swiftTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
private func token_to_piece(token: llama_token) -> String {
|
/// - note: The result does not contain a null terminator
|
||||||
|
private func token_to_piece(token: llama_token) -> [CChar] {
|
||||||
let result = UnsafeMutablePointer<Int8>.allocate(capacity: 8)
|
let result = UnsafeMutablePointer<Int8>.allocate(capacity: 8)
|
||||||
result.initialize(repeating: Int8(0), count: 8)
|
result.initialize(repeating: Int8(0), count: 8)
|
||||||
defer {
|
defer {
|
||||||
@ -176,10 +197,12 @@ actor LlamaContext {
|
|||||||
defer {
|
defer {
|
||||||
newResult.deallocate()
|
newResult.deallocate()
|
||||||
}
|
}
|
||||||
_ = llama_token_to_piece(model, token, newResult, -nTokens)
|
let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens)
|
||||||
return String(cString: newResult)
|
let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
|
||||||
|
return Array(bufferPointer)
|
||||||
} else {
|
} else {
|
||||||
return String(cString: result)
|
let bufferPointer = UnsafeBufferPointer(start: result, count: Int(nTokens))
|
||||||
|
return Array(bufferPointer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user