From 72e7ef4e53a88f3c4a582f2881b307d2964ff831 Mon Sep 17 00:00:00 2001
From: slaren
Date: Tue, 26 Sep 2023 23:19:36 +0200
Subject: [PATCH] simple : fixes

---
 examples/simple/simple.cpp | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index cf48ce0c0..85f19d25f 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -1,6 +1,7 @@
 #include "common.h"
 #include "llama.h"
 
+#include <algorithm>
 #include <cmath>
 #include <cstdio>
 #include <string>
@@ -42,7 +43,9 @@ int main(int argc, char ** argv) {
     llama_context_params ctx_params = llama_context_default_params();
 
     ctx_params.seed  = 1234;
-    ctx_params.n_ctx = 2048;
+    ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301)
+    ctx_params.n_batch = std::max(n_len, n_parallel);
+    // ctx_params.n_gpu_layers = 99; // offload all layers to the GPU
 
     llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
 
@@ -66,11 +69,11 @@ int main(int argc, char ** argv) {
     const int n_ctx    = llama_n_ctx(ctx);
     const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
 
-    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req);
+    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
 
     // make sure the KV cache is big enough to hold all the prompt and generated tokens
     if (n_kv_req > n_ctx) {
-        LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
+        LOG_TEE("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req);
         LOG_TEE("%s:        either reduce n_parallel or increase n_ctx\n", __func__);
         return 1;
     }
@@ -88,7 +91,7 @@ int main(int argc, char ** argv) {
 
     // create a llama_batch with size 512
     // we use this object to submit token data for decoding
-    llama_batch batch = llama_batch_init(512, 0);
+    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
 
     // evaluate the initial prompt
     batch.n_tokens = tokens_list.size();
@@ -133,12 +136,6 @@ int main(int argc, char ** argv) {
     const auto t_main_start = ggml_time_us();
 
     while (n_cur <= n_len) {
-        // evaluate the current batch with the transformer model
-        if (llama_decode(ctx, batch, params.n_threads)) {
-            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
-            return 1;
-        }
-
         // prepare the next batch
         batch.n_tokens = 0;
 
@@ -149,8 +146,8 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            auto n_vocab = llama_n_vocab(ctx);
-            auto logits  = llama_get_logits(ctx) + i_batch[i] * n_vocab;
+            auto   n_vocab = llama_n_vocab(ctx);
+            auto * logits  = llama_get_logits(ctx) + i_batch[i] * n_vocab;
 
             std::vector<llama_token_data> candidates;
             candidates.reserve(n_vocab);
@@ -178,7 +175,7 @@ int main(int argc, char ** argv) {
                 i_batch[i] = -1;
                 LOG_TEE("\n");
                 if (n_parallel > 1) {
-                    LOG_TEE("%s: stream %d finished", __func__, i);
+                    LOG_TEE("%s: stream %d finished at n_cur = %d", __func__, i, n_cur);
                 }
 
                 continue;
@@ -211,6 +208,12 @@ int main(int argc, char ** argv) {
         }
 
         n_cur += 1;
+
+        // evaluate the current batch with the transformer model
+        if (llama_decode(ctx, batch, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+            return 1;
+        }
     }
 
     LOG_TEE("\n");