finetune : fix ggml_allocr lifetimes (tmp workaround) (#5033)

* Fix issue with the alloc lifetime causing max_compute_size to be calculated incorrectly

* remove ggml_allocr_free as suggested in issue #4791
Uzo Nweke, 2024-01-19 13:20:50 -05:00 (committed by GitHub)
parent a5cacb22b2
commit 381ee19572


@@ -263,7 +263,6 @@ static void init_model(struct my_llama_model * model) {
     model->data.resize(size + tensor_alignment);
     alloc = ggml_allocr_new(model->data.data(), model->data.size(), tensor_alignment);
     alloc_model(alloc, model);
-    ggml_allocr_free(alloc);
 }
 
 static void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) {
@@ -1102,7 +1101,6 @@ int main(int argc, char ** argv) {
     alloc = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment);
     ggml_allocr_alloc(alloc, tokens_input);
     ggml_allocr_alloc(alloc, target_probs);
-    ggml_allocr_free(alloc);
 
     // context for compute tensors without their data
     const size_t estimated_compute_size_wo_data = (
@@ -1149,7 +1147,6 @@ int main(int argc, char ** argv) {
             best_compute_size = max_compute_size;
             best_order = gf->order;
         }
-        ggml_allocr_free(alloc);
         ggml_free(ctx_compute);
     }
     size_t max_compute_size = best_compute_size;
@@ -1177,7 +1174,6 @@ int main(int argc, char ** argv) {
         params.common.use_flash,
         params.common.use_checkpointing
     );
-    ggml_allocr_free(alloc);
 
     std::vector<llama_token> train_tokens;
     std::vector<size_t> train_samples_begin;
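
For reference, the lifetime rule these hunks work around can be seen in a minimal sketch, assuming the legacy ggml-alloc API (ggml_allocr_new / ggml_allocr_alloc / ggml_allocr_free) that finetune.cpp used at the time; the buffer size, tensor shape, and the deferred free below are illustrative only, not taken from the patch:

// Sketch only (not from the patch): one tensor placed inside a caller-owned
// buffer via the legacy ggml-alloc API, with the free deferred until the
// tensor is no longer used. Sizes and names here are made up for illustration.
#include "ggml.h"
#include "ggml-alloc.h"

#include <cstdio>
#include <vector>

static const size_t tensor_alignment = 32;

int main() {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,   // tensor data comes from the allocator below
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);

    // caller-owned backing buffer, same pattern as model->data / mem_input_data above
    std::vector<uint8_t> data(ggml_nbytes(t) + tensor_alignment);

    struct ggml_allocr * alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
    ggml_allocr_alloc(alloc, t);     // assigns t->data to a slot inside `data`

    ((float *) t->data)[0] = 1.0f;   // use the tensor while the allocator is still alive
    printf("t[0] = %f\n", ((float *) t->data)[0]);

    ggml_allocr_free(alloc);         // free only after the last use of t
    ggml_free(ctx);
    return 0;
}

The patch itself goes a step further: as a temporary workaround it drops the ggml_allocr_free calls entirely, leaking the allocators rather than trying to pin down exactly when the tensors allocated from them are last used.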