diff --git a/common/common.cpp b/common/common.cpp
index 489462b5a..10ef11829 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1704,6 +1704,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     }
     fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
     fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
+    fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep);
     fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
     fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 611c327bb..de4331a11 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -248,7 +248,10 @@ static llama_token llama_sampling_sample_impl(
             llama_sample_temp(ctx_main, &cur_p, temp);
             id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
         } else {
-            sampler_queue(ctx_main, params, cur_p, 1);
+            // temperature sampling
+            size_t min_keep = std::max(1, params.min_keep);
+
+            sampler_queue(ctx_main, params, cur_p, min_keep);
 
             id = llama_sample_token(ctx_main, &cur_p);
diff --git a/common/sampling.h b/common/sampling.h
index e1279a894..95d875394 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -22,6 +22,7 @@ enum class llama_sampler_type : char {
 typedef struct llama_sampling_params {
     int32_t n_prev = 64; // number of previous tokens to remember
     int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
     int32_t top_k = 40; // <= 0 to use vocab size
     float top_p = 0.95f; // 1.0 = disabled
    float min_p = 0.05f; // 0.0 = disabled
diff --git a/examples/server/README.md b/examples/server/README.md
index ac5133d24..809e2d37c 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -199,6 +199,8 @@ node index.js
 
     `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token (default: 0)
 
+    `min_keep`: If greater than 0, force samplers to return N possible tokens at minimum (default: 0)
+
     `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `prompt`. You can determine the place of the image in the prompt as in the following: `USER:[img-12]Describe the image in detail.\nASSISTANT:`. In this case, `[img-12]` will be replaced by the embeddings of the image with id `12` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 12}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
 
     `slot_id`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot (default: -1)
diff --git a/examples/server/public/index.html b/examples/server/public/index.html
index b059c75f2..84038ddce 100644
--- a/examples/server/public/index.html
+++ b/examples/server/public/index.html
@@ -234,6 +234,7 @@
       mirostat_eta: 0.1, // learning rate
       grammar: '',
       n_probs: 0, // no completion_probabilities,
+      min_keep: 0, // min probs from each sampler,
       image_data: [],
       cache_prompt: true,
       api_key: ''
@@ -791,6 +792,9 @@
               ${IntField({ label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs })}
+              ${IntField({ label: "Min Probabilities from each Sampler", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 4f2e9c898..22c344dd4 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -548,6 +548,7 @@ struct llama_server_context
         slot->params.seed = json_value(data, "seed", default_params.seed);
         slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
         slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
+        slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
 
         if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) {
             // Might be better to reject the request with a 400 ?
@@ -1093,6 +1094,7 @@ struct llama_server_context
             {"stream", slot.params.stream},
             {"logit_bias", slot.sparams.logit_bias},
             {"n_probs", slot.sparams.n_probs},
+            {"min_keep", slot.sparams.min_keep},
             {"grammar", slot.sparams.grammar},
             {"samplers", samplers_sequence}
         };
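
Usage note (not part of the patch): a minimal sketch of how the new field can be driven through the common sampling helpers declared in common/sampling.h (llama_sampling_init, llama_sampling_sample, llama_sampling_accept, llama_sampling_free). The wrapper function below is illustrative only; it assumes a llama_context on which llama_decode() has already produced logits, and that the common headers are on the include path.

#include "sampling.h"   // common/sampling.h

// Illustrative helper (not part of the patch): sample one token while asking
// every sampler in the queue to keep at least `min_keep` candidates.
static llama_token sample_with_min_keep(llama_context * ctx) {
    llama_sampling_params sparams;
    sparams.top_k    = 40;
    sparams.top_p    = 0.95f;
    sparams.min_keep = 5;    // new field: 0 (the default) leaves behavior unchanged,
                             // since sampler_queue() receives std::max(1, params.min_keep)

    llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);

    // Assumes llama_decode() has already run on `ctx`, so logits for the last
    // position are available to the sampler.
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
    llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar = */ false);

    llama_sampling_free(ctx_sampling);
    return id;
}

Over HTTP, the same knob is exposed as the "min_keep" field of the /completion request body, next to "n_probs", as documented in the examples/server/README.md hunk above.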