mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-31 22:04:35 +00:00
sampling : avoid expensive softmax during greedy sampling
ggml-ci
This commit is contained in:
parent
37f8c7b4c9
commit
8241bc71b5
@ -209,7 +209,10 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
|
|||||||
GGML_ASSERT(false && "unknown mirostat version");
|
GGML_ASSERT(false && "unknown mirostat version");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
if (params.n_probs > 0) {
|
||||||
|
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
|
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
|
||||||
|
}
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
|
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1066,6 +1066,7 @@ extern "C" {
|
|||||||
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
|
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
|
||||||
|
|
||||||
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||||
|
/// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
|
LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
|
||||||
|
|
||||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||||
|
@ -3,13 +3,14 @@
|
|||||||
#include "llama-vocab.h"
|
#include "llama-vocab.h"
|
||||||
#include "llama-grammar.h"
|
#include "llama-grammar.h"
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstring>
|
#include <cassert>
|
||||||
#include <ctime>
|
|
||||||
#include <cfloat>
|
#include <cfloat>
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <ctime>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "llama-sampling.h"
|
|
||||||
|
|
||||||
#ifdef NDEBUG
|
#ifdef NDEBUG
|
||||||
#undef NDEBUG
|
#undef NDEBUG
|
||||||
@ -249,6 +248,45 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
|
|||||||
samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
|
samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define BENCH(__cnstr, __data, __n_iter) do { \
|
||||||
|
auto * cnstr = (__cnstr); \
|
||||||
|
std::vector<llama_token_data> cur((__data).size()); \
|
||||||
|
std::copy((__data).begin(), (__data).end(), cur.begin()); \
|
||||||
|
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; \
|
||||||
|
llama_sampler_apply(cnstr, &cur_p); \
|
||||||
|
llama_sampler_reset(cnstr); \
|
||||||
|
const int64_t t_start = ggml_time_us(); \
|
||||||
|
const int n_iter = (__n_iter); \
|
||||||
|
for (int i = 0; i < n_iter; i++) { \
|
||||||
|
std::copy((__data).begin(), (__data).end(), cur.begin()); \
|
||||||
|
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; \
|
||||||
|
llama_sampler_apply(cnstr, &cur_p); \
|
||||||
|
llama_sampler_reset(cnstr); \
|
||||||
|
} \
|
||||||
|
const int64_t t_end = ggml_time_us(); \
|
||||||
|
llama_sampler_free(cnstr); \
|
||||||
|
printf("%-42s: %8.3f us/iter\n", #__cnstr, (t_end - t_start) / (float)n_iter); \
|
||||||
|
} while(0)
|
||||||
|
|
||||||
|
static void test_perf() {
|
||||||
|
const int n_vocab = 1 << 17;
|
||||||
|
|
||||||
|
std::vector<llama_token_data> data;
|
||||||
|
|
||||||
|
data.reserve(n_vocab);
|
||||||
|
for (int i = 0; i < n_vocab; i++) {
|
||||||
|
const float logit = 2.0f*((float)(rand())/RAND_MAX - 0.5f);
|
||||||
|
data.emplace_back(llama_token_data{i, logit, 0.0f});
|
||||||
|
}
|
||||||
|
|
||||||
|
BENCH(llama_sampler_init_top_k (40), data, 32);
|
||||||
|
BENCH(llama_sampler_init_top_p (0.8f, 1), data, 32);
|
||||||
|
BENCH(llama_sampler_init_min_p (0.2f, 1), data, 32);
|
||||||
|
BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
|
||||||
|
BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
|
||||||
|
BENCH(llama_sampler_init_softmax (), data, 32);
|
||||||
|
}
|
||||||
|
|
||||||
int main(void) {
|
int main(void) {
|
||||||
ggml_time_init();
|
ggml_time_init();
|
||||||
|
|
||||||
@ -316,5 +354,7 @@ int main(void) {
|
|||||||
|
|
||||||
printf("OK\n");
|
printf("OK\n");
|
||||||
|
|
||||||
|
test_perf();
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user