simple : fixes

This commit is contained in:
slaren 2023-09-26 23:19:36 +02:00
parent 8845160058
commit 72e7ef4e53

View File

@ -1,6 +1,7 @@
#include "common.h" #include "common.h"
#include "llama.h" #include "llama.h"
#include <algorithm>
#include <cmath> #include <cmath>
#include <cstdio> #include <cstdio>
#include <string> #include <string>
@ -42,7 +43,9 @@ int main(int argc, char ** argv) {
llama_context_params ctx_params = llama_context_default_params(); llama_context_params ctx_params = llama_context_default_params();
ctx_params.seed = 1234; ctx_params.seed = 1234;
ctx_params.n_ctx = 2048; ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301)
ctx_params.n_batch = std::max(n_len, n_parallel);
// ctx_params.n_gpu_layers = 99; // offload all layers to the GPU
llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params); llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
@ -66,11 +69,11 @@ int main(int argc, char ** argv) {
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel; const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req); LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
// make sure the KV cache is big enough to hold all the prompt and generated tokens // make sure the KV cache is big enough to hold all the prompt and generated tokens
if (n_kv_req > n_ctx) { if (n_kv_req > n_ctx) {
LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__); LOG_TEE("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req);
LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__); LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
return 1; return 1;
} }
@ -88,7 +91,7 @@ int main(int argc, char ** argv) {
// create a llama_batch with size 512 // create a llama_batch with size 512
// we use this object to submit token data for decoding // we use this object to submit token data for decoding
llama_batch batch = llama_batch_init(512, 0); llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
// evaluate the initial prompt // evaluate the initial prompt
batch.n_tokens = tokens_list.size(); batch.n_tokens = tokens_list.size();
@ -133,12 +136,6 @@ int main(int argc, char ** argv) {
const auto t_main_start = ggml_time_us(); const auto t_main_start = ggml_time_us();
while (n_cur <= n_len) { while (n_cur <= n_len) {
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
// prepare the next batch // prepare the next batch
batch.n_tokens = 0; batch.n_tokens = 0;
@ -150,7 +147,7 @@ int main(int argc, char ** argv) {
} }
auto n_vocab = llama_n_vocab(ctx); auto n_vocab = llama_n_vocab(ctx);
auto logits = llama_get_logits(ctx) + i_batch[i] * n_vocab; auto * logits = llama_get_logits(ctx) + i_batch[i] * n_vocab;
std::vector<llama_token_data> candidates; std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab); candidates.reserve(n_vocab);
@ -178,7 +175,7 @@ int main(int argc, char ** argv) {
i_batch[i] = -1; i_batch[i] = -1;
LOG_TEE("\n"); LOG_TEE("\n");
if (n_parallel > 1) { if (n_parallel > 1) {
LOG_TEE("%s: stream %d finished", __func__, i); LOG_TEE("%s: stream %d finished at n_cur = %d", __func__, i, n_cur);
} }
continue; continue;
@ -211,6 +208,12 @@ int main(int argc, char ** argv) {
} }
n_cur += 1; n_cur += 1;
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
} }
LOG_TEE("\n"); LOG_TEE("\n");