mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 04:00:16 +00:00
Merge c9d1eb3a06
into 9a483999a6
This commit is contained in:
commit
49382ff2fa
@ -2214,6 +2214,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||||||
params.vocoder.model = value;
|
params.vocoder.model = value;
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
|
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
|
||||||
|
add_opt(common_arg(
|
||||||
|
{"--tts-use-guide-tokens"},
|
||||||
|
"Use guide tokens to improve TTS word recall",
|
||||||
|
[](common_params & params) {
|
||||||
|
params.vocoder.use_guide_tokens = true;
|
||||||
|
}
|
||||||
|
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
|
||||||
|
|
||||||
// model-specific
|
// model-specific
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
@ -178,6 +178,8 @@ struct common_params_vocoder {
|
|||||||
|
|
||||||
std::string model = ""; // model path // NOLINT
|
std::string model = ""; // model path // NOLINT
|
||||||
std::string model_url = ""; // model url to download // NOLINT
|
std::string model_url = ""; // model url to download // NOLINT
|
||||||
|
|
||||||
|
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
||||||
};
|
};
|
||||||
|
|
||||||
struct common_params {
|
struct common_params {
|
||||||
|
@ -425,6 +425,29 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
|
|||||||
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
|
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::vector<llama_token> prepare_guide_tokens(const llama_model * model, const std::string& str)
|
||||||
|
{
|
||||||
|
const std::string& delimiter = "<|text_sep|>";
|
||||||
|
|
||||||
|
std::vector<llama_token> result;
|
||||||
|
size_t start = 0;
|
||||||
|
size_t end = str.find(delimiter);
|
||||||
|
|
||||||
|
while (end != std::string::npos) {
|
||||||
|
std::string current_word = str.substr(start, end - start);
|
||||||
|
auto tmp = common_tokenize(model, current_word, false, true);
|
||||||
|
result.push_back(tmp[0]);
|
||||||
|
start = end + delimiter.length();
|
||||||
|
end = str.find(delimiter, start);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the last part
|
||||||
|
std::string current_word = str.substr(start);
|
||||||
|
auto tmp = common_tokenize(model, current_word, false, true);
|
||||||
|
result.push_back(tmp[0]);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
common_params params;
|
common_params params;
|
||||||
|
|
||||||
@ -494,6 +517,7 @@ int main(int argc, char ** argv) {
|
|||||||
const auto t_main_start = ggml_time_us();
|
const auto t_main_start = ggml_time_us();
|
||||||
|
|
||||||
std::vector<llama_token> codes;
|
std::vector<llama_token> codes;
|
||||||
|
std::vector<llama_token> guide_tokens;
|
||||||
|
|
||||||
// process prompt and generate voice codes
|
// process prompt and generate voice codes
|
||||||
{
|
{
|
||||||
@ -508,6 +532,10 @@ int main(int argc, char ** argv) {
|
|||||||
// convert the input text into the necessary format expected by OuteTTS
|
// convert the input text into the necessary format expected by OuteTTS
|
||||||
{
|
{
|
||||||
std::string prompt_clean = process_text(params.prompt);
|
std::string prompt_clean = process_text(params.prompt);
|
||||||
|
if(params.vocoder.use_guide_tokens)
|
||||||
|
{
|
||||||
|
guide_tokens = prepare_guide_tokens(model_ttc,prompt_clean);
|
||||||
|
}
|
||||||
|
|
||||||
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
|
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
|
||||||
|
|
||||||
@ -717,6 +745,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
|||||||
int n_past = batch.n_tokens;
|
int n_past = batch.n_tokens;
|
||||||
int n_decode = 0;
|
int n_decode = 0;
|
||||||
|
|
||||||
|
bool next_token_uses_guide_token = true;
|
||||||
|
|
||||||
while (n_decode <= n_predict) {
|
while (n_decode <= n_predict) {
|
||||||
// prepare the next batch
|
// prepare the next batch
|
||||||
common_batch_clear(batch);
|
common_batch_clear(batch);
|
||||||
@ -728,7 +758,18 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
|
llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
|
||||||
|
|
||||||
|
//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
|
||||||
|
if(!guide_tokens.empty() && next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
|
||||||
|
{
|
||||||
|
llama_token guide_token = guide_tokens[0];
|
||||||
|
guide_tokens.erase(guide_tokens.begin());
|
||||||
|
new_token_id = guide_token; //ensure correct word fragment is used
|
||||||
|
}
|
||||||
|
|
||||||
|
//this is the token id that always precedes a new word
|
||||||
|
next_token_uses_guide_token = (new_token_id == 198);
|
||||||
|
|
||||||
common_sampler_accept(smpl[i], new_token_id, true);
|
common_sampler_accept(smpl[i], new_token_id, true);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user