mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
tts : text pre-processing
This commit is contained in:
parent
8f34d0dd8b
commit
f1b5b6b5a1
@ -7,12 +7,14 @@
|
|||||||
#define _USE_MATH_DEFINES // For M_PI on MSVC
|
#define _USE_MATH_DEFINES // For M_PI on MSVC
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstdio>
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <string>
|
#include <cstdio>
|
||||||
#include <vector>
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <map>
|
||||||
|
#include <regex>
|
||||||
|
#include <string>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
//
|
//
|
||||||
// Terminal utils
|
// Terminal utils
|
||||||
@ -267,6 +269,143 @@ static std::vector<double> embd_to_audio(
|
|||||||
return audio;
|
return audio;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static const std::map<int, std::string> ones = {
|
||||||
|
{0, "zero"}, {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"},
|
||||||
|
{5, "five"}, {6, "six"}, {7, "seven"}, {8, "eight"}, {9, "nine"},
|
||||||
|
{10, "ten"}, {11, "eleven"}, {12, "twelve"}, {13, "thirteen"}, {14, "fourteen"},
|
||||||
|
{15, "fifteen"}, {16, "sixteen"}, {17, "seventeen"}, {18, "eighteen"}, {19, "nineteen"}
|
||||||
|
};
|
||||||
|
|
||||||
|
static const std::map<int, std::string> tens = {
|
||||||
|
{2, "twenty"}, {3, "thirty"}, {4, "forty"}, {5, "fifty"},
|
||||||
|
{6, "sixty"}, {7, "seventy"}, {8, "eighty"}, {9, "ninety"}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert a number less than 1000 to words
|
||||||
|
static std::string convert_less_than_thousand(int num) {
|
||||||
|
std::string result;
|
||||||
|
|
||||||
|
if (num >= 100) {
|
||||||
|
result += ones.at(num / 100) + " hundred ";
|
||||||
|
num %= 100;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (num >= 20) {
|
||||||
|
result += tens.at(num / 10);
|
||||||
|
if (num % 10 > 0) {
|
||||||
|
result += "-" + ones.at(num % 10);
|
||||||
|
}
|
||||||
|
} else if (num > 0) {
|
||||||
|
result += ones.at(num);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string number_to_words(const std::string & number_str) {
|
||||||
|
try {
|
||||||
|
size_t decimal_pos = number_str.find('.');
|
||||||
|
std::string integer_part = number_str.substr(0, decimal_pos);
|
||||||
|
|
||||||
|
int int_number = std::stoi(integer_part);
|
||||||
|
std::string result;
|
||||||
|
|
||||||
|
if (int_number == 0) {
|
||||||
|
result = "zero";
|
||||||
|
} else {
|
||||||
|
if (int_number >= 1000000000) {
|
||||||
|
int billions = int_number / 1000000000;
|
||||||
|
result += convert_less_than_thousand(billions) + " billion ";
|
||||||
|
int_number %= 1000000000;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (int_number >= 1000000) {
|
||||||
|
int millions = int_number / 1000000;
|
||||||
|
result += convert_less_than_thousand(millions) + " million ";
|
||||||
|
int_number %= 1000000;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (int_number >= 1000) {
|
||||||
|
int thousands = int_number / 1000;
|
||||||
|
result += convert_less_than_thousand(thousands) + " thousand ";
|
||||||
|
int_number %= 1000;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (int_number > 0) {
|
||||||
|
result += convert_less_than_thousand(int_number);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle decimal part
|
||||||
|
if (decimal_pos != std::string::npos) {
|
||||||
|
result += " point";
|
||||||
|
std::string decimal_part = number_str.substr(decimal_pos + 1);
|
||||||
|
for (char digit : decimal_part) {
|
||||||
|
result += " " + ones.at(digit - '0');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
// Skip if fails
|
||||||
|
return " ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string replace_numbers_with_words(const std::string & input_text) {
|
||||||
|
std::regex number_pattern(R"(\d+(\.\d+)?)");
|
||||||
|
std::string result;
|
||||||
|
auto it = std::sregex_iterator(input_text.begin(), input_text.end(), number_pattern);
|
||||||
|
auto end = std::sregex_iterator();
|
||||||
|
|
||||||
|
size_t last_pos = 0;
|
||||||
|
for (std::sregex_iterator i = it; i != end; ++i) {
|
||||||
|
const std::smatch& match = *i;
|
||||||
|
result.append(input_text, last_pos, match.position() - last_pos);
|
||||||
|
result.append(number_to_words(match.str()));
|
||||||
|
last_pos = match.position() + match.length();
|
||||||
|
}
|
||||||
|
result.append(input_text, last_pos);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
|
||||||
|
static std::string process_text(const std::string & text) {
|
||||||
|
|
||||||
|
// For now I skipped text romanization as I am unsure how to handle
|
||||||
|
// uroman and MeCab implementations in C++
|
||||||
|
// maybe something like https://github.com/anyascii/anyascii/ could work.
|
||||||
|
// currently only English would be supported in this function
|
||||||
|
|
||||||
|
std::string processed_text = replace_numbers_with_words(text);
|
||||||
|
|
||||||
|
std::transform(processed_text.begin(), processed_text.end(),
|
||||||
|
processed_text.begin(), ::tolower);
|
||||||
|
|
||||||
|
std::regex special_chars(R"([-_/,\.\\])");
|
||||||
|
processed_text = std::regex_replace(processed_text, special_chars, " ");
|
||||||
|
|
||||||
|
std::regex non_alpha(R"([^a-z\s])");
|
||||||
|
processed_text = std::regex_replace(processed_text, non_alpha, "");
|
||||||
|
|
||||||
|
std::regex multiple_spaces(R"(\s+)");
|
||||||
|
processed_text = std::regex_replace(processed_text, multiple_spaces, " ");
|
||||||
|
|
||||||
|
processed_text = std::regex_replace(processed_text, std::regex(R"(^\s+|\s+$)"), "");
|
||||||
|
|
||||||
|
/*
|
||||||
|
Replace spaces with the separator token same as in line 365
|
||||||
|
|
||||||
|
for (auto & c : prompt_user) {
|
||||||
|
if (c == ' ') {
|
||||||
|
prompt_clean += "<|text_sep|>";
|
||||||
|
*/
|
||||||
|
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>");
|
||||||
|
|
||||||
|
return processed_text;
|
||||||
|
}
|
||||||
|
|
||||||
static void prompt_add(llama_tokens & prompt, llama_token token) {
|
static void prompt_add(llama_tokens & prompt, llama_token token) {
|
||||||
prompt.push_back(token);
|
prompt.push_back(token);
|
||||||
}
|
}
|
||||||
@ -353,23 +492,11 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
prompt_add(prompt_inp, model_ttc, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
|
prompt_add(prompt_inp, model_ttc, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
|
||||||
|
|
||||||
// TODO: not sure if this is correct
|
// convert the input text into the necessary format expected by OuteTTS
|
||||||
{
|
{
|
||||||
std::string prompt_clean;
|
std::string prompt_clean = process_text(params.prompt);
|
||||||
std::string prompt_user = params.prompt;
|
|
||||||
|
|
||||||
for (auto & c : prompt_user) {
|
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
|
||||||
if (c == ' ') {
|
|
||||||
prompt_clean += "<|text_sep|>";
|
|
||||||
} else {
|
|
||||||
if (isalpha(c) || isdigit(c)) {
|
|
||||||
c = tolower(c);
|
|
||||||
} else {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
prompt_clean += c;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
prompt_add(prompt_inp, model_ttc, prompt_clean, false, true);
|
prompt_add(prompt_inp, model_ttc, prompt_clean, false, true);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user