Control vector loading fixes (#8137)

* Fixed leak in llama_control_vector_load_one() and allow llama_control_vector_load() to grow

* refactored `llama_control_vector_load_one()`

* allow multiple directions for same layer in same file

* llama_control_vector_load_one() and llama_control_vector_load() now break on error

* removed unnecessary ggml_free() call
This commit is contained in:
jukofyork 2024-06-27 15:48:07 +01:00 committed by GitHub
parent 387952651a
commit 97877eb10b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2804,125 +2804,87 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n)
// //
static llama_control_vector_data llama_control_vector_load_one(const llama_control_vector_load_info & load_info) { static llama_control_vector_data llama_control_vector_load_one(const llama_control_vector_load_info & load_info) {
int32_t n_tensors;
size_t n_bytes = 0;
uint32_t max_direction_layer = 0;
llama_control_vector_data result = { -1, {} }; llama_control_vector_data result = { -1, {} };
// calculate size of ctx needed for tensors, ensure tensors are f32, and find max layer ggml_context * ctx = nullptr;
{
struct ggml_init_params meta_params = {
/* .mem_size = */ ggml_tensor_overhead() * 128 + ggml_graph_overhead(),
/* .mem_buffer = */ nullptr,
/* .no_alloc = */ true,
};
ggml_context * meta_ctx = ggml_init(meta_params);
struct gguf_init_params meta_gguf_params = { struct gguf_init_params meta_gguf_params = {
/* .no_alloc = */ true, /* .no_alloc = */ false,
/* .ctx = */ &meta_ctx, /* .ctx = */ &ctx,
}; };
struct gguf_context * meta_ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params); struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
if (!meta_ctx_gguf) { if (!ctx_gguf) {
fprintf(stderr, "%s: failed to load control vector from %s\n", __func__, load_info.fname.c_str()); fprintf(stderr, "%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx);
return result; return result;
} }
n_tensors = gguf_get_n_tensors(meta_ctx_gguf); int32_t n_tensors = gguf_get_n_tensors(ctx_gguf);
if (n_tensors == 0) {
fprintf(stderr, "%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str());
}
for (int i = 0; i < n_tensors; i++) { for (int i = 0; i < n_tensors; i++) {
std::string name = gguf_get_tensor_name(meta_ctx_gguf, i); std::string name = gguf_get_tensor_name(ctx_gguf, i);
int layer_idx = -1;
// split on '.' // split on '.'
size_t dotpos = name.find('.'); size_t dotpos = name.find('.');
if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") { if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") {
try { try {
uint32_t layer = std::stoi(name.substr(dotpos + 1)); layer_idx = std::stoi(name.substr(dotpos + 1));
if (layer == 0) {
fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
return result;
}
if (layer > max_direction_layer) {
max_direction_layer = layer;
}
} catch (...) { } catch (...) {
fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str()); layer_idx = -1;
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
return result;
} }
} }
if (layer_idx < 0) {
fprintf(stderr, "%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
result.n_embd = -1;
break;
} else if (layer_idx == 0) {
fprintf(stderr, "%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
result.n_embd = -1;
break;
}
struct ggml_tensor * tensor_meta = ggml_get_tensor(meta_ctx, name.c_str()); struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
if (tensor_meta->type != GGML_TYPE_F32 || ggml_n_dims(tensor_meta) != 1) { if (tensor->type != GGML_TYPE_F32) {
fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str()); fprintf(stderr, "%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx); result.n_embd = -1;
gguf_free(meta_ctx_gguf); break;
return result;
} }
if (ggml_n_dims(tensor) != 1) {
fprintf(stderr, "%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str());
result.n_embd = -1;
break;
}
if (result.n_embd == -1) { if (result.n_embd == -1) {
result.n_embd = ggml_nelements(tensor_meta); result.n_embd = ggml_nelements(tensor);
} else if (ggml_nelements(tensor_meta) != result.n_embd) { } else if (ggml_nelements(tensor) != result.n_embd) {
fprintf(stderr, "%s: direction tensor sizes mismatched in %s\n", __func__, load_info.fname.c_str()); fprintf(stderr, "%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx); result.n_embd = -1;
gguf_free(meta_ctx_gguf); break;
return result;
}
n_bytes += ggml_nbytes(tensor_meta);
}
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
} }
if (n_tensors == 0) { // extend if necessary - do not store data for layer 0 (it's not used)
fprintf(stderr, "%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str()); result.data.resize(std::max(result.data.size(), static_cast<size_t>(result.n_embd * layer_idx)), 0.0f);
return result;
}
// load and scale tensors into final control vector context
struct ggml_init_params ggml_params = {
/* .mem_size = */ ggml_tensor_overhead() * n_tensors + n_bytes,
/* .mem_buffer = */ nullptr,
/* .no_alloc = */ false,
};
struct ggml_context * ctx = ggml_init(ggml_params);
struct gguf_init_params params = {
/*.no_alloc = */ false,
/*.ctx = */ &ctx,
};
struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), params);
if (!ctx_gguf) {
fprintf(stderr, "%s: failed to load control vector from %s\n", __func__, load_info.fname.c_str());
ggml_free(ctx);
return result;
}
// do not store data for layer 0 (it's not used)
result.data.resize(result.n_embd * max_direction_layer);
for (uint32_t il = 1; il <= max_direction_layer; il++) {
const std::string name = "direction." + std::to_string(il);
const ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
float * dst = result.data.data() + result.n_embd * (il - 1);
if (tensor) {
const float * src = (const float *) tensor->data; const float * src = (const float *) tensor->data;
float * dst = result.data.data() + result.n_embd * (layer_idx - 1); // layer 1 at [0]
for (int j = 0; j < result.n_embd; j++) { for (int j = 0; j < result.n_embd; j++) {
dst[j] = src[j] * load_info.strength; dst[j] += src[j] * load_info.strength; // allows multiple directions for same layer in same file
}
} else {
for (int j = 0; j < result.n_embd; j++) {
dst[j] = 0.0f;
} }
} }
if (result.n_embd == -1) {
fprintf(stderr, "%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str());
result.data.clear();
} }
gguf_free(ctx_gguf);
ggml_free(ctx);
return result; return result;
} }
@ -2933,16 +2895,19 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
auto cur = llama_control_vector_load_one(info); auto cur = llama_control_vector_load_one(info);
if (cur.n_embd == -1) { if (cur.n_embd == -1) {
return result; result.n_embd = -1;
break;
} }
if (result.n_embd != -1 && (result.n_embd != cur.n_embd || result.data.size() != cur.data.size())) { if (result.n_embd != -1 && result.n_embd != cur.n_embd) {
fprintf(stderr, "%s: control vector in %s does not match previous vector dimensions\n", __func__, info.fname.c_str()); fprintf(stderr, "%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str());
return result; result.n_embd = -1;
break;
} }
if (result.n_embd == -1) { if (result.n_embd == -1) {
result = std::move(cur); result = std::move(cur);
} else { } else {
result.data.resize(std::max(result.data.size(), cur.data.size()), 0.0f); // extend if necessary
for (size_t i = 0; i < cur.data.size(); i++) { for (size_t i = 0; i < cur.data.size(); i++) {
result.data[i] += cur.data[i]; result.data[i] += cur.data[i];
} }
@ -2950,7 +2915,8 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
} }
if (result.n_embd == -1) { if (result.n_embd == -1) {
fprintf(stderr, "%s: no vectors passed\n", __func__); fprintf(stderr, "%s: no valid control vector files passed\n", __func__);
result.data.clear();
} }
return result; return result;