#version 450

#include "common.comp"

#include "op_mul_mv_q_n_pre.comp"

// size in bytes of the fp16 scale `d` at the start of each q8_0 block
#define SIZE_OF_D 2

#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32

#define NB_Q8_0 8 // quants handled per thread per block

void main() {
    // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
    if (gl_SubgroupInvocationID > 31)
        return;

    const int nr  = N_DST;
    const int nsg = N_SIMDGROUP;
    const int nw  = N_SIMDWIDTH;

    const int nb = pcs.ne00/QK8_0; // number of q8_0 blocks per row
    const uint r0 = gl_WorkGroupID.x;
    const uint r1 = gl_WorkGroupID.y;
    const uint im = gl_WorkGroupID.z;

    const uint first_row = (r0 * nsg + gl_SubgroupID) * nr;

    const uint i12 = im%pcs.ne12;
    const uint i13 = im/pcs.ne12;

    const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);

    const uint x = offset0*sizeof_block_q8_0 + pcs.inAOff; // byte offset into inA
    const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff; // element offset into inB

    float yl[NB_Q8_0];
    float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f};

    // split the 32 subgroup invocations into 8 groups of 4: invocations that
    // share `ix` process the same block, each covering a different quarter
    // (`il`) of its QK8_0 quants
    const uint ix = gl_SubgroupInvocationID/4;
    const uint il = gl_SubgroupInvocationID%4;

    uint yb = y + ix * QK8_0 + NB_Q8_0*il;

    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
    for (uint ib = ix; ib < nb; ib += nw/4) {
        for (int i = 0; i < NB_Q8_0; ++i) {
            yl[i] = inB[yb + i];
        }

        for (int row = 0; row < nr; row++) {
            const uint block_offset = (ib+row*nb) * sizeof_block_q8_0;
            float sumq = 0.f;
            for (int iq = 0; iq < NB_Q8_0; ++iq) {
                const int8_t qs_iq = int8_t(inA[x + block_offset + SIZE_OF_D + NB_Q8_0*il + iq]);
                sumq += qs_iq * yl[iq];
            }
            // scale the partial dot product by the block's fp16 delta
            const float16_t d = u8BufToFloat16(inA, x + block_offset);
            sumf[row] += sumq*d;
        }

        yb += NB_Q8_0 * nw;
    }

    // reduce the partial sums across the subgroup; one invocation writes
    // each row's result
    for (int row = 0; row < nr; ++row) {
        const float tot = subgroupAdd(sumf[row]);
        if (subgroupElect() && first_row + row < pcs.ne01) {
            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row] = tot;
        }
    }
}
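
// Reference sketch (an editorial note, not part of the upstream file): the
// ggml q8_0 block layout this shader reads byte-by-byte from inA. Each block
// packs a half-precision scale followed by QK8_0 = 32 signed 8-bit quants:
//
//     // C-style sketch of one q8_0 block (34 bytes total)
//     // struct block_q8_0 {
//     //     ggml_half d;      // fp16 delta -- SIZE_OF_D = 2 bytes
//     //     int8_t    qs[32]; // QK8_0 quantized values
//     // };
//
// so sizeof_block_q8_0 == 2 + 32 == 34, and dequantized value i is d * qs[i],
// which is why the inner loop above skips SIZE_OF_D bytes before indexing the
// quants and multiplies the accumulated sum by d once per block.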