From 684780141a08200ec98eba3e982dbafd1d0b5000 Mon Sep 17 00:00:00 2001 From: Alexey Parfenov Date: Sun, 11 Feb 2024 13:38:14 +0000 Subject: [PATCH] server : allow to specify tokens as strings in logit_bias (#5003) * server: allow to specify tokens as strings in logit_bias * Apply suggestions from code review Co-authored-by: Georgi Gerganov --------- Co-authored-by: Georgi Gerganov --- examples/server/README.md | 2 +- examples/server/server.cpp | 32 +++++++++++++++++++++++++------- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/examples/server/README.md b/examples/server/README.md index 1db7cdf21..0f7373ae8 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -185,7 +185,7 @@ node index.js `ignore_eos`: Ignore end of stream token and continue generating (default: false). - `logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced (default: []). + `logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. (default: []). `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token (default: 0) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 4d212f1f0..1699eb76b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -626,18 +626,36 @@ struct llama_server_context const int n_vocab = llama_n_vocab(model); for (const auto &el : *logit_bias) { - if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) + if (el.is_array() && el.size() == 2) { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) + float bias; + if (el[1].is_number()) { - if (el[1].is_number()) + bias = el[1].get(); + } + else if (el[1].is_boolean() && !el[1].get()) + { + bias = -INFINITY; + } + else + { + continue; + } + + if (el[0].is_number_integer()) + { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { - slot->sparams.logit_bias[tok] = el[1].get(); + slot->sparams.logit_bias[tok] = bias; } - else if (el[1].is_boolean() && !el[1].get()) + } + else if (el[0].is_string()) + { + auto toks = llama_tokenize(model, el[0].get(), false); + for (auto tok : toks) { - slot->sparams.logit_bias[tok] = -INFINITY; + slot->sparams.logit_bias[tok] = bias; } } }