/**
 * Copyright (c) 2023 Nomic, Inc. All rights reserved.
 *
 * This software is licensed under the terms of the Software for Open Models License (SOM),
 * version 1.0, as detailed in the LICENSE_SOM.txt file. A copy of this license should accompany
 * this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc.
 */

#extension GL_KHR_shader_subgroup_arithmetic : require
// Only needed if the debugPrintfEXT diagnostic in main() is re-enabled.
#extension GL_EXT_debug_printf : enable

// Matrix-vector product kernel body shared by block-quantized ("q_n") formats.
//
// Everything below is referenced but not declared in this file. Presumably the
// including shader provides it (TODO confirm against the including shaders):
//   - `pcs`: a push-constant block (ne00, ne01, ne0, ne1, ne10, gqa, inBOff, outOff).
//   - `out_`: the float output buffer.
//   - `BLOCKS_IN_QUANT`, `N_ROWS`: compile-time constants.
//   - `block_q_n_dot_y(block_index, yb, il)`: returns this lane's partial dot
//     product for one quantized block of inA against inB starting at `yb`.
void main() {
    // Number of quantized blocks in one row of inA.
    const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT);

    const uint r0 = gl_WorkGroupID.x; // selects a group of output rows
    const uint r1 = gl_WorkGroupID.y; // selects the inB column / output column
    const uint im = gl_WorkGroupID.z; // selects the batch

    // Each subgroup of the workgroup owns N_ROWS consecutive output rows.
    const uint first_row = (r0 * gl_NumSubgroups + gl_SubgroupID) * N_ROWS;
    // Dividing im by pcs.gqa lets several consecutive batches read the same inA
    // matrix (presumably a grouped-query broadcast; confirm with the host code).
    const uint offset0 = first_row * nb + im/pcs.gqa*(nb*pcs.ne0);

    const uint x = offset0; // Based from inA without base offset
    const uint y = r1*uint(pcs.ne10)+im*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB

    // Per-lane partial sums, one per owned row. Zeroed with a loop rather than a
    // fixed four-element initializer list, so the shader still compiles when
    // N_ROWS is not 4.
    float sumf[N_ROWS];
    for (int row = 0; row < N_ROWS; ++row) {
        sumf[row] = 0.0f;
    }

    // Lanes work in pairs. Both lanes of a pair visit the same block (ix), and
    // the lane's parity picks `il`, the in-block offset handed to block_q_n_dot_y.
    const uint ix = gl_SubgroupInvocationID/2;
    const uint il = (BLOCKS_IN_QUANT/4)*(gl_SubgroupInvocationID%2);

    uint yb = y + ix * BLOCKS_IN_QUANT + il;

    // Diagnostic dump of the subgroup/workgroup layout. It is disabled because it
    // fired once per invocation of every dispatch, flooding the printf buffer and
    // slowing every run with validation enabled. Uncomment locally when debugging.
    //debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n",
    //    gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize,
    //    gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z);

    // The subgroup holds gl_SubgroupSize/2 lane pairs, so each iteration covers
    // that many blocks. yb advances by the matching number of inB elements.
    for (uint ib = ix; ib < nb; ib += gl_SubgroupSize/2) {
        for (int row = 0; row < N_ROWS; row++) {
            const uint block_index = x + ib + row * nb;
            sumf[row] += block_q_n_dot_y(block_index, yb, il);
        }

        yb += BLOCKS_IN_QUANT * gl_SubgroupSize/2;
    }

    for (int row = 0; row < N_ROWS; ++row) {
        // subgroupAdd must be reached by every lane of the subgroup, so it is
        // evaluated before (outside) the bounds check below.
        const float tot = subgroupAdd(sumf[row]);
        // Drop rows past the end of inA (pcs.ne01 rows). A single elected lane
        // stores the reduced value.
        if (first_row + row < pcs.ne01 && subgroupElect()) {
            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = tot;
        }
    }
}