#include "common.h"
#include "llama.h"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>

// Batched generation example.
//
// Evaluates a single prompt once, shares the resulting KV cache entries with
// `n_parallel` sequences via llama_kv_cache_seq_cp(), then samples all sequences
// in lock-step: every llama_decode() call submits one new token per sequence
// that has not finished yet.
//
// usage: MODEL_PATH [PROMPT] [PARALLEL]
int main(int argc, char ** argv) {
    gpt_params params;

    if (argc == 1 || argv[1][0] == '-') {
        printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL]\n" , argv[0]);
        return 1 ;
    }

    // number of sequences generated in parallel from the same prompt
    int n_parallel = 1;

    if (argc >= 2) {
        params.model = argv[1];
    }

    if (argc >= 3) {
        params.prompt = argv[2];
    }

    if (argc >= 4) {
        n_parallel = std::atoi(argv[3]);

        // atoi() yields 0 for non-numeric input; a non-positive count would later be
        // converted into a huge std::vector size
        if (n_parallel < 1) {
            fprintf(stderr, "%s: error: PARALLEL must be a positive integer, got '%s'\n", __func__, argv[3]);
            return 1;
        }
    }

    if (params.prompt.empty()) {
        params.prompt = "Hello my name is";
    }

    // total length of the sequences including the prompt
    const int n_len = 32;

    // init LLM

    llama_backend_init(params.numa);

    llama_context_params ctx_params = llama_context_default_params();

    ctx_params.seed  = 1234;
    ctx_params.n_ctx = 2048;

    llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);

    if (model == NULL) {
        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
        return 1;
    }

    llama_context * ctx = llama_new_context_with_model(model, ctx_params);

    if (ctx == NULL) {
        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
        llama_free_model(model);
        return 1;
    }

    // tokenize the prompt

    std::vector<llama_token> tokens_list;
    tokens_list = ::llama_tokenize(ctx, params.prompt, true);

    if (tokens_list.empty()) {
        fprintf(stderr, "%s: error: the prompt produced no tokens\n", __func__);
        return 1;
    }

    const int n_ctx    = llama_n_ctx(ctx);
    const int n_prompt = static_cast<int>(tokens_list.size());

    // KV cells required: the prompt once (shared by every sequence) plus the tokens
    // generated by each sequence. Computed in signed int so the intermediate result
    // does not wrap through size_t when the prompt is longer than n_len.
    const int n_kv_req = n_prompt + (n_len - n_prompt)*n_parallel;

    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_parallel, n_kv_req);

    // make sure the KV cache is big enough to hold the prompt and all generated tokens
    if (n_kv_req > n_ctx) {
        LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
        LOG_TEE("%s:        either reduce n_parallel or increase n_ctx\n", __func__);
        return 1;
    }

    // print the prompt token by token

    fprintf(stderr, "\n");

    for (auto id : tokens_list) {
        fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
    }

    fflush(stderr);

    // create a llama_batch large enough for the whole prompt and for one token per
    // parallel sequence (a fixed size would overflow on long prompts / many sequences)
    // we use this object to submit token data for decoding
    llama_batch batch = llama_batch_init(std::max(n_prompt, n_parallel), 0);

    // evaluate the initial prompt
    batch.n_tokens = n_prompt;

    for (int32_t i = 0; i < batch.n_tokens; i++) {
        batch.token[i]  = tokens_list[i];
        batch.pos[i]    = i;
        batch.seq_id[i] = 0;
        batch.logits[i] = false;
    }

    // llama_decode will output logits only for the last token of the prompt
    batch.logits[batch.n_tokens - 1] = true;

    if (llama_decode(ctx, batch, params.n_threads) != 0) {
        LOG_TEE("%s: llama_decode() failed\n", __func__);
        return 1;
    }

    // assign the system KV cache to all parallel sequences
    // this way the parallel sequences reuse the prompt tokens without re-evaluating them
    for (int32_t i = 1; i < n_parallel; ++i) {
        llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens);
    }

    if (n_parallel > 1) {
        LOG_TEE("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
    }

    // main loop

    // we will store the parallel decoded sequences in this vector
    std::vector<std::string> streams(n_parallel);

    // remember the batch index of the last token for each parallel sequence
    // we use this to know which logits to sample from; -1 marks a finished stream
    std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);

    // sampling settings and buffers are loop-invariant
    const int   n_vocab = llama_n_vocab(ctx);
    const int   top_k   = 40;
    const float top_p   = 0.9f;
    const float temp    = 0.4f;

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);

    int n_cur    = batch.n_tokens;
    int n_decode = 0;

    const auto t_main_start = ggml_time_us();

    while (n_cur <= n_len) {
        // prepare the next batch
        // (the logits sampled below come from the previous llama_decode() call, so
        // the batch must not be decoded again before it has been refilled)
        batch.n_tokens = 0;

        // sample the next token for each parallel sequence / stream
        for (int32_t i = 0; i < n_parallel; ++i) {
            if (i_batch[i] < 0) {
                // the stream has already finished
                continue;
            }

            const float * logits = llama_get_logits(ctx) + i_batch[i] * n_vocab;

            candidates.clear();
            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
            }

            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

            llama_sample_top_k(ctx, &candidates_p, top_k, 1);
            llama_sample_top_p(ctx, &candidates_p, top_p, 1);
            llama_sample_temp (ctx, &candidates_p, temp);

            const llama_token new_token_id = llama_sample_token(ctx, &candidates_p);

            //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);

            // is it an end of stream? -> mark this stream as finished
            if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
                i_batch[i] = -1;
                LOG_TEE("\n");
                if (n_parallel > 1) {
                    LOG_TEE("%s: stream %d finished", __func__, i);
                }

                continue;
            }

            if (n_parallel == 1) {
                // print the new token (LOG_TEE echoes to stderr)
                LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
                fflush(stderr);
            }

            streams[i] += llama_token_to_piece(ctx, new_token_id);

            // push this new token for next evaluation
            batch.token [batch.n_tokens] = new_token_id;
            batch.pos   [batch.n_tokens] = n_cur;
            batch.seq_id[batch.n_tokens] = i;
            batch.logits[batch.n_tokens] = true;

            i_batch[i] = batch.n_tokens;

            batch.n_tokens += 1;

            n_decode += 1;
        }

        if (batch.n_tokens == 0) {
            // all streams are finished
            break;
        }

        n_cur += 1;

        // evaluate the freshly sampled tokens with the transformer model
        const int ret = llama_decode(ctx, batch, params.n_threads);
        if (ret != 0) {
            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, ret);
            return 1;
        }
    }

    LOG_TEE("\n");

    if (n_parallel > 1) {
        LOG_TEE("\n");

        for (int32_t i = 0; i < n_parallel; ++i) {
            LOG_TEE("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str());
        }
    }

    const auto t_main_end = ggml_time_us();

    LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

    llama_print_timings(ctx);

    fprintf(stderr, "\n");

    llama_batch_free(batch);

    llama_free(ctx);
    llama_free_model(model);

    llama_backend_free();

    return 0;
}