Update tests/test-sampling.cpp
Some checks failed
flake8 Lint / Lint (push) Has been cancelled

Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
Georgi Gerganov 2024-09-23 17:18:12 +03:00 committed by GitHub
parent 3cb33a8e29
commit a5a11bfbc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -248,25 +248,25 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p); samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
} }
#define BENCH(__cnstr, __data, __n_iter) do { \ static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector<llama_token_data> & data, int n_iter) {
auto * cnstr = (__cnstr); \ std::vector<llama_token_data> cur(data.size());
std::vector<llama_token_data> cur((__data).size()); \ std::copy(data.begin(), data.end(), cur.begin());
std::copy((__data).begin(), (__data).end(), cur.begin()); \ llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; \ llama_sampler_apply(cnstr, &cur_p);
llama_sampler_apply(cnstr, &cur_p); \ llama_sampler_reset(cnstr);
llama_sampler_reset(cnstr); \ const int64_t t_start = ggml_time_us();
const int64_t t_start = ggml_time_us(); \ for (int i = 0; i < n_iter; i++) {
const int n_iter = (__n_iter); \ std::copy(data.begin(), data.end(), cur.begin());
for (int i = 0; i < n_iter; i++) { \ llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
std::copy((__data).begin(), (__data).end(), cur.begin()); \ llama_sampler_apply(cnstr, &cur_p);
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; \ llama_sampler_reset(cnstr);
llama_sampler_apply(cnstr, &cur_p); \ }
llama_sampler_reset(cnstr); \ const int64_t t_end = ggml_time_us();
} \ llama_sampler_free(cnstr);
const int64_t t_end = ggml_time_us(); \ printf("%-42s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
llama_sampler_free(cnstr); \ }
printf("%-42s: %8.3f us/iter\n", #__cnstr, (t_end - t_start) / (float)n_iter); \
} while(0) #define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
static void test_perf() { static void test_perf() {
const int n_vocab = 1 << 17; const int n_vocab = 1 << 17;