diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index d6e5e0497..32eea7869 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -1123,15 +1124,19 @@ struct sql_printer : public printer { static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) { llama_set_n_threads(ctx, n_threads, n_threads); - //std::vector tokens(n_prompt, llama_token_bos(llama_get_model(ctx))); - //llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt, n_past, 0)); - //GGML_UNUSED(n_batch); + const llama_model * model = llama_get_model(ctx); + const int32_t n_vocab = llama_n_vocab(model); + + std::vector tokens(n_batch); - std::vector tokens(n_batch, llama_token_bos(llama_get_model(ctx))); int n_processed = 0; while (n_processed < n_prompt) { int n_tokens = std::min(n_prompt - n_processed, n_batch); + tokens[0] = n_processed == 0 && llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab; + for (int i = 1; i < n_tokens; i++) { + tokens[i] = std::rand() % n_vocab; + } llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0)); n_processed += n_tokens; } @@ -1142,11 +1147,15 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { llama_set_n_threads(ctx, n_threads, n_threads); - llama_token token = llama_token_bos(llama_get_model(ctx)); + const llama_model * model = llama_get_model(ctx); + const int32_t n_vocab = llama_n_vocab(model); + + llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab; for (int i = 0; i < n_gen; i++) { llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0)); llama_synchronize(ctx); + token = std::rand() % n_vocab; } }