Apply suggestions from the PR: employ CPU buffers to copy results, use the correct ctx_size, and add a GGML_ASSERT to check v_output

Lucas Nogueira 2024-11-16 01:13:46 -03:00
parent 5c1d1177d3
commit 1840df1b58
3 changed files with 290 additions and 273 deletions
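The central change: results are now copied from the compute backend into host (CPU) buffers via ggml_backend_tensor_get instead of being read through tensor->data, which is only valid on the CPU backend. A minimal sketch of the pattern, with illustrative names (gf is a graph already computed on backend):

// Sketch: copy a computed result back to a host (CPU) buffer.
struct ggml_tensor * result = ggml_graph_node(gf, ggml_graph_n_nodes(gf) - 1);
float * host_buf = (float *) malloc(ggml_nbytes(result));
ggml_backend_tensor_get(result, host_buf, 0, ggml_nbytes(result)); // device -> host
// ... consume host_buf on the CPU ...
free(host_buf);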

File 1 of 3

@@ -2,7 +2,9 @@
#include "common.h"
#include "llama.h"
#include "ggml.h"
#include "vanilla_pca.hpp"
#include "mean.hpp"
#include "pca.hpp"
#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"

File 2 of 3

@@ -2,7 +2,7 @@
#include "common.h"
#include "llama.h"
#include "ggml.h"
#include "../vanilla_pca.hpp"
#include "../pca.hpp"
#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"
@@ -15,28 +15,11 @@
#include <cstdio>
#include <cstring>
// Function to initialize ggml with optional GPU backend support
struct ggml_context *initialize_ggml_context() {
#ifdef GGML_USE_CUDA
struct ggml_init_params params = { .mem_size = 1024 * 1024, .mem_buffer = NULL, .use_gpu = true };
printf("Initializing with GPU backend...\n");
#else
struct ggml_init_params params = { .mem_size = 1024 * 1024, .mem_buffer = NULL };
printf("Initializing with CPU backend...\n");
#endif
return ggml_init(params);
}
// Helper function to create a tensor from a matrix
struct ggml_tensor *create_tensor(struct ggml_context *ctx, float *data, int rows, int cols) {
struct ggml_tensor *tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, cols, rows);
memcpy(tensor->data, data, ggml_nbytes(tensor));
return tensor;
}
// Function to run PCA and print results
void run_pca_test(struct ggml_context *ctx, float *matrix, int rows, int cols) {
struct ggml_tensor *input_tensor = create_tensor(ctx, matrix, rows, cols);
static void run_pca_test(struct ggml_context *ctx, float *matrix, int rows, int cols) {
// struct ggml_tensor *input_tensor = create_tensor(ctx, matrix, rows, cols);
struct ggml_tensor *input_tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, rows, cols);
memcpy(input_tensor->data, matrix, rows * cols * sizeof(float));
PCA::pca_params pca_params;
pca_params.n_threads = 8;
@@ -44,20 +27,37 @@ void run_pca_test(struct ggml_context *ctx, float *matrix, int rows, int cols) {
pca_params.n_iterations = 1000;
pca_params.tolerance = 1e-5;
PCA::pca_result result;
PCA::pca_result result = {NULL, 0};
PCA::run_single_pca(pca_params, input_tensor, result);
printf("\nPrincipal components:\n");
float *b = (float *)result.principal_component->data;
for (int i = 0; i < result.principal_component->ne[0]; i++) {
printf("%f ", b[i]);
printf("Principal components:\n");
for (int i = 0; i < cols; i++) {
printf("%f ", result.principal_component[i]);
}
printf("\nEigenvalue: %f\n", result.explained_variance);
free(result.principal_component);
}
int main() {
// Initialize ggml context
struct ggml_context *ctx = initialize_ggml_context();
size_t ctx_size = 0;
ctx_size += 4 * 4 * ggml_type_size(GGML_TYPE_F32);
ctx_size += 10 * 10 * ggml_type_size(GGML_TYPE_F32);
ctx_size += 3 * 3 * ggml_type_size(GGML_TYPE_F32);
ctx_size += 3 * 3 * ggml_type_size(GGML_TYPE_F32);
ctx_size += 4 * ggml_tensor_overhead();
ctx_size += 1024;
// Step 2. Initialize GGML Context
struct ggml_init_params ctx_params {
ctx_size, // mem_size
NULL, // mem_buffer
false, // no_alloc
};
struct ggml_context * ctx = ggml_init(ctx_params);
if (ctx == NULL) {
printf("Failed to initialize ggml context\n");
return 1;
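For reference, the ctx_size arithmetic above follows the usual ggml sizing rule: data bytes per tensor, plus ggml_tensor_overhead() per tensor, plus some slack. The same rule as a sketch (size_for is a hypothetical helper, not part of the diff):

// Sketch: context sizing for fixed-shape F32 tensors.
static size_t size_for(int64_t rows, int64_t cols) {
    return rows * cols * ggml_type_size(GGML_TYPE_F32) // tensor data
         + ggml_tensor_overhead();                     // per-tensor bookkeeping
}
// e.g. for the 4x4, 10x10 and two 3x3 test matrices sized above:
size_t ctx_size = size_for(4, 4) + size_for(10, 10)
                + size_for(3, 3) + size_for(3, 3)
                + 1024; // slack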

File 3 of 3

@@ -6,15 +6,15 @@
#include "ggml-cuda.h"
#endif
#ifdef GGML_USE_METAL
#include "ggml-metal.h"
#endif
#include <cstdio>
#include <ctime>
#include <random>
#include <string>
#include <tuple>
#include <vector>
#include <algorithm>
#include <iostream>
#include <fstream>
#define DEBUG_POS 5
@@ -28,6 +28,7 @@ static void print_debug_tensor(struct ggml_tensor * t, bool with_data = true) {
printf(" ... ]\n");
}
// begin vanilla pca namespace
namespace PCA {
// input params for PCA computations
@@ -36,280 +37,294 @@ struct pca_params {
int n_batch = 20; // number of iterations to do in one batch; the larger the batch, the more memory is used
int n_iterations = 1000;
float tolerance = 1e-7;
// for debugging
int i_layer = 0;
int n_layers = 0;
};
// result from each iteration
struct pca_result {
struct ggml_tensor * calculated_square = NULL;
std::vector<struct ggml_tensor *> eigenvectors;
std::vector<float> distances;
float * principal_component; // dominant eigenvector of the covariance matrix
float explained_variance; // the corresponding eigenvalue
};
struct pca_model {
static void compute_covariance(struct pca_params &pca_params,
struct ggml_tensor * X,
float * covariance,
struct ggml_backend * backend) {
size_t ctx_size = 0;
ctx_size += 7 * X->ne[0] * X->ne[1] * ggml_type_size(GGML_TYPE_F32);
ctx_size += 7 * ggml_tensor_overhead();
ctx_size += ggml_graph_overhead();
ctx_size += 1024;
// Memory allocation
struct ggml_cgraph * gf = NULL;
struct ggml_context * ctx = NULL;
struct ggml_init_params ctx_params = {
ctx_size,
NULL,
true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
};
ctx = ggml_init(ctx_params);
gf = ggml_new_graph(ctx);
// Step 0: Transpose the input because of row-major
X = ggml_cont(ctx, ggml_transpose(ctx, X));
// Step 1: Compute the mean for each feature
struct ggml_tensor * mean = ggml_repeat(ctx, ggml_mean(ctx, X), X); // per-feature mean, broadcast to X's shape so it can be subtracted elementwise
struct ggml_tensor * centered_data = ggml_sub(ctx, X, mean);
// Step 2: Compute the covariance matrix
struct ggml_tensor * cov = ggml_mul_mat(ctx, centered_data, centered_data); // C = X * X^T
cov = ggml_scale(ctx, cov, 1.0/(X->ne[0]-1));
ggml_build_forward_expand(gf, cov);
// Step 3: Create ggml_gallocr for graph computation
ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
ggml_gallocr_alloc_graph(allocr, gf);
// Step 4: Check if CPU and compute the result of the graph
if (ggml_backend_is_cpu(backend)) {
ggml_backend_cpu_set_n_threads(backend, pca_params.n_threads);
}
ggml_backend_graph_compute(backend, gf);
// Step 5: Store covariance matrix in the data pointer
struct ggml_tensor * result = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-1);
ggml_backend_tensor_get(result, covariance, 0, ggml_nbytes(result));
// Step 6: Free memory
ggml_gallocr_free(allocr);
ggml_free(ctx);
}
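In matrix terms, compute_covariance evaluates C = (X - mean)(X - mean)^T / (n - 1) on the selected backend and copies it into a caller-supplied host buffer. The caller-side pattern looks like this (a sketch; it mirrors run_single_pca further down):

// Sketch: host-side buffer for the covariance result.
int m = X->ne[1];                                       // number of features
float * cov = (float *) malloc(m * m * sizeof(float));  // CPU output buffer
compute_covariance(pca_params, X, cov, backend);        // fills cov (m x m, row-major)
// ... pass cov to the power iteration, then free(cov) ...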
static void compute_cross_covariance(struct pca_params &pca_params,
struct ggml_tensor * A,
struct ggml_tensor * B,
float * cross_covariance,
struct ggml_backend * backend) {
size_t ctx_size = 0;
ctx_size += 9 * A->ne[0] * B->ne[1] * ggml_type_size(GGML_TYPE_F32);
ctx_size += 9 * ggml_tensor_overhead();
ctx_size += ggml_graph_overhead();
ctx_size += 1024;
// Memory allocation
struct ggml_cgraph * gf = NULL;
struct ggml_context * ctx = NULL;
struct ggml_init_params ctx_params = {
ctx_size,
NULL,
true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
};
ctx = ggml_init(ctx_params);
gf = ggml_new_graph(ctx);
// Step 1: Compute matrices of cross_covariance
struct ggml_tensor * AT = ggml_cont(ctx, ggml_transpose(ctx, A));
struct ggml_tensor * BT = ggml_cont(ctx, ggml_transpose(ctx, B));
struct ggml_tensor * AT_B = ggml_mul_mat(ctx, AT, BT);
struct ggml_tensor * BT_A = ggml_cont(ctx, ggml_transpose(ctx, AT_B));
// Step 2: Compute the covariance matrix
struct ggml_tensor * cross_cov = ggml_add(ctx, AT_B, BT_A);
cross_cov = ggml_scale(ctx, cross_cov, 0.5);
ggml_build_forward_expand(gf, cross_cov);
// Step 3: Create ggml_gallocr for graph computation
ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
ggml_gallocr_alloc_graph(allocr, gf);
// Step 4: Check if CPU and compute the result of the graph
if (ggml_backend_is_cpu(backend)) {
ggml_backend_cpu_set_n_threads(backend, pca_params.n_threads);
}
ggml_backend_graph_compute(backend, gf);
// Step 5: Store covariance matrix in the data pointer
struct ggml_tensor * result = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-1);
ggml_backend_tensor_get(result, cross_covariance, 0, ggml_nbytes(result));
// Step 6: Free memory
ggml_gallocr_free(allocr);
ggml_free(ctx);
}
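Worth spelling out why the two products are averaged:

// Note: M = (A^T B + B^T A) / 2 is symmetric by construction, so its
// eigenvalues are real and the power iteration below is well-defined;
// with A == B it reduces to the plain covariance-style product A^T A.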
// Find the dominant eigenvector of tensor M
static void power_iteration(struct pca_params &pca_params,
struct ggml_tensor * M,
struct pca_result &result,
struct ggml_backend * backend) {
int m = M->ne[1];
// Initialize random vector
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
float * b = result.principal_component;
for (int i = 0; i < m; i++) {
b[i] = dist(gen);
}
float eigenvalue = 0;
// Iterate
int n_rounds = pca_params.n_iterations / pca_params.n_batch;
for(int i = 0; i < n_rounds; i++) {
// Memory allocation
struct ggml_cgraph * gf = NULL;
struct ggml_context * ctx = NULL;
struct ggml_init_params params = {
ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
NULL,
true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
};
ctx = ggml_init(params);
gf = ggml_new_graph(ctx);
// Fill current eigen vector
struct ggml_tensor * e_curr = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, m);
struct ggml_tensor * e_prev = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, m);
ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
ggml_backend_tensor_set(e_curr, b, 0, ggml_nbytes(e_curr));
ggml_backend_tensor_set(e_prev, b, 0, ggml_nbytes(e_curr));
struct ggml_tensor * e_next = NULL;
struct ggml_tensor * e_norm = NULL;
for(int j = 0; j < pca_params.n_batch; j++) {
// Compute next candidate vector multiplying M with the current vector
e_next = ggml_mul_mat(ctx, M, e_curr);
// Compute the norm of the new vector (and normalize it)
// this will give us the next eigenvector and eigenvalue
e_norm = ggml_sqrt_inplace(ctx, ggml_sum_rows(ctx, ggml_sqr(ctx, e_next)));
e_curr = ggml_div_inplace(ctx, e_next, e_norm);
ggml_format_name(e_norm, "eigenvalue_%d", j);
ggml_format_name(e_curr, "eigenvector_%d", j);
// Update graph
ggml_build_forward_expand(gf, e_curr);
}
// Compute the similarity between the current eigenvector and the previous (dot product)
struct ggml_tensor * similarity = ggml_mul_mat(ctx, e_curr, e_prev);
ggml_build_forward_expand(gf, similarity);
// Create ggml_gallocr for graph computation
ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
ggml_gallocr_alloc_graph(allocr, gf);
// Check if CPU and compute the result of the graph
if (ggml_backend_is_cpu(backend)) {
ggml_backend_cpu_set_n_threads(backend, pca_params.n_threads);
}
ggml_status graph_status = ggml_backend_graph_compute(backend, gf);
// Get graph results (eigenvector and eigenvalue) and store it in b and eigenvalue
if(graph_status == GGML_STATUS_SUCCESS){
// Similarity is the last node in the graph
struct ggml_tensor * similarity_tensor = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-1);
float similarity = (float)((float*) similarity_tensor->data)[0];
// Eigenvector is the second last node in the graph
// struct ggml_tensor * eigenvector_tensor = gf->nodes[gf->n_nodes-2];
struct ggml_tensor * eigenvector_tensor = ggml_graph_node(gf,ggml_graph_n_nodes(gf)-2);
ggml_backend_tensor_get(eigenvector_tensor, b, 0, ggml_nbytes(eigenvector_tensor));
// Eigenvalue computation is 1 operation before eigenvector computation
// struct ggml_tensor * eigenvalue_tensor = gf->nodes[gf->n_nodes-3];
struct ggml_tensor * eigenvalue_tensor = ggml_graph_node(gf, ggml_graph_n_nodes(gf)-3);
eigenvalue = (float)((float*) eigenvalue_tensor->data)[0];
// Check if the similarity is close enough to 1, if so we converged and should break
if(1 - similarity < pca_params.tolerance)
break;
}
// Free memory
ggml_backend_buffer_free(buffer);
ggml_gallocr_free(allocr);
ggml_free(ctx);
}
// Store result
result.principal_component = b;
result.explained_variance = eigenvalue;
return;
}
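For readers following the math, here is the same multiply, normalize, check-convergence loop without ggml, as a self-contained reference sketch (not part of the diff):

#include <cmath>
#include <vector>

// Reference power iteration on a dense row-major m x m symmetric matrix M.
// Returns the dominant eigenvalue; b holds the unit eigenvector on exit.
static float power_iteration_ref(const std::vector<float> & M,
                                 std::vector<float> & b,
                                 int m, int n_iterations, float tolerance) {
    float eigenvalue = 0.0f;
    for (int it = 0; it < n_iterations; ++it) {
        std::vector<float> next(m, 0.0f);
        for (int i = 0; i < m; ++i)                 // next = M * b
            for (int j = 0; j < m; ++j)
                next[i] += M[i * m + j] * b[j];
        float norm = 0.0f;
        for (int i = 0; i < m; ++i) norm += next[i] * next[i];
        norm = std::sqrt(norm);
        eigenvalue = norm;                          // ||M b|| -> |lambda_max| for unit b
        float similarity = 0.0f;
        for (int i = 0; i < m; ++i) {
            next[i] /= norm;                        // normalize
            similarity += next[i] * b[i];           // dot with previous vector
        }
        b = next;
        if (1.0f - std::fabs(similarity) < tolerance) break; // converged
    }
    return eigenvalue;
}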
static void run_single_pca(struct pca_params &pca_params,
struct ggml_tensor * X,
struct pca_result &result
) {
ggml_set_name(X, "input_tensor");
int m = X->ne[1]; // Number of features
// Step 1. Initialize GGML Backend
ggml_backend_t backend = NULL;
ggml_backend_buffer_t buffer;
struct ggml_context * ctx; // context to compute graph on target device
struct ggml_context * ctx_host; // host context to store results
// tensors on target device
struct ggml_tensor * dev_input;
struct ggml_tensor * dev_square;
struct ggml_tensor * dev_eigenvector;
pca_model(struct ggml_tensor * t_input) {
#ifdef GGML_USE_CUDA
fprintf(stderr, "%s: using CUDA backend\n", __func__);
backend = ggml_backend_cuda_init(0); // init device 0
if (!backend) {
fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
}
if (!backend) { fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); }
#endif
// TODO: enable Metal support when support for GGML_OP_SQRT is added
// #ifdef GGML_USE_METAL
// fprintf(stderr, "%s: using Metal backend\n", __func__);
// backend = ggml_backend_metal_init();
// if (!backend) {
// fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
// }
// #endif
// If there aren't GPU Backends fallback to CPU backend
if (!backend) { backend = ggml_backend_cpu_init(); }
// if there aren't GPU Backends fallback to CPU backend
if (!backend) {
backend = ggml_backend_cpu_init();
}
// Compute the context size needed
size_t ctx_size = 0;
ctx_size += m * m * ggml_type_size(GGML_TYPE_F32);
ctx_size += 1 * ggml_tensor_overhead();
const int num_tensors = 4;
struct ggml_init_params params {
/*.mem_size =*/ ggml_tensor_overhead() * num_tensors,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
// Step 2. Initialize GGML Context
struct ggml_init_params ctx_params {
ctx_size, // mem_size
NULL, // mem_buffer
true, // no_alloc
};
ctx = ggml_init(params);
struct ggml_context * ctx = ggml_init(ctx_params);
auto n_samples = t_input->ne[0];
auto n_embd = t_input->ne[1];
// Step 3. Compute the data covariance matrix
// Using a CPU buffer to copy data from the backend
float * covariance = (float *) malloc(m * m * sizeof(float));
compute_covariance(pca_params, X, covariance, backend);
dev_input = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_samples, n_embd);
dev_square = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
dev_eigenvector = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
// Create covariance tensor on backend
struct ggml_tensor * covariance_tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, m, m);
ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
ggml_backend_tensor_set(covariance_tensor, covariance, 0, ggml_nbytes(covariance_tensor));
ggml_set_name(dev_input, "dev_input");
ggml_set_name(dev_square, "dev_square");
ggml_set_name(dev_eigenvector, "dev_eigenvector");
buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
ggml_backend_tensor_set(dev_input, t_input->data, 0, ggml_nbytes(t_input));
// Step 4. Power iteration
result.principal_component = (float *) malloc(m * sizeof(float));
power_iteration(pca_params, covariance_tensor, result, backend);
// initialize eigenvector to random normalized vector
{
std::vector<float> random_vec(ggml_nelements(dev_eigenvector), 0.0);
std::default_random_engine generator(static_cast<unsigned int>(std::time(0)));
std::uniform_real_distribution<float> distribution(0.0, 1.0);
float sum_sqr = 0.0; // for normalizing random_vec
for (size_t i = 0; i < random_vec.size(); ++i) {
float f = distribution(generator);
sum_sqr += f * f;
random_vec[i] = f;
}
// normalize it
float random_vec_norm = std::sqrt(sum_sqr);
for (size_t i = 0; i < random_vec.size(); ++i) {
random_vec[i] /= random_vec_norm;
}
ggml_backend_tensor_set(dev_eigenvector, random_vec.data(), 0, ggml_nbytes(dev_eigenvector));
}
}
~pca_model() {
// Step 5. Free ggml ctx and backend
ggml_free(ctx);
ggml_backend_buffer_free(buffer);
ggml_backend_free(backend);
}
};
static struct ggml_cgraph * build_graph_piter(
const struct pca_params & params,
const pca_model & model,
bool calc_square = false) {
GGML_ASSERT(params.n_batch > 0);
// TODO: buf_size must be able to scale with params.n_batch
static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
static std::vector<uint8_t> buf(buf_size);
struct ggml_init_params params0 = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf.data(),
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph()
};
// create a temporary context to build the graph
struct ggml_context * ctx0 = ggml_init(params0);
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
// turn v_diff_original into square matrix if needed
struct ggml_tensor * tmp_square;
if (calc_square) {
tmp_square = ggml_mul_mat(ctx0, model.dev_input, model.dev_input);
ggml_set_name(tmp_square, "tmp_square");
}
struct ggml_tensor * b_tensor;
struct ggml_tensor * distance;
struct ggml_tensor * old_eigen = model.dev_eigenvector;
struct ggml_tensor * input_square = calc_square ? tmp_square : model.dev_square;
for (int i = 0; i < params.n_batch; ++i) {
// b_tensor = square * eigenvector^T
b_tensor = ggml_mul_mat(ctx0, input_square, old_eigen);
ggml_set_name(b_tensor, "b_tensor");
// normalize
b_tensor = ggml_div_inplace(ctx0,
b_tensor,
ggml_sqrt_inplace(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, b_tensor)))
);
ggml_format_name(b_tensor, "b_tensor_norm_%d", i);
// calculate distance(new eigenvector - old eigenvector)
// we don't use ggml_sub because it may not be implemented on GPU backend
struct ggml_tensor * new_sub_old = ggml_add(ctx0, old_eigen, ggml_scale(ctx0, b_tensor, -1));
distance = ggml_sqrt_inplace(ctx0,
ggml_sum_rows(ctx0, ggml_sqr_inplace(ctx0, new_sub_old)));
ggml_format_name(distance, "distance_%d", i);
old_eigen = b_tensor;
// build operations nodes
ggml_build_forward_expand(gf, distance);
}
// delete the temporary context used to build the graph
ggml_free(ctx0);
return gf;
}
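Summarizing what each unrolled step of the graph above computes:

// step i:  b_{i+1} = (S b_i) / ||S b_i||     -> node "b_tensor_norm_i"
//          d_i     = ||b_{i+1} - b_i||       -> node "distance_i"
// where S is dev_square (or tmp_square, the input multiplied with its
// own transpose, on the first call).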
static ggml_status compute_piter(
const struct pca_params & params,
const pca_model & model,
struct ggml_cgraph * gf,
ggml_gallocr_t allocr,
struct pca_result & result) {
// allocate tensors
ggml_gallocr_alloc_graph(allocr, gf);
if (ggml_backend_is_cpu(model.backend)) {
ggml_backend_cpu_set_n_threads(model.backend, params.n_threads);
}
ggml_status res = ggml_backend_graph_compute(model.backend, gf);
if (res == GGML_STATUS_SUCCESS) {
auto extract_i = [](std::string prefix, std::string str) -> int {
int i = -1;
if (str.rfind(prefix, 0) == 0) {
sscanf(str.c_str(), (prefix + "%d").c_str(), &i);
}
return i;
};
result.calculated_square = NULL;
result.eigenvectors.clear();
result.distances.clear();
result.eigenvectors.resize(params.n_batch);
result.distances.resize(params.n_batch);
// get output nodes
for (int i = 0; i < ggml_graph_n_nodes(gf); ++i) {
auto node = ggml_graph_node(gf, i);
int iter = -1;
// find b_tensor (without copying data from device)
if ((iter = extract_i("b_tensor_norm_", node->name)) > -1) {
result.eigenvectors[iter] = node;
}
// find distances, then copy data from device
if ((iter = extract_i("distance_", node->name)) > -1) {
float d;
ggml_backend_tensor_get(node, &d, 0, sizeof(float));
result.distances[iter] = d;
// std::cout << node->name << " = " << d << "\n";
}
// find tmp_square if it exists (without copying data from device)
if (std::string(node->name) == "tmp_square") {
result.calculated_square = node;
}
}
}
return res;
}
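The extract_i lambda simply inverts the naming scheme from build_graph_piter; for example:

// Sketch of the node-name round trip:
//   build:   ggml_format_name(distance, "distance_%d", i);
//   extract: extract_i("distance_", "distance_7")      -> 7
//            extract_i("distance_", "b_tensor_norm_7") -> -1 (prefix mismatch)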
static void power_iteration(
const struct pca_params & params,
struct ggml_tensor * input, // shape of input: [n_samples, n_embd]
struct ggml_tensor * output) {
//printf("in power iteration\n");
struct pca_model model(input);
ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));
struct pca_result result;
struct ggml_tensor * last_eigenvector = NULL;
int n_iters = params.n_iterations / params.n_batch; // more batch, fewer iterations
for (int iter = 0; iter < n_iters; ++iter) {
bool calc_square = (iter == 0); // only need to calculate square for first iteration
struct ggml_cgraph * gf = build_graph_piter(params, model, calc_square);
// ggml_graph_dump_dot(gf, nullptr, "/tmp/_cgraph.dot");
compute_piter(params, model, gf, allocr, result);
for (size_t k = 0; k < result.distances.size(); ++k) {
last_eigenvector = result.eigenvectors[k];
if (result.distances[k] < params.tolerance) {
break; // done
}
}
if (calc_square) {
// copy and store the square matrix if needed
GGML_ASSERT(result.calculated_square != NULL);
ggml_backend_tensor_copy(result.calculated_square, model.dev_square);
}
{
// copy last eigen vector and store as input for next iteration
GGML_ASSERT(last_eigenvector != NULL);
ggml_backend_tensor_copy(last_eigenvector, model.dev_eigenvector);
}
printf("%s: layer %d/%d, iteration: %d / total: %d (batch = %d) ...\n",
__func__, params.i_layer+1, params.n_layers, iter+1, n_iters, params.n_batch);
}
// get output tensor
GGML_ASSERT(last_eigenvector);
ggml_backend_tensor_get(last_eigenvector, output->data, 0, ggml_nbytes(last_eigenvector));
//print_debug_tensor(output);
ggml_gallocr_free(allocr);
// TODO @ngxson : The output vector is randomly inverted
// Solution: https://github.com/ggerganov/llama.cpp/pull/8069#issuecomment-2185328171
}
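A common remedy for the random inversion, sketched here as the usual convention (not necessarily the fix in the linked comment): flip the vector whenever a fixed reference component, e.g. the largest-magnitude one, is negative.

#include <cmath>

// Sketch: make the eigenvector sign deterministic (hypothetical helper).
static void canonicalize_sign(float * v, int n) {
    int imax = 0;
    for (int i = 1; i < n; ++i) {
        if (std::fabs(v[i]) > std::fabs(v[imax])) imax = i;
    }
    if (v[imax] < 0.0f) {
        for (int i = 0; i < n; ++i) v[i] = -v[i];
    }
}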
static void run_pca(
struct pca_params & params,
const std::vector<struct ggml_tensor *> & v_input, // shape of v_input[0]: [n_samples, n_embd]
const std::vector<struct ggml_tensor *> & v_output) {
printf("%s: Running PCA...\n", __func__);
for (size_t il = 0; il < v_input.size(); ++il) {
// prepare output vector
struct ggml_tensor * ctrl_out = v_output[il];
ggml_format_name(ctrl_out, "direction.%ld", il+1);
// run power_iteration
params.i_layer = il;
params.n_layers = v_input.size();
power_iteration(params, v_input[il], ctrl_out);
printf("%s: Done layer %d / %d\n", __func__, (int) il+1, (int) v_input.size());
for (size_t i = 0; i < v_input.size(); i++) {
// Check shape of tensor inside v_output
GGML_ASSERT(v_output[i]->ne[0] == v_input[i]->ne[1]);
struct pca_result result = {NULL, 0};
run_single_pca(params, v_input[i], result);
ggml_backend_tensor_set(v_output[i], result.principal_component, 0, ggml_nbytes(v_output[i]));
free(result.principal_component);
}
}
// end namespace PCA
}