gguf : write metadata in gguf_file_saver

This commit is contained in:
M. Yusuf Sarıgöz 2023-08-11 20:07:43 +03:00
parent 781b9ec3f5
commit d09fd10713
3 changed files with 38 additions and 1 deletions

4
ggml.c
View File

@ -19035,6 +19035,10 @@ enum gguf_type gguf_get_kv_type(struct gguf_context * ctx, int i) {
return ctx->header.kv[i].type;
}
enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i) {
return ctx->header.kv[i].value.arr.type;
}
const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) {
struct gguf_kv * kv = &ctx->header.kv[key_id];
struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];

1
ggml.h
View File

@ -1748,6 +1748,7 @@ extern "C" {
GGML_API int gguf_find_key(struct gguf_context * ctx, const char * key);
GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i);
GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i);
GGML_API enum gguf_type gguf_get_arr_type (struct gguf_context * ctx, int i);
GGML_API void gguf_get_val (struct gguf_context * ctx, int i, void * val);
GGML_API const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i);

View File

@ -647,6 +647,28 @@ struct gguf_file_saver {
file.write_i32(n_kv);
}
void write_hparam_arr_str(const std::string & key, enum gguf_type type, int i, int n_arr) {
std::vector<std::string> data(n_arr);
for (int j = 0; j < n_arr; ++j) {
std::string val = gguf_get_arr_str(any_file_loader->gguf_ctx, i, j);
data[j] = val;
}
file.write_arr<std::string>(key, type, data);
}
void write_hparam_arr_f32(const std::string & key, enum gguf_type type, int i, int n_arr) {
std::vector<float> data(n_arr);
for (int j = 0; j < n_arr; ++j) {
float val = gguf_get_arr_f32(any_file_loader->gguf_ctx, i, j);
data[j] = val;
}
file.write_arr<float>(key, type, data);
}
void write_hparams(enum llama_ftype new_ftype) {
const int32_t n_kv = gguf_get_n_kv(any_file_loader->gguf_ctx);
for (int i = 0; i < n_kv; ++i) {
@ -665,7 +687,8 @@ struct gguf_file_saver {
uint16_t u16_val;
uint32_t u32_val;
uint8_t u8_val;
gguf_type arr_type;
int n_arr;
switch(vtype) {
case GGUF_TYPE_BOOL:
@ -705,6 +728,15 @@ struct gguf_file_saver {
file.write_val<uint8_t>(key, GGUF_TYPE_UINT8, u8_val);
break;
case GGUF_TYPE_ARRAY:
arr_type = gguf_get_arr_type(any_file_loader->gguf_ctx, i);
n_arr = gguf_get_arr_n(any_file_loader->gguf_ctx, i);
if (arr_type == GGUF_TYPE_FLOAT32) {
write_hparam_arr_f32(key, arr_type, i, n_arr);
} else if (arr_type == GGUF_TYPE_STRING) {
write_hparam_arr_str(key, GGUF_TYPE_STRING, i, n_arr);
} else {
throw std::runtime_error("not implemented");
}
break;
default:
throw std::runtime_error(format("cannot recognize value type for key %s\n", key));