diff --git a/ggml.c b/ggml.c index e00f09fa4..c8fa60328 100644 --- a/ggml.c +++ b/ggml.c @@ -19035,6 +19035,10 @@ enum gguf_type gguf_get_kv_type(struct gguf_context * ctx, int i) { return ctx->header.kv[i].type; } +enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i) { + return ctx->header.kv[i].value.arr.type; +} + const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) { struct gguf_kv * kv = &ctx->header.kv[key_id]; struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i]; diff --git a/ggml.h b/ggml.h index 9a266e175..fb3db10e2 100644 --- a/ggml.h +++ b/ggml.h @@ -1748,6 +1748,7 @@ extern "C" { GGML_API int gguf_find_key(struct gguf_context * ctx, const char * key); GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i); GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i); + GGML_API enum gguf_type gguf_get_arr_type (struct gguf_context * ctx, int i); GGML_API void gguf_get_val (struct gguf_context * ctx, int i, void * val); GGML_API const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i); diff --git a/gguf-llama.cpp b/gguf-llama.cpp index e70cae44c..27e0b5d43 100644 --- a/gguf-llama.cpp +++ b/gguf-llama.cpp @@ -647,6 +647,28 @@ struct gguf_file_saver { file.write_i32(n_kv); } + void write_hparam_arr_str(const std::string & key, enum gguf_type type, int i, int n_arr) { + std::vector data(n_arr); + + for (int j = 0; j < n_arr; ++j) { + std::string val = gguf_get_arr_str(any_file_loader->gguf_ctx, i, j); + data[j] = val; + } + + file.write_arr(key, type, data); + } + + void write_hparam_arr_f32(const std::string & key, enum gguf_type type, int i, int n_arr) { + std::vector data(n_arr); + + for (int j = 0; j < n_arr; ++j) { + float val = gguf_get_arr_f32(any_file_loader->gguf_ctx, i, j); + data[j] = val; + } + + file.write_arr(key, type, data); + } + void write_hparams(enum llama_ftype new_ftype) { const int32_t n_kv = gguf_get_n_kv(any_file_loader->gguf_ctx); for (int i = 0; i < n_kv; ++i) { @@ -665,7 +687,8 @@ struct gguf_file_saver { uint16_t u16_val; uint32_t u32_val; uint8_t u8_val; - + gguf_type arr_type; + int n_arr; switch(vtype) { case GGUF_TYPE_BOOL: @@ -705,6 +728,15 @@ struct gguf_file_saver { file.write_val(key, GGUF_TYPE_UINT8, u8_val); break; case GGUF_TYPE_ARRAY: + arr_type = gguf_get_arr_type(any_file_loader->gguf_ctx, i); + n_arr = gguf_get_arr_n(any_file_loader->gguf_ctx, i); + if (arr_type == GGUF_TYPE_FLOAT32) { + write_hparam_arr_f32(key, arr_type, i, n_arr); + } else if (arr_type == GGUF_TYPE_STRING) { + write_hparam_arr_str(key, GGUF_TYPE_STRING, i, n_arr); + } else { + throw std::runtime_error("not implemented"); + } break; default: throw std::runtime_error(format("cannot recognize value type for key %s\n", key));