diff --git a/common/common.cpp b/common/common.cpp
index 9969cb97d..8fbff1da7 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -436,8 +436,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.use_mmap = false;
         } else if (arg == "--numa") {
             params.numa = true;
-        } else if (arg == "--export") {
-            params.export_cgraph = true;
         } else if (arg == "--verbose-prompt") {
            params.verbose_prompt = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
@@ -685,7 +683,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        Not recommended since this is both slower and uses more VRAM.\n");
 #endif // GGML_USE_CUBLAS
 #endif
-    printf("  --export              export the computation graph to 'llama.ggml'\n");
     printf("  --verbose-prompt      print prompt before generation\n");
     fprintf(stderr, "  --simple-io           use basic IO for better compatibility in subprocesses and limited consoles\n");
     printf("  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
@@ -782,7 +779,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     {
         LOG("warming up the model with an empty run\n");
 
-        const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
+        std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
         llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
         llama_reset_timings(lctx);
     }
@@ -1182,7 +1179,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
     fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
-    fprintf(stream, "export: %s # default: false\n", params.export_cgraph ? "true" : "false");
     fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
     fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", params.frequency_penalty);
     dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
diff --git a/common/common.h b/common/common.h
index 37d15415f..504e5944e 100644
--- a/common/common.h
+++ b/common/common.h
@@ -111,7 +111,6 @@ struct gpt_params {
     bool use_mmap          = true;  // use mmap for faster loads
     bool use_mlock         = false; // use mlock to keep model in memory
     bool numa              = false; // attempt optimizations that help on some NUMA systems
-    bool export_cgraph     = false; // export the computation graph
     bool verbose_prompt    = false; // print prompt tokens before generation
 };
 
diff --git a/examples/beam-search/beam-search.cpp b/examples/beam-search/beam-search.cpp
index 6b31aea78..37c9f81a9 100644
--- a/examples/beam-search/beam-search.cpp
+++ b/examples/beam-search/beam-search.cpp
@@ -158,7 +158,8 @@ int main(int argc, char ** argv)
     }
     std::cout << std::flush;
 
-    int n_past = llama_get_kv_cache_token_count(ctx);
+    int n_past = 0;
+
     if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
     {
         fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index a8179f1bf..b9d26bc75 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -198,15 +198,6 @@ int main(int argc, char ** argv) {
             params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
-    // export the cgraph and exit
-    if (params.export_cgraph) {
-        llama_eval_export(ctx, "llama.ggml");
-        llama_free(ctx);
-        llama_free_model(model);
-
-        return 0;
-    }
-
     std::string path_session = params.path_prompt_cache;
     std::vector<llama_token> session_tokens;
 
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 3a1c8c28d..2b7472dcc 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -400,7 +400,7 @@ results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
     return {tokens, ppl, logit_history, prob_history};
 }
 
-std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
+std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int> & tokens, int n_past, int n_batch,
     int n_vocab, int n_thread) {
     std::vector<float> result;
     result.reserve(tokens.size() * n_vocab);
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index ba5de0cc6..9f160376a 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -73,10 +73,12 @@ int main(int argc, char ** argv) {
 
     const int n_gen = std::min(32, max_context_size);
 
-    while (llama_get_kv_cache_token_count(ctx) < n_gen) {
+    int n_cur = 0;
+
+    while (n_cur < n_gen) {
         // evaluate the transformer
 
-        if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) {
+        if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), n_cur, params.n_threads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return 1;
         }
diff --git a/ggml.c b/ggml.c
index e4faafee6..207561794 100644
--- a/ggml.c
+++ b/ggml.c
@@ -12462,13 +12462,11 @@ static void ggml_compute_forward_alibi_f16(
         return;
     }
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
-    assert(n_past >= 0);
-
     const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
     const int ne1 = src0->ne[1]; // seq_len_without_past
     const int ne2 = src0->ne[2]; // n_head -> this is k
@@ -12483,7 +12481,7 @@ static void ggml_compute_forward_alibi_f16(
     //const int nb3 = src0->nb[3];
 
     GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
+    //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
     GGML_ASSERT(n_head == ne2);
 
     // add alibi to src0 (KQ_scaled)
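Note on the example changes above: since the cache no longer exposes a single meaningful token count, the examples keep the decode position in a local variable (n_past / n_cur) and pass it to llama_eval explicitly. Below is a minimal caller-side sketch of that pattern; it is hypothetical code, not part of this patch, and assumes a llama_context set up as in simple.cpp, using only the public calls from llama.h.

#include <algorithm>
#include <cstdio>
#include <vector>
#include "llama.h"

// greedy generation loop where the caller owns the position counter
static void generate_greedy(llama_context * ctx, std::vector<llama_token> tokens, int n_gen, int n_threads) {
    int n_cur = 0;

    while (n_cur < n_gen) {
        // evaluate the pending tokens at positions [n_cur, n_cur + tokens.size())
        if (llama_eval(ctx, tokens.data(), (uint32_t) tokens.size(), n_cur, n_threads)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return;
        }
        n_cur += (int) tokens.size();

        // pick the most likely next token from the last row of logits
        const float * logits  = llama_get_logits(ctx);
        const int     n_vocab = llama_n_vocab(ctx);
        const llama_token id  = (llama_token) (std::max_element(logits, logits + n_vocab) - logits);

        if (id == llama_token_eos(ctx)) {
            return;
        }

        tokens = { id };
    }
}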
diff --git a/llama.cpp b/llama.cpp
index 9d41689f7..532937da8 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -71,6 +71,7 @@
 #include
 #include
 #include
+#include <set>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -975,7 +976,25 @@ struct llama_layer {
     struct ggml_tensor * w3; // ffn_up
 };
 
+struct llama_kv_cell {
+    llama_pos pos = -1;
+
+    std::set<llama_seq_id> seq_id;
+
+    bool has_seq_id(const llama_seq_id & id) const {
+        return seq_id.find(id) != seq_id.end();
+    }
+};
+
+// ring-buffer of cached KV data
 struct llama_kv_cache {
+    bool is_roped = false;
+
+    uint32_t head = 0;
+    uint32_t size = 0;
+
+    std::vector<llama_kv_cell> cells;
+
     struct ggml_tensor * k = NULL;
     struct ggml_tensor * v = NULL;
 
@@ -983,8 +1002,6 @@ struct llama_kv_cache {
 
     llama_buffer buf;
 
-    int n; // number of tokens currently in the cache
-
     ~llama_kv_cache() {
         if (ctx) {
             ggml_free(ctx);
@@ -1167,16 +1184,21 @@ static bool llama_kv_cache_init(
         const struct llama_hparams & hparams,
              struct llama_kv_cache & cache,
                          ggml_type   wtype,
-                               int   n_ctx,
                                int   n_gpu_layers) {
-    const int n_embd  = hparams.n_embd_gqa();
-    const int n_layer = hparams.n_layer;
+    const uint32_t n_embd  = hparams.n_embd_gqa();
+    const uint32_t n_layer = hparams.n_layer;
+    const uint32_t n_ctx   = hparams.n_ctx;
 
     const int64_t n_mem      = n_layer*n_ctx;
     const int64_t n_elements = n_embd*n_mem;
 
+    cache.head = 0;
+    cache.size = n_ctx;
+
+    cache.cells.clear();
+    cache.cells.resize(n_ctx);
+
     cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
-    cache.n = 0;
 
     struct ggml_init_params params;
     params.mem_size = cache.buf.size;
@@ -1208,6 +1230,68 @@ static bool llama_kv_cache_init(
     return true;
 }
 
+// find an empty slot of size "n_tokens" in the cache
+// updates the cache head
+static bool llama_kv_cache_find_slot(
+        struct llama_kv_cache & cache,
+           struct llama_batch & batch) {
+    const uint32_t n_ctx    = cache.size;
+    const uint32_t n_tokens = batch.n_tokens;
+
+    if (n_tokens > n_ctx) {
+        LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
+        return false;
+    }
+
+    uint32_t n_tested = 0;
+
+    while (true) {
+        if (cache.head + n_tokens > n_ctx) {
+            cache.head = 0;
+            n_tested += n_ctx - cache.head;
+            continue;
+        }
+
+        bool found = true;
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            if (cache.cells[cache.head + i].pos >= 0) {
+                found = false;
+                cache.head += i + 1;
+                n_tested   += i + 1;
+                break;
+            }
+        }
+
+        if (found) {
+            break;
+        }
+
+        if (n_tested >= n_ctx) {
+            LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+            return false;
+        }
+    }
+
+    for (uint32_t i = 0; i < n_tokens; i++) {
+        cache.cells[cache.head + i].pos = batch.pos[i];
+        cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i]);
+    }
+
+    return true;
+}
+
+void llama_kv_cache_clear(struct llama_kv_cache & cache, int32_t p0, int32_t p1) {
+    cache.head = p0;
+
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = cache.size;
+
+    for (int32_t i = p0; i < p1; ++i) {
+        cache.cells[i].pos = -1;
+        cache.cells[i].seq_id.clear();
+    }
+}
+
 //
 // model loading and saving
 //
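The cache is now a ring buffer of cells, each holding a token position and a set of sequence ids, instead of a single counter. A short sketch of how these helpers are meant to be driven from inside llama.cpp (the surrounding function is hypothetical; the flow mirrors llama_eval_internal further down in this diff):

// illustrative driver for the new cache helpers (not part of the patch itself)
static bool kv_cache_step(llama_kv_cache & cache, llama_batch & batch) {
    // reserve batch.n_tokens contiguous cells starting at cache.head;
    // on success each cell records the token position and sequence id from the batch
    if (!llama_kv_cache_find_slot(cache, batch)) {
        return false; // no free slot large enough
    }

    // ... write the batch K/V at offset cache.head and evaluate the graph ...

    // advance the ring-buffer head past the cells that were just filled
    cache.head += batch.n_tokens;

    // to forget everything from position p0 onward (e.g. when rewinding generation):
    //     llama_kv_cache_clear(cache, /*p0=*/ p0, /*p1=*/ -1);

    return true;
}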
@@ -2308,15 +2392,7 @@ static bool llama_model_load(
 
 static struct ggml_cgraph * llm_build_llama(
          llama_context & lctx,
-     const llama_token * tokens,
-       const float * embd,
-               int   n_tokens,
-               int   n_past) {
-
-    GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
-
-    const int N = n_tokens;
-
+     llama_batch & batch) {
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
@@ -2340,6 +2416,8 @@ static struct ggml_cgraph * llm_build_llama(
 
     const int n_gpu_layers = model.n_gpu_layers;
 
+    const int32_t n_tokens = batch.n_tokens;
+
     auto & buf_compute = lctx.buf_compute;
 
     struct ggml_init_params params = {
@@ -2357,12 +2435,12 @@ static struct ggml_cgraph * llm_build_llama(
     struct ggml_tensor * cur;
     struct ggml_tensor * inpL;
 
-    if (tokens) {
-        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    if (batch.token) {
+        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
 
         ggml_allocr_alloc(lctx.alloc, inp_tokens);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
-            memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
+            memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens));
         }
         ggml_set_name(inp_tokens, "inp_tokens");
@@ -2372,11 +2450,11 @@ static struct ggml_cgraph * llm_build_llama(
         GGML_ASSERT(false && "not implemented");
 #endif
 
-        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
+        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
 
         ggml_allocr_alloc(lctx.alloc, inpL);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
-            memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
+            memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL));
         }
     }
@@ -2408,33 +2486,35 @@ static struct ggml_cgraph * llm_build_llama(
     struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     ggml_allocr_alloc(lctx.alloc, KQ_scale);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
-        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
+        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head)));
     }
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
-    // KQ_mask
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, N, 1);
+    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_ctx, n_tokens, 1);
     ggml_allocr_alloc(lctx.alloc, KQ_mask);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         float * data = (float *) KQ_mask->data;
         memset(data, 0, ggml_nbytes(KQ_mask));
 
         for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < N; ++j) {
-                for (int i = n_past + j + 1; i < n_past + N; ++i) {
-                    data[h*(n_past + N)*N + j*(n_past + N) + i] = -INFINITY;
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_ctx; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(batch.seq_id[j]) || kv_self.cells[i].pos > batch.pos[j]) {
+                        data[h*(n_ctx*n_tokens) + j*n_ctx + i] = -INFINITY;
+                    }
                 }
             }
         }
     }
 
     // KQ_pos - contains the positions
-    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
     ggml_allocr_alloc(lctx.alloc, KQ_pos);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         int * data = (int *) KQ_pos->data;
-        for (int i = 0; i < N; ++i) {
-            data[i] = n_past + i;
+        for (int i = 0; i < n_tokens; ++i) {
+            data[i] = batch.pos[i];
         }
     }
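The causal mask is no longer a lower-triangular band over n_past + N columns; it is rebuilt per batch from the cache cells, so a token only attends to cells that belong to its sequence and hold an earlier-or-equal position (empty cells have pos == -1 and are therefore masked too). The same rule, pulled out into a standalone helper purely for clarity — a sketch, not code from this patch:

#include <cmath>
#include <cstring>

// fill an additive attention mask of shape [n_ctx, n_tokens] (one head, broadcast later)
static void fill_kq_mask(float * data, const llama_kv_cell * cells, const llama_batch & batch, int n_ctx) {
    const int n_tokens = (int) batch.n_tokens;

    memset(data, 0, sizeof(float)*n_ctx*n_tokens);

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_ctx; ++i) {
            // visible only if the cell belongs to this token's sequence and is not in its future
            if (!cells[i].has_seq_id(batch.seq_id[j]) || cells[i].pos > batch.pos[j]) {
                data[j*n_ctx + i] = -INFINITY;
            }
        }
    }
}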
@@ -2474,33 +2554,33 @@ static struct ggml_cgraph * llm_build_llama(
             offload_func_kq(tmpq);
             ggml_set_name(tmpq, "tmpq");
 
-            struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+            struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
             offload_func_kq(Kcur);
             ggml_set_name(Kcur, "Kcur");
 
-            struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+            struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
             offload_func_kq(Qcur);
             ggml_set_name(Qcur, "Qcur");
 
             // store key and value to memory
             {
-                // compute the transposed [N, n_embd] V matrix
+                // compute the transposed [n_tokens, n_embd] V matrix
 
                 struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
                 offload_func_v(tmpv);
                 ggml_set_name(tmpv, "tmpv");
 
-                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens));
                 offload_func_v(Vcur);
                 ggml_set_name(Vcur, "Vcur");
 
-                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_self.head));
                 offload_func_kq(k);
                 ggml_set_name(k, "k");
 
-                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
                         (   n_ctx)*ggml_element_size(kv_self.v),
-                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_self.head*ggml_element_size(kv_self.v));
                 offload_func_v(v);
                 ggml_set_name(v, "v");
@@ -2515,7 +2595,7 @@ static struct ggml_cgraph * llm_build_llama(
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_embd_head, n_past + N, n_head_kv,
+                        n_embd_head, n_ctx, n_head_kv,
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_head,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
@@ -2528,14 +2608,13 @@ static struct ggml_cgraph * llm_build_llama(
             ggml_set_name(KQ, "KQ");
 
             // KQ_scaled = KQ / sqrt(n_embd_head)
-            // KQ_scaled shape [n_past + N, N, n_head, 1]
+            // KQ_scaled shape [n_ctx, n_tokens, n_head, 1]
             struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
             offload_func_kq(KQ_scaled);
             ggml_set_name(KQ_scaled, "KQ_scaled");
 
             // KQ_masked = mask_past(KQ_scaled)
             struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
-            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
             offload_func_kq(KQ_masked);
             ggml_set_name(KQ_masked, "KQ_masked");
@@ -2547,7 +2626,7 @@ static struct ggml_cgraph * llm_build_llama(
             // split cached V into n_head heads
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_past + N, n_embd_head, n_head_kv,
+                        n_ctx, n_embd_head, n_head_kv,
                         ggml_element_size(kv_self.v)*n_ctx,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
@@ -2562,7 +2641,7 @@ static struct ggml_cgraph * llm_build_llama(
             // make V contiguous in memory to speed up the matmul, however we waste time on the copy
             // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
             // is there a better way?
-            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
+            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head));
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
 #endif
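For reference, the byte offsets in the k/v views above reduce to the following element offsets: K is stored as one [n_embd_gqa, n_ctx] block per layer, while V is kept transposed, which is why its view is two-dimensional with a row stride of n_ctx. These are hypothetical helpers that only spell out the arithmetic; they are not part of the patch.

#include <cstddef>

// element offset of the first K value written for the current batch
static size_t k_cache_offset(size_t n_embd_gqa, size_t n_ctx, size_t il, size_t head) {
    return n_embd_gqa*(il*n_ctx + head);
}

// element offset of the first V value written for the current batch;
// each of the n_embd_gqa rows is n_ctx long, and the batch occupies column "head"
static size_t v_cache_offset(size_t n_embd_gqa, size_t n_ctx, size_t il, size_t head) {
    return il*n_ctx*n_embd_gqa + head;
}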
@@ -2571,10 +2650,10 @@ static struct ggml_cgraph * llm_build_llama(
             offload_func_v(KQV_merged);
             ggml_set_name(KQV_merged, "KQV_merged");
 
-            // cur = KQV_merged.contiguous().view(n_embd, N)
+            // cur = KQV_merged.contiguous().view(n_embd, n_tokens)
             cur = ggml_cpy(ctx0,
                     KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens));
             offload_func_v(cur);
             ggml_set_name(cur, "KQV_merged_contiguous");
@@ -2665,18 +2744,9 @@ static struct ggml_cgraph * llm_build_llama(
     return gf;
 }
 
-
 static struct ggml_cgraph * llm_build_baichaun(
          llama_context & lctx,
-     const llama_token * tokens,
-       const float * embd,
-               int   n_tokens,
-               int   n_past) {
-
-    GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
-
-    const int N = n_tokens;
-
+     llama_batch & batch) {
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
@@ -2700,6 +2770,8 @@ static struct ggml_cgraph * llm_build_baichaun(
 
     const int n_gpu_layers = model.n_gpu_layers;
 
+    const int32_t n_tokens = batch.n_tokens;
+
     auto & buf_compute = lctx.buf_compute;
 
     struct ggml_init_params params = {
@@ -2717,12 +2789,12 @@ static struct ggml_cgraph * llm_build_baichaun(
     struct ggml_tensor * cur;
     struct ggml_tensor * inpL;
 
-    if (tokens) {
-        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    if (batch.token) {
+        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
 
         ggml_allocr_alloc(lctx.alloc, inp_tokens);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
-            memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
+            memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens));
         }
         ggml_set_name(inp_tokens, "inp_tokens");
@@ -2732,11 +2804,11 @@ static struct ggml_cgraph * llm_build_baichaun(
         GGML_ASSERT(false && "not implemented");
 #endif
 
-        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
+        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
 
         ggml_allocr_alloc(lctx.alloc, inpL);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
-            memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
+            memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL));
         }
     }
@@ -2772,29 +2844,31 @@ static struct ggml_cgraph * llm_build_baichaun(
     }
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
-    // KQ_mask
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, N, 1);
+    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_ctx, n_tokens, 1);
     ggml_allocr_alloc(lctx.alloc, KQ_mask);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         float * data = (float *) KQ_mask->data;
         memset(data, 0, ggml_nbytes(KQ_mask));
 
         for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < N; ++j) {
-                for (int i = n_past + j + 1; i < n_past + N; ++i) {
-                    data[h*(n_past + N)*N + j*(n_past + N) + i] = -INFINITY;
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_ctx; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(batch.seq_id[j]) || kv_self.cells[i].pos > batch.pos[j]) {
+                        data[h*(n_ctx*n_tokens) + j*n_ctx + i] = -INFINITY;
+                    }
                 }
             }
         }
     }
 
     // KQ_pos - contains the positions
-    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
     ggml_allocr_alloc(lctx.alloc, KQ_pos);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         int * data = (int *) KQ_pos->data;
-        for (int i = 0; i < N; ++i) {
-            data[i] = n_past + i;
+        for (int i = 0; i < n_tokens; ++i) {
+            data[i] = batch.pos[i];
         }
     }
@@ -2838,12 +2912,12 @@ static struct ggml_cgraph * llm_build_baichaun(
             struct ggml_tensor * Qcur;
             switch (model.type) {
                 case MODEL_7B:
-                    Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
-                    Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+                    Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+                    Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
                     break;
                 case MODEL_13B:
-                    Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N);
-                    Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N);
+                    Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, n_tokens);
+                    Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, n_tokens);
                     break;
                 default:
                     GGML_ASSERT(false);
@@ -2857,23 +2931,23 @@ static struct ggml_cgraph * llm_build_baichaun(
 
             // store key and value to memory
             {
-                // compute the transposed [N, n_embd] V matrix
+                // compute the transposed [n_tokens, n_embd] V matrix
 
                 struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
                 offload_func_v(tmpv);
                 ggml_set_name(tmpv, "tmpv");
 
-                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens));
                 offload_func_v(Vcur);
                 ggml_set_name(Vcur, "Vcur");
 
-                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_self.head));
                 offload_func_kq(k);
                 ggml_set_name(k, "k");
 
-                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
                         (   n_ctx)*ggml_element_size(kv_self.v),
-                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_self.head*ggml_element_size(kv_self.v));
                 offload_func_v(v);
                 ggml_set_name(v, "v");
@@ -2888,7 +2962,7 @@ static struct ggml_cgraph * llm_build_baichaun(
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_embd_head, n_past + N, n_head_kv,
+                        n_embd_head, n_ctx, n_head_kv,
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_head,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
@@ -2901,7 +2975,7 @@ static struct ggml_cgraph * llm_build_baichaun(
             ggml_set_name(KQ, "KQ");
 
             // KQ_scaled = KQ / sqrt(n_embd_head)
-            // KQ_scaled shape [n_past + N, N, n_head, 1]
+            // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1]
             struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
             offload_func_kq(KQ_scaled);
             ggml_set_name(KQ_scaled, "KQ_scaled");
@@ -2912,22 +2986,16 @@
             switch (model.type) {
                 case MODEL_7B:
                     KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
-                    //KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
                     break;
                 case MODEL_13B:
-                    KQ_scaled_alibi =ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8);
+                    // TODO: replace with ggml_add()
+                    KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ 0, n_head, 8);
                     ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
                     KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask);
-                    //KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past);
                     break;
                 default:
                     GGML_ASSERT(false);
             }
-            // KQ_masked = mask_past(KQ_scaled)
-            // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
-            // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past);
-            // offload_func_kq(KQ_masked);
-            // ggml_set_name(KQ_masked, "KQ_masked");
 
             // KQ = soft_max(KQ_masked)
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
@@ -2937,34 +3005,26 @@ static struct ggml_cgraph * llm_build_baichaun(
             // split cached V into n_head heads
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_past + N, n_embd_head, n_head_kv,
+                        n_ctx, n_embd_head, n_head_kv,
                         ggml_element_size(kv_self.v)*n_ctx,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
             offload_func_v(V);
             ggml_set_name(V, "V");
 
-#if 1
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
             offload_func_v(KQV);
             ggml_set_name(KQV, "KQV");
-#else
-            // make V contiguous in memory to speed up the matmul, however we waste time on the copy
-            // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
-            // is there a better way?
-            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
-#endif
 
             // KQV_merged = KQV.permute(0, 2, 1, 3)
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
             offload_func_v(KQV_merged);
             ggml_set_name(KQV_merged, "KQV_merged");
 
-            // cur = KQV_merged.contiguous().view(n_embd, N)
+            // cur = KQV_merged.contiguous().view(n_embd, n_tokens)
             cur = ggml_cpy(ctx0,
                     KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens));
             offload_func_v(cur);
             ggml_set_name(cur, "KQV_merged_contiguous");
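On the MODEL_13B branch above: ggml_alibi is still used (with n_past forced to 0) and the new TODO suggests replacing it with a plain ggml_add. ALiBi is just a per-head slope applied to key positions, so one possible direction would be to bake the bias into an [n_ctx, n_tokens, n_head] mask instead of the broadcast one-head mask. The slope schedule below follows the standard ALiBi handling of arbitrary head counts (max_bias is 8 in this call); treat it as a sketch of the idea, not as the agreed-upon fix or as ggml's exact code.

#include <cmath>

// per-head ALiBi slope m_k for head k in [0, n_head)
static float alibi_slope(int k, int n_head, float max_bias) {
    const int   n_head_log2 = 1 << (int) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -max_bias        / n_head_log2);
    const float m1 = powf(2.0f, -max_bias / 2.0f / n_head_log2);

    // heads beyond the largest power of two get interleaved slopes
    return k < n_head_log2 ? powf(m0, (float) (k + 1))
                           : powf(m1, (float) (2*(k - n_head_log2) + 1));
}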
@@ -3057,15 +3117,7 @@ static struct ggml_cgraph * llm_build_falcon(
          llama_context & lctx,
-     const llama_token * tokens,
-       const float * embd,
-               int   n_tokens,
-               int   n_past) {
-
-    GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
-
-    const int N = n_tokens;
-
+     llama_batch & batch) {
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
@@ -3089,6 +3141,8 @@ static struct ggml_cgraph * llm_build_falcon(
 
     const int n_gpu_layers = model.n_gpu_layers;
 
+    const int32_t n_tokens = batch.n_tokens;
+
     auto & buf_compute = lctx.buf_compute;
 
     struct ggml_init_params params = {
@@ -3106,12 +3160,12 @@ static struct ggml_cgraph * llm_build_falcon(
     struct ggml_tensor * cur;
     struct ggml_tensor * inpL;
 
-    if (tokens) {
-        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    if (batch.token) {
+        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
 
         ggml_allocr_alloc(lctx.alloc, inp_tokens);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
-            memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
+            memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens));
         }
         ggml_set_name(inp_tokens, "inp_tokens");
@@ -3121,11 +3175,11 @@ static struct ggml_cgraph * llm_build_falcon(
         GGML_ASSERT(false && "not implemented");
 #endif
 
-        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
+        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
 
         ggml_allocr_alloc(lctx.alloc, inpL);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
-            memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
+            memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL));
         }
     }
@@ -3161,29 +3215,31 @@ static struct ggml_cgraph * llm_build_falcon(
     }
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 
-    // KQ_mask
-    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, N, 1);
+    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_ctx, n_tokens, 1);
     ggml_allocr_alloc(lctx.alloc, KQ_mask);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         float * data = (float *) KQ_mask->data;
         memset(data, 0, ggml_nbytes(KQ_mask));
 
         for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < N; ++j) {
-                for (int i = n_past + j + 1; i < n_past + N; ++i) {
-                    data[h*(n_past + N)*N + j*(n_past + N) + i] = -INFINITY;
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_ctx; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(batch.seq_id[j]) || kv_self.cells[i].pos > batch.pos[j]) {
+                        data[h*(n_ctx*n_tokens) + j*n_ctx + i] = -INFINITY;
+                    }
                 }
             }
         }
     }
 
     // KQ_pos - contains the positions
-    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
     ggml_allocr_alloc(lctx.alloc, KQ_pos);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         int * data = (int *) KQ_pos->data;
-        for (int i = 0; i < N; ++i) {
-            data[i] = n_past + i;
+        for (int i = 0; i < n_tokens; ++i) {
+            data[i] = batch.pos[i];
         }
     }
@@ -3242,21 +3298,21 @@ static struct ggml_cgraph * llm_build_falcon(
             // TODO: these 2 ggml_conts are technically not needed, but we add them until CUDA support for
             //       non-contiguous views is added for the rope operator
             struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_3d(
-                ctx0, cur, n_embd_head, n_head, N,
+                ctx0, cur, n_embd_head, n_head, n_tokens,
                 wsize * n_embd_head,
                 wsize * n_embd_head * (n_head + 2 * n_head_kv),
                 0));
             offload_func_kq(tmpq);
 
             struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_3d(
-                ctx0, cur, n_embd_head, n_head_kv, N,
+                ctx0, cur, n_embd_head, n_head_kv, n_tokens,
                 wsize * n_embd_head,
                 wsize * n_embd_head * (n_head + 2 * n_head_kv),
                 wsize * n_embd_head * n_head));
             offload_func_kq(tmpk);
 
             struct ggml_tensor * tmpv = ggml_view_3d(
-                ctx0, cur, n_embd_head, n_head_kv, N,
+                ctx0, cur, n_embd_head, n_head_kv, n_tokens,
                 wsize * n_embd_head,
                 wsize * n_embd_head * (n_head + 2 * n_head_kv),
                 wsize * n_embd_head * (n_head + n_head_kv));
@@ -3269,18 +3325,18 @@ static struct ggml_cgraph * llm_build_falcon(
             offload_func_kq(Kcur);
 
             {
-                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N));
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens));
                 offload_func_v(Vcur);
                 offload_func_v(Vcur->src[0]->src[0]);
                 ggml_set_name(Vcur, "Vcur");
 
-                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_self.head));
                 offload_func_kq(k);
                 ggml_set_name(k, "k");
 
-                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
                         (   n_ctx)*ggml_element_size(kv_self.v),
-                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_self.head*ggml_element_size(kv_self.v));
                 offload_func_v(v);
 
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
@@ -3293,7 +3349,7 @@ static struct ggml_cgraph * llm_build_falcon(
 
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_embd_head, n_past + N, n_head_kv,
+                        n_embd_head, n_ctx, n_head_kv,
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_head,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
@@ -3308,7 +3364,7 @@ static struct ggml_cgraph * llm_build_falcon(
             offload_func_kq(KQ_scaled);
             ggml_set_name(KQ_scaled, "KQ_scaled");
 
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+            struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask);
             offload_func_kq(KQ_masked);
             ggml_set_name(KQ_masked, "KQ_masked");
@@ -3318,7 +3374,7 @@ static struct ggml_cgraph * llm_build_falcon(
 
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_past + N, n_embd_head, n_head_kv,
+                        n_ctx, n_embd_head, n_head_kv,
                         ggml_element_size(kv_self.v)*n_ctx,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
@@ -3333,7 +3389,7 @@ static struct ggml_cgraph * llm_build_falcon(
             offload_func_v(KQV_merged);
             ggml_set_name(KQV_merged, "KQV_merged");
 
-            cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+            cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens));
             offload_func_v(cur);
             ggml_set_name(cur, "KQV_merged_contiguous");
@@ -3391,10 +3447,7 @@ static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
-     const llama_token * tokens,
-       const float * embd,
-               int   n_tokens,
-               int   n_past) {
+     llama_batch & batch) {
     const auto & model = lctx.model;
 
     struct ggml_cgraph * result = NULL;
@@ -3402,15 +3455,15 @@ static struct ggml_cgraph * llama_build_graph(
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
             {
-                result = llm_build_llama(lctx, tokens, embd, n_tokens, n_past);
+                result = llm_build_llama(lctx, batch);
             } break;
         case LLM_ARCH_BAICHUAN:
             {
-                result = llm_build_baichaun(lctx, tokens, embd, n_tokens, n_past);
+                result = llm_build_baichaun(lctx, batch);
             } break;
         case LLM_ARCH_FALCON:
             {
-                result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past);
+                result = llm_build_falcon(lctx, batch);
             } break;
         default:
             GGML_ASSERT(false);
@@ -3422,52 +3475,48 @@ static struct ggml_cgraph * llama_build_graph(
 // evaluate the transformer
 //
 //   - lctx:      llama context
-//   - tokens:    new batch of tokens to process
-//   - embd       embeddings input
-//   - n_tokens   number of tokens
-//   - n_past:    the context size so far
+//   - batch:     batch to evaluate
 //   - n_threads: number of threads to use
 //
 static bool llama_eval_internal(
          llama_context & lctx,
-    const llama_token * tokens,
-      const float * embd,
-              int   n_tokens,
-              int   n_past,
-              int   n_threads,
-       const char * cgraph_fname) {
+          llama_batch & batch,
+                  int   n_threads) {
+    const uint32_t n_tokens = batch.n_tokens;
 
-    GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
+    if (n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
+        return false;
+    }
 
-    GGML_ASSERT(n_tokens > 0);
-    GGML_ASSERT(n_past >= 0);
-    // TODO: keep the values of n_batch and n_ctx
-    //       GGML_ASSERT(n_tokens <= n_batch);
-    //       GGML_ASSERT(n_past + n_tokens <= n_ctx);
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
     const int64_t t_start_us = ggml_time_us();
 
 #ifdef GGML_USE_MPI
+    // TODO: needs fix after #3228
     ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
 #endif
 
     GGML_ASSERT(n_threads > 0);
 
-    const int N = n_tokens;
-
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
 
-    const auto & kv_self = lctx.kv_self;
+    auto & kv_self = lctx.kv_self;
 
     GGML_ASSERT(!!kv_self.ctx);
 
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
 
+    if (!llama_kv_cache_find_slot(kv_self, batch)) {
+        return false;
+    }
+
     ggml_allocr_reset(lctx.alloc);
 
-    ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);
+    ggml_cgraph * gf = llama_build_graph(lctx, batch);
 
     ggml_allocr_alloc_graph(lctx.alloc, gf);
@@ -3494,7 +3543,7 @@ static bool llama_eval_internal(
     // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well
     //       we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering
     //       with the BLAS calls. need a better solution
-    if (N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
+    if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
         n_threads = std::min(4, n_threads);
     }
@@ -3524,12 +3573,8 @@ static bool llama_eval_internal(
     ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer);
 #endif
 
-    // update kv token count
-    lctx.kv_self.n = n_past + N;
-
-    if (cgraph_fname) {
-        ggml_graph_export(gf, cgraph_fname);
-    }
+    // update the kv ring buffer head
+    lctx.kv_self.head += n_tokens;
 
 #ifdef GGML_PERF
     // print timing information per ggml operation (for debugging purposes)
@@ -3547,12 +3592,12 @@ static bool llama_eval_internal(
         auto & logits_out = lctx.logits;
 
         if (lctx.logits_all) {
-            logits_out.resize(n_vocab * N);
-            memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
+            logits_out.resize(n_vocab * n_tokens);
+            memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
         } else {
             // return result for just the last token
             logits_out.resize(n_vocab);
-            memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+            memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab);
         }
     }
@@ -3561,17 +3606,17 @@ static bool llama_eval_internal(
         auto & embedding_out = lctx.embedding;
 
         embedding_out.resize(n_embd);
-        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
+        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(n_tokens - 1)), sizeof(float)*n_embd);
     }
 
     // measure the performance only for the single-token evals
-    if (N == 1) {
+    if (n_tokens == 1) {
         lctx.t_eval_us += ggml_time_us() - t_start_us;
         lctx.n_eval++;
     }
-    else if (N > 1) {
+    else if (n_tokens > 1) {
         lctx.t_p_eval_us += ggml_time_us() - t_start_us;
-        lctx.n_p_eval += N;
+        lctx.n_p_eval += n_tokens;
     }
 
     return true;
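Since a batch may now carry several tokens, the logits buffer holds one row per batch token when logits_all is enabled (and only the last row otherwise). A small hypothetical accessor, not part of the patch, that makes the indexing explicit using only public API calls:

// row j of the logits corresponds to token j of the most recent batch
// (requires the context to have been created with logits_all enabled)
static const float * logits_for_token(llama_context * ctx, int j) {
    return llama_get_logits(ctx) + (size_t) j * llama_n_vocab(ctx);
}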
@@ -6043,12 +6088,16 @@ struct llama_context * llama_new_context_with_model(
 
     // reserve memory for context buffers
     if (!params.vocab_only) {
-        if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, ctx->model.hparams.n_ctx, params.n_gpu_layers)) {
+        if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, params.n_gpu_layers)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
             return nullptr;
         }
 
+        if (model->arch == LLM_ARCH_LLAMA) {
+            ctx->kv_self.is_roped = true;
+        }
+
         {
             const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v);
             LLAMA_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
@@ -6076,10 +6125,11 @@ struct llama_context * llama_new_context_with_model(
             ctx->alloc = ggml_allocr_new_measure(tensor_alignment);
 
             // build worst-case graph
-            int n_tokens = std::min((int)hparams.n_ctx, params.n_batch);
-            int n_past = hparams.n_ctx - n_tokens;
+            uint32_t n_tokens = std::min((int)hparams.n_ctx, params.n_batch);
             llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-            ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past);
+            llama_batch batch = { n_tokens, &token, nullptr, nullptr, nullptr };
+            ggml_cgraph * gf = llama_build_graph(*ctx, batch);
+
 #ifdef GGML_USE_METAL
             if (params.n_gpu_layers > 0) {
                 ctx->ctx_metal = ggml_metal_init(1);
@@ -6279,7 +6329,11 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
 }
 
 int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
-    return ctx->kv_self.n;
+    return ctx->kv_self.head;
+}
+
+void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1) {
+    llama_kv_cache_clear(ctx->kv_self, p0, p1);
 }
 
 #define LLAMA_MAX_RNG_STATE (64*1024)
@@ -6376,6 +6430,16 @@ struct llama_data_file_context : llama_data_context {
  *
 */
 void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
+    // TODO: does not support multi-sequence states
+    {
+        const auto & kv_self = ctx->kv_self;
+        for (uint32_t i = 0; i < kv_self.head; ++i) {
+            GGML_ASSERT(kv_self.cells[i].pos == (int32_t) i);
+            GGML_ASSERT(kv_self.cells[i].seq_id.size() == 1);
+            GGML_ASSERT(kv_self.cells[i].has_seq_id(0));
+        }
+    }
+
     // copy rng
     {
         std::stringstream rng_ss;
@@ -6431,7 +6495,7 @@ void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_conte
         const int n_ctx = hparams.n_ctx;
 
         const size_t kv_size = kv_self.buf.size;
-        const int kv_ntok = llama_get_kv_cache_token_count(ctx);
+        const int kv_ntok = kv_self.head;
 
         data_ctx->write(&kv_size, sizeof(kv_size));
         data_ctx->write(&kv_ntok, sizeof(kv_ntok));
@@ -6575,7 +6639,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
             ggml_free(cpy_ctx);
         }
 
-        ctx->kv_self.n = kv_ntok;
+        ctx->kv_self.head = kv_ntok;
     }
 
     const size_t nread = inp - src;
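llama_get_kv_cache_token_count is kept only for compatibility (it now reports the ring-buffer head), and rewinding is expressed through the new clear call instead. A hedged usage sketch built on the public llama_kv_clear declared below in llama.h; the helper itself is hypothetical:

// drop everything the model has seen from position n_keep onwards;
// the next llama_eval(..., /*n_past=*/ n_keep, ...) will reuse the freed cells
static void rewind_to(llama_context * ctx, int n_keep) {
    llama_kv_clear(ctx, n_keep, -1); // p1 < 0 means "to the end of the cache"
}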
@@ -6671,10 +6735,24 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
 int llama_eval(
         struct llama_context * ctx,
            const llama_token * tokens,
-                          int   n_tokens,
+                     uint32_t   n_tokens,
                           int   n_past,
                           int   n_threads) {
-    if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) {
+    std::vector<llama_pos> pos(n_tokens);
+    for (uint32_t i = 0; i < n_tokens; i++) {
+        pos[i] = n_past + i;
+    }
+
+    std::vector<llama_seq_id> seq_id(n_tokens);
+    for (uint32_t i = 0; i < n_tokens; i++) {
+        seq_id[i] = 0;
+    }
+
+    llama_batch batch = { n_tokens, tokens, nullptr, pos.data(), seq_id.data(), };
+
+    llama_kv_cache_clear(ctx->kv_self, n_past, -1);
+
+    if (!llama_eval_internal(*ctx, batch, n_threads)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -6692,10 +6770,22 @@ int llama_eval(
 int llama_eval_embd(
         struct llama_context * ctx,
                  const float * embd,
-                          int   n_tokens,
+                     uint32_t   n_tokens,
                           int   n_past,
                           int   n_threads) {
-    if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, n_threads, nullptr)) {
+    std::vector<llama_pos> pos(n_tokens);
+    for (uint32_t i = 0; i < n_tokens; i++) {
+        pos[i] = n_past + i;
+    }
+
+    std::vector<llama_seq_id> seq_id(n_tokens);
+    for (uint32_t i = 0; i < n_tokens; i++) {
+        seq_id[i] = 0;
+    }
+
+    llama_batch batch = { n_tokens, nullptr, embd, pos.data(), seq_id.data(), };
+
+    if (!llama_eval_internal(*ctx, batch, n_threads)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -6710,20 +6800,6 @@ int llama_eval_embd(
     return 0;
 }
 
-int llama_eval_export(struct llama_context * ctx, const char * fname) {
-    const int n_batch = 1;
-    const int n_ctx   = 512 - n_batch;
-
-    const std::vector<llama_token> tmp(n_batch, llama_token_bos(ctx));
-
-    if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, fname)) {
-        LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
-        return 1;
-    }
-
-    return 0;
-}
-
 float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits.data();
 }
diff --git a/llama.h b/llama.h
index 37975bebe..043b62e10 100644
--- a/llama.h
+++ b/llama.h
@@ -60,7 +60,20 @@ extern "C" {
     struct llama_model;
     struct llama_context;
 
-    typedef int llama_token;
+    typedef int32_t llama_pos;
+    typedef int32_t llama_token;
+    typedef int32_t llama_seq_id;
+
+    // data used for batch inference
+    typedef struct llama_batch {
+        uint32_t n_tokens;
+
+        // TODO: not sure about these consts - might just get in the way all the time with no benefit
+        const llama_token  * token;
+        const float        * embd;
+        const llama_pos    * pos;
+        const llama_seq_id * seq_id;
+    } llama_seq;
 
     enum llama_log_level {
         LLAMA_LOG_LEVEL_ERROR = 2,
@@ -289,8 +302,15 @@ extern "C" {
                              const char * path_base_model,
                                      int   n_threads);
 
+    //
+    // KV cache API
+    //
+
     // Returns the number of tokens in the KV cache
-    LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
+    LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
+            "avoid using this, it will be removed in the future");
+
+    LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1);
 
     // Sets the current rng seed.
     LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
@@ -319,7 +339,7 @@ extern "C" {
     LLAMA_API int llama_eval(
             struct llama_context * ctx,
                const llama_token * tokens,
-                              int   n_tokens,
+                         uint32_t   n_tokens,
                               int   n_past,
                               int   n_threads);
 
@@ -327,16 +347,10 @@ extern "C" {
     LLAMA_API int llama_eval_embd(
             struct llama_context * ctx,
                      const float * embd,
-                              int   n_tokens,
+                         uint32_t   n_tokens,
                               int   n_past,
                               int   n_threads);
 
-    // Export a static computation graph for context of 511 and batch size of 1
-    // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
-    //       parameters here to keep things simple
-    // IMPORTANT: do not use for anything else other than debugging and testing!
-    LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname);
-
     // Token logits obtained from the last call to llama_eval()
     // The logits for the last token are stored in the last row
     // Can be mutated in order to change the probabilities of the next token
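The new llama_batch carries explicit per-token positions and sequence ids. In this diff it is only filled in by the llama_eval/llama_eval_embd wrappers above (one sequence, consecutive positions); no public entry point accepts a caller-built batch yet, and note that the typedef alias in this revision is still llama_seq while the struct tag llama_batch is what is used in C++ code. Purely as an illustration of the layout, a two-sequence batch could be assembled like this (the token ids are made-up example values):

#include <vector>
#include "llama.h"

// hypothetical helper: build a batch holding two independent sequences
static llama_batch make_example_batch(std::vector<llama_token> & tok,
                                      std::vector<llama_pos>   & pos,
                                      std::vector<llama_seq_id> & sid) {
    tok = { 15043, 3186, 15043, 3186, 29991 }; // example token ids
    pos = {     0,     1,     0,     1,     2 }; // per-token positions
    sid = {     0,     0,     1,     1,     1 }; // sequence 0 and sequence 1

    return { (uint32_t) tok.size(), tok.data(), nullptr, pos.data(), sid.data() };
}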
diff --git a/tests/test-tokenizer-1-llama.cpp b/tests/test-tokenizer-1-llama.cpp
index ab3d822f2..24c420372 100644
--- a/tests/test-tokenizer-1-llama.cpp
+++ b/tests/test-tokenizer-1-llama.cpp
@@ -87,10 +87,11 @@ int main(int argc, char **argv) {
         std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
         std::string check = llama_detokenize_spm(ctx, tokens);
         if (check != str) {
-            fprintf(stderr, "%s : error: token %d detokenizes to >%s<(%llu) but tokenization of this detokenizes to >%s<(%llu)\n",
-                __func__, i, str.c_str(), str.length(), check.c_str(), check.length());
-            if(i != 3)
+            fprintf(stderr, "%s : error: token %d detokenizes to >%s<(%d) but tokenization of this detokenizes to >%s<(%d)\n",
+                __func__, i, str.c_str(), (int) str.length(), check.c_str(), (int) check.length());
+            if (i != 3) {
                 return 2;
+            }
         }
     }
 
@@ -100,10 +101,11 @@ int main(int argc, char **argv) {
             std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
             std::string check = llama_detokenize_spm(ctx, tokens);
             if (str != check) {
-                fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%llu) instead of >%s<(%llu)\n",
-                    __func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
-                if(cp != 0 && cp != 9601)
+                fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%d) instead of >%s<(%d)\n",
+                    __func__, cp, check.c_str(), (int) check.length(), str.c_str(), (int) str.length());
+                if (cp != 0 && cp != 9601) {
                     return 3;
+                }
             }
         }
     }
@@ -112,8 +114,8 @@ int main(int argc, char **argv) {
         std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
         std::string check = llama_detokenize_spm(ctx, tokens);
         if (str != check) {
-            fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%llu) instead of >%s<(%llu)\n",
-                __func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
+            fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%d) instead of >%s<(%d)\n",
+                __func__, cp, check.c_str(), (int) check.length(), str.c_str(), (int) str.length());
             return 4;
         }
     }