metal : Q3_K speedup (#2995)

* Slightly faster Q3_K and Q5_K on metal

* Another Q3_K speedup on metal

Combined with the previous commit, we are now +9.6% for TG (token generation).
PP (prompt processing) is not affected, as it goes through the matrix
multiplication templates.

* Slowly progressing on Q3_K on metal

We are now 13% faster than master

* Another small improvement for Q3_K on metal

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
commit ba7ffbb251 (parent e64f5b5578)
Author: Kawrakow
Date:   2023-09-08 18:01:04 +02:00


@@ -1123,31 +1123,40 @@ kernel void kernel_mul_mat_q3_K_f32(
     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
     device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;

-    float yl[16];
+    float yl[32];

-    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask1 = 0x3030;
     const uint16_t kmask2 = 0x0f0f;

-    const int tid = tiisg/2;
-    const int ix = tiisg%2;
-    const int ip = tid/8;         // 0 or 1
-    const int il = tid/2 - 4*ip;  // 0...3
+    const int tid = tiisg/4;
+    const int ix = tiisg%4;
+    const int ip = tid/4;         // 0 or 1
+    const int il = 2*((tid%4)/2); // 0 or 2
     const int ir = tid%2;
     const int n = 8;
     const int l0 = n*ir;

-    const uint16_t m1 = 1 << (4*ip + il);
-    const uint16_t m2 = m1 << 8;
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tables.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];

     const int shift = 2*il;
-    const uint16_t qm1 = 0x0003 << shift;
-    const uint16_t qm2 = 0x0300 << shift;
-    const int32_t v1 = 4 << shift;
-    const int32_t v2 = 1024 << shift;
+    const float v1 = il == 0 ? 4.f : 64.f;
+    const float v2 = 4.f * v1;

     const uint16_t s_shift1 = 4*ip;
-    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
-    const int ik = 4 + (il%2);
+    const uint16_t s_shift2 = s_shift1 + il;

     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
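
The mask tables above replace the old data-dependent shifts (m1 = 1 << (4*ip + il), m2 = m1 << 8): each thread now fetches a ready-made ushort4 covering the (m1, m2) pairs for two adjacent il positions. A minimal CPU-side check, in plain C++ rather than Metal, that the table reproduces the old shift formula (the harness is illustrative; only the table values come from the kernel):

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    // mm[2*ip + il/2] from the kernel; columns 2*(il%2) and 2*(il%2)+1 hold
    // the (m1, m2) pair the old code computed for that il.
    const uint16_t mm[4][4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
                               {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
                               {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
                               {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
    for (int ip = 0; ip < 2; ++ip) {
        for (int il = 0; il < 4; ++il) {
            const uint16_t m1 = 1 << (4*ip + il);  // old kernel's dynamic mask
            const uint16_t m2 = m1 << 8;
            assert(mm[2*ip + il/2][2*(il%2) + 0] == m1);
            assert(mm[2*ip + il/2][2*(il%2) + 1] == m2);
        }
    }
    printf("mask table matches the old shift-built masks\n");
}
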
@@ -1156,12 +1165,19 @@ kernel void kernel_mul_mat_q3_K_f32(

     device const float * y1 = yy + ix*QK_K + y_offset;

-    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
-    for (int i = ix; i < nb; i += 2) {
+    uint32_t scales32, aux32;
+    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+    thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+    float sumf1[2] = {0.f};
+    float sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 4) {

         for (int l = 0; l < 8; ++l) {
             yl[l+ 0] = y1[l+ 0];
             yl[l+ 8] = y1[l+16];
+            yl[l+16] = y1[l+32];
+            yl[l+24] = y1[l+48];
         }

         device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
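
With ix = tiisg%4 and the block loop stepping i += 4, each simdgroup now spreads its 32 lanes over four super-blocks at a time, eight lanes (tid = tiisg/4) cooperating per block, and each lane caches 32 y values per iteration instead of 16. A quick sanity sketch in plain C++ (illustrative, not from the commit) that this mapping visits every super-block exactly once per cooperating lane:

#include <cassert>

int main() {
    const int nb = 19;                       // any super-block count per row
    int visits[19] = {0};
    for (int lane = 0; lane < 32; ++lane) {  // one simdgroup
        const int ix = lane % 4;             // which block this lane starts on
        for (int i = ix; i < nb; i += 4) {   // the kernel's block loop
            ++visits[i];
        }
    }
    for (int i = 0; i < nb; ++i) {
        assert(visits[i] == 8);              // 8 cooperating lanes per block
    }
}
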
@@ -1172,27 +1188,43 @@ kernel void kernel_mul_mat_q3_K_f32(
         for (int row = 0; row < 2; ++row) {

             const float d_all = (float)dh[0];

-            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+            scales16[0] = a[4];
+            scales16[1] = a[5];
+            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+            scales16[0] = a[il+0];
+            scales16[1] = a[il+1];
+            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;

-            float s1 = 0, s2 = 0;
+            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2];
-                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
-                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2];
+                s1 += yl[l+0] * (qs & qm[il/2][0]);
+                s2 += yl[l+1] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+                s4 += yl[l+16] * (qs & qm[il/2][2]);
+                s5 += yl[l+17] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
             }

-            float d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[0];
-            sumf2[row] += d;
+            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[0] - 32);
+            sumf2[row] += d2 * (scales[2] - 32);

-            s1 = s2 = 0;
+            s1 = s2 = s3 = s4 = s5 = s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2+8];
-                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
-                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2+8];
+                s1 += yl[l+8] * (qs & qm[il/2][0]);
+                s2 += yl[l+9] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+                s4 += yl[l+24] * (qs & qm[il/2][2]);
+                s5 += yl[l+25] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
             }

-            d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[1];
-            sumf2[row] += d;
+            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[1] - 32);
+            sumf2[row] += d2 * (scales[3] - 32);

             q += step;
             h += step;
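
The scales16/scales32 shuffling above swaps the old per-pair char2 extraction for SIMD-within-a-register unpacking: four 6-bit scales are assembled in one go, the low 4 bits of each with one masked shift and the high 2 bits with another. A CPU-side check in plain C++ (the packed example values and byte-wise reference are illustrative; the mask/shift pattern is the kernel's) that the 32-bit form matches byte-at-a-time unpacking for the shifts the kernel uses, assuming little-endian byte order as on Apple hardware:

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    const uint32_t lo32 = 0x87654321u;  // four bytes holding 4-bit fields
    const uint32_t hi32 = 0xD9C8B7A6u;  // four bytes holding 2-bit fields
    for (int s1 = 0; s1 <= 4; s1 += 4) {           // s_shift1 = 4*ip
        for (int s2 = s1; s2 <= s1 + 2; s2 += 2) { // s_shift2 = s_shift1 + il
            // One-shot form used by the kernel.
            const uint32_t aux32    = ((hi32 >> s2) << 4) & 0x30303030u;
            const uint32_t scales32 = ((lo32 >> s1) & 0x0f0f0f0fu) | aux32;
            int8_t scales[4];
            memcpy(scales, &scales32, 4);
            // Byte-at-a-time reference.
            for (int b = 0; b < 4; ++b) {
                const uint8_t lo = (uint8_t)(lo32 >> (8*b));
                const uint8_t hi = (uint8_t)(hi32 >> (8*b));
                const int8_t ref = ((lo >> s1) & 0x0f) | (((hi >> s2) & 0x03) << 4);
                assert(scales[b] == ref);
            }
        }
    }
    printf("packed 6-bit scale extraction matches byte-wise unpacking\n");
}
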
@@ -1201,17 +1233,20 @@ kernel void kernel_mul_mat_q3_K_f32(
         }

-        y1 += 2 * QK_K;
+        y1 += 4 * QK_K;
     }

     for (int row = 0; row < 2; ++row) {
-        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
-        const float tot = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+        sumf1[row] = simd_sum(sumf);
+    }
+
+    if (tiisg == 0) {
+        for (int row = 0; row < 2; ++row) {
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
         }
     }
 }

 #else

 kernel void kernel_mul_mat_q3_K_f32(
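
The reduction changes because the -32 offset is now folded into each block's (scales[k] - 32) factor, relying on sum_i d_i*(s_i - 32) = sum_i d_i*s_i - 32*sum_i d_i, and because sumf2 now carries the second pair of sub-blocks, whose masked quants sit two bits higher (v2 = 4*v1), hence the 0.25f that replaces the old -32.f. The simd_sum is also hoisted so both rows are reduced before a single lane writes the results. A toy check of the offset identity in plain C++ (illustrative values):

#include <cassert>
#include <cmath>

int main() {
    const float d[3] = {0.5f, -1.25f, 2.f};  // per-block partial dot products
    const float s[3] = {7.f, 42.f, 13.f};    // per-block scales
    float sum_ds = 0.f, sum_d = 0.f, folded = 0.f;
    for (int i = 0; i < 3; ++i) {
        sum_ds += d[i] * s[i];               // old: accumulate d*s ...
        sum_d  += d[i];                      // ... and d separately
        folded += d[i] * (s[i] - 32.f);      // new: fold the -32 in per block
    }
    assert(std::fabs(folded - (sum_ds - 32.f * sum_d)) < 1e-4f);
}
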
@@ -1564,17 +1599,25 @@ kernel void kernel_mul_mat_q5_K_f32(
             sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
             sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);

-            float4 acc = {0.f, 0.f, 0.f, 0.f};
+            float4 acc1 = {0.f};
+            float4 acc2 = {0.f};
             for (int l = 0; l < n; ++l) {
                 uint8_t h = qh[l];
-                acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
-                acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
-                acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
-                acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
             }

             const float dall = dh[0];
             const float dmin = dh[1];
-            sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+            sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                                 sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);

             q1 += step;
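
The Q5_K change applies the same trick: rather than adding the high bit into the quant before each multiply, the high-bit contribution is kept as a separate sum of y values and merged once at the end, using y*((q & 0x0F) + (hbit ? 16 : 0)) = y*(q & 0x0F) + 16*(hbit ? y : 0); for the upper nibble the old +256 becomes the same 16.f factor once the 1/16 normalization moves onto acc1. A small plain-C++ check of the identity (example data is illustrative):

#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
    const float   y[4]  = {0.25f, -1.f, 3.5f, 0.125f};
    const uint8_t q[4]  = {0x3A, 0x07, 0xF0, 0x9C};
    const uint8_t hb[4] = {1, 0, 1, 1};                    // high bit per element
    float acc = 0.f, acc1 = 0.f, acc2 = 0.f;
    for (int l = 0; l < 4; ++l) {
        acc  += y[l] * ((q[l] & 0x0F) + (hb[l] ? 16 : 0)); // old: add before multiply
        acc1 += y[l] * (q[l] & 0x0F);                      // new: quant bits only
        acc2 += hb[l] ? y[l] : 0.f;                        // new: sum of flagged y
    }
    assert(std::fabs(acc - (acc1 + 16.f * acc2)) < 1e-5f); // acc == acc1 + 16*acc2
}
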