From 102cd98074f62aab00b1e591bab8ac48f01113b7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 23 Apr 2023 14:44:36 +0300 Subject: [PATCH] ggml : Q4_3c using 2x "Full range" approach --- ggml-cuda.cu | 34 ++++-- ggml.c | 331 +++++++++++++++++++++++---------------------------- 2 files changed, 168 insertions(+), 197 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index fa511c1dc..2c2617626 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -31,8 +31,8 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 #define QK4_3 16 typedef struct { - __half d; // delta - __half m; // min + __half d0; // delta + __half d1; // delta uint8_t qs[QK4_3 / 2]; // nibbles / quants } block_q4_3; static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); @@ -112,22 +112,32 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) { const int i = blockIdx.x; - const float d = x[i].d; - const float m = x[i].m; + const float d0 = x[i].d0; + const float d1 = x[i].d1; const uint8_t * pp = x[i].qs; - for (int l = 0; l < QK4_3; l += 2) { - const uint8_t vi = pp[l/2]; + for (int l = 0; l < QK4_3/2; l += 2) { + const uint8_t vi0 = pp[l/2]; + const uint8_t vi1 = pp[l/2 + QK4_3/4]; - const int8_t vi0 = vi & 0xf; - const int8_t vi1 = vi >> 4; + const int8_t vi0_0 = vi0 & 0xf; + const int8_t vi0_1 = vi0 >> 4; - const float v0 = vi0*d + m; - const float v1 = vi1*d + m; + const int8_t vi1_0 = vi1 & 0xf; + const int8_t vi1_1 = vi1 >> 4; - y[i*QK4_3 + l + 0] = v0; - y[i*QK4_3 + l + 1] = v1; + const float v0_0 = (vi0_0 - 8)*d0; + const float v0_1 = (vi0_1 - 8)*d0; + + const float v1_0 = (vi1_0 - 8)*d1; + const float v1_1 = (vi1_1 - 8)*d1; + + y[i*QK4_3 + l + 0] = v0_0; + y[i*QK4_3 + l + 1] = v0_1; + + y[i*QK4_3 + l + 0 + QK4_3/2] = v1_0; + y[i*QK4_3 + l + 1 + QK4_3/2] = v1_1; } } diff --git a/ggml.c b/ggml.c index e73c098bc..d04cb4c28 100644 --- a/ggml.c +++ b/ggml.c @@ -655,8 +655,8 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 #define QK4_3 16 typedef struct { - ggml_fp16_t d; // delta - ggml_fp16_t m; // min + ggml_fp16_t d0; // delta + ggml_fp16_t d1; // min uint8_t qs[QK4_3 / 2]; // nibbles / quants } block_q4_3; static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); @@ -1219,93 +1219,12 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r } } -static inline int nearest_int(float fval) { - assert(fval <= 4194303.f); - float val = fval + 12582912.f; - int i; memcpy(&i, &val, sizeof(int)); - return (i & 0x007fffff) - 0x00400000; -} - -static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates, - const float * restrict candidates, int8_t * restrict L) { - assert (nmin >= INT8_MIN); - assert (nmax <= INT8_MAX); - float amax = 0; - for (int i=0; i sumlxM2*suml2P) { - if (sumlxP2 > best*suml2P) { - best = sumlxP2/suml2P; bestScale = iscale; - } - } else { - if (sumlxM2 > best*suml2M) { - best = sumlxM2/suml2M; bestScale = -iscale; - } - } - } - float sumlx = 0; int suml2 = 0; - for (int i=0; i max) max = v; + for (int l = 0; l < QK4_3/2; l++) { + const float v0 = x[i*QK4_3 + l]; + const float v1 = x[i*QK4_3 + l + QK4_3/2]; + + if (amax0 < fabsf(v0)) { + amax0 = fabsf(v0); + max0 = v0; + } + + if (amax1 < fabsf(v1)) { + amax1 = fabsf(v1); + max1 = v1; + } } - const float d = (max - min) / ((1 << 4) - 1); - const float id = d ? 1.0f/d : 0.0f; + const float d0 = max0 / -8; + const float d1 = max1 / -8; - y[i].d = GGML_FP32_TO_FP16(d); - y[i].m = GGML_FP32_TO_FP16(min); + const float id0 = d0 ? 1.0f/d0 : 0.0f; + const float id1 = d1 ? 1.0f/d1 : 0.0f; - for (int l = 0; l < QK4_3; l += 2) { - const float v0 = (x[i*QK4_3 + l + 0] - min)*id; - const float v1 = (x[i*QK4_3 + l + 1] - min)*id; + y[i].d0 = GGML_FP32_TO_FP16(d0); + y[i].d1 = GGML_FP32_TO_FP16(d1); - const uint8_t vi0 = (int) (v0 + 0.5f); - const uint8_t vi1 = (int) (v1 + 0.5f); + for (int l = 0; l < QK4_3/2; l += 2) { + const float v0_0 = x[i*QK4_3 + l + 0]*id0; + const float v0_1 = x[i*QK4_3 + l + 1]*id0; - assert(vi0 < 16); - assert(vi1 < 16); + const float v1_0 = x[i*QK4_3 + l + 0 + QK4_3/2]*id1; + const float v1_1 = x[i*QK4_3 + l + 1 + QK4_3/2]*id1; - y[i].qs[l/2] = vi0 | (vi1 << 4); + const uint8_t vi0_0 = MIN(15, (int8_t)roundf(v0_0) + 8); + const uint8_t vi0_1 = MIN(15, (int8_t)roundf(v0_1) + 8); + + const uint8_t vi1_0 = MIN(15, (int8_t)roundf(v1_0) + 8); + const uint8_t vi1_1 = MIN(15, (int8_t)roundf(v1_1) + 8); + + y[i].qs[l/2 ] = vi0_0 | (vi0_1 << 4); + y[i].qs[l/2 + QK4_3/4] = vi1_0 | (vi1_1 << 4); } } } @@ -1810,25 +1747,32 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in const block_q4_3 * restrict x = vx; for (int i = 0; i < nb; i++) { - const float d = GGML_FP16_TO_FP32(x[i].d); - const float m = GGML_FP16_TO_FP32(x[i].m); + const float d0 = GGML_FP16_TO_FP32(x[i].d0); + const float d1 = GGML_FP16_TO_FP32(x[i].d1); const uint8_t * restrict pp = x[i].qs; - for (int l = 0; l < QK4_3; l += 2) { - const uint8_t vi = pp[l/2]; + for (int l = 0; l < QK4_3/2; l += 2) { + const uint8_t vi0 = pp[l/2]; + const uint8_t vi1 = pp[l/2 + QK4_3/4]; - const int8_t vi0 = vi & 0xf; - const int8_t vi1 = vi >> 4; + const int8_t vi0_0 = vi0 & 0xf; + const int8_t vi0_1 = vi0 >> 4; - const float v0 = vi0*d + m; - const float v1 = vi1*d + m; + const int8_t vi1_0 = vi1 & 0xf; + const int8_t vi1_1 = vi1 >> 4; - y[i*QK4_3 + l + 0] = v0; - y[i*QK4_3 + l + 1] = v1; + const float v0_0 = (vi0_0 - 8)*d0; + const float v0_1 = (vi0_1 - 8)*d0; - assert(!isnan(y[i*QK4_3 + l + 0])); - assert(!isnan(y[i*QK4_3 + l + 1])); + const float v1_0 = (vi1_0 - 8)*d1; + const float v1_1 = (vi1_1 - 8)*d1; + + y[i*QK4_3 + l + 0] = v0_0; + y[i*QK4_3 + l + 1] = v0_1; + + y[i*QK4_3 + l + 0 + QK4_3/2] = v1_0; + y[i*QK4_3 + l + 1 + QK4_3/2] = v1_1; } } } @@ -2937,17 +2881,16 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * assert(n % QK8_0 == 0); assert(nb % 2 == 0); - assert(QK8_0 == 2*QK4_2); + assert(QK8_0 == 2*QK4_3); const block_q4_3 * restrict x = vx; const block_q8_0 * restrict y = vy; #if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - float summs0 = 0.0f; - float summs1 = 0.0f; + float32x2_t sumv0 = vdup_n_f32(0.0f); + float32x2_t sumv1 = vdup_n_f32(0.0f); + float32x2_t sumv2 = vdup_n_f32(0.0f); + float32x2_t sumv3 = vdup_n_f32(0.0f); for (int i = 0; i < nb; ++i) { const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0]; @@ -2955,29 +2898,46 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * const block_q8_0 * restrict y0 = &y[i + 0]; - summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0; - summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1; - - const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); + const uint8x8_t v0_0 = vld1_u8(x0_0->qs); + const uint8x8_t v0_1 = vld1_u8(x0_1->qs); // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0xf))); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x8_t v0_0l = vreinterpret_s8_u8(vand_u8 (v0_0, vdup_n_u8(0xf))); + const int8x8_t v0_0h = vreinterpret_s8_u8(vshr_n_u8(v0_0, 4)); + const int8x8_t v0_1l = vreinterpret_s8_u8(vand_u8 (v0_1, vdup_n_u8(0xf))); + const int8x8_t v0_1h = vreinterpret_s8_u8(vshr_n_u8(v0_1, 4)); + + // sub 8 + const int8x8_t v0_0ls = vsub_s8(v0_0l, vdup_n_s8(8)); + const int8x8_t v0_0hs = vsub_s8(v0_0h, vdup_n_s8(8)); + const int8x8_t v0_1ls = vsub_s8(v0_1l, vdup_n_s8(8)); + const int8x8_t v0_1hs = vsub_s8(v0_1h, vdup_n_s8(8)); // interleave - const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h); - const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h); + const int8x8_t v0_0lz = vzip1_s8(v0_0ls, v0_0hs); + const int8x8_t v0_0hz = vzip2_s8(v0_0ls, v0_0hs); + const int8x8_t v0_1lz = vzip1_s8(v0_1ls, v0_1hs); + const int8x8_t v0_1hz = vzip2_s8(v0_1ls, v0_1hs); // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x8_t v1_0l = vld1_s8(y0->qs); + const int8x8_t v1_0h = vld1_s8(y0->qs + 8); + const int8x8_t v1_1l = vld1_s8(y0->qs + 16); + const int8x8_t v1_1h = vld1_s8(y0->qs + 24); - const float x0_0d = GGML_FP16_TO_FP32(x0_0->d); - const float x0_1d = GGML_FP16_TO_FP32(x0_1->d); + const float x0_0d = GGML_FP16_TO_FP32(x0_0->d0); + const float x0_1d = GGML_FP16_TO_FP32(x0_0->d1); + const float x1_0d = GGML_FP16_TO_FP32(x0_1->d0); + const float x1_1d = GGML_FP16_TO_FP32(x0_1->d1); #if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d); + //sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d); + //sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d); + + sumv0 = vmla_n_f32(sumv0, vcvt_f32_s32(vdot_s32(vdup_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d); + sumv1 = vmla_n_f32(sumv1, vcvt_f32_s32(vdot_s32(vdup_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d); + sumv2 = vmla_n_f32(sumv2, vcvt_f32_s32(vdot_s32(vdup_n_s32(0), v0_1lz, v1_1l)), x1_0d*y0->d); + sumv3 = vmla_n_f32(sumv3, vcvt_f32_s32(vdot_s32(vdup_n_s32(0), v0_1hz, v1_1h)), x1_1d*y0->d); #else const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l)); @@ -2992,77 +2952,79 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * #endif } - *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1; + *s = vaddv_f32(vadd_f32(vadd_f32(sumv0, sumv1), vadd_f32(sumv2, sumv3))); #elif defined(__AVX2__) + GGML_ASSERT(false); // TODO // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); + //__m256 acc = _mm256_setzero_ps(); - // Main loop - for (int i = 0; i < nb; i++) { - const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d)); - const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d)); - const __m256 dx = _mm256_set_m128(d1, d0); + //// Main loop + //for (int i = 0; i < nb; i++) { + // const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d)); + // const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d)); + // const __m256 dx = _mm256_set_m128(d1, d0); - const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m)); - const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m)); - const __m256 mx = _mm256_set_m128(m1, m0); + // const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m)); + // const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m)); + // const __m256 mx = _mm256_set_m128(m1, m0); - const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs); - const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs); - const __m256i bx = _mm256_set_m128i(bx1, bx0); + // const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs); + // const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs); + // const __m256i bx = _mm256_set_m128i(bx1, bx0); - const __m256 dy = _mm256_broadcast_ss(&y[i].d); - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + // const __m256 dy = _mm256_broadcast_ss(&y[i].d); + // const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by); - const __m256 syf = sum_i16_pairs_float(syi); + // const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by); + // const __m256 syf = sum_i16_pairs_float(syi); - const __m256 q = mul_sum_i8_pairs_float(bx, by); + // const __m256 q = mul_sum_i8_pairs_float(bx, by); - const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf)); - acc = _mm256_fmadd_ps(sxy, dy, acc); - } + // const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf)); + // acc = _mm256_fmadd_ps(sxy, dy, acc); + //} - *s = hsum_float_8(acc); + //*s = hsum_float_8(acc); #else - // scalar - float sumf = 0.0; - for (int i = 0; i < nb; i++) { - const uint8_t * restrict x0 = x[2*i + 0].qs; - const uint8_t * restrict x1 = x[2*i + 1].qs; - const int8_t * restrict y0 = y[i].qs; + GGML_ASSERT(false); // TODO + //// scalar + //float sumf = 0.0; + //for (int i = 0; i < nb; i++) { + // const uint8_t * restrict x0 = x[2*i + 0].qs; + // const uint8_t * restrict x1 = x[2*i + 1].qs; + // const int8_t * restrict y0 = y[i].qs; - const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d); - const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m); - const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d); - const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m); + // const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d); + // const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m); + // const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d); + // const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m); - int sxy_0 = 0; - int sxy_1 = 0; + // int sxy_0 = 0; + // int sxy_1 = 0; - for (int j = 0; j < QK8_0/4; j++) { - const uint8_t v0 = x0[j]; - const uint8_t v1 = x1[j]; + // for (int j = 0; j < QK8_0/4; j++) { + // const uint8_t v0 = x0[j]; + // const uint8_t v1 = x1[j]; - const int x0_0 = v0 & 0xf; - const int x1_0 = v0 >> 4; + // const int x0_0 = v0 & 0xf; + // const int x1_0 = v0 >> 4; - const int x0_1 = v1 & 0xf; - const int x1_1 = v1 >> 4; + // const int x0_1 = v1 & 0xf; + // const int x1_1 = v1 >> 4; - const int y0_0 = y0[2*j + 0]; - const int y1_0 = y0[2*j + 1]; + // const int y0_0 = y0[2*j + 0]; + // const int y1_0 = y0[2*j + 1]; - const int y0_1 = y0[2*(j + QK8_0/4) + 0]; - const int y1_1 = y0[2*(j + QK8_0/4) + 1]; + // const int y0_1 = y0[2*(j + QK8_0/4) + 0]; + // const int y1_1 = y0[2*(j + QK8_0/4) + 1]; - sxy_0 += x0_0*y0_0 + x1_0*y1_0; - sxy_1 += x0_1*y0_1 + x1_1*y1_1; - } + // sxy_0 += x0_0*y0_0 + x1_0*y1_0; + // sxy_1 += x0_1*y0_1 + x1_1*y1_1; + // } - sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1; - } - *s = sumf; + // sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1; + //} + //*s = sumf; #endif } @@ -12189,7 +12151,6 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2; quantize_row_q4_2_reference(src + j, y, k); - //quantize_row_q4_2_rmse(src + j, y, k); for (int i = 0; i < nb; i++) { for (int l = 0; l < QK4_2; l += 2) {