diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index 1fdf756d5..8eaf5a262 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -58,7 +58,7 @@ static void print_usage(int, char ** argv) { LOG("\n"); } -void fill_hann_window(int length, bool periodic, float * output) { +static void fill_hann_window(int length, bool periodic, double * output) { int offset = -1; if (periodic) { offset = 0; @@ -69,31 +69,31 @@ void fill_hann_window(int length, bool periodic, float * output) { } // very poor-man fft -void twiddle(float * real, float * imag, int k, int N) { - float angle = 2 * M_PI * k / N; +static void twiddle(double * real, double * imag, int k, int N) { + double angle = 2 * M_PI * k / N; *real = cos(angle); *imag = sin(angle); } -void irfft(int n, float * inp_cplx, float * out_real) { +static void irfft(int n, const double * inp_cplx, double * out_real) { int N = n / 2 + 1; - std::vector real_input(N); - std::vector imag_input(N); + std::vector real_input(N); + std::vector imag_input(N); for (int i = 0; i < N; ++i) { real_input[i] = inp_cplx[2 * i]; imag_input[i] = inp_cplx[2 * i + 1]; } - std::vector real_output(n); - std::vector imag_output(n); + std::vector real_output(n); + std::vector imag_output(n); for (int k = 0; k < n; ++k) { real_output[k] = 0.0f; imag_output[k] = 0.0f; for (int m = 0; m < N; ++m) { - float twiddle_real; - float twiddle_imag; + double twiddle_real; + double twiddle_imag; twiddle(&twiddle_real, &twiddle_imag, k * m, n); @@ -107,6 +107,38 @@ void irfft(int n, float * inp_cplx, float * out_real) { } } +static void fold( + const std::vector & data, + int64_t output_size, + int64_t win_length, + int64_t hop_length, + int64_t pad, + std::vector& output +) { + int64_t output_height = output_size; + int64_t kernel_w = win_length; + int64_t stride_w = hop_length; + + int64_t width = output_size; + + output.resize(width, 0.0f); + + int64_t col_idx = 0; + for (int64_t w_col = 0; w_col < width; ++w_col) { + int64_t start = w_col * stride_w - pad; + int64_t end = start + kernel_w; + + for (int64_t w_im = start; w_im < end; ++w_im) { + if (w_im >= 0 && w_im < output_height) { + output[w_im] += data[col_idx]; + } + col_idx++; + } + } + + output.resize(output_size - 2 * pad); +} + int main(int argc, char ** argv) { common_params params; @@ -226,16 +258,17 @@ int main(int argc, char ** argv) { const int n_hop = 320; const int n_win = 1280; const int n_pad = (n_win - n_hop)/2; + const int n_out = (n - 1)*n_hop + n_win; - std::vector hann(n_fft); + std::vector hann(n_fft); fill_hann_window(hann.size(), true, hann.data()); int n_spec = n_embd*n; - std::vector E (n_spec); - std::vector S (n_spec); - std::vector ST(n_spec); + std::vector E (n_spec); + std::vector S (n_spec); + std::vector ST(n_spec); for (int l = 0; l < n; ++l) { for (int k = 0; k < n_embd; ++k) { @@ -245,8 +278,8 @@ int main(int argc, char ** argv) { for (int k = 0; k < n_embd/2; ++k) { for (int l = 0; l < n; ++l) { - float mag = E[(k )*n + l]; - float phi = E[(k + n_embd/2)*n + l]; + double mag = E[(k )*n + l]; + double phi = E[(k + n_embd/2)*n + l]; mag = exp(mag); @@ -265,7 +298,8 @@ int main(int argc, char ** argv) { } } - std::vector res(n*n_fft); + std::vector res (n*n_fft); + std::vector hann2(n*n_fft); const int n_thread = std::thread::hardware_concurrency(); std::vector workers(n_thread); @@ -273,6 +307,10 @@ int main(int argc, char ** argv) { workers[i] = std::thread([&, i]() { for (int l = i; l < n; l += n_thread) { irfft(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft); + for (int j = 0; j < n_fft; ++j) { + res [l*n_fft + j] *= hann[j]; + hann2[l*n_fft + j] = hann[j] * hann[j]; + } } }); } @@ -280,34 +318,55 @@ int main(int argc, char ** argv) { workers[i].join(); } - LOG("result (%d):\n", res.size()); - for (int i = 0; i < n_fft; ++i) { - LOG("%d - %8.5f\n", i, res[5*n_fft + i]); - } - LOG("\n"); - double sum = 0.0; - for (int i = 0; i < n_fft; ++i) { - sum += res[5*n_fft + i]; - } - LOG("sum: %f\n", sum); + //LOG("result (%d):\n", res.size()); + //for (int i = 0; i < n_fft; ++i) { + // LOG("%d - %8.5f\n", i, res[5*n_fft + i]); + //} + //LOG("\n"); + //double sum = 0.0; + //for (int i = 0; i < n_fft; ++i) { + // sum += res[5*n_fft + i]; + //} + //LOG("sum: %f\n", sum); - { - LOG("result:\n"); - for (int i = 0; i < 10; ++i) { - LOG("%8.3f ", S[i]); - } - LOG("\n"); - for (int i = n_spec - 10; i < n_spec; ++i) { - LOG("%8.3f ", S[i]); - } - LOG("\n"); - double sum = 0.0; - for (int i = 0; i < n_spec; ++i) { - sum += S[i]; - } - LOG("sum: %f\n", sum); + std::vector audio; + std::vector env; + + fold(res, n_out, n_win, n_hop, n_pad, audio); + fold(hann2, n_out, n_win, n_hop, n_pad, env); + + for (size_t i = 0; i < audio.size(); ++i) { + audio[i] /= env[i]; } + //LOG("audio (%d):\n", audio.size()); + //for (int i = 0; i < 1000; ++i) { + // LOG("%d: %8.5f\n", i, audio[i]); + //} + //LOG("\n"); + //double sum = 0.0; + //for (int i = 0; i < 1000; ++i) { + // sum += audio[i]; + //} + //LOG("sum: %f\n", sum); + + //{ + // LOG("result:\n"); + // for (int i = 0; i < 10; ++i) { + // LOG("%8.3f ", S[i]); + // } + // LOG("\n"); + // for (int i = n_spec - 10; i < n_spec; ++i) { + // LOG("%8.3f ", S[i]); + // } + // LOG("\n"); + // double sum = 0.0; + // for (int i = 0; i < n_spec; ++i) { + // sum += S[i]; + // } + // LOG("sum: %f\n", sum); + //} + fprintf(stderr, "\n"); llama_free(ctx_ttc);