mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 19:04:35 +00:00
llama-bench : use random tokens to improve accuracy with mixtral (#6069)
This commit is contained in:
parent
4755afd1cb
commit
b0bc9f4a9d
@ -8,6 +8,7 @@
|
|||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <ctime>
|
#include <ctime>
|
||||||
|
#include <cstdlib>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
@ -1123,15 +1124,19 @@ struct sql_printer : public printer {
|
|||||||
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
|
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
|
||||||
llama_set_n_threads(ctx, n_threads, n_threads);
|
llama_set_n_threads(ctx, n_threads, n_threads);
|
||||||
|
|
||||||
//std::vector<llama_token> tokens(n_prompt, llama_token_bos(llama_get_model(ctx)));
|
const llama_model * model = llama_get_model(ctx);
|
||||||
//llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt, n_past, 0));
|
const int32_t n_vocab = llama_n_vocab(model);
|
||||||
//GGML_UNUSED(n_batch);
|
|
||||||
|
std::vector<llama_token> tokens(n_batch);
|
||||||
|
|
||||||
std::vector<llama_token> tokens(n_batch, llama_token_bos(llama_get_model(ctx)));
|
|
||||||
int n_processed = 0;
|
int n_processed = 0;
|
||||||
|
|
||||||
while (n_processed < n_prompt) {
|
while (n_processed < n_prompt) {
|
||||||
int n_tokens = std::min(n_prompt - n_processed, n_batch);
|
int n_tokens = std::min(n_prompt - n_processed, n_batch);
|
||||||
|
tokens[0] = n_processed == 0 && llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;
|
||||||
|
for (int i = 1; i < n_tokens; i++) {
|
||||||
|
tokens[i] = std::rand() % n_vocab;
|
||||||
|
}
|
||||||
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
|
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
|
||||||
n_processed += n_tokens;
|
n_processed += n_tokens;
|
||||||
}
|
}
|
||||||
@ -1142,11 +1147,15 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
|
|||||||
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
|
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
|
||||||
llama_set_n_threads(ctx, n_threads, n_threads);
|
llama_set_n_threads(ctx, n_threads, n_threads);
|
||||||
|
|
||||||
llama_token token = llama_token_bos(llama_get_model(ctx));
|
const llama_model * model = llama_get_model(ctx);
|
||||||
|
const int32_t n_vocab = llama_n_vocab(model);
|
||||||
|
|
||||||
|
llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;
|
||||||
|
|
||||||
for (int i = 0; i < n_gen; i++) {
|
for (int i = 0; i < n_gen; i++) {
|
||||||
llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
|
llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
|
token = std::rand() % n_vocab;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user