2024-01-29 20:50:50 +00:00
|
|
|
// Compute-shader entry point: each subgroup accumulates N_ROWS consecutive
// src0 rows (starting at first_row) against one src1 column and writes one
// float per row to out_.
//
// Work-group mapping (as read from the code below):
//   gl_WorkGroupID.x -> group of src0 rows (r0)
//   gl_WorkGroupID.y -> src1 row / output row index (r1)
//   gl_WorkGroupID.z -> flattened batch index (im), split into i12 / i13
//
// NOTE(review): pcs, out_, block_q_n_dot_y, N_ROWS, BLOCKS_IN_QUANT and
// SIZE_OF_BLOCK are declared outside this block; their exact semantics are
// assumed from how they are used here — confirm against their definitions.
void main() {
    // NB: hack to make compatible with AMD GPUs that have a subgroup size of 64.
    // Only lanes 0..31 take part; the subgroup operations further down then
    // operate on the invocations that are still active.
    if (gl_SubgroupInvocationID > 31)
        return;

    // Number of quant blocks along ne00 (presumably blocks per src0 row).
    const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT);

    const uint r0 = gl_WorkGroupID.x;
    const uint r1 = gl_WorkGroupID.y;
    const uint im = gl_WorkGroupID.z;

    // First of the N_ROWS rows handled by this subgroup.
    const uint first_row = (r0 * gl_NumSubgroups + gl_SubgroupID) * N_ROWS;

    // Split the flattened batch index into its dim-2 / dim-3 components.
    const uint i12 = im%pcs.ne12;
    const uint i13 = im/pcs.ne12;

    // pointers to src0 rows
    // nb01/nb02/nb03 are divided by SIZE_OF_BLOCK, so ax[] is expressed in
    // blocks; i12 and i13 are divided by r2 / r3 before being applied to src0
    // (presumably broadcast ratios — TODO confirm).
    uint ax[N_ROWS];
    for (int row = 0; row < N_ROWS; ++row) {
        const uint offset0 = (first_row + row)*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);

        ax[row] = offset0 + pcs.inAOff;
    }

    // Base offset into src1; the strides are divided by 4 (presumably bytes
    // to 32-bit float elements — TODO confirm).
    const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;

    // Per-row partial sums. Initialised in a loop so the declaration stays
    // valid for any N_ROWS instead of relying on a 4-element initialiser list.
    float sumf[N_ROWS];
    for (int row = 0; row < N_ROWS; ++row) {
        sumf[row] = 0.0f;
    }

    // Two lanes share one block: ix selects the block (0..15 for the 32
    // participating lanes), il selects which part of it this lane handles.
    const uint ix = gl_SubgroupInvocationID/2;
    const uint il = (BLOCKS_IN_QUANT/4)*(gl_SubgroupInvocationID%2);

    uint yb = y + ix * BLOCKS_IN_QUANT + il;

    //debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n",
    //    gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize,
    //    gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z);

    // The 32 participating lanes cover 16 blocks per iteration (2 lanes per block).
    for (uint ib = ix; ib < nb; ib += 16) {
        for (int row = 0; row < N_ROWS; row++) {
            // Skip rows at or beyond ne01: their result is never stored below,
            // and touching them would read outside the src0 rows. The
            // condition depends only on first_row and row, so it is the same
            // for every lane of the subgroup.
            if (first_row + row < pcs.ne01) {
                sumf[row] += block_q_n_dot_y(ax[row] + ib, yb, il);
            }
        }

        yb += BLOCKS_IN_QUANT * 16;
    }

    // Reduce each row across the subgroup; a single elected lane stores it.
    for (int row = 0; row < N_ROWS; ++row) {
        const float tot = subgroupAdd(sumf[row]);
        if (first_row + row < pcs.ne01 && subgroupElect()) {
            out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = tot;
        }
    }
}
|