From 55e2cae83f52a4568964b852ba767ef1641abc8e Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 9 Jan 2024 18:22:20 +0100 Subject: [PATCH] iq2_xs: Metal now works --- ggml-metal.metal | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index f20b0b024..399debf4b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -3867,9 +3867,9 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int nb32 = nb * (QK_K / 32); threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); { - int nval = 4; + int nval = 8; int pos = (32*sgitg + tiisg)*nval; for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i]; nval = 2; @@ -3922,9 +3922,9 @@ void kernel_mul_mv_iq2_xs_f32_impl( } sumf[row] += d1 * sum1 + d2 * sum2; - dh += nb*sizeof(block_iq2_xxs)/2; - q2 += nb*sizeof(block_iq2_xxs)/2; - sc += nb*sizeof(block_iq2_xxs); + dh += nb*sizeof(block_iq2_xs)/2; + q2 += nb*sizeof(block_iq2_xs)/2; + sc += nb*sizeof(block_iq2_xs); } y4 += 32 * 32; @@ -4275,12 +4275,12 @@ void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 device const uint16_t * q2 = xb->qs + 4*ib32; const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; - constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + (q2[2*il+0] & 511)); + constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511)); uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9]; for (int i = 0; i < 8; ++i) { reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); } - grid = (constant uint8_t *)(iq2xxs_grid + (q2[2*il+1] & 511)); + grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511)); signs = ksigns_iq2xs[q2[2*il+1] >> 9]; for (int i = 0; i < 8; ++i) { reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);