// very poor-man fft

// Computes the complex twiddle factor e^(i * 2*pi * k / N).
//
//   real - receives cos(2*pi*k/N)
//   imag - receives sin(2*pi*k/N)
//   k    - harmonic index
//   N    - transform length; N == 0 makes the angle non-finite, so both outputs become NaN
void twiddle(float * real, float * imag, int k, int N) {
    // local constant instead of M_PI, which is not standard C++ and is hidden
    // behind _USE_MATH_DEFINES on some toolchains
    constexpr double pi = 3.14159265358979323846;

    // evaluate the angle in double and narrow only the final results, so large
    // k does not lose phase precision before cos/sin are taken
    const double angle = 2.0 * pi * k / N;

    *real = static_cast<float>(std::cos(angle));
    *imag = static_cast<float>(std::sin(angle));
}

// One-sided inverse DFT of a half spectrum, producing n real samples.
//
//   n        - number of output samples; n <= 0 is a no-op
//   inp_cplx - N = n/2 + 1 complex bins stored interleaved as (re, im) pairs,
//              i.e. 2*N floats are read
//   out_real - receives n floats:
//                out_real[k] = (1/N) * sum_{m=0}^{N-1} Re( X[m] * e^(i*2*pi*k*m/n) )
//
// NOTE(review): a conventional inverse real FFT (e.g. numpy.fft.irfft) also adds the
// conjugate-mirrored bins and divides by n. The one-sided sum and the 1/N scale are
// kept here so the output matches the existing behaviour - confirm that is intended.
void irfft(int n, float * inp_cplx, float * out_real) {
    if (n <= 0) {
        return;
    }

    const int N = n / 2 + 1;

    // e^(i*2*pi*j/n) only takes n distinct values, so build them once per call
    // instead of evaluating cos/sin for every (k, m) pair
    std::vector<float> tw_real(n);
    std::vector<float> tw_imag(n);
    for (int j = 0; j < n; ++j) {
        twiddle(&tw_real[j], &tw_imag[j], j, n);
    }

    // accumulate into a scratch buffer so the input is fully consumed before
    // out_real is written (the input was likewise read up front before this change)
    std::vector<float> acc(n, 0.0f);

    for (int k = 0; k < n; ++k) {
        // j tracks (k * m) % n incrementally: no large products, and the table
        // index always stays within [0, n)
        int j = 0;
        for (int m = 0; m < N; ++m) {
            // only the real part of the complex product is needed
            acc[k] += inp_cplx[2 * m] * tw_real[j] - inp_cplx[2 * m + 1] * tw_imag[j];

            j += k;
            if (j >= n) {
                j -= n;
            }
        }
    }

    for (int i = 0; i < n; ++i) {
        out_real[i] = acc[i] / N;
    }
}
++k) { + E[k*n + l] = embd[l*n_embd + k]; + } } - LOG("\n"); - for (int i = n - 10; i < n; ++i) { - LOG("%8.3f ", embd[i]); + + for (int k = 0; k < n_embd/2; ++k) { + for (int l = 0; l < n; ++l) { + float mag = E[(k )*n + l]; + float phi = E[(k + n_embd/2)*n + l]; + + mag = exp(mag); + + if (mag > 1e2) { + mag = 1e2; + } + S[2*(k*n + l) + 0] = mag*cosf(phi); + S[2*(k*n + l) + 1] = mag*sinf(phi); + } + } + + for (int l = 0; l < n; ++l) { + for (int k = 0; k < n_embd/2; ++k) { + ST[l*n_embd + 2*k + 0] = S[2*(k*n + l) + 0]; + ST[l*n_embd + 2*k + 1] = S[2*(k*n + l) + 1]; + } + } + + std::vector res(n*n_fft); + + const int n_thread = std::thread::hardware_concurrency(); + std::vector workers(n_thread); + for (int i = 0; i < n_thread; ++i) { + workers[i] = std::thread([&, i]() { + for (int l = i; l < n; l += n_thread) { + irfft(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft); + } + }); + } + for (int i = 0; i < n_thread; ++i) { + workers[i].join(); + } + + LOG("result (%d):\n", res.size()); + for (int i = 0; i < n_fft; ++i) { + LOG("%d - %8.5f\n", i, res[5*n_fft + i]); } LOG("\n"); double sum = 0.0; - for (int i = 0; i < n; ++i) { - sum += embd[i]; + for (int i = 0; i < n_fft; ++i) { + sum += res[5*n_fft + i]; } LOG("sum: %f\n", sum); + { + LOG("result:\n"); + for (int i = 0; i < 10; ++i) { + LOG("%8.3f ", S[i]); + } + LOG("\n"); + for (int i = n_spec - 10; i < n_spec; ++i) { + LOG("%8.3f ", S[i]); + } + LOG("\n"); + double sum = 0.0; + for (int i = 0; i < n_spec; ++i) { + sum += S[i]; + } + LOG("sum: %f\n", sum); + } + fprintf(stderr, "\n"); llama_free(ctx_ttc);