lookup: add print for drafting performance (#5450)

This commit is contained in:
Johannes Gäßler 2024-02-11 12:44:51 +01:00 committed by GitHub
parent 907e08c110
commit e4640d8fdf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,7 +1,9 @@
#include "common.h" #include "common.h"
#include "ggml.h"
#include "llama.h" #include "llama.h"
#include <cmath> #include <cmath>
#include <cstdint>
#include <cstdio> #include <cstdio>
#include <string> #include <string>
#include <vector> #include <vector>
@ -73,6 +75,8 @@ int main(int argc, char ** argv){
int n_drafted = 0; int n_drafted = 0;
int n_accept = 0; int n_accept = 0;
int64_t t_draft_us = 0;
int n_past = inp.size(); int n_past = inp.size();
bool has_eos = false; bool has_eos = false;
@ -160,7 +164,7 @@ int main(int argc, char ** argv){
// generate n_pred tokens through prompt lookup // generate n_pred tokens through prompt lookup
auto prompt_lookup = [&]() -> void { auto prompt_lookup = [&]() -> void {
int inp_size = inp.size(); const int inp_size = inp.size();
for (int ngram_size = ngram_max ; ngram_size > ngram_min; --ngram_size){ for (int ngram_size = ngram_max ; ngram_size > ngram_min; --ngram_size){
const llama_token * ngram = &inp[inp_size - ngram_size]; const llama_token * ngram = &inp[inp_size - ngram_size];
@ -191,8 +195,12 @@ int main(int argc, char ** argv){
return; return;
}; };
const int64_t t_start_draft_us = ggml_time_us();
prompt_lookup(); prompt_lookup();
t_draft_us += ggml_time_us() - t_start_draft_us;
llama_decode(ctx, batch_tgt); llama_decode(ctx, batch_tgt);
++n_past; ++n_past;
@ -210,6 +218,8 @@ int main(int argc, char ** argv){
LOG_TEE("n_draft = %d\n", n_draft); LOG_TEE("n_draft = %d\n", n_draft);
LOG_TEE("n_predict = %d\n", n_predict); LOG_TEE("n_predict = %d\n", n_predict);
LOG_TEE("n_drafted = %d\n", n_drafted); LOG_TEE("n_drafted = %d\n", n_drafted);
LOG_TEE("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n",
t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us));
LOG_TEE("n_accept = %d\n", n_accept); LOG_TEE("n_accept = %d\n", n_accept);
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);