ggml : fix ggml_flash_attn to use op_params (#2387)

* ggml : fix ggml_flash_attn to use op_params
This commit is contained in:
slaren 2023-07-25 16:20:12 +02:00 committed by GitHub
parent fce48caf9a
commit 07aaa0f63f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

18
ggml.c
View File

@ -7030,14 +7030,16 @@ struct ggml_tensor * ggml_flash_attn(
} }
//struct ggml_tensor * result = ggml_dup_tensor(ctx, q); //struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne); struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, q->n_dims, q->ne);
int32_t t = masked ? 1 : 0;
ggml_set_op_params(result, &t, sizeof(t));
result->op = GGML_OP_FLASH_ATTN; result->op = GGML_OP_FLASH_ATTN;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = q; result->src[0] = q;
result->src[1] = k; result->src[1] = k;
result->src[2] = v; result->src[2] = v;
result->src[3] = ggml_new_i32(ctx, masked ? 1 : 0);
return result; return result;
} }
@ -7061,7 +7063,7 @@ struct ggml_tensor * ggml_flash_ff(
} }
//struct ggml_tensor * result = ggml_dup_tensor(ctx, a); //struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne); struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne);
result->op = GGML_OP_FLASH_FF; result->op = GGML_OP_FLASH_FF;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -7127,13 +7129,15 @@ struct ggml_tensor * ggml_flash_attn_back(
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
int32_t masked_i = masked ? 1 : 0;
ggml_set_op_params(result, &masked_i, sizeof(masked_i));
result->op = GGML_OP_FLASH_ATTN_BACK; result->op = GGML_OP_FLASH_ATTN_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = q; result->src[0] = q;
result->src[1] = k; result->src[1] = k;
result->src[2] = v; result->src[2] = v;
result->src[3] = d; result->src[3] = d;
result->src[4] = ggml_new_i32(ctx, masked ? 1 : 0);
return result; return result;
} }
@ -14773,7 +14777,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break; } break;
case GGML_OP_FLASH_ATTN: case GGML_OP_FLASH_ATTN:
{ {
const int32_t t = ggml_get_i32_1d(tensor->src[3], 0); const int32_t t = ggml_get_op_params_i32(tensor, 0);
GGML_ASSERT(t == 0 || t == 1); GGML_ASSERT(t == 0 || t == 1);
const bool masked = t != 0; const bool masked = t != 0;
ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor); ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor);
@ -14784,7 +14788,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break; } break;
case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_FLASH_ATTN_BACK:
{ {
int32_t t = ggml_get_i32_1d(tensor->src[4], 0); int32_t t = ggml_get_op_params_i32(tensor, 0);
GGML_ASSERT(t == 0 || t == 1); GGML_ASSERT(t == 0 || t == 1);
bool masked = t != 0; bool masked = t != 0;
ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor); ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor);
@ -15402,7 +15406,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{ {
struct ggml_tensor * flash_grad = NULL; struct ggml_tensor * flash_grad = NULL;
if (src0->grad || src1->grad || tensor->src[2]->grad) { if (src0->grad || src1->grad || tensor->src[2]->grad) {
int32_t t = ggml_get_i32_1d(tensor->src[3], 0); int32_t t = ggml_get_op_params_i32(tensor, 0);
GGML_ASSERT(t == 0 || t == 1); GGML_ASSERT(t == 0 || t == 1);
bool masked = t != 0; bool masked = t != 0;
flash_grad = flash_grad =