server : remove beam-search functionality

Georgi Gerganov 2023-10-19 14:10:37 +03:00
parent 654e0a1fe0
commit a8c981b734


@@ -1296,12 +1296,6 @@ struct llama_server_context
     }
 };
 
-struct server_beam_search_callback_data
-{
-    llama_context * ctx;
-    llama_client_slot * slot;
-};
-
 static void server_print_usage(const char *argv0, const gpt_params &params,
                                const server_params &sparams)
 {
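
Note: the removed struct above existed only to bundle a context and a slot so both could travel through the beam-search callback's single void * user-data parameter. The following is a minimal, self-contained sketch of that C-style callback pattern; every name in it is a hypothetical stand-in, not llama.cpp API.

#include <cstdio>

// Hypothetical stand-ins for llama_context and llama_client_slot.
struct fake_context { int id; };
struct fake_slot    { int id; };

// Same shape as the removed server_beam_search_callback_data: one struct
// carrying everything the callback will need.
struct callback_data {
    fake_context * ctx;
    fake_slot    * slot;
};

// A C-style callback type: the library hands the opaque pointer back untouched.
typedef void (*event_fn)(void * user_data);

static void on_event(void * user_data) {
    // Recover the typed state from the opaque pointer, as the removed
    // beam_search_callback did with static_cast.
    auto & data = *static_cast<callback_data *>(user_data);
    std::printf("ctx=%d slot=%d\n", data.ctx->id, data.slot->id);
}

// Stand-in for an API like llama_beam_search that takes (callback, user_data).
static void run_search(event_fn cb, void * user_data) {
    cb(user_data);
}

int main() {
    fake_context ctx{1};
    fake_slot    slot{7};
    callback_data data{&ctx, &slot};
    run_search(on_event, &data); // prints: ctx=1 slot=7
    return 0;
}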
@@ -2006,46 +2000,6 @@ static void log_server_request(const httplib::Request &req, const httplib::Response &res)
     });
 }
 
-static bool is_at_eob(const server_beam_search_callback_data & server_context, const llama_token *tokens, const size_t n_tokens) {
-    return n_tokens && tokens[n_tokens - 1] == llama_token_eos(server_context.ctx);
-}
-
-// Function matching type llama_beam_search_callback_fn_t.
-// Custom callback example is called each time the beam lengths increase:
-//  * Show progress by printing ',' followed by the number of convergent beam tokens, if any.
-//  * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
-//    This is also called when the stop condition is met.
-//    Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.
-// NOT TESTED after PR #3589
-static void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
-    auto & llama = *static_cast<server_beam_search_callback_data *>(callback_data);
-    // Mark beams as EOS as needed.
-    for (size_t i = 0; i < beams_state.n_beams; ++i) {
-        llama_beam_view & beam_view = beams_state.beam_views[i];
-        if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
-            beam_view.eob = true;
-        }
-    }
-    printf(","); // Show progress
-    if (const size_t n = beams_state.common_prefix_length) {
-        llama.slot->generated_token_probs.resize(llama.slot->generated_token_probs.size() + n);
-        assert(0u < beams_state.n_beams);
-        const llama_token * tokens = beams_state.beam_views[0].tokens;
-        const auto map = [](llama_token tok) { return completion_token_output{{}, tok, ""}; };
-        std::transform(tokens, tokens + n, llama.slot->generated_token_probs.end() - n, map);
-        printf("%zu", n);
-    }
-    fflush(stdout);
-#if 0 // DEBUG: print current beams for this iteration
-    std::cout << "\n\nCurrent beams:\n";
-    for (size_t i = 0; i < beams_state.n_beams; ++i) {
-        std::cout << "beams[" << i << "]: " << ostream_beam_view{llama.ctx, beams_state.beam_views[i]} << std::endl;
-    }
-#endif
-}
-
 struct token_translator
 {
     llama_context * ctx;
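
Note: the core of the removed beam_search_callback is the common-prefix step: once all beams agree on their first n tokens, those tokens are final and are appended to the slot's generated_token_probs with a resize plus std::transform. Below is a self-contained sketch of just that idiom, using mocked stand-ins for the llama.cpp types; only the resize/transform shape is taken from the removed code.

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

using llama_token = int; // stand-in for the real token type

// Mocked stand-in for the server's completion_token_output.
struct completion_token_output {
    std::vector<float> probs;  // left empty, mirroring the {} in the original
    llama_token        tok;
    std::string        text_to_send;
};

// When the beam search reports a common prefix of n tokens (all beams agree
// on them), append those tokens to the output exactly once. Mirrors the
// resize + std::transform idiom of the removed callback.
static void collect_common_prefix(const llama_token * tokens, size_t n,
                                  std::vector<completion_token_output> & out) {
    if (n == 0) {
        return; // nothing converged on this call
    }
    out.resize(out.size() + n);
    const auto map = [](llama_token tok) { return completion_token_output{{}, tok, ""}; };
    std::transform(tokens, tokens + n, out.end() - n, map);
}

int main() {
    std::vector<completion_token_output> generated;

    const llama_token first[]  = {11, 22}; // beams converged on two tokens
    collect_common_prefix(first, 2, generated);

    const llama_token second[] = {33};     // a later call converges on one more
    collect_common_prefix(second, 1, generated);

    assert(generated.size() == 3);
    assert(generated[0].tok == 11 && generated[2].tok == 33);
    return 0;
}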
@@ -2176,20 +2130,8 @@ int main(int argc, char **argv)
             }
 
             if (!slot->params.stream) {
-                std::string completion_text = "";
-                if (llama.params.n_beams)
-                {
-                    // Fill llama.generated_token_probs vector with final beam.
-                    server_beam_search_callback_data data_beam;
-                    data_beam.slot = slot;
-                    data_beam.ctx = llama.ctx;
-                    llama_beam_search(llama.ctx, beam_search_callback, &data_beam, llama.params.n_beams,
-                                      slot->n_past, llama.params.n_predict);
-                    // Translate llama.generated_token_probs to llama.generated_text.
-                    append_to_generated_text_from_generated_token_probs(llama, slot);
-                }
-                else
-                {
+                std::string completion_text;
                 while (slot->is_processing())
                 {
                     if (slot->has_new_token())
@@ -2201,7 +2143,6 @@ int main(int argc, char **argv)
                         std::this_thread::sleep_for(std::chrono::microseconds(5));
                     }
                 }
-                }
 
                 auto probs = slot->generated_token_probs;
                 if (slot->sparams.n_probs > 0 && slot->stopped_word)
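
Note: with the beam-search branch gone, the non-streaming path always uses the polling loop kept above: spin while the slot is processing, append each new token, and sleep for 5 microseconds when idle. The sketch below reproduces that wait-and-drain pattern with a hypothetical miniature slot type standing in for llama_client_slot.

#include <atomic>
#include <chrono>
#include <string>
#include <thread>

// Hypothetical miniature of llama_client_slot: a worker thread publishes
// tokens and eventually clears the processing flag.
struct mini_slot {
    std::atomic<bool> processing{true};
    std::atomic<bool> new_token{false};
    std::string       pending;

    bool is_processing() const { return processing.load(); }
    bool has_new_token() const { return new_token.load(); }
    std::string take_token() { new_token.store(false); return pending; }
};

// The surviving non-streaming path: poll until generation finishes,
// appending each token to completion_text and sleeping 5 us when idle.
static std::string wait_for_completion(mini_slot & slot) {
    std::string completion_text;
    while (slot.is_processing()) {
        if (slot.has_new_token()) {
            completion_text += slot.take_token();
        } else {
            std::this_thread::sleep_for(std::chrono::microseconds(5));
        }
    }
    if (slot.has_new_token()) {
        completion_text += slot.take_token(); // drain a token that raced the exit
    }
    return completion_text;
}

int main() {
    mini_slot slot;
    std::thread worker([&slot] {
        slot.pending = "hello";       // publish one token...
        slot.new_token.store(true);
        slot.processing.store(false); // ...then finish
    });
    const std::string text = wait_for_completion(slot);
    worker.join();
    return text == "hello" ? 0 : 1;
}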