ggml : fix backward rope after YaRN (#3974)

* fix backward process of rope

rope backward process was broken after YaRN RoPE (#2268) implementation, due to missing changes in backward functions.

the code for the backward process is nearly identically to the forward process:
the only difference is the sign of the sin-values.

to avoid future regressions remove the near-duplicate backward functions and reuse the forward code:

for this a new function argument `bool forward` was added to `ggml_compute_forward_rope_f32` and `ggml_compute_forward_rope_f16`.
the sin-values will be negated when forward is false.

* fix finetune rope call to use correct default attn_factor of 1.0f

* remove unused `ggml_rope_xpos_back`

it is better to have only one `ggml_rope_back` function that accepts all rope parameters, so that `ggml_compute_backward` can propagate all parameters without having to switch between different rope_back variants.

* fix comments explaining the sinus sign in ggml_forward_rope

* add missing function arguments in declaration

* fix function argument type in declaration
This commit is contained in:
xaedes 2023-11-07 09:04:51 +01:00 committed by GitHub
parent 54b4df8886
commit e9c1cecb9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 84 additions and 253 deletions

View File

@ -643,7 +643,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
return ggml_rope_custom(ctx, return ggml_rope_custom(ctx,
t, KQ_pos, n_rot, rope_mode, n_ctx, 0, t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
rope_freq_base, rope_freq_scale, 0.0f, 0.0f, 0.0f, 0.0f rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
); );
}; };

318
ggml.c
View File

