diff --git a/common/sampling.cpp b/common/sampling.cpp index 422292175..6f0af3c4a 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -66,6 +66,24 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds dst->prev = src->prev; } +llama_token llama_sampling_last(llama_sampling_context * ctx) { + return ctx->prev.back(); +} + +std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) { + const int size = ctx_sampling->prev.size(); + + n = std::min(n, size); + + std::string result; + + for (int i = size - n; i < size; i++) { + result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]); + } + + return result; +} + std::string llama_sampling_print(const llama_sampling_params & params) { char result[1024]; @@ -193,11 +211,12 @@ llama_token llama_sampling_sample( void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, - llama_token id) { + llama_token id, + bool apply_grammar) { ctx_sampling->prev.erase(ctx_sampling->prev.begin()); ctx_sampling->prev.push_back(id); - if (ctx_sampling->grammar != NULL) { + if (ctx_sampling->grammar != NULL && apply_grammar) { llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id); } } diff --git a/common/sampling.h b/common/sampling.h index d8ee5126a..62ea6d4cf 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -70,6 +70,12 @@ void llama_sampling_reset(llama_sampling_context * ctx); // Copy the sampler context void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); +// Get the last sampled token +llama_token llama_sampling_last(llama_sampling_context * ctx); + +// Get a string representation of the last sampled tokens +std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n); + // Print sampling parameters into a string std::string llama_sampling_print(const llama_sampling_params & params); @@ -99,4 +105,5 @@ llama_token llama_sampling_sample( void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, - llama_token id); + llama_token id, + bool apply_grammar); diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index b6bafb8ba..75b8df676 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -12,25 +12,26 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}) if (EMSCRIPTEN) else() - add_subdirectory(main) - add_subdirectory(quantize) - add_subdirectory(quantize-stats) - add_subdirectory(perplexity) - add_subdirectory(embedding) - add_subdirectory(save-load-state) - add_subdirectory(benchmark) add_subdirectory(baby-llama) - add_subdirectory(train-text-from-scratch) - add_subdirectory(finetune) - add_subdirectory(convert-llama2c-to-ggml) - add_subdirectory(simple) add_subdirectory(batched) add_subdirectory(batched-bench) - add_subdirectory(speculative) - add_subdirectory(parallel) - add_subdirectory(llava) - add_subdirectory(llama-bench) add_subdirectory(beam-search) + add_subdirectory(benchmark) + add_subdirectory(convert-llama2c-to-ggml) + add_subdirectory(embedding) + add_subdirectory(finetune) + add_subdirectory(infill) + add_subdirectory(llama-bench) + add_subdirectory(llava) + add_subdirectory(main) + add_subdirectory(parallel) + add_subdirectory(perplexity) + add_subdirectory(quantize) + add_subdirectory(quantize-stats) + add_subdirectory(save-load-state) + add_subdirectory(simple) + add_subdirectory(speculative) + add_subdirectory(train-text-from-scratch) if (LLAMA_METAL) add_subdirectory(metal) endif() diff --git a/examples/infill/CMakeLists.txt b/examples/infill/CMakeLists.txt index 046f9b1e7..57d01cb0b 100644 --- a/examples/infill/CMakeLists.txt +++ b/examples/infill/CMakeLists.txt @@ -4,5 +4,5 @@ install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) if(TARGET BUILD_INFO) - add_dependencies(${TARGET} BUILD_INFO) + add_dependencies(${TARGET} BUILD_INFO) endif() diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 8f520e38e..6331335e3 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -523,7 +523,7 @@ int main(int argc, char ** argv) { const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); - llama_sampling_accept(ctx_sampling, ctx, id); + llama_sampling_accept(ctx_sampling, ctx, id, true); LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str()); @@ -541,8 +541,11 @@ int main(int argc, char ** argv) { LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed); while ((int) embd_inp.size() > n_consumed) { embd.push_back(embd_inp[n_consumed]); - ctx_sampling->prev.erase(ctx_sampling->prev.begin()); - ctx_sampling->prev.push_back(embd_inp[n_consumed]); + + // push the prompt in the sampling context in order to apply repetition penalties later + // for the prompt, we don't apply grammar rules + llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false); + ++n_consumed; if ((int) embd.size() >= params.n_batch) { break; @@ -574,7 +577,7 @@ int main(int argc, char ** argv) { if ((int) embd_inp.size() <= n_consumed) { // deal with eot token in infill mode - if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){ + if ((llama_sampling_last(ctx_sampling) == llama_token_eot(ctx) || is_interacting) && params.interactive){ if(is_interacting && !params.interactive_first) { // print an eot token printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str()); @@ -591,7 +594,7 @@ int main(int argc, char ** argv) { buffer += line; } while (another_line); // check if we got an empty line, if so we use the old input - if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) { + if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) { params.input_prefix = buffer; } buffer.clear(); @@ -601,7 +604,7 @@ int main(int argc, char ** argv) { buffer += line; } while (another_line); // check if we got an empty line - if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) { + if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) { params.input_suffix = buffer; } buffer.clear(); @@ -614,7 +617,7 @@ int main(int argc, char ** argv) { process_escapes(params.input_suffix); } suff_rm_leading_spc = params.escape; - if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { + if (suff_rm_leading_spc && params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) { params.input_suffix.erase(0, 1); suff_rm_leading_spc = false; } @@ -641,7 +644,7 @@ int main(int argc, char ** argv) { is_interacting = false; } // deal with end of text token in interactive mode - else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) { + else if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) { LOG("found EOS token\n"); if (params.interactive) { diff --git a/examples/main/main.cpp b/examples/main/main.cpp index d36e2a43b..db5309afe 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -611,7 +611,7 @@ int main(int argc, char ** argv) { const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); - llama_sampling_accept(ctx_sampling, ctx, id); + llama_sampling_accept(ctx_sampling, ctx, id, true); LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str()); @@ -630,12 +630,9 @@ int main(int argc, char ** argv) { while ((int) embd_inp.size() > n_consumed) { embd.push_back(embd_inp[n_consumed]); - // GG: I'm not sure it's a good idea to push the prompt tokens into the sampling context - // Most likely will remove this in the future to avoid exposing "prev" - // Same thing is done in "server". If we stop pushing the prompt tokens, then the repetition - // penalty will be applied only based on the tokens generated by the model. - ctx_sampling->prev.erase(ctx_sampling->prev.begin()); - ctx_sampling->prev.push_back(embd_inp[n_consumed]); + // push the prompt in the sampling context in order to apply repetition penalties later + // for the prompt, we don't apply grammar rules + llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false); ++n_consumed; if ((int) embd.size() >= params.n_batch) { @@ -666,12 +663,10 @@ int main(int argc, char ** argv) { // if not currently processing queued inputs; if ((int) embd_inp.size() <= n_consumed) { - // check for reverse prompt + // check for reverse prompt in the last n_prev tokens if (!params.antiprompt.empty()) { - std::string last_output; - for (auto id : ctx_sampling->prev) { - last_output += llama_token_to_piece(ctx, id); - } + const int n_prev = 32; + const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev); is_antiprompt = false; // Check if each of the reverse prompts appears at the end of the output. @@ -698,7 +693,7 @@ int main(int argc, char ** argv) { } // deal with end of text token in interactive mode - if (ctx_sampling->prev.back() == llama_token_eos(ctx)) { + if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) { LOG("found EOS token\n"); if (params.interactive) { diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 3af36ed58..eb64adef8 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -330,7 +330,7 @@ int main(int argc, char ** argv) { const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i); - llama_sampling_accept(client.ctx_sampling, ctx, id); + llama_sampling_accept(client.ctx_sampling, ctx, id, true); if (client.n_decoded == 1) { // start measuring generation time after the first token to make sure all concurrent clients diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 50fdb2d3a..b5ad3cc99 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -195,10 +195,12 @@ struct llama_server_context json prompt; std::vector embd; + gpt_params params; + llama_model *model = nullptr; llama_context *ctx = nullptr; - gpt_params params; llama_sampling_context *ctx_sampling = nullptr; + int n_ctx; bool truncated = false; @@ -246,7 +248,10 @@ struct llama_server_context multibyte_pending = 0; n_remain = 0; n_past = 0; + params.sparams.n_prev = n_ctx; + } + void initSampling() { if (ctx_sampling != nullptr) { llama_sampling_free(ctx_sampling); } @@ -311,16 +316,32 @@ struct llama_server_context return prompt_tokens; } - bool loadGrammar() - { - ctx_sampling = llama_sampling_init(params.sparams); - return true; + void truncatePrompt(std::vector &prompt_tokens) { + const int n_left = n_ctx - params.n_keep; + const int n_block_size = n_left / 2; + const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_block_size) / n_block_size; + + // Keep n_keep tokens at start of prompt (at most n_ctx - 4) + std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); + + new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_block_size, prompt_tokens.end()); + + LOG_VERBOSE("input truncated", { + {"n_ctx", n_ctx}, + {"n_keep", params.n_keep}, + {"n_left", n_left}, + {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, + {"num_prompt_tokens", new_tokens.size()} + }); + + truncated = true; + prompt_tokens = new_tokens; } void loadInfill() { bool suff_rm_leading_spc = true; - if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { + if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) { params.input_suffix.erase(0, 1); suff_rm_leading_spc = false; } @@ -336,6 +357,7 @@ struct llama_server_context prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx)); prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); prefix_tokens.push_back(llama_token_middle(ctx)); + auto prompt_tokens = prefix_tokens; num_prompt_tokens = prompt_tokens.size(); @@ -347,31 +369,18 @@ struct llama_server_context params.n_keep = std::min(params.n_ctx - 4, params.n_keep); // if input prompt is too big, truncate like normal - if (num_prompt_tokens >= (size_t)params.n_ctx) + if (num_prompt_tokens >= (size_t) n_ctx) { - printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens); - // todo we probably want to cut from both sides - const int n_left = (params.n_ctx - params.n_keep) / 2; - std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); - const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; - new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); - std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin()); + truncatePrompt(prompt_tokens); + num_prompt_tokens = prompt_tokens.size(); - LOG_VERBOSE("input truncated", { - {"n_ctx", params.n_ctx}, - {"n_keep", params.n_keep}, - {"n_left", n_left}, - {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, - }); - - truncated = true; - prompt_tokens = new_tokens; + GGML_ASSERT(num_prompt_tokens < (size_t)n_ctx); } - else + + // push the prompt into the sampling context (do not apply grammar) + for (auto & token : prompt_tokens) { - const size_t ps = num_prompt_tokens; - std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0); - std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps); + llama_sampling_accept(ctx_sampling, ctx, token, false); } // compare the evaluated prompt with the new prompt @@ -409,29 +418,18 @@ struct llama_server_context params.n_keep = std::min(n_ctx - 4, params.n_keep); // if input prompt is too big, truncate like normal - if (num_prompt_tokens >= (size_t)n_ctx) + if (num_prompt_tokens >= (size_t) n_ctx) { - const int n_left = (n_ctx - params.n_keep) / 2; - std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); - const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; - new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); - std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin()); + truncatePrompt(prompt_tokens); + num_prompt_tokens = prompt_tokens.size(); - LOG_VERBOSE("input truncated", { - {"n_ctx", n_ctx}, - {"n_keep", params.n_keep}, - {"n_left", n_left}, - {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, - }); - - truncated = true; - prompt_tokens = new_tokens; + GGML_ASSERT(num_prompt_tokens < (size_t)n_ctx); } - else + + // push the prompt into the sampling context (do not apply grammar) + for (auto & token : prompt_tokens) { - const size_t ps = num_prompt_tokens; - std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0); - std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps); + llama_sampling_accept(ctx_sampling, ctx, token, false); } // compare the evaluated prompt with the new prompt @@ -542,7 +540,7 @@ struct llama_server_context result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p}); } - llama_sampling_accept(ctx_sampling, ctx, result.tok); + llama_sampling_accept(ctx_sampling, ctx, result.tok, true); if (tg) { num_tokens_predicted++; @@ -1206,8 +1204,6 @@ static void parse_options_completion(const json &body, llama_server_context &lla } } - llama.ctx_sampling = llama_sampling_init(llama.params.sparams); - LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama)); } @@ -1376,15 +1372,9 @@ int main(int argc, char **argv) llama.rewind(); llama_reset_timings(llama.ctx); - parse_options_completion(json::parse(req.body), llama); - if (!llama.loadGrammar()) - { - res.status = 400; - return; - } - + llama.initSampling(); llama.loadPrompt(); llama.beginCompletion(); @@ -1539,14 +1529,9 @@ int main(int argc, char **argv) llama.rewind(); llama_reset_timings(llama.ctx); - parse_options_infill(json::parse(req.body), llama); - if (!llama.loadGrammar()) - { - res.status = 400; - return; - } + llama.initSampling(); llama.loadInfill(); llama.beginCompletion(); const auto chunked_content_provider = [&](size_t, DataSink & sink) { @@ -1696,7 +1681,9 @@ int main(int argc, char **argv) const json body = json::parse(req.body); llama.rewind(); + llama_reset_timings(llama.ctx); + if (body.count("content") != 0) { llama.prompt = body["content"]; @@ -1706,6 +1693,8 @@ int main(int argc, char **argv) llama.prompt = ""; } llama.params.n_predict = 0; + + llama.initSampling(); llama.loadPrompt(); llama.beginCompletion(); llama.doCompletion(); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 92fb9a43b..894321ce9 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -154,7 +154,7 @@ int main(int argc, char ** argv) { // sample from the target model llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); - llama_sampling_accept(ctx_sampling, ctx_tgt, id); + llama_sampling_accept(ctx_sampling, ctx_tgt, id, true); //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); @@ -328,7 +328,7 @@ int main(int argc, char ** argv) { const int s = sa[is]; - llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id); + llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true); drafts[s].tokens.push_back(id);