llama.swiftui : improve bench

This commit is contained in:
Georgi Gerganov 2023-12-17 19:37:22 +02:00
parent 5c5bdba605
commit 865066621b
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 110 additions and 47 deletions

View File

@ -27,6 +27,7 @@ actor LlamaContext {
private var context: OpaquePointer
private var batch: llama_batch
private var tokens_list: [llama_token]
/// This variable is used to store temporarily invalid cchars
private var temporary_invalid_cchars: [CChar]
@ -195,62 +196,100 @@ actor LlamaContext {
return new_token_str
}
func bench() -> String {
let pp = 512
let tg = 128
let pl = 1
func bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) -> String {
var pp_avg: Double = 0
var tg_avg: Double = 0
// bench prompt processing
var pp_std: Double = 0
var tg_std: Double = 0
llama_batch_clear(&batch)
for r in 0..<nr {
// bench prompt processing
let n_tokens = pp
for i in 0..<n_tokens {
llama_batch_add(&batch, 0, Int32(i), [0], false)
}
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
llama_kv_cache_clear(context)
let t_pp_start = ggml_time_us()
if llama_decode(context, batch) != 0 {
print("llama_decode() failed during prompt")
}
let t_pp_end = ggml_time_us()
// bench text generation
llama_kv_cache_clear(context)
let t_tg_start = ggml_time_us()
for i in 0..<tg {
llama_batch_clear(&batch)
for j in 0..<pl {
llama_batch_add(&batch, 0, Int32(i), [Int32(j)], true)
let n_tokens = pp
for i in 0..<n_tokens {
llama_batch_add(&batch, 0, Int32(i), [0], false)
}
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
llama_kv_cache_clear(context)
let t_pp_start = ggml_time_us()
if llama_decode(context, batch) != 0 {
print("llama_decode() failed during text generation")
print("llama_decode() failed during prompt")
}
let t_pp_end = ggml_time_us()
// bench text generation
llama_kv_cache_clear(context)
let t_tg_start = ggml_time_us()
for i in 0..<tg {
llama_batch_clear(&batch)
for j in 0..<pl {
llama_batch_add(&batch, 0, Int32(i), [Int32(j)], true)
}
if llama_decode(context, batch) != 0 {
print("llama_decode() failed during text generation")
}
}
let t_tg_end = ggml_time_us()
llama_kv_cache_clear(context)
let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
let speed_pp = Double(pp) / t_pp
let speed_tg = Double(pl*tg) / t_tg
pp_avg += speed_pp
tg_avg += speed_tg
pp_std += speed_pp * speed_pp
tg_std += speed_tg * speed_tg
print("pp \(speed_pp) t/s, tg \(speed_tg) t/s")
}
let t_tg_end = ggml_time_us()
pp_avg /= Double(nr)
tg_avg /= Double(nr)
llama_kv_cache_clear(context)
if nr > 1 {
pp_std = sqrt(pp_std / Double(nr - 1) - pp_avg * pp_avg * Double(nr) / Double(nr - 1))
tg_std = sqrt(tg_std / Double(nr - 1) - tg_avg * tg_avg * Double(nr) / Double(nr - 1))
} else {
pp_std = 0
tg_std = 0
}
let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
let model_desc = model_info();
let model_size = String(format: "%.2f GiB", Double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0);
let model_n_params = String(format: "%.2f B", Double(llama_model_n_params(model)) / 1e9);
let backend = "Metal";
let pp_avg_str = String(format: "%.2f", pp_avg);
let tg_avg_str = String(format: "%.2f", tg_avg);
let pp_std_str = String(format: "%.2f", pp_std);
let tg_std_str = String(format: "%.2f", tg_std);
let speed_pp = Double(pp) / t_pp
let speed_tg = Double(pl*tg) / t_tg
var result = ""
return String(format: "PP 512 speed: %7.2f t/s\n", speed_pp) +
String(format: "TG 128 speed: %7.2f t/s\n", speed_tg)
result += String("| model | size | params | backend | test | t/s |\n")
result += String("| --- | --- | --- | --- | --- | --- |\n")
result += String("| \(model_desc) | \(model_size) | \(model_n_params) | \(backend) | pp \(pp) | \(pp_avg_str) ± \(pp_std_str) |\n")
result += String("| \(model_desc) | \(model_size) | \(model_n_params) | \(backend) | tg \(tg) | \(tg_avg_str) ± \(tg_std_str) |\n")
return result;
}
func clear() {

View File

@ -7,11 +7,11 @@
objects = {
/* Begin PBXBuildFile section */
542376082B0D9BFB008E6A1C /* ggml-quants.c in Sources */ = {isa = PBXBuildFile; fileRef = 542376072B0D9BFB008E6A1C /* ggml-quants.c */; };
5423760B2B0D9C4B008E6A1C /* ggml-backend.c in Sources */ = {isa = PBXBuildFile; fileRef = 5423760A2B0D9C4B008E6A1C /* ggml-backend.c */; };
542376082B0D9BFB008E6A1C /* ggml-quants.c in Sources */ = {isa = PBXBuildFile; fileRef = 542376072B0D9BFB008E6A1C /* ggml-quants.c */; settings = {COMPILER_FLAGS = "-O3"; }; };
5423760B2B0D9C4B008E6A1C /* ggml-backend.c in Sources */ = {isa = PBXBuildFile; fileRef = 5423760A2B0D9C4B008E6A1C /* ggml-backend.c */; settings = {COMPILER_FLAGS = "-O3"; }; };
542378792ACE3F3500834A7B /* ggml-metal.metal in Resources */ = {isa = PBXBuildFile; fileRef = 549479C82AC9E10B00E0F78B /* ggml-metal.metal */; };
542EA09D2AC8723900A8AEE9 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 542EA09B2AC8723900A8AEE9 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -DGGML_USE_METAL -DGGML_USE_K_QUANTS -O3"; }; };
542EA0A02AC8725700A8AEE9 /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 542EA09F2AC8725700A8AEE9 /* ggml-alloc.c */; };
542EA0A02AC8725700A8AEE9 /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 542EA09F2AC8725700A8AEE9 /* ggml-alloc.c */; settings = {COMPILER_FLAGS = "-O3"; }; };
542EA0A32AC8729100A8AEE9 /* llama.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 542EA0A12AC8729100A8AEE9 /* llama.cpp */; settings = {COMPILER_FLAGS = "-DGGML_USE_K_QUANTS -DGGML_USE_METAL -O3"; }; };
549479CB2AC9E16000E0F78B /* Metal.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 549479CA2AC9E16000E0F78B /* Metal.framework */; };
549479CD2AC9E42A00E0F78B /* ggml-metal.m in Sources */ = {isa = PBXBuildFile; fileRef = 549479C52AC9E0F200E0F78B /* ggml-metal.m */; settings = {COMPILER_FLAGS = "-fno-objc-arc -DGGML_SWIFT -DGGML_USE_METAL -O3"; }; };

View File

@ -50,12 +50,28 @@ class LlamaState: ObservableObject {
return
}
messageLog += "\n"
messageLog += "Running benchmark...\n"
messageLog += "Model info: "
messageLog += await llamaContext.model_info() + "\n"
messageLog += "Running benchmark...\n"
await llamaContext.bench() // heat up
let result = await llamaContext.bench()
let t_start = DispatchTime.now().uptimeNanoseconds
await llamaContext.bench(pp: 8, tg: 4, pl: 1) // heat up
let t_end = DispatchTime.now().uptimeNanoseconds
let t_heat = Double(t_end - t_start) / 1_000_000_000.0
messageLog += "Heat up time: \(t_heat) seconds, please wait...\n"
// if more than 5 seconds, then we're probably running on a slow device
if t_heat > 5.0 {
messageLog += "Heat up time is too long, aborting benchmark\n"
return
}
let result = await llamaContext.bench(pp: 512, tg: 128, pl: 1, nr: 3)
messageLog += "\(result)"
messageLog += "\n"
}
func clear() async {

View File

@ -62,6 +62,14 @@ struct ContentView: View {
.background(Color.blue)
.foregroundColor(.white)
.cornerRadius(8)
Button("Copy") {
UIPasteboard.general.string = llamaState.messageLog
}
.padding(8)
.background(Color.blue)
.foregroundColor(.white)
.cornerRadius(8)
}
VStack {