ggml : sync ggml (add GPT-NeoX RoPE implementation)

commit 12b5900dbc
parent 9ff334f3c9
Author: Georgi Gerganov
Date:   2023-04-20 23:32:59 +03:00
3 changed files with 49 additions and 17 deletions

ggml.c

@@ -8653,9 +8653,11 @@ static void ggml_compute_forward_rope_f32(
     const float theta_scale = powf(10000.0, -2.0f/n_dims);
 
+    const bool is_neox = mode & 2;
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = (mode == 0 ? n_past + i2 : i2);
+        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
@@ -8668,6 +8670,7 @@ static void ggml_compute_forward_rope_f32(
                     theta *= theta_scale;
 
+                    if (!is_neox) {
                     const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
                           float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -8676,6 +8679,16 @@ static void ggml_compute_forward_rope_f32(
                     dst_data[0] = x0*cos_theta - x1*sin_theta;
                     dst_data[1] = x0*sin_theta + x1*cos_theta;
+                    } else {
+                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[n_dims/2];
+
+                        dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                    }
                 }
             }
         }
     }
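
The hunk above distinguishes two pairing schemes for the rotation: the original path rotates adjacent element pairs (x[i0], x[i0+1]), while the GPT-NeoX path rotates pairs split half a head apart (x[i0/2], x[i0/2 + n_dims/2]). For orientation, here is a minimal, self-contained C sketch of the same arithmetic on a single head, stripped of ggml's tensor strides and threading; rope_ref and its parameter names are illustrative, not part of the commit:

    #include <math.h>

    // Rotate one head of n_dims elements at position p.
    // neox == 0: pair adjacent elements (x[i0], x[i0+1]), as in the original path.
    // neox != 0: pair elements half a head apart (x[i0/2], x[i0/2 + n_dims/2]).
    static void rope_ref(const float * src, float * dst, int n_dims, int p, int neox) {
        const float theta_scale = powf(10000.0f, -2.0f/n_dims);

        float theta = (float) p;

        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            const float cos_theta = cosf(theta);
            const float sin_theta = sinf(theta);

            theta *= theta_scale;

            if (!neox) {
                const float x0 = src[i0];
                const float x1 = src[i0 + 1];

                dst[i0]     = x0*cos_theta - x1*sin_theta;
                dst[i0 + 1] = x0*sin_theta + x1*cos_theta;
            } else {
                const float x0 = src[i0/2];
                const float x1 = src[i0/2 + n_dims/2];

                dst[i0/2]            = x0*cos_theta - x1*sin_theta;
                dst[i0/2 + n_dims/2] = x0*sin_theta + x1*cos_theta;
            }
        }
    }
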
@@ -8730,9 +8743,11 @@ static void ggml_compute_forward_rope_f16(
     const float theta_scale = powf(10000.0, -2.0f/n_dims);
 
+    const bool is_neox = mode & 2;
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = (mode == 0 ? n_past + i2 : i2);
+        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
@@ -8745,6 +8760,7 @@ static void ggml_compute_forward_rope_f16(
                     theta *= theta_scale;
 
+                    if (!is_neox) {
                     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
                           ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -8753,6 +8769,16 @@ static void ggml_compute_forward_rope_f16(
                     dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
                     dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    } else {
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+
+                        dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
                 }
             }
         }
     }
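
Note that the f16 variant mirrors the f32 one exactly: operands are widened with GGML_FP16_TO_FP32, the rotation is computed in 32-bit floats, and the results are narrowed back with GGML_FP32_TO_FP16, so both paths share the same trigonometric precision and differ only in storage type.
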

ggml.h

@@ -630,7 +630,8 @@ struct ggml_tensor * ggml_soft_max(
 // rotary position embedding
 // in-place, returns view(a)
-// if mode == 1, skip n_past elements
+// if mode & 1 == 1, skip n_past elements
+// if mode & 2 == 1, GPT-NeoX style
 // TODO: avoid creating a new tensor every time
 struct ggml_tensor * ggml_rope(
         struct ggml_context * ctx,
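
For context, a hedged usage sketch of the new flag. It assumes the ggml_rope signature of this period, ggml_rope(ctx, a, n_past, n_dims, mode); ctx0, Qcur, and n_rot are placeholder names from a typical caller, not part of this commit:

    // mode is a bitmask:
    //   bit 0 (mode & 1): offset positions by n_past (set during incremental inference)
    //   bit 1 (mode & 2): use the GPT-NeoX pairing instead of the interleaved one
    struct ggml_tensor * Qrot = ggml_rope(ctx0, Qcur, n_past, n_rot, 2); // GPT-NeoX style RoPE
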

llama.cpp

@@ -1618,6 +1618,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         // quantize only 2D tensors
         quantize &= (tensor.ne.size() == 2);
 
+        // GG: uncomment this to keep the output layer in FP16
+        //if (tensor.name.rfind("output")) {
+        //    quantize = false;
+        //}
+
         enum ggml_type new_type;
         void * new_data;
         size_t new_size;
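
One caveat on the commented-out check: std::string::rfind returns a position, so tensor.name.rfind("output") is truthy for any name in which "output" is not found at index 0, including names that do not contain it at all (rfind then yields std::string::npos, which is non-zero). A sketch of the apparent intent as an explicit prefix test, hypothetical and not part of the commit:

    // keep the output layer in FP16: skip quantization when the tensor name starts with "output"
    if (tensor.name.rfind("output", 0) == 0) {
        quantize = false;
    }
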