Fixes for subgroup size to bring AMD and NVIDIA inline with eachother for all kernels.

This commit is contained in:
Adam Treat 2023-10-04 14:24:35 -04:00 committed by cebtenzzre
parent de589ced7c
commit bc4b5ed1cb
4 changed files with 32 additions and 18 deletions

View File

@ -785,7 +785,8 @@ void ggml_vk_soft_max(kp::Sequence& seq,
std::shared_ptr<kp::Algorithm> s_algo = nullptr; std::shared_ptr<kp::Algorithm> s_algo = nullptr;
if (!komputeManager()->hasAlgorithm(__func__)) { if (!komputeManager()->hasAlgorithm(__func__)) {
const uint32_t local_x = ggml_vk_current_device().subgroupSize; // FIXME: The softmax kernel needs to be fixed to use the subgroupsize which can vary by device
const uint32_t local_x = 32;
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts}); s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts});
} else { } else {
s_algo = komputeManager()->getAlgorithm(__func__); s_algo = komputeManager()->getAlgorithm(__func__);
@ -981,8 +982,8 @@ void ggml_vk_mul_mat_q6_k(kp::Sequence& seq,
std::shared_ptr<kp::Algorithm> s_algo = nullptr; std::shared_ptr<kp::Algorithm> s_algo = nullptr;
if (!komputeManager()->hasAlgorithm(__func__)) { if (!komputeManager()->hasAlgorithm(__func__)) {
// const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2; const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {2,32}, {pushConsts}); s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
} else { } else {
s_algo = komputeManager()->getAlgorithm(__func__); s_algo = komputeManager()->getAlgorithm(__func__);
s_algo->setTensors({inA, inB, out}); s_algo->setTensors({inA, inB, out});

View File

@ -44,31 +44,38 @@ void main() {
const uint r1 = gl_WorkGroupID.y; const uint r1 = gl_WorkGroupID.y;
const uint r2 = gl_WorkGroupID.z; const uint r2 = gl_WorkGroupID.z;
const uint row = 2 * r0 + gl_SubgroupID; const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);
const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0); const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0);
const uint x = row * nb + offset0; // Based from inA without base offset const uint x = row * nb + offset0; // Based from inA without base offset
const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
float sumf = 0; float sumf = 0;
const uint tid = gl_SubgroupInvocationID/2; // bits of invocation ID for gl_SubgroupSize=32:
const uint ix = gl_SubgroupInvocationID%2; // x x x x x
const uint ip = tid/8; // 0 or 1 // 4 3 2 1 0
const uint il = tid%8; // ( tid ) ix
const uint n = 4; // ip ( il )
const uint l0 = n*il;
const uint is = 8*ip + l0/16; const uint block_stride = gl_SubgroupSize / 16; // number of blocks each subgroup processes
const uint tid = gl_SubgroupInvocationID/block_stride; // first block_stride groups have tid=0
const uint ix = gl_SubgroupInvocationID%block_stride; // first block is 0..block_stride-1
const uint ip = tid/8; // first or second half of block (0 or 1)
const uint il = tid%8; // each half has 8 parts, one per scale
const uint n = 4; // 4 scales at a time (and 4 sums)
const uint l0 = n*il; // offset into half-block, 0..28
const uint is = 8*ip + l0/16; // 0, 1, 8, 9
const uint y_offset = 128*ip + l0; const uint y_offset = 128*ip + l0;
const uint q_offset_l = 64*ip + l0; const uint q_offset_l = 64*ip + l0;
const uint q_offset_h = 32*ip + l0; const uint q_offset_h = 32*ip + l0;
for (uint i = ix; i < nb; i += 2) { for (uint i = ix; i < nb; i += block_stride) {
const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff; const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff;
const uint qlIndex = q_offset_l; const uint qlIndex = q_offset_l;
const uint q2Index = qlIndex + 32; const uint q2Index = qlIndex + QK_K/8;
const uint qhIndex = q_offset_h; const uint qhIndex = q_offset_h;
const uint y = yy + i * QK_K + y_offset; const uint y = yy + i * QK_K + y_offset;

View File

@ -7,6 +7,9 @@
*/ */
void main() { void main() {
if (gl_SubgroupInvocationID > 31)
return;
const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT); const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT);
const uint r0 = gl_WorkGroupID.x; const uint r0 = gl_WorkGroupID.x;
const uint r1 = gl_WorkGroupID.y; const uint r1 = gl_WorkGroupID.y;
@ -28,13 +31,13 @@ void main() {
// gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize, // gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize,
// gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z); // gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z);
for (uint ib = ix; ib < nb; ib += gl_SubgroupSize/2) { for (uint ib = ix; ib < nb; ib += 16) {
for (int row = 0; row < N_ROWS; row++) { for (int row = 0; row < N_ROWS; row++) {
const uint block_index = x + ib + row * nb; const uint block_index = x + ib + row * nb;
sumf[row] += block_q_n_dot_y(block_index, yb, il); sumf[row] += block_q_n_dot_y(block_index, yb, il);
} }
yb += BLOCKS_IN_QUANT * gl_SubgroupSize/2; yb += BLOCKS_IN_QUANT * 16;
} }
for (int row = 0; row < N_ROWS; ++row) { for (int row = 0; row < N_ROWS; ++row) {

View File

@ -24,6 +24,9 @@ layout(push_constant) uniform PushConstants {
} pcs; } pcs;
void main() { void main() {
if (gl_SubgroupInvocationID > 31)
return;
const uint i03 = gl_WorkGroupID.z; const uint i03 = gl_WorkGroupID.z;
const uint i02 = gl_WorkGroupID.y; const uint i02 = gl_WorkGroupID.y;
const uint i01 = gl_WorkGroupID.x; const uint i01 = gl_WorkGroupID.x;
@ -34,21 +37,21 @@ void main() {
// parallel max // parallel max
float localMax = uintBitsToFloat(0xFF800000); float localMax = uintBitsToFloat(0xFF800000);
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += gl_SubgroupSize) { for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
localMax = max(localMax, in_[psrc0 + i00]); localMax = max(localMax, in_[psrc0 + i00]);
} }
float max_ = subgroupMax(localMax); float max_ = subgroupMax(localMax);
// parallel sum // parallel sum
float localSum = 0.0f; float localSum = 0.0f;
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += gl_SubgroupSize) { for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
const float exp_psrc0 = exp(in_[psrc0 + i00] - max_); const float exp_psrc0 = exp(in_[psrc0 + i00] - max_);
localSum += exp_psrc0; localSum += exp_psrc0;
out_[pdst + i00] = exp_psrc0; out_[pdst + i00] = exp_psrc0;
} }
const float sum = subgroupAdd(localSum); const float sum = subgroupAdd(localSum);
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += gl_SubgroupSize) { for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
out_[pdst + i00] /= sum; out_[pdst + i00] /= sum;
} }
} }