From 4b223ec4329a24f3b932ea1a9c0456ef11b851ea Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Mon, 2 Oct 2023 09:04:02 -0400 Subject: [PATCH] Refactor getrows to use common code and get ready for q6_k. --- kompute/common.comp | 138 +++++++++++++++-------------------- kompute/op_getrows.comp | 25 +++++++ kompute/op_getrows_f16.comp | 10 ++- kompute/op_getrows_q4_0.comp | 38 +++------- kompute/op_getrows_q4_1.comp | 41 +++-------- 5 files changed, 111 insertions(+), 141 deletions(-) create mode 100644 kompute/op_getrows.comp diff --git a/kompute/common.comp b/kompute/common.comp index 2e843a878..040b87375 100644 --- a/kompute/common.comp +++ b/kompute/common.comp @@ -16,27 +16,12 @@ #extension GL_EXT_debug_printf : enable #define QK4_0 32 -#define QR4_0 2 #define QK4_1 32 #define GELU_COEF_A 0.044715 #define SQRT_2_OVER_PI 0.79788456080286535587989211986876 -#ifndef QK_K #define QK_K 256 -#endif - -#if QK_K == 256 -#define K_SCALE_SIZE 12 -#else -#define K_SCALE_SIZE 4 -#endif - -#define BM 128 -#define BN 128 -#define BK 8 -#define TM 8 -#define TN 8 #define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx]) #define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx) @@ -44,83 +29,76 @@ #define u8BufToFloat(buf, idx) uintBitsToFloat u8BufToU32(buf, idx) #define sizeof_block_q4_0 0x12 -#define sizeof_block_q4_1 0x14 struct block_q4_0 { float16_t d; uint8_t qs[QK4_0 / 2]; }; +mat4 dequantize_q4_0(const block_q4_0 xb, uint il) { + const float d1 = il != 0 ? (xb.d / 16.f) : xb.d; + const float d2 = d1 / 256.f; + const float md = -8.f * xb.d; + const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F); + const uint16_t mask1 = mask0 << 8; + + mat4 reg; + for (int i=0;i<8;i++) { + uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]); + reg[i/2][2*(i%2)+0] = d1 * (b & mask0) + md; + reg[i/2][2*(i%2)+1] = d2 * (b & mask1) + md; + } + return reg; +} + +#define sizeof_block_q4_1 0x14 struct block_q4_1 { float16_t d; float16_t m; uint8_t qs[QK4_1 / 2]; }; +mat4 dequantize_q4_1(const block_q4_1 xb, uint il) { + const float d1 = il != 0 ? (xb.d / 16.f) : xb.d; + const float d2 = d1 / 256.f; + const float m = xb.m; + const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F); + const uint16_t mask1 = mask0 << 8; -#ifndef QK_K -#define QK_K 256 -#endif + mat4 reg; + for (int i=0;i<8;i++) { + uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]); + reg[i/2][2*(i%2)+0] = ((b & mask0) * d1) + m; + reg[i/2][2*(i%2)+1] = ((b & mask1) * d2) + m; + } + return reg; +} -#if QK_K == 256 -#define K_SCALE_SIZE 12 -#else -#define K_SCALE_SIZE 4 -#endif - -struct block_q2_K { - uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits - uint8_t qs[QK_K/4]; // quants - float16_t d; // super-block scale for quantized scales - float16_t dmin; // super-block scale for quantized mins -}; -// 84 bytes / block - -struct block_q3_K { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits -#if QK_K == 64 - uint8_t scales[2]; -#else - uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits -#endif - float16_t d; // super-block scale -}; - -#if QK_K == 64 -typedef struct { - float16_t d[2]; // super-block scales/mins - uint8_t scales[2]; - uint8_t qs[QK_K/2]; // 4-bit quants -} block_q4_K; -#else -struct block_q4_K { - float16_t d; // super-block scale for quantized scales - float16_t dmin; // super-block scale for quantized mins - uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits - uint8_t qs[QK_K/2]; // 4--bit quants -}; -#endif - -#if QK_K == 64 -struct block_q5_K { - float16_t d; // super-block scales/mins - int8_t scales[QK_K/16]; // 8-bit block scales - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -}; -#else -struct block_q5_K { - float16_t d; // super-block scale for quantized scales - float16_t dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -}; -// 176 bytes / block -#endif - -struct block_q6_K { +#define sizeof_block_q6_k 210 +struct block_q6_k { uint8_t ql[QK_K/2]; // quants, lower 4 bits uint8_t qh[QK_K/4]; // quants, upper 2 bits int8_t scales[QK_K/16]; // scales, quantized with 8 bits - float16_t d; // super-block scale + float16_t d; // super-block scale }; -// 210 bytes / block +mat4 dequantize_q6_k(const block_q6_k xb, uint il) { + const float16_t d_all = xb.d; + uint8_t ql[QK_K/2]; + uint8_t qh[QK_K/4]; + int8_t scales[QK_K/16]; + + const uint qlIndex = 64*(il/8) + 32*((il/2)&1) + 16*(il&1); + const uint qhIndex = 32*(il/8) + 16*(il&1); + float16_t sc = xb.scales[(il%2) + 2 * ((il/2))]; + il = (il/2) & 3; + + const uint16_t kmask1 = il>1 ? uint16_t(il>2 ? 192 : 48) : uint16_t(il>0 ? 12 : 3); + const uint16_t kmask2 = il>1 ? uint8_t(0xF0) : uint8_t(0x0F); + const float16_t coef = il>1 ? float16_t(1.f/16.f) : float16_t(1.f); + const float16_t ml = float16_t(d_all * sc * 32.f); + const float16_t dl = float16_t(d_all * sc * coef); + mat4 reg; + for (int i = 0; i < 16; ++i) { + const float16_t q = (il&1) != 0 ? ((ql[qlIndex + i] & kmask2) | ((qh[qhIndex + i] & kmask1) << 2)) + : ((ql[qlIndex + i] & kmask2) | ((qh[qhIndex + i] & kmask1) << 4)); + reg[i/4][i%4] = dl * q - ml; + } + return reg; +} diff --git a/kompute/op_getrows.comp b/kompute/op_getrows.comp new file mode 100644 index 000000000..a4d8bb9a0 --- /dev/null +++ b/kompute/op_getrows.comp @@ -0,0 +1,25 @@ +/** + * Copyright (c) 2023 Nomic, Inc. All rights reserved. + * + * This software is licensed under the terms of the Software for Open Models License (SOM), + * version 1.0, as detailed in the LICENSE_SOM.txt file. A copy of this license should accompany + * this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc. + */ + +void main() { + const uint i = gl_WorkGroupID.x; + const int r = inB[i + pcs.inBOff]; + + int z = 0; + for (uint ind = gl_LocalInvocationID.x; ind < pcs.ne00/16; ind += gl_WorkGroupSize.x) { + const uint inIndex = (r * pcs.nb01 + pcs.inAOff) + ind/NL * SIZE_OF_BLOCK; + const mat4 result = dequantize_block(inIndex, ind%NL); + for (uint j = 0; j < 4; ++j) { + for (uint k = 0; k < 4; ++k) { + const uint outIndex = i * pcs.nb1/BYTES_FOR_TYPE + pcs.outOff + z; + out_[outIndex] = result[j][k]; + ++z; + } + } + } +} diff --git a/kompute/op_getrows_f16.comp b/kompute/op_getrows_f16.comp index 17b478b5e..3f2b16724 100644 --- a/kompute/op_getrows_f16.comp +++ b/kompute/op_getrows_f16.comp @@ -25,11 +25,15 @@ layout (push_constant) uniform parameter { int nb1; } pcs; +void dequantize_row_f16(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) { + for (int j = 0; j < k; j++) { + out_[y + j] = inA[x + j]; + } +} + void main() { const uint i = gl_WorkGroupID.x; const int r = inB[i + pcs.inBOff]; - for (int j = 0; j < pcs.ne00; j++) { - out_[i*pcs.nb1 + j + pcs.outOff] = inA[r*pcs.nb01/2+j + pcs.inAOff]; - } + dequantize_row_f16(r*pcs.nb01/2/*bytes for float16*/ + pcs.inAOff, i*pcs.nb1 + pcs.outOff, pcs.ne00); } diff --git a/kompute/op_getrows_q4_0.comp b/kompute/op_getrows_q4_0.comp index 590f218e6..0449b1987 100644 --- a/kompute/op_getrows_q4_0.comp +++ b/kompute/op_getrows_q4_0.comp @@ -10,6 +10,10 @@ #include "common.comp" +#define NL 2 +#define BYTES_FOR_TYPE 4 /*bytes for float*/ +#define SIZE_OF_BLOCK sizeof_block_q4_0 + layout(local_size_x = 1) in; layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; @@ -25,40 +29,18 @@ layout (push_constant) uniform parameter { int nb1; } pcs; -#define UNALIGNED_INPUT inA - block_q4_0 get_unaligned_block_q4_0(uint index) { block_q4_0 fres; - fres.d = u8BufToFloat16(UNALIGNED_INPUT, index); + fres.d = u8BufToFloat16(inA, index); [[unroll]] for (uint it = 0; it != QK4_0 / 2; it++) { - fres.qs[it] = UNALIGNED_INPUT[index+2+it]; + fres.qs[it] = inA[index+2+it]; } return fres; } -void dequantize_row_q4_0(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) { - const uint qk = QK4_0; - - const uint nb = k / qk; - - for (uint i = 0; i < nb; i++) { - const block_q4_0 block = get_unaligned_block_q4_0(x + i*sizeof_block_q4_0); - - const float16_t d = block.d; - - for (uint j = 0; j < qk/2; ++j) { - const int x0 = (block.qs[j] & 0x0F) - 8; - const int x1 = (block.qs[j] >> 4) - 8; - - out_[y+i*qk + j + 0 ] = float(x0)*d; - out_[y+i*qk + j + qk/2] = float(x1)*d; - } - } +mat4 dequantize_block(uint index, uint il) { + const block_q4_0 block = get_unaligned_block_q4_0(index); + return dequantize_q4_0(block, il); } -void main() { - const uint i = gl_WorkGroupID.x; - const int r = inB[i + pcs.inBOff]; - - dequantize_row_q4_0(uint(r*pcs.nb01) + pcs.inAOff, uint(i*pcs.nb1/4) + pcs.outOff, pcs.ne00); -} +#include "op_getrows.comp" diff --git a/kompute/op_getrows_q4_1.comp b/kompute/op_getrows_q4_1.comp index 3d00928d3..64586cdc9 100644 --- a/kompute/op_getrows_q4_1.comp +++ b/kompute/op_getrows_q4_1.comp @@ -10,6 +10,10 @@ #include "common.comp" +#define NL 2 +#define BYTES_FOR_TYPE 4 /*bytes for float*/ +#define SIZE_OF_BLOCK sizeof_block_q4_1 + layout(local_size_x = 1) in; layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; @@ -25,42 +29,19 @@ layout (push_constant) uniform parameter { int nb1; } pcs; -#define UNALIGNED_INPUT inA - block_q4_1 get_unaligned_block_q4_1(uint index) { block_q4_1 fres; - fres.d = u8BufToFloat16(UNALIGNED_INPUT, index); - fres.m = u8BufToFloat16(UNALIGNED_INPUT, index+2); + fres.d = u8BufToFloat16(inA, index); + fres.m = u8BufToFloat16(inA, index+2); [[unroll]] for (uint it = 0; it != QK4_1 / 2; it++) { - fres.qs[it] = UNALIGNED_INPUT[index+4+it]; + fres.qs[it] = inA[index+4+it]; } return fres; } -void dequantize_row_q4_1(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) { - const uint qk = QK4_1; - - const uint nb = k / qk; - - for (uint i = 0; i < nb; i++) { - const block_q4_1 block = get_unaligned_block_q4_1(x + i*sizeof_block_q4_1); - - const float16_t d = block.d; - const float16_t m = block.m; - - for (uint j = 0; j < qk/2; ++j) { - const int x0 = (block.qs[j] & 0x0F); - const int x1 = (block.qs[j] >> 4); - - out_[y+i*qk + j + 0 ] = float(x0)*d + m; - out_[y+i*qk + j + qk/2] = float(x1)*d + m; - } - } +mat4 dequantize_block(uint index, uint il) { + const block_q4_1 block = get_unaligned_block_q4_1(index); + return dequantize_q4_1(block, il); } -void main() { - const uint i = gl_WorkGroupID.x; - const int r = inB[i + pcs.inBOff]; - - dequantize_row_q4_1(uint(r*pcs.nb01) + pcs.inAOff, uint(i*pcs.nb1/4) + pcs.outOff, pcs.ne00); -} +#include "op_getrows.comp"