mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 04:44:34 +00:00
Fixes for subgroup size to bring AMD and NVIDIA inline with eachother for all kernels.
This commit is contained in:
parent
de589ced7c
commit
bc4b5ed1cb
@ -785,7 +785,8 @@ void ggml_vk_soft_max(kp::Sequence& seq,
|
||||
|
||||
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
|
||||
if (!komputeManager()->hasAlgorithm(__func__)) {
|
||||
const uint32_t local_x = ggml_vk_current_device().subgroupSize;
|
||||
// FIXME: The softmax kernel needs to be fixed to use the subgroupsize which can vary by device
|
||||
const uint32_t local_x = 32;
|
||||
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts});
|
||||
} else {
|
||||
s_algo = komputeManager()->getAlgorithm(__func__);
|
||||
@ -981,8 +982,8 @@ void ggml_vk_mul_mat_q6_k(kp::Sequence& seq,
|
||||
|
||||
std::shared_ptr<kp::Algorithm> s_algo = nullptr;
|
||||
if (!komputeManager()->hasAlgorithm(__func__)) {
|
||||
// const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
|
||||
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {2,32}, {pushConsts});
|
||||
const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
|
||||
s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
|
||||
} else {
|
||||
s_algo = komputeManager()->getAlgorithm(__func__);
|
||||
s_algo->setTensors({inA, inB, out});
|
||||
|
@ -44,31 +44,38 @@ void main() {
|
||||
const uint r1 = gl_WorkGroupID.y;
|
||||
const uint r2 = gl_WorkGroupID.z;
|
||||
|
||||
const uint row = 2 * r0 + gl_SubgroupID;
|
||||
const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);
|
||||
const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0);
|
||||
const uint x = row * nb + offset0; // Based from inA without base offset
|
||||
const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
|
||||
|
||||
float sumf = 0;
|
||||
|
||||
const uint tid = gl_SubgroupInvocationID/2;
|
||||
const uint ix = gl_SubgroupInvocationID%2;
|
||||
const uint ip = tid/8; // 0 or 1
|
||||
const uint il = tid%8;
|
||||
const uint n = 4;
|
||||
const uint l0 = n*il;
|
||||
const uint is = 8*ip + l0/16;
|
||||
// bits of invocation ID for gl_SubgroupSize=32:
|
||||
// x x x x x
|
||||
// 4 3 2 1 0
|
||||
// ( tid ) ix
|
||||
// ip ( il )
|
||||
|
||||
const uint block_stride = gl_SubgroupSize / 16; // number of blocks each subgroup processes
|
||||
const uint tid = gl_SubgroupInvocationID/block_stride; // first block_stride groups have tid=0
|
||||
const uint ix = gl_SubgroupInvocationID%block_stride; // first block is 0..block_stride-1
|
||||
const uint ip = tid/8; // first or second half of block (0 or 1)
|
||||
const uint il = tid%8; // each half has 8 parts, one per scale
|
||||
const uint n = 4; // 4 scales at a time (and 4 sums)
|
||||
const uint l0 = n*il; // offset into half-block, 0..28
|
||||
const uint is = 8*ip + l0/16; // 0, 1, 8, 9
|
||||
|
||||
const uint y_offset = 128*ip + l0;
|
||||
const uint q_offset_l = 64*ip + l0;
|
||||
const uint q_offset_h = 32*ip + l0;
|
||||
|
||||
for (uint i = ix; i < nb; i += 2) {
|
||||
for (uint i = ix; i < nb; i += block_stride) {
|
||||
|
||||
const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff;
|
||||
|
||||
const uint qlIndex = q_offset_l;
|
||||
const uint q2Index = qlIndex + 32;
|
||||
const uint q2Index = qlIndex + QK_K/8;
|
||||
const uint qhIndex = q_offset_h;
|
||||
const uint y = yy + i * QK_K + y_offset;
|
||||
|
||||
|
@ -7,6 +7,9 @@
|
||||
*/
|
||||
|
||||
void main() {
|
||||
if (gl_SubgroupInvocationID > 31)
|
||||
return;
|
||||
|
||||
const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT);
|
||||
const uint r0 = gl_WorkGroupID.x;
|
||||
const uint r1 = gl_WorkGroupID.y;
|
||||
@ -28,13 +31,13 @@ void main() {
|
||||
// gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize,
|
||||
// gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z);
|
||||
|
||||
for (uint ib = ix; ib < nb; ib += gl_SubgroupSize/2) {
|
||||
for (uint ib = ix; ib < nb; ib += 16) {
|
||||
for (int row = 0; row < N_ROWS; row++) {
|
||||
const uint block_index = x + ib + row * nb;
|
||||
sumf[row] += block_q_n_dot_y(block_index, yb, il);
|
||||
}
|
||||
|
||||
yb += BLOCKS_IN_QUANT * gl_SubgroupSize/2;
|
||||
yb += BLOCKS_IN_QUANT * 16;
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_ROWS; ++row) {
|
||||
|
@ -24,6 +24,9 @@ layout(push_constant) uniform PushConstants {
|
||||
} pcs;
|
||||
|
||||
void main() {
|
||||
if (gl_SubgroupInvocationID > 31)
|
||||
return;
|
||||
|
||||
const uint i03 = gl_WorkGroupID.z;
|
||||
const uint i02 = gl_WorkGroupID.y;
|
||||
const uint i01 = gl_WorkGroupID.x;
|
||||
@ -34,21 +37,21 @@ void main() {
|
||||
|
||||
// parallel max
|
||||
float localMax = uintBitsToFloat(0xFF800000);
|
||||
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += gl_SubgroupSize) {
|
||||
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
|
||||
localMax = max(localMax, in_[psrc0 + i00]);
|
||||
}
|
||||
float max_ = subgroupMax(localMax);
|
||||
|
||||
// parallel sum
|
||||
float localSum = 0.0f;
|
||||
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += gl_SubgroupSize) {
|
||||
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
|
||||
const float exp_psrc0 = exp(in_[psrc0 + i00] - max_);
|
||||
localSum += exp_psrc0;
|
||||
out_[pdst + i00] = exp_psrc0;
|
||||
}
|
||||
|
||||
const float sum = subgroupAdd(localSum);
|
||||
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += gl_SubgroupSize) {
|
||||
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
|
||||
out_[pdst + i00] /= sum;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user