iq3_s: minor improvement on Metal

49.4 t/s -> 50.3 t/s
This commit is contained in:
Iwan Kawrakow 2024-03-01 17:46:33 +02:00
parent 9c5b594cde
commit 7b629c3b65

View File

@ -4773,8 +4773,10 @@ void kernel_mul_mv_iq3_s_f32_impl(
float2 sum = {0}; float2 sum = {0};
for (int l = 0; l < 4; ++l) { for (int l = 0; l < 4; ++l) {
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
for (int j = 0; j < 4; ++j) { for (int j = 0; j < 4; ++j) {
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);