ggml : alternative fix for race condition bug in non-inplace ggml_compute_forward_diag_mask_f32 (#1454)

* fix race condition bug in non-inplace ggml_compute_forward_diag_mask_f32

memcpy needs to be synchronized across threads to avoid race conditions.
=> do it in INIT phase

* remove trailing whitespace

* Update ggml.c

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
xaedes 2023-05-14 17:55:02 +02:00 committed by GitHub
parent 13c351ad72
commit 79b2d5b69d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

40
ggml.c
View File

@ -10501,34 +10501,28 @@ static void ggml_compute_forward_diag_mask_f32(
assert(src1->type == GGML_TYPE_I32); assert(src1->type == GGML_TYPE_I32);
assert(ggml_nelements(src1) == 2); assert(ggml_nelements(src1) == 2);
const int n_past = ((int32_t *) src1->data)[0];
const bool inplace = (bool)((int32_t *) src1->data)[1];
if (params->type == GGML_TASK_INIT) {
// TODO: this hack is not good, need a better way to handle this
if (!inplace) {
// use the init task to copy src -> dst
struct ggml_compute_params params_cpy = *params;
params_cpy.ith = 0;
params_cpy.nth = 1;
params_cpy.type = GGML_TASK_COMPUTE;
ggml_compute_forward_dup_same_cont(&params_cpy, src0, dst);
}
return;
}
if (params->type == GGML_TASK_FINALIZE) {
return;
}
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
const int n_past = ((int32_t *) src1->data)[0];
const bool inplace = (bool)((int32_t *) src1->data)[1];
assert(n_past >= 0); assert(n_past >= 0);
if (!inplace && (params->type == GGML_TASK_INIT)) {
// memcpy needs to be synchronized across threads to avoid race conditions.
// => do it in INIT phase
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
memcpy(
((char *) dst->data),
((char *) src0->data),
ggml_nbytes(dst));
}
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
// TODO: handle transposed/permuted matrices // TODO: handle transposed/permuted matrices
const int n = ggml_nrows(src0); const int n = ggml_nrows(src0);