llama : more consistent names of count variables (#5994)

* llama : more consistent names of count variables

ggml-ci

* llama : n_parallel -> n_seq_max

* common : fix param name

* examples : fix param name
This commit is contained in:
Georgi Gerganov 2024-03-11 17:49:47 +02:00 committed by GitHub
parent 83796e62bc
commit 05b06210c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 35 additions and 34 deletions

View File

@ -10,7 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
### Recent API changes ### Recent API changes
- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_max_seq()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328 - [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_seq_max()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
- [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796 - [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796
- [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849 - [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849

View File

@ -1288,7 +1288,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.n_ctx = params.n_ctx; cparams.n_ctx = params.n_ctx;
cparams.n_batch = params.n_batch; cparams.n_batch = params.n_batch;
cparams.n_parallel = params.n_parallel; cparams.n_seq_max = params.n_parallel;
cparams.n_threads = params.n_threads; cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.seed = params.seed; cparams.seed = params.seed;
@ -1786,17 +1786,17 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+"; static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d", printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx); view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
llama_kv_cache_view_cell * c_curr = view.cells; llama_kv_cache_view_cell * c_curr = view.cells;
llama_seq_id * cs_curr = view.cells_sequences; llama_seq_id * cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) { for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
if (i % row_size == 0) { if (i % row_size == 0) {
printf("\n%5d: ", i); printf("\n%5d: ", i);
} }
int seq_count = 0; int seq_count = 0;
for (int j = 0; j < view.n_max_seq; j++) { for (int j = 0; j < view.n_seq_max; j++) {
if (cs_curr[j] >= 0) { seq_count++; } if (cs_curr[j] >= 0) { seq_count++; }
} }
putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]); putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
@ -1809,14 +1809,14 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n", printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx); view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
std::unordered_map<llama_seq_id, size_t> seqs; std::unordered_map<llama_seq_id, size_t> seqs;
llama_kv_cache_view_cell * c_curr = view.cells; llama_kv_cache_view_cell * c_curr = view.cells;
llama_seq_id * cs_curr = view.cells_sequences; llama_seq_id * cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) { for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
for (int j = 0; j < view.n_max_seq; j++) { for (int j = 0; j < view.n_seq_max; j++) {
if (cs_curr[j] < 0) { continue; } if (cs_curr[j] < 0) { continue; }
if (seqs.find(cs_curr[j]) == seqs.end()) { if (seqs.find(cs_curr[j]) == seqs.end()) {
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; } if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
@ -1835,11 +1835,11 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
c_curr = view.cells; c_curr = view.cells;
cs_curr = view.cells_sequences; cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) { for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
if (i % row_size == 0) { if (i % row_size == 0) {
printf("\n%5d: ", i); printf("\n%5d: ", i);
} }
for (int j = 0; j < view.n_max_seq; j++) { for (int j = 0; j < view.n_seq_max; j++) {
if (cs_curr[j] >= 0) { if (cs_curr[j] >= 0) {
const auto & it = seqs.find(cs_curr[j]); const auto & it = seqs.find(cs_curr[j]);
putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+'); putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');

View File

@ -106,7 +106,7 @@ int main(int argc, char ** argv) {
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
// ensure enough sequences are available // ensure enough sequences are available
ctx_params.n_parallel = *std::max_element(n_pl.begin(), n_pl.end()); ctx_params.n_seq_max = *std::max_element(n_pl.begin(), n_pl.end());
llama_context * ctx = llama_new_context_with_model(model, ctx_params); llama_context * ctx = llama_new_context_with_model(model, ctx_params);

View File

@ -80,7 +80,7 @@ int main(int argc, char ** argv) {
ctx_params.seed = 1234; ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_req; ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_len, n_parallel); ctx_params.n_batch = std::max(n_len, n_parallel);
ctx_params.n_parallel = n_parallel; ctx_params.n_seq_max = n_parallel;
ctx_params.n_threads = params.n_threads; ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

View File

@ -878,6 +878,7 @@ int main(int argc, char ** argv) {
const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true); const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
const auto line_inp = ::llama_tokenize(ctx, buffer, false, false); const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true); const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str()); LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end()); embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());

View File

@ -841,7 +841,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int max_tasks_per_batch = 32; const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx)); const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
@ -1118,7 +1118,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int max_tasks_per_batch = 128; const int max_tasks_per_batch = 128;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_max_seq(ctx)); const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
@ -1470,7 +1470,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int max_tasks_per_batch = 32; const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx)); const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

View File

@ -12538,7 +12538,7 @@ struct llama_context_params llama_context_default_params() {
/*.seed =*/ LLAMA_DEFAULT_SEED, /*.seed =*/ LLAMA_DEFAULT_SEED,
/*.n_ctx =*/ 512, /*.n_ctx =*/ 512,
/*.n_batch =*/ 512, /*.n_batch =*/ 512,
/*.n_parallel =*/ 1, /*.n_seq_max =*/ 1,
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
@ -12700,7 +12700,7 @@ struct llama_context * llama_new_context_with_model(
auto & cparams = ctx->cparams; auto & cparams = ctx->cparams;
cparams.n_batch = params.n_batch; cparams.n_batch = params.n_batch;
// TODO: maybe add n_parallel here too // TODO: maybe add n_seq_max here too
cparams.n_threads = params.n_threads; cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch; cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor; cparams.yarn_ext_factor = params.yarn_ext_factor;
@ -12767,7 +12767,7 @@ struct llama_context * llama_new_context_with_model(
// Mamba only needs a constant number of KV cache cells per sequence // Mamba only needs a constant number of KV cache cells per sequence
if (model->arch == LLM_ARCH_MAMBA) { if (model->arch == LLM_ARCH_MAMBA) {
// Mamba needs at least as many KV cells as there are sequences kept at any time // Mamba needs at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_parallel); kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states // it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
@ -13024,7 +13024,7 @@ uint32_t llama_n_batch(const struct llama_context * ctx) {
return ctx->cparams.n_batch; return ctx->cparams.n_batch;
} }
uint32_t llama_n_max_seq(const struct llama_context * ctx) { uint32_t llama_n_seq_max(const struct llama_context * ctx) {
return ctx->kv_self.size; return ctx->kv_self.size;
} }
@ -13188,10 +13188,10 @@ int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const
} }
} }
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) { struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max) {
struct llama_kv_cache_view result = { struct llama_kv_cache_view result = {
/*.n_cells = */ 0, /*.n_cells = */ 0,
/*.n_max_seq = */ n_max_seq, /*.n_seq_max = */ n_seq_max,
/*.token_count = */ 0, /*.token_count = */ 0,
/*.used_cells = */ llama_get_kv_cache_used_cells(ctx), /*.used_cells = */ llama_get_kv_cache_used_cells(ctx),
/*.max_contiguous = */ 0, /*.max_contiguous = */ 0,
@ -13219,7 +13219,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells); void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells"); GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
view->cells = (struct llama_kv_cache_view_cell *)p; view->cells = (struct llama_kv_cache_view_cell *)p;
p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells); p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences"); GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
view->cells_sequences = (llama_seq_id *)p; view->cells_sequences = (llama_seq_id *)p;
} }
@ -13233,7 +13233,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
uint32_t max_contig = 0; uint32_t max_contig = 0;
int32_t max_contig_idx = -1; int32_t max_contig_idx = -1;
for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) { for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) {
const size_t curr_size = kv_cells[i].seq_id.size(); const size_t curr_size = kv_cells[i].seq_id.size();
token_count += curr_size; token_count += curr_size;
c_curr->pos = kv_cells[i].pos + kv_cells[i].delta; c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
@ -13250,7 +13250,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
int seq_idx = 0; int seq_idx = 0;
for (const llama_seq_id it : kv_cells[i].seq_id) { for (const llama_seq_id it : kv_cells[i].seq_id) {
if (seq_idx >= view->n_max_seq) { if (seq_idx >= view->n_seq_max) {
break; break;
} }
cs_curr[seq_idx] = it; cs_curr[seq_idx] = it;
@ -13259,7 +13259,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
if (seq_idx != 0) { if (seq_idx != 0) {
used_cells++; used_cells++;
} }
for (; seq_idx < view->n_max_seq; seq_idx++) { for (; seq_idx < view->n_seq_max; seq_idx++) {
cs_curr[seq_idx] = -1; cs_curr[seq_idx] = -1;
} }
} }
@ -13921,12 +13921,12 @@ int32_t llama_tokenize(
const char * text, const char * text,
int32_t text_len, int32_t text_len,
llama_token * tokens, llama_token * tokens,
int32_t n_max_tokens, int32_t n_tokens_max,
bool add_bos, bool add_bos,
bool special) { bool special) {
auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special); auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special);
if (n_max_tokens < (int) res.size()) { if (n_tokens_max < (int) res.size()) {
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
return -((int) res.size()); return -((int) res.size());
} }

