cuda : less diff in the rope_neox kernel

This commit is contained in:
Georgi Gerganov 2023-12-17 09:14:29 +02:00
parent f703ca8a3c
commit 42e9525884
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -4998,29 +4998,31 @@ static __global__ void rope_neox(
const int ib = col / n_dims; const int ib = col / n_dims;
const int ic = col % n_dims; const int ic = col % n_dims;
if (ib == 0) { if (ib > 0) {
const int i = row*ncols + ib*n_dims + ic/2;
const int i2 = row/p_delta_rows;
float cur_rot = inv_ndims * ic - ib;
const int p = has_pos ? pos[i2] : 0;
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0];
const float x1 = x[i + n_dims/2];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
const int i = row*ncols + ib*n_dims + ic; const int i = row*ncols + ib*n_dims + ic;
dst[i + 0] = x[i + 0]; dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1]; dst[i + 1] = x[i + 1];
return;
} }
const int i = row*ncols + ib*n_dims + ic/2;
const int i2 = row/p_delta_rows;
float cur_rot = inv_ndims * ic - ib;
const int p = has_pos ? pos[i2] : 0;
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0];
const float x1 = x[i + n_dims/2];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
} }
static __global__ void rope_glm_f32( static __global__ void rope_glm_f32(