llama2c : fix segfault and alloc-dealloc-mismatch (#2913)
* llama2c : fix segfault if vocab is not found
* llama2c : fix mismatch between new[] and delete
* llama2c : fix basename on Windows
* llama2c : use a destructor to prevent memory leaks
parent e8d9158925
commit 18705a30ef
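For background on the alloc-dealloc-mismatch part of this fix: memory obtained with new[] must be released with delete[], not plain delete, and the patch below moves that cleanup into a destructor so it also runs on early exits. The following is a minimal standalone sketch of that pattern, not the actual TransformerWeights code; the struct and field names are made up for illustration.

#include <cstdio>

// Hypothetical owner of heap-allocated float arrays (illustration only, not the real struct).
struct Buffers {
    float* a; // allocated with new[], so it must be freed with delete[]
    float* b;

    ~Buffers() {
        delete[] a; // delete[] pairs with new[]; plain delete here would be undefined behaviour
        delete[] b; // delete[] on a null pointer is a no-op, so partial initialization is safe
    }
};

int main() {
    Buffers buf = {};        // zero-initializes the pointers, as the patch does for weights
    buf.a = new float[16](); // zero-initialized array, released by ~Buffers()
    std::printf("%f\n", buf.a[0]);
    return 0;                // destructor runs on every return path, so nothing leaks
}

Because delete[] on a null pointer does nothing, value-initializing the owner (as the patch does with TransformerWeights weights = {};) keeps the destructor safe even if loading fails before every array is allocated.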
@@ -75,7 +75,7 @@ typedef struct {
     int seq_len; // max sequence length
 } Config;
 
-typedef struct {
+struct TransformerWeights {
     // token embedding table
     float* token_embedding_table; // (vocab_size, dim)
     // weights for rmsnorms
@@ -97,7 +97,22 @@ typedef struct {
     // float* freq_cis_imag; // (seq_len, dim/2)
     // (optional) classifier weights for the logits, on the last layer
     float* wcls;
-} TransformerWeights;
+
+    ~TransformerWeights() {
+        delete[] token_embedding_table;
+        delete[] rms_att_weight;
+        delete[] rms_ffn_weight;
+        delete[] wq;
+        delete[] wk;
+        delete[] wv;
+        delete[] wo;
+        delete[] w1;
+        delete[] w2;
+        delete[] w3;
+        delete[] rms_final_weight;
+        delete[] wcls;
+    }
+};
 
 void malloc_weights(TransformerWeights* w, Config* p, bool shared_weights) {
     // we calloc instead of malloc to keep valgrind happy
@@ -173,21 +188,6 @@ int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f, bool shar
     return 0;
 }
 
-void free_weights(TransformerWeights* w) {
-    delete w->token_embedding_table;
-    delete w->rms_att_weight;
-    delete w->rms_ffn_weight;
-    delete w->wq;
-    delete w->wk;
-    delete w->wv;
-    delete w->wo;
-    delete w->w1;
-    delete w->w2;
-    delete w->w3;
-    delete w->rms_final_weight;
-    if (w->wcls) delete w->wcls;
-}
-
 void print_sample_weights(TransformerWeights *w){
     printf("----- Quick print of first of the weight vales of all the variables\n");
     printf("%f\n", w->token_embedding_table[0]);
@@ -596,6 +596,10 @@ void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab)
         // assume llama2.c vocabulary
         printf("Assuming llama2.c vocabulary since %s is not a gguf file\n", filename);
         llama_file file(filename, "rb");
+        if (!file.fp) {
+            fprintf(stderr, "error: %s: %s\n", strerror(errno), filename);
+            exit(1);
+        }
         const int n_vocab = config->vocab_size;
         /* uint32_t max_token_length = */ file.read_u32(); // unused
         vocab->id_to_token.resize(n_vocab);
@@ -898,7 +902,7 @@ bool params_parse(int argc, char ** argv, struct train_params * params) {
 }
 
 std::string basename(const std::string &path) {
-    size_t pos = path.find_last_of("/");
+    size_t pos = path.find_last_of("/\\");
     if (pos == std::string::npos) {
         return path;
     }
@@ -911,7 +915,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
     Config config;
-    TransformerWeights weights;
+    TransformerWeights weights = {};
     {
         FILE *file = fopen(params.fn_llama2c_model, "rb");
         if (!file) { printf("Unable to open the checkpoint file %s!\n", params.fn_llama2c_model); return 1; }
@@ -953,6 +957,5 @@ int main(int argc, char ** argv) {
     printf("Saving llama.c model file %s in ggml format at %s\n", params.fn_llama2c_model, params.fn_llama2c_output_model);
 
     ggml_free(model.ctx);
-    free_weights(&weights);
     return 0;
 }
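To illustrate the basename change above (searching for both '/' and '\' so Windows paths are handled), here is a small self-contained usage sketch. The tail of the function and the example paths are assumptions for demonstration; only the find_last_of change appears in the hunk.

#include <cstddef>
#include <cstdio>
#include <string>

// Mirrors the patched lookup: strip everything up to the last '/' or '\' separator.
static std::string basename(const std::string & path) {
    std::size_t pos = path.find_last_of("/\\");
    if (pos == std::string::npos) {
        return path;
    }
    return path.substr(pos + 1);
}

int main() {
    // Example paths are made up for illustration.
    std::printf("%s\n", basename("models/model.bin").c_str());      // prints: model.bin
    std::printf("%s\n", basename("C:\\models\\model.bin").c_str()); // now also prints: model.bin
    return 0;
}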