14
llama.h
View File

@ -235,7 +235,7 @@ extern "C" {
uint32_t seed; // RNG seed, -1 for random uint32_t seed; // RNG seed, -1 for random
uint32_t n_ctx; // text context, 0 = from model uint32_t n_ctx; // text context, 0 = from model
uint32_t n_batch; // prompt processing maximum batch size uint32_t n_batch; // prompt processing maximum batch size
uint32_t n_parallel; // number of parallel sequences (i.e. distinct states for recurrent models) uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
uint32_t n_threads; // number of threads to use for generation uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing uint32_t n_threads_batch; // number of threads to use for batch processing
@ -377,7 +377,7 @@ extern "C" {
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_max_seq (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
@ -456,7 +456,7 @@ extern "C" {
// Maximum number of sequences that can exist in a cell. It's not an error // Maximum number of sequences that can exist in a cell. It's not an error
// if there are more sequences in a cell than this value, however they will // if there are more sequences in a cell than this value, however they will
// not be visible in the view cells_sequences. // not be visible in the view cells_sequences.
int32_t n_max_seq; int32_t n_seq_max;
// Number of tokens in the cache. For example, if there are two populated // Number of tokens in the cache. For example, if there are two populated
// cells, the first with 1 sequence id in it and the second with 2 sequence // cells, the first with 1 sequence id in it and the second with 2 sequence
@ -476,12 +476,12 @@ extern "C" {
// Information for an individual cell. // Information for an individual cell.
struct llama_kv_cache_view_cell * cells; struct llama_kv_cache_view_cell * cells;
// The sequences for each cell. There will be n_max_seq items per cell. // The sequences for each cell. There will be n_seq_max items per cell.
llama_seq_id * cells_sequences; llama_seq_id * cells_sequences;
}; };
// Create an empty KV cache view. (use only for debugging purposes) // Create an empty KV cache view. (use only for debugging purposes)
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq); LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max);
// Free a KV cache view. (use only for debugging purposes) // Free a KV cache view. (use only for debugging purposes)
LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view); LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
@ -708,7 +708,7 @@ extern "C" {
/// @details Convert the provided text into tokens. /// @details Convert the provided text into tokens.
/// @param tokens The tokens pointer must be large enough to hold the resulting tokens. /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
/// @return Returns the number of tokens on success, no more than n_max_tokens /// @return Returns the number of tokens on success, no more than n_tokens_max
/// @return Returns a negative number on failure - the number of tokens that would have been returned /// @return Returns a negative number on failure - the number of tokens that would have been returned
/// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
/// Does not insert a leading space. /// Does not insert a leading space.
@ -717,7 +717,7 @@ extern "C" {
const char * text, const char * text,
int32_t text_len, int32_t text_len,
llama_token * tokens, llama_token * tokens,
int32_t n_max_tokens, int32_t n_tokens_max,
bool add_bos, bool add_bos,
bool special); bool special);