Update tests/test-sampling.cpp
Some checks failed
flake8 Lint / Lint (push) Has been cancelled

Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
Georgi Gerganov 2024-09-23 17:18:12 +03:00 committed by GitHub
parent 3cb33a8e29
commit a5a11bfbc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -248,25 +248,25 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
}
#define BENCH(__cnstr, __data, __n_iter) do { \
auto * cnstr = (__cnstr); \
std::vector<llama_token_data> cur((__data).size()); \
std::copy((__data).begin(), (__data).end(), cur.begin()); \
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; \
llama_sampler_apply(cnstr, &cur_p); \
llama_sampler_reset(cnstr); \
const int64_t t_start = ggml_time_us(); \
const int n_iter = (__n_iter); \
for (int i = 0; i < n_iter; i++) { \
std::copy((__data).begin(), (__data).end(), cur.begin()); \
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; \
llama_sampler_apply(cnstr, &cur_p); \
llama_sampler_reset(cnstr); \
} \
const int64_t t_end = ggml_time_us(); \
llama_sampler_free(cnstr); \
printf("%-42s: %8.3f us/iter\n", #__cnstr, (t_end - t_start) / (float)n_iter); \
} while(0)
static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector<llama_token_data> & data, int n_iter) {
std::vector<llama_token_data> cur(data.size());
std::copy(data.begin(), data.end(), cur.begin());
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
llama_sampler_apply(cnstr, &cur_p);
llama_sampler_reset(cnstr);
const int64_t t_start = ggml_time_us();
for (int i = 0; i < n_iter; i++) {
std::copy(data.begin(), data.end(), cur.begin());
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
llama_sampler_apply(cnstr, &cur_p);
llama_sampler_reset(cnstr);
}
const int64_t t_end = ggml_time_us();
llama_sampler_free(cnstr);
printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
}
#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
static void test_perf() {
const int n_vocab = 1 << 17;