mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-10 02:31:46 +00:00
llama : deprecate softmax sampler + fix dist sampler
ggml-ci
This commit is contained in:
parent
3752217ed5
commit
e31c8790ff
@ -203,7 +203,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||||||
GGML_ASSERT(false && "unknown sampler type");
|
GGML_ASSERT(false && "unknown sampler type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
|
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
||||||
} else if (params.mirostat == 1) {
|
} else if (params.mirostat == 1) {
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
||||||
@ -222,7 +221,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||||||
// the following will not produce exactly the same probs as applyging softmax to the full vocabulary, but
|
// the following will not produce exactly the same probs as applyging softmax to the full vocabulary, but
|
||||||
// it is much faster, since we avoid sorting all tokens and should give a good approximation
|
// it is much faster, since we avoid sorting all tokens and should give a good approximation
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
|
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
|
|
||||||
}
|
}
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
|
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
|
||||||
}
|
}
|
||||||
|
@ -46,7 +46,6 @@ actor LlamaContext {
|
|||||||
let sparams = llama_sampler_chain_default_params()
|
let sparams = llama_sampler_chain_default_params()
|
||||||
self.sampling = llama_sampler_chain_init(sparams)
|
self.sampling = llama_sampler_chain_init(sparams)
|
||||||
llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
|
llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
|
||||||
llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
|
|
||||||
llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
|
llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,7 +42,6 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||||
|
|
||||||
llama_sampler_chain_add(smpl, llama_sampler_init_softmax());
|
|
||||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed));
|
llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed));
|
||||||
|
|
||||||
// tokenize prompt
|
// tokenize prompt
|
||||||
@ -96,7 +95,6 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
|
llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
|
||||||
|
|
||||||
llama_sampler_chain_add(smpl2, llama_sampler_init_softmax());
|
|
||||||
llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed));
|
llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed));
|
||||||
|
|
||||||
printf("\nsecond run: %s", params.prompt.c_str());
|
printf("\nsecond run: %s", params.prompt.c_str());
|
||||||
@ -156,7 +154,6 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
|
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
|
||||||
|
|
||||||
llama_sampler_chain_add(smpl3, llama_sampler_init_softmax());
|
|
||||||
llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed));
|
llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed));
|
||||||
|
|
||||||
printf("\nsingle seq run: %s", params.prompt.c_str());
|
printf("\nsingle seq run: %s", params.prompt.c_str());
|
||||||
|
@ -180,8 +180,6 @@ int main(int argc, char ** argv) {
|
|||||||
// target model sampling context (reuse the llama_context's sampling instance)
|
// target model sampling context (reuse the llama_context's sampling instance)
|
||||||
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
|
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
|
||||||
|
|
||||||
struct llama_sampler * softmax = llama_sampler_init_softmax();
|
|
||||||
|
|
||||||
// draft sequence data
|
// draft sequence data
|
||||||
std::vector<seq_draft> drafts(n_seq_dft);
|
std::vector<seq_draft> drafts(n_seq_dft);
|
||||||
|
|
||||||
@ -624,7 +622,6 @@ int main(int argc, char ** argv) {
|
|||||||
common_sampler_free(drafts[s].smpl);
|
common_sampler_free(drafts[s].smpl);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampler_free(softmax);
|
|
||||||
llama_batch_free(batch_dft);
|
llama_batch_free(batch_dft);
|
||||||
|
|
||||||
llama_free(ctx_tgt);
|
llama_free(ctx_tgt);
|
||||||
|
@ -217,6 +217,7 @@ extern "C" {
|
|||||||
|
|
||||||
typedef struct llama_token_data_array {
|
typedef struct llama_token_data_array {
|
||||||
// TODO: consider SoA
|
// TODO: consider SoA
|
||||||
|
// NOTE: this pointer can be modified by the samplers
|
||||||
llama_token_data * data;
|
llama_token_data * data;
|
||||||
size_t size;
|
size_t size;
|
||||||
int64_t selected; // this is the index in the data array (i.e. not the token id)
|
int64_t selected; // this is the index in the data array (i.e. not the token id)
|
||||||
@ -1086,7 +1087,8 @@ extern "C" {
|
|||||||
|
|
||||||
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||||
/// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
|
/// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
|
DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
|
||||||
|
"will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)");
|
||||||
|
|
||||||
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
|
LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
|
||||||
|
@ -427,6 +427,9 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
|
|||||||
|
|
||||||
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
||||||
|
|
||||||
|
llama_sampler_softmax_impl(cur_p);
|
||||||
|
|
||||||
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
|
cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,20 +24,22 @@ static void dump(const llama_token_data_array * cur_p) {
|
|||||||
llama_sampler_free(cnstr); \
|
llama_sampler_free(cnstr); \
|
||||||
} while(0)
|
} while(0)
|
||||||
|
|
||||||
|
#define CUR_P_FROM_PROBS() \
|
||||||
|
const size_t n_vocab = probs.size(); \
|
||||||
|
std::vector<llama_token_data> cur; \
|
||||||
|
cur.reserve(n_vocab); \
|
||||||
|
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { \
|
||||||
|
const float logit = logf(probs[token_id]); \
|
||||||
|
cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); \
|
||||||
|
} \
|
||||||
|
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }
|
||||||
|
|
||||||
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
|
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
|
||||||
const size_t n_vocab = probs.size();
|
CUR_P_FROM_PROBS();
|
||||||
|
|
||||||
std::vector<llama_token_data> cur;
|
|
||||||
cur.reserve(n_vocab);
|
|
||||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
|
||||||
const float logit = logf(probs[token_id]);
|
|
||||||
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
|
||||||
APPLY(llama_sampler_init_softmax(), &cur_p);
|
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
APPLY(llama_sampler_init_top_k(k), &cur_p);
|
APPLY(llama_sampler_init_top_k(k), &cur_p);
|
||||||
|
APPLY(llama_sampler_init_dist (0), &cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
|
|
||||||
GGML_ASSERT(cur_p.size == expected_probs.size());
|
GGML_ASSERT(cur_p.size == expected_probs.size());
|
||||||
@ -47,19 +49,12 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||||
const size_t n_vocab = probs.size();
|
CUR_P_FROM_PROBS();
|
||||||
|
|
||||||
std::vector<llama_token_data> cur;
|
|
||||||
cur.reserve(n_vocab);
|
|
||||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
|
||||||
const float logit = logf(probs[token_id]);
|
|
||||||
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
|
||||||
APPLY(llama_sampler_init_softmax(), &cur_p);
|
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
APPLY(llama_sampler_init_top_p(p, 1), &cur_p);
|
APPLY(llama_sampler_init_top_p(p, 1), &cur_p);
|
||||||
|
APPLY(llama_sampler_init_dist (0), &cur_p);
|
||||||
|
DUMP(&cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
|
|
||||||
GGML_ASSERT(cur_p.size == expected_probs.size());
|
GGML_ASSERT(cur_p.size == expected_probs.size());
|
||||||
@ -69,16 +64,8 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
|
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
|
||||||
const size_t n_vocab = probs.size();
|
CUR_P_FROM_PROBS();
|
||||||
|
|
||||||
std::vector<llama_token_data> cur;
|
|
||||||
cur.reserve(n_vocab);
|
|
||||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
|
||||||
const float logit = logf(probs[token_id]);
|
|
||||||
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
APPLY(llama_sampler_init_tail_free(z, 1), &cur_p);
|
APPLY(llama_sampler_init_tail_free(z, 1), &cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
@ -90,20 +77,12 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||||
const size_t n_vocab = probs.size();
|
CUR_P_FROM_PROBS();
|
||||||
|
|
||||||
std::vector<llama_token_data> cur;
|
|
||||||
cur.reserve(n_vocab);
|
|
||||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
|
||||||
const float logit = logf(probs[token_id]);
|
|
||||||
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
APPLY(llama_sampler_init_min_p(p, 1), &cur_p);
|
APPLY(llama_sampler_init_min_p(p, 1), &cur_p);
|
||||||
|
APPLY(llama_sampler_init_dist (0), &cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
APPLY(llama_sampler_init_softmax(), &cur_p);
|
|
||||||
|
|
||||||
GGML_ASSERT(cur_p.size == expected_probs.size());
|
GGML_ASSERT(cur_p.size == expected_probs.size());
|
||||||
for (size_t i = 0; i < cur_p.size; i++) {
|
for (size_t i = 0; i < cur_p.size; i++) {
|
||||||
@ -112,17 +91,8 @@ static void test_min_p(const std::vector<float> & probs, const std::vector<float
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void test_xtc(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p, float t) {
|
static void test_xtc(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p, float t) {
|
||||||
const size_t n_vocab = probs.size();
|
CUR_P_FROM_PROBS();
|
||||||
|
|
||||||
std::vector<llama_token_data> cur;
|
|
||||||
cur.reserve(n_vocab);
|
|
||||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
|
||||||
const float logit = logf(probs[token_id]);
|
|
||||||
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
|
||||||
APPLY(llama_sampler_init_softmax(), &cur_p);
|
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
APPLY(llama_sampler_init_xtc(p, t, 0, 0), &cur_p);
|
APPLY(llama_sampler_init_xtc(p, t, 0, 0), &cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
@ -134,16 +104,8 @@ static void test_xtc(const std::vector<float> & probs, const std::vector<float>
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||||
const size_t n_vocab = probs.size();
|
CUR_P_FROM_PROBS();
|
||||||
|
|
||||||
std::vector<llama_token_data> cur;
|
|
||||||
cur.reserve(n_vocab);
|
|
||||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
|
||||||
const float logit = logf(probs[token_id]);
|
|
||||||
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
APPLY(llama_sampler_init_typical(p, 1), &cur_p);
|
APPLY(llama_sampler_init_typical(p, 1), &cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
@ -160,16 +122,7 @@ static void test_penalties(
|
|||||||
) {
|
) {
|
||||||
GGML_ASSERT(probs.size() == expected_probs.size());
|
GGML_ASSERT(probs.size() == expected_probs.size());
|
||||||
|
|
||||||
const size_t n_vocab = probs.size();
|
CUR_P_FROM_PROBS();
|
||||||
|
|
||||||
std::vector<llama_token_data> cur;
|
|
||||||
cur.reserve(n_vocab);
|
|
||||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
|
||||||
const float logit = logf(probs[token_id]);
|
|
||||||
cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
|
||||||
|
|
||||||
auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
|
auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
|
||||||
|
|
||||||
@ -177,10 +130,9 @@ static void test_penalties(
|
|||||||
llama_sampler_accept(sampler, last_tokens[i]);
|
llama_sampler_accept(sampler, last_tokens[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
APPLY(llama_sampler_init_softmax(), &cur_p);
|
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
APPLY(sampler, &cur_p);
|
APPLY(sampler, &cur_p);
|
||||||
APPLY(llama_sampler_init_softmax(), &cur_p);
|
APPLY(llama_sampler_init_dist(0), &cur_p);
|
||||||
DUMP(&cur_p);
|
DUMP(&cur_p);
|
||||||
|
|
||||||
GGML_ASSERT(cur_p.size == expected_probs.size());
|
GGML_ASSERT(cur_p.size == expected_probs.size());
|
||||||
@ -214,7 +166,7 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
|
|||||||
default : GGML_ABORT("Unknown sampler");
|
default : GGML_ABORT("Unknown sampler");
|
||||||
}
|
}
|
||||||
|
|
||||||
APPLY(llama_sampler_init_softmax(), &cur_p); // make sure tokens are sorted for tests
|
APPLY(llama_sampler_init_dist(0), &cur_p);
|
||||||
|
|
||||||
const int size = cur_p.size;
|
const int size = cur_p.size;
|
||||||
|
|
||||||
@ -307,21 +259,20 @@ static void test_perf() {
|
|||||||
BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
|
BENCH(llama_sampler_init_tail_free(0.5f, 1), data, 32);
|
||||||
BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
|
BENCH(llama_sampler_init_typical (0.5f, 1), data, 32);
|
||||||
BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);
|
BENCH(llama_sampler_init_xtc (1.0f, 0.1f, 1, 1), data, 32);
|
||||||
BENCH(llama_sampler_init_softmax (), data, 32);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(void) {
|
int main(void) {
|
||||||
ggml_time_init();
|
ggml_time_init();
|
||||||
|
|
||||||
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
|
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
|
||||||
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
|
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
|
||||||
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
|
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
|
||||||
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
|
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
|
||||||
|
|
||||||
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
|
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 0);
|
||||||
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);
|
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f}, 0.7f);
|
||||||
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f);
|
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 0.8f);
|
||||||
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);
|
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
|
||||||
|
|
||||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f);
|
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f);
|
||||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f);
|
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f);
|
||||||
|
Loading…
Reference in New Issue
Block a user