llama-bench : add time-to-first-byte stat
Some checks failed
Python check requirements.txt / check-requirements (push) Has been cancelled
flake8 Lint / Lint (push) Has been cancelled
Python Type-Check / pyright type-check (push) Has been cancelled

This commit is contained in:
Georgi Gerganov 2024-09-19 09:15:29 +03:00
parent afd9909a64
commit bc82fc2ed8
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -929,6 +929,13 @@ struct test {
return ts; return ts;
} }
std::vector<double> get_ttfb() const {
int n_tokens = n_prompt + n_gen;
std::vector<double> ts;
std::transform(samples_ns.begin(), samples_ns.end(), std::back_inserter(ts), [n_tokens](uint64_t t) { return t/1e6; });
return ts;
}
double avg_ts() const { double avg_ts() const {
return ::avg(get_ts()); return ::avg(get_ts());
} }
@ -937,6 +944,14 @@ struct test {
return ::stdev(get_ts()); return ::stdev(get_ts());
} }
double avg_ttfb() const {
return ::avg(get_ttfb());
}
double stdev_ttfb() const {
return ::stdev(get_ttfb());
}
static std::string get_backend() { static std::string get_backend() {
if (cuda) { if (cuda) {
return GGML_CUDA_NAME; return GGML_CUDA_NAME;
@ -1187,6 +1202,9 @@ struct markdown_printer : public printer {
if (field == "model") { if (field == "model") {
return -30; return -30;
} }
if (field == "ttfb") {
return 30;
}
if (field == "t/s") { if (field == "t/s") {
return 20; return 20;
} }
@ -1314,6 +1332,7 @@ struct markdown_printer : public printer {
} }
fields.emplace_back("test"); fields.emplace_back("test");
fields.emplace_back("t/s"); fields.emplace_back("t/s");
fields.emplace_back("ttfb");
fprintf(fout, "|"); fprintf(fout, "|");
for (const auto & field : fields) { for (const auto & field : fields) {
@ -1368,6 +1387,9 @@ struct markdown_printer : public printer {
} else if (field == "t/s") { } else if (field == "t/s") {
snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ts(), t.stdev_ts()); snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ts(), t.stdev_ts());
value = buf; value = buf;
} else if (field == "ttfb") {
snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ttfb(), t.stdev_ttfb());
value = buf;
} else if (vmap.find(field) != vmap.end()) { } else if (vmap.find(field) != vmap.end()) {
value = vmap.at(field); value = vmap.at(field);
} else { } else {
@ -1376,7 +1398,7 @@ struct markdown_printer : public printer {
} }
int width = get_field_width(field); int width = get_field_width(field);
if (field == "t/s") { if (field == "t/s" || field == "ttfb") {
// HACK: the utf-8 character is 2 bytes // HACK: the utf-8 character is 2 bytes
width += 1; width += 1;
} }