iq2_xs: Metal now works

This commit is contained in:
Iwan Kawrakow 2024-01-09 18:22:20 +01:00
parent 0aacd55159
commit 55e2cae83f

View File

@ -3867,9 +3867,9 @@ void kernel_mul_mv_iq2_xs_f32_impl(
const int nb32 = nb * (QK_K / 32);
threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512);
{
int nval = 4;
int nval = 8;
int pos = (32*sgitg + tiisg)*nval;
for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
nval = 2;
@ -3922,9 +3922,9 @@ void kernel_mul_mv_iq2_xs_f32_impl(
}
sumf[row] += d1 * sum1 + d2 * sum2;
dh += nb*sizeof(block_iq2_xxs)/2;
q2 += nb*sizeof(block_iq2_xxs)/2;
sc += nb*sizeof(block_iq2_xxs);
dh += nb*sizeof(block_iq2_xs)/2;
q2 += nb*sizeof(block_iq2_xs)/2;
sc += nb*sizeof(block_iq2_xs);
}
y4 += 32 * 32;
@ -4275,12 +4275,12 @@ void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint16_t * q2 = xb->qs + 4*ib32;
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + (q2[2*il+0] & 511));
constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
for (int i = 0; i < 8; ++i) {
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
grid = (constant uint8_t *)(iq2xxs_grid + (q2[2*il+1] & 511));
grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
signs = ksigns_iq2xs[q2[2*il+1] >> 9];
for (int i = 0; i < 8; ++i) {
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);