@ -4970,8 +4970,13 @@ struct ggml_tensor * ggml_rope_back(
int n_dims, int n_dims,
int mode, int mode,
int n_ctx, int n_ctx,
int n_orig_ctx,
float freq_base, float freq_base,
float freq_scale, float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
float xpos_base, float xpos_base,
bool xpos_down) { bool xpos_down) {
GGML_ASSERT(ggml_is_vector(b)); GGML_ASSERT(ggml_is_vector(b));
@ -4988,11 +4993,15 @@ struct ggml_tensor * ggml_rope_back(
struct ggml_tensor * result = ggml_dup_tensor(ctx, a); struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx }; int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
memcpy(params + 4, &freq_base, sizeof(float)); memcpy(params + 5, &freq_base, sizeof(float));
memcpy(params + 5, &freq_scale, sizeof(float)); memcpy(params + 6, &freq_scale, sizeof(float));
memcpy(params + 6, &xpos_base, sizeof(float)); memcpy(params + 7, &ext_factor, sizeof(float));
memcpy(params + 7, &xpos_down, sizeof(bool)); memcpy(params + 8, &attn_factor, sizeof(float));
memcpy(params + 9, &beta_fast, sizeof(float));
memcpy(params + 10, &beta_slow, sizeof(float));
memcpy(params + 11, &xpos_base, sizeof(float));
memcpy(params + 12, &xpos_down, sizeof(bool));
ggml_set_op_params(result, params, sizeof(params)); ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_ROPE_BACK; result->op = GGML_OP_ROPE_BACK;
@ -10974,7 +10983,8 @@ static void ggml_compute_forward_rope_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1, const struct ggml_tensor * src1,
struct ggml_tensor * dst) { struct ggml_tensor * dst,
const bool forward) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return; return;
} }
@ -11033,6 +11043,11 @@ static void ggml_compute_forward_rope_f32(
const bool is_neox = mode & 2; const bool is_neox = mode & 2;
const bool is_glm = mode & 4; const bool is_glm = mode & 4;
// backward process uses inverse rotation by cos and sin.
// cos and sin build a rotation matrix, where the inverse is the transpose.
// this essentially just switches the sign of sin.
const float sin_sign = forward ? 1.0f : -1.0f;
const int32_t * pos = (const int32_t *) src1->data; const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i3 = 0; i3 < ne3; i3++) {
@ -11049,9 +11064,9 @@ static void ggml_compute_forward_rope_f32(
float block_theta = MAX(p - (n_ctx - 2), 0); float block_theta = MAX(p - (n_ctx - 2), 0);
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
const float cos_theta = cosf(theta_base); const float cos_theta = cosf(theta_base);
const float sin_theta = sinf(theta_base); const float sin_theta = sinf(theta_base) * sin_sign;
const float cos_block_theta = cosf(block_theta); const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta); const float sin_block_theta = sinf(block_theta) * sin_sign;
theta_base *= theta_scale; theta_base *= theta_scale;
block_theta *= theta_scale; block_theta *= theta_scale;
@ -11075,6 +11090,7 @@ static void ggml_compute_forward_rope_f32(
rope_yarn( rope_yarn(
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
); );
sin_theta *= sin_sign;
// zeta scaling for xPos only: // zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
@ -11105,6 +11121,7 @@ static void ggml_compute_forward_rope_f32(
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
&cos_theta, &sin_theta &cos_theta, &sin_theta
); );
sin_theta *= sin_sign;
theta_base *= theta_scale; theta_base *= theta_scale;
@ -11130,7 +11147,8 @@ static void ggml_compute_forward_rope_f16(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1, const struct ggml_tensor * src1,
struct ggml_tensor * dst) { struct ggml_tensor * dst,
const bool forward) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return; return;
} }
@ -11182,6 +11200,11 @@ static void ggml_compute_forward_rope_f16(
const bool is_neox = mode & 2; const bool is_neox = mode & 2;
const bool is_glm = mode & 4; const bool is_glm = mode & 4;
// backward process uses inverse rotation by cos and sin.
// cos and sin build a rotation matrix, where the inverse is the transpose.
// this essentially just switches the sign of sin.
const float sin_sign = forward ? 1.0f : -1.0f;
const int32_t * pos = (const int32_t *) src1->data; const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i3 = 0; i3 < ne3; i3++) {
@ -11198,9 +11221,9 @@ static void ggml_compute_forward_rope_f16(
float block_theta = MAX(p - (n_ctx - 2), 0); float block_theta = MAX(p - (n_ctx - 2), 0);
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
const float cos_theta = cosf(theta_base); const float cos_theta = cosf(theta_base);
const float sin_theta = sinf(theta_base); const float sin_theta = sinf(theta_base) * sin_sign;
const float cos_block_theta = cosf(block_theta); const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta); const float sin_block_theta = sinf(block_theta) * sin_sign;
theta_base *= theta_scale; theta_base *= theta_scale;
block_theta *= theta_scale; block_theta *= theta_scale;
@ -11224,6 +11247,7 @@ static void ggml_compute_forward_rope_f16(
rope_yarn( rope_yarn(
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
); );
sin_theta *= sin_sign;
theta_base *= theta_scale; theta_base *= theta_scale;
@ -11250,6 +11274,7 @@ static void ggml_compute_forward_rope_f16(
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
&cos_theta, &sin_theta &cos_theta, &sin_theta
); );
sin_theta *= sin_sign;
theta_base *= theta_scale; theta_base *= theta_scale;
@ -11279,11 +11304,11 @@ static void ggml_compute_forward_rope(
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
ggml_compute_forward_rope_f16(params, src0, src1, dst); ggml_compute_forward_rope_f16(params, src0, src1, dst, true);
} break; } break;
case GGML_TYPE_F32: case GGML_TYPE_F32:
{ {
ggml_compute_forward_rope_f32(params, src0, src1, dst); ggml_compute_forward_rope_f32(params, src0, src1, dst, true);
} break; } break;
default: default:
{ {
@ -11294,216 +11319,6 @@ static void ggml_compute_forward_rope(
// ggml_compute_forward_rope_back // ggml_compute_forward_rope_back
static void ggml_compute_forward_rope_back_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
// y = rope(x, src1)
// dx = rope_back(dy, src1)
// src0 is dy, src1 contains options
float freq_base;
float freq_scale;
// these two only relevant for xPos RoPE:
float xpos_base;
bool xpos_down;
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
GGML_TENSOR_UNARY_OP_LOCALS
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
assert(nb0 == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_nrows(dst);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
// row index used to determine which thread to use
int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims);
const bool is_neox = mode & 2;
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
float theta_base = freq_scale * (float)p;
if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta_base);
const float sin_theta = sinf(theta_base);
// zeta scaling for xPos only:
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
if (xpos_down) zeta = 1.0f / zeta;
theta_base *= theta_scale;
const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float dy0 = dy[0];
const float dy1 = dy[1];
dx[0] = dy0*cos_theta*zeta + dy1*sin_theta*zeta;
dx[1] = - dy0*sin_theta*zeta + dy1*cos_theta*zeta;
}
} else {
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
const float cos_theta = cosf(theta_base);
const float sin_theta = sinf(theta_base);
theta_base *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2;
const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float dy0 = dy[0];
const float dy1 = dy[n_dims/2];
dx[0] = dy0*cos_theta + dy1*sin_theta;
dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta;
}
}
}
}
}
}
}
static void ggml_compute_forward_rope_back_f16(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
// y = rope(x, src1)
// dx = rope_back(dy, src1)
// src0 is dy, src1 contains options
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
GGML_TENSOR_UNARY_OP_LOCALS
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
assert(nb0 == sizeof(ggml_fp16_t));
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_nrows(dst);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
// row index used to determine which thread to use
int ir = 0;
const float theta_scale = powf(10000.0, -2.0f/n_dims);
const bool is_neox = mode & 2;
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
const int64_t p = pos[i2];
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
float theta_base = (float)p;
if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta_base);
const float sin_theta = sinf(theta_base);
theta_base *= theta_scale;
const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float dy0 = GGML_FP16_TO_FP32(dy[0]);
const float dy1 = GGML_FP16_TO_FP32(dy[1]);
dx[0] = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
dx[1] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
}
} else {
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
const float cos_theta = cosf(theta_base);
const float sin_theta = sinf(theta_base);
theta_base *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2;
const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float dy0 = GGML_FP16_TO_FP32(dy[0]);
const float dy1 = GGML_FP16_TO_FP32(dy[n_dims/2]);
dx[0] = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
dx[n_dims/2] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
}
}
}
}
}
}
}
static void ggml_compute_forward_rope_back( static void ggml_compute_forward_rope_back(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
@ -11512,11 +11327,11 @@ static void ggml_compute_forward_rope_back(
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
ggml_compute_forward_rope_back_f16(params, src0, src1, dst); ggml_compute_forward_rope_f16(params, src0, src1, dst, false);
} break; } break;
case GGML_TYPE_F32: case GGML_TYPE_F32:
{ {
ggml_compute_forward_rope_back_f32(params, src0, src1, dst); ggml_compute_forward_rope_f32(params, src0, src1, dst, false);
} break; } break;
default: default:
{ {
@ -15562,14 +15377,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
const int n_dims = ((int32_t *) tensor->op_params)[1]; const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2]; const int mode = ((int32_t *) tensor->op_params)[2];
const int n_ctx = ((int32_t *) tensor->op_params)[3]; const int n_ctx = ((int32_t *) tensor->op_params)[3];
float freq_base; const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
float freq_scale; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
float xpos_base;
bool xpos_down; memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float)); memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float)); memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float)); memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool)); memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
src0->grad = ggml_add_or_set(ctx, src0->grad = ggml_add_or_set(ctx,
src0->grad, src0->grad,
@ -15579,8 +15397,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
n_dims, n_dims,
mode, mode,
n_ctx, n_ctx,
n_orig_ctx,
freq_base, freq_base,
freq_scale, freq_scale,
ext_factor,
attn_factor,
beta_fast,
beta_slow,
xpos_base, xpos_base,
xpos_down), xpos_down),
zero_table); zero_table);
@ -15593,14 +15416,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
const int n_dims = ((int32_t *) tensor->op_params)[1]; const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2]; const int mode = ((int32_t *) tensor->op_params)[2];
const int n_ctx = ((int32_t *) tensor->op_params)[3]; const int n_ctx = ((int32_t *) tensor->op_params)[3];
float freq_base; const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
float freq_scale; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down;
float xpos_base;
bool xpos_down; memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float)); memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float)); memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float)); memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool)); memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
src0->grad = ggml_add_or_set(ctx, src0->grad = ggml_add_or_set(ctx,
src0->grad, src0->grad,
@ -15609,14 +15435,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src1, src1,
n_dims, n_dims,
mode, mode,
0,
n_ctx, n_ctx,
n_orig_ctx,
freq_base, freq_base,
freq_scale, freq_scale,
0.0f, ext_factor,
1.0f, attn_factor,
0.0f, beta_fast,
0.0f, beta_slow,
xpos_base, xpos_base,
xpos_down, xpos_down,
false), false),

5
ggml.h
View File

@ -1372,8 +1372,13 @@ extern "C" {
int n_dims, int n_dims,
int mode, int mode,
int n_ctx, int n_ctx,
int n_orig_ctx,
float freq_base, float freq_base,
float freq_scale, float freq_scale,
float ext_factor,
float attn_factor,
float beta_fast,
float beta_slow,
float xpos_base, float xpos_base,
bool xpos_down); bool xpos_down);