mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
fft
This commit is contained in:
parent
5aaf4a8aa6
commit
55c9e328d9
@ -9,6 +9,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
//
|
//
|
||||||
// Terminal utils
|
// Terminal utils
|
||||||
@ -67,6 +68,45 @@ void fill_hann_window(int length, bool periodic, float * output) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// very poor-man fft
|
||||||
|
void twiddle(float * real, float * imag, int k, int N) {
|
||||||
|
float angle = 2 * M_PI * k / N;
|
||||||
|
*real = cos(angle);
|
||||||
|
*imag = sin(angle);
|
||||||
|
}
|
||||||
|
|
||||||
|
void irfft(int n, float * inp_cplx, float * out_real) {
|
||||||
|
int N = n / 2 + 1;
|
||||||
|
|
||||||
|
std::vector<float> real_input(N);
|
||||||
|
std::vector<float> imag_input(N);
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
real_input[i] = inp_cplx[2 * i];
|
||||||
|
imag_input[i] = inp_cplx[2 * i + 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> real_output(n);
|
||||||
|
std::vector<float> imag_output(n);
|
||||||
|
|
||||||
|
for (int k = 0; k < n; ++k) {
|
||||||
|
real_output[k] = 0.0f;
|
||||||
|
imag_output[k] = 0.0f;
|
||||||
|
for (int m = 0; m < N; ++m) {
|
||||||
|
float twiddle_real;
|
||||||
|
float twiddle_imag;
|
||||||
|
|
||||||
|
twiddle(&twiddle_real, &twiddle_imag, k * m, n);
|
||||||
|
|
||||||
|
real_output[k] += real_input[m] * twiddle_real - imag_input[m] * twiddle_imag;
|
||||||
|
imag_output[k] += real_input[m] * twiddle_imag + imag_input[m] * twiddle_real;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
out_real[i] = real_output[i] / N;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
common_params params;
|
common_params params;
|
||||||
|
|
||||||
@ -181,28 +221,93 @@ int main(int argc, char ** argv) {
|
|||||||
const int n_embd = llama_n_embd(model_cts);
|
const int n_embd = llama_n_embd(model_cts);
|
||||||
const float * embd = llama_get_embeddings(ctx_cts);
|
const float * embd = llama_get_embeddings(ctx_cts);
|
||||||
|
|
||||||
const int w = 1280;
|
const int n = prompt_inp.size();
|
||||||
std::vector<float> hann(w);
|
const int n_fft = 1280;
|
||||||
|
const int n_hop = 320;
|
||||||
|
const int n_win = 1280;
|
||||||
|
const int n_pad = (n_win - n_hop)/2;
|
||||||
|
|
||||||
|
std::vector<float> hann(n_fft);
|
||||||
|
|
||||||
fill_hann_window(hann.size(), true, hann.data());
|
fill_hann_window(hann.size(), true, hann.data());
|
||||||
|
|
||||||
|
int n_spec = n_embd*n;
|
||||||
|
|
||||||
int n = n_embd*261;
|
std::vector<float> E (n_spec);
|
||||||
|
std::vector<float> S (n_spec);
|
||||||
|
std::vector<float> ST(n_spec);
|
||||||
|
|
||||||
LOG("result:\n");
|
for (int l = 0; l < n; ++l) {
|
||||||
for (int i = 0; i < 10; ++i) {
|
for (int k = 0; k < n_embd; ++k) {
|
||||||
LOG("%8.3f ", embd[i]);
|
E[k*n + l] = embd[l*n_embd + k];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
LOG("\n");
|
|
||||||
for (int i = n - 10; i < n; ++i) {
|
for (int k = 0; k < n_embd/2; ++k) {
|
||||||
LOG("%8.3f ", embd[i]);
|
for (int l = 0; l < n; ++l) {
|
||||||
|
float mag = E[(k )*n + l];
|
||||||
|
float phi = E[(k + n_embd/2)*n + l];
|
||||||
|
|
||||||
|
mag = exp(mag);
|
||||||
|
|
||||||
|
if (mag > 1e2) {
|
||||||
|
mag = 1e2;
|
||||||
|
}
|
||||||
|
S[2*(k*n + l) + 0] = mag*cosf(phi);
|
||||||
|
S[2*(k*n + l) + 1] = mag*sinf(phi);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int l = 0; l < n; ++l) {
|
||||||
|
for (int k = 0; k < n_embd/2; ++k) {
|
||||||
|
ST[l*n_embd + 2*k + 0] = S[2*(k*n + l) + 0];
|
||||||
|
ST[l*n_embd + 2*k + 1] = S[2*(k*n + l) + 1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> res(n*n_fft);
|
||||||
|
|
||||||
|
const int n_thread = std::thread::hardware_concurrency();
|
||||||
|
std::vector<std::thread> workers(n_thread);
|
||||||
|
for (int i = 0; i < n_thread; ++i) {
|
||||||
|
workers[i] = std::thread([&, i]() {
|
||||||
|
for (int l = i; l < n; l += n_thread) {
|
||||||
|
irfft(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
for (int i = 0; i < n_thread; ++i) {
|
||||||
|
workers[i].join();
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG("result (%d):\n", res.size());
|
||||||
|
for (int i = 0; i < n_fft; ++i) {
|
||||||
|
LOG("%d - %8.5f\n", i, res[5*n_fft + i]);
|
||||||
}
|
}
|
||||||
LOG("\n");
|
LOG("\n");
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
for (int i = 0; i < n; ++i) {
|
for (int i = 0; i < n_fft; ++i) {
|
||||||
sum += embd[i];
|
sum += res[5*n_fft + i];
|
||||||
}
|
}
|
||||||
LOG("sum: %f\n", sum);
|
LOG("sum: %f\n", sum);
|
||||||
|
|
||||||
|
{
|
||||||
|
LOG("result:\n");
|
||||||
|
for (int i = 0; i < 10; ++i) {
|
||||||
|
LOG("%8.3f ", S[i]);
|
||||||
|
}
|
||||||
|
LOG("\n");
|
||||||
|
for (int i = n_spec - 10; i < n_spec; ++i) {
|
||||||
|
LOG("%8.3f ", S[i]);
|
||||||
|
}
|
||||||
|
LOG("\n");
|
||||||
|
double sum = 0.0;
|
||||||
|
for (int i = 0; i < n_spec; ++i) {
|
||||||
|
sum += S[i];
|
||||||
|
}
|
||||||
|
LOG("sum: %f\n", sum);
|
||||||
|
}
|
||||||
|
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
|
|
||||||
llama_free(ctx_ttc);
|
llama_free(ctx_ttc);
|
||||||
|
Loading…
Reference in New Issue
Block a user