Control vector loading fixes (#8137)

* Fixed leak in llama_control_vector_load_one() and allow llama_control_vector_load() to grow

* refactored `llama_control_vector_load_one()`

* allow multiple directions for same layer in same file

* llama_control_vector_load_one() and llama_control_vector_load() now break on error

* removed unnecessary ggml_free() call
This commit is contained in:
jukofyork 2024-06-27 15:48:07 +01:00 committed by GitHub
parent 387952651a
commit 97877eb10b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2804,125 +2804,87 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n)
// //
static llama_control_vector_data llama_control_vector_load_one(const llama_control_vector_load_info & load_info) { static llama_control_vector_data llama_control_vector_load_one(const llama_control_vector_load_info & load_info) {
int32_t n_tensors;
size_t n_bytes = 0;
uint32_t max_direction_layer = 0;
llama_control_vector_data result = { -1, {} }; llama_control_vector_data result = { -1, {} };
// calculate size of ctx needed for tensors, ensure tensors are f32, and find max layer ggml_context * ctx = nullptr;
{ struct gguf_init_params meta_gguf_params = {
struct ggml_init_params meta_params = { /* .no_alloc = */ false,
/* .mem_size = */ ggml_tensor_overhead() * 128 + ggml_graph_overhead(), /* .ctx = */ &ctx,
/* .mem_buffer = */ nullptr, };
/* .no_alloc = */ true, struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
}; if (!ctx_gguf) {
ggml_context * meta_ctx = ggml_init(meta_params); fprintf(stderr, "%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str());
struct gguf_init_params meta_gguf_params = { return result;
/* .no_alloc = */ true,
/* .ctx = */ &meta_ctx,
};
struct gguf_context * meta_ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
if (!meta_ctx_gguf) {
fprintf(stderr, "%s: failed to load control vector from %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx);
return result;
}
n_tensors = gguf_get_n_tensors(meta_ctx_gguf);
for (int i = 0; i < n_tensors; i++) {
std::string name = gguf_get_tensor_name(meta_ctx_gguf, i);
// split on '.'
size_t dotpos = name.find('.');
if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") {
try {
uint32_t layer = std::stoi(name.substr(dotpos + 1));
if (layer == 0) {
fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
return result;
}
if (layer > max_direction_layer) {
max_direction_layer = layer;
}
} catch (...) {
fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
return result;
}
}
struct ggml_tensor * tensor_meta = ggml_get_tensor(meta_ctx, name.c_str());
if (tensor_meta->type != GGML_TYPE_F32 || ggml_n_dims(tensor_meta) != 1) {
fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
return result;
}
if (result.n_embd == -1) {
result.n_embd = ggml_nelements(tensor_meta);
} else if (ggml_nelements(tensor_meta) != result.n_embd) {
fprintf(stderr, "%s: direction tensor sizes mismatched in %s\n", __func__, load_info.fname.c_str());
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
return result;
}
n_bytes += ggml_nbytes(tensor_meta);
}
ggml_free(meta_ctx);
gguf_free(meta_ctx_gguf);
} }
int32_t n_tensors = gguf_get_n_tensors(ctx_gguf);
if (n_tensors == 0) { if (n_tensors == 0) {
fprintf(stderr, "%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str()); fprintf(stderr, "%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str());
return result;
} }
// load and scale tensors into final control vector context for (int i = 0; i < n_tensors; i++) {
struct ggml_init_params ggml_params = { std::string name = gguf_get_tensor_name(ctx_gguf, i);
/* .mem_size = */ ggml_tensor_overhead() * n_tensors + n_bytes,
/* .mem_buffer = */ nullptr,
/* .no_alloc = */ false,
};
struct ggml_context * ctx = ggml_init(ggml_params);
struct gguf_init_params params = { int layer_idx = -1;
/*.no_alloc = */ false,
/*.ctx = */ &ctx,
};
struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), params);
if (!ctx_gguf) {
fprintf(stderr, "%s: failed to load control vector from %s\n", __func__, load_info.fname.c_str());
ggml_free(ctx);
return result;
}
// do not store data for layer 0 (it's not used) // split on '.'
result.data.resize(result.n_embd * max_direction_layer); size_t dotpos = name.find('.');
if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") {
for (uint32_t il = 1; il <= max_direction_layer; il++) { try {
const std::string name = "direction." + std::to_string(il); layer_idx = std::stoi(name.substr(dotpos + 1));
const ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str()); } catch (...) {
layer_idx = -1;
float * dst = result.data.data() + result.n_embd * (il - 1);
if (tensor) {
const float * src = (const float *) tensor->data;
for (int j = 0; j < result.n_embd; j++) {
dst[j] = src[j] * load_info.strength;
}
} else {
for (int j = 0; j < result.n_embd; j++) {
dst[j] = 0.0f;
} }
} }
if (layer_idx < 0) {
fprintf(stderr, "%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
result.n_embd = -1;
break;
} else if (layer_idx == 0) {
fprintf(stderr, "%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
result.n_embd = -1;
break;
}
struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
if (tensor->type != GGML_TYPE_F32) {
fprintf(stderr, "%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str());
result.n_embd = -1;
break;
}
if (ggml_n_dims(tensor) != 1) {
fprintf(stderr, "%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str());
result.n_embd = -1;
break;
}
if (result.n_embd == -1) {
result.n_embd = ggml_nelements(tensor);
} else if (ggml_nelements(tensor) != result.n_embd) {
fprintf(stderr, "%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str());
result.n_embd = -1;
break;
}
// extend if necessary - do not store data for layer 0 (it's not used)
result.data.resize(std::max(result.data.size(), static_cast<size_t>(result.n_embd * layer_idx)), 0.0f);
const float * src = (const float *) tensor->data;
float * dst = result.data.data() + result.n_embd * (layer_idx - 1); // layer 1 at [0]
for (int j = 0; j < result.n_embd; j++) {
dst[j] += src[j] * load_info.strength; // allows multiple directions for same layer in same file
}
} }
if (result.n_embd == -1) {
fprintf(stderr, "%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str());
result.data.clear();
}
gguf_free(ctx_gguf);
ggml_free(ctx);
return result; return result;
} }
@ -2933,16 +2895,19 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
auto cur = llama_control_vector_load_one(info); auto cur = llama_control_vector_load_one(info);
if (cur.n_embd == -1) { if (cur.n_embd == -1) {
return result; result.n_embd = -1;
break;
} }
if (result.n_embd != -1 && (result.n_embd != cur.n_embd || result.data.size() != cur.data.size())) { if (result.n_embd != -1 && result.n_embd != cur.n_embd) {
fprintf(stderr, "%s: control vector in %s does not match previous vector dimensions\n", __func__, info.fname.c_str()); fprintf(stderr, "%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str());
return result; result.n_embd = -1;
break;
} }
if (result.n_embd == -1) { if (result.n_embd == -1) {
result = std::move(cur); result = std::move(cur);
} else { } else {
result.data.resize(std::max(result.data.size(), cur.data.size()), 0.0f); // extend if necessary
for (size_t i = 0; i < cur.data.size(); i++) { for (size_t i = 0; i < cur.data.size(); i++) {
result.data[i] += cur.data[i]; result.data[i] += cur.data[i];
} }
@ -2950,7 +2915,8 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
} }
if (result.n_embd == -1) { if (result.n_embd == -1) {
fprintf(stderr, "%s: no vectors passed\n", __func__); fprintf(stderr, "%s: no valid control vector files passed\n", __func__);
result.data.clear();
} }
return result; return result;