diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 239f913f5..01d70d1a6 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -579,29 +579,48 @@ uint32_t safe_divide(uint32_t a, uint32_t b) { return a / b; } -void ggml_vk_add(kp::Sequence& seq, - const std::shared_ptr& inA, - const std::shared_ptr& inB, - const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - uint32_t size) { +void ggml_vk_add( + kp::Sequence& seq, + const std::shared_ptr& inA, + const std::shared_ptr& inB, + const std::shared_ptr& out, + uint32_t inAOff, uint32_t inBOff, uint32_t outOff, + int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne03, + int32_t nb00, int32_t nb01, int32_t nb02, int32_t nb03, + int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13, + int32_t nb10, int32_t nb11, int32_t nb12, int32_t nb13, + int32_t ne0, + int32_t nb0, int32_t nb1, int32_t nb2, int32_t nb3 +) { const static auto spirv = getSpirvShader(kp::shader_data::op_add_comp_spv, kp::shader_data::op_add_comp_spv_len); struct PushConstants { uint32_t inAOff, inBOff, outOff; + int32_t ne00; + int32_t nb00, nb01, nb02, nb03; + int32_t ne10, ne11, ne12, ne13; + int32_t nb10, nb11, nb12, nb13; + int32_t ne0; + int32_t nb0, nb1, nb2, nb3; } const pushConsts { - safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4) + safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4), + ne00, + nb00, nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb10, nb11, nb12, nb13, + ne0, + nb0, nb1, nb2, nb3 }; std::shared_ptr s_algo = nullptr; - if (!komputeManager()->hasAlgorithm(__func__)) - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {size}, {}, {pushConsts}); - else { + if (!komputeManager()->hasAlgorithm(__func__)) { + s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}); + } else { s_algo = komputeManager()->getAlgorithm(__func__); s_algo->setTensors({inA, inB, out}); - s_algo->setWorkgroup({size}); + s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)}); s_algo->setPushConstants({pushConsts}); s_algo->updateDescriptors(s_kompute_context->pool.get()); } @@ -1315,12 +1334,12 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph const int32_t ne10 = src1 ? src1->ne[0] : 0; const int32_t ne11 = src1 ? src1->ne[1] : 0; const int32_t ne12 = src1 ? src1->ne[2] : 0; -// const int32_t ne13 = src1 ? src1->ne[3] : 0; + const int32_t ne13 = src1 ? src1->ne[3] : 0; -// const uint32_t nb10 = src1 ? src1->nb[0] : 0; + const uint32_t nb10 = src1 ? src1->nb[0] : 0; const uint32_t nb11 = src1 ? src1->nb[1] : 0; const uint32_t nb12 = src1 ? src1->nb[2] : 0; -// const uint32_t nb13 = src1 ? src1->nb[3] : 0; + const uint32_t nb13 = src1 ? src1->nb[3] : 0; const int32_t ne0 = dst ? dst->ne[0] : 0; const int32_t ne1 = dst ? dst->ne[1] : 0; @@ -1354,11 +1373,19 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph } break; case GGML_OP_ADD: { - if (ggml_nelements(src1) == ne10) { + if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) { // src1 is a row ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00); } else { - ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4); + ggml_vk_add( + seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, + ne00, ne01, ne02, ne03, + nb00, nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb10, nb11, nb12, nb13, + ne0, + nb0, nb1, nb2, nb3 + ); } } break; case GGML_OP_MUL: diff --git a/kompute/op_add.comp b/kompute/op_add.comp index 314116aac..df3fdc59c 100644 --- a/kompute/op_add.comp +++ b/kompute/op_add.comp @@ -10,7 +10,7 @@ #include "common.comp" -layout(local_size_x = 1) in; +layout(local_size_x = 1024) in; layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; }; @@ -20,13 +20,47 @@ layout(push_constant) uniform PushConstants { uint inAOff; uint inBOff; uint outOff; + int ne00; + int nb00; + int nb01; + int nb02; + int nb03; + int ne10; + int ne11; + int ne12; + int ne13; + int nb10; + int nb11; + int nb12; + int nb13; + int ne0; + int nb0; + int nb1; + int nb2; + int nb3; } pcs; +// general-purpose kernel for addition of two tensors +// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 +// cons: not very efficient void main() { - const uint baseIndex = gl_WorkGroupID.x * 4; + const uint i03 = gl_WorkGroupID.z; + const uint i02 = gl_WorkGroupID.y; + const uint i01 = gl_WorkGroupID.x; - for (uint x = 0; x < 4; x++) { - const uint i = baseIndex + x; - out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[i + pcs.inBOff]; + const uint i13 = i03 % pcs.ne13; + const uint i12 = i02 % pcs.ne12; + const uint i11 = i01 % pcs.ne11; + + uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + gl_SubgroupInvocationID.x*pcs.nb00) / 4); + uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11 + gl_SubgroupInvocationID.x*pcs.nb10) / 4); + uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1 + gl_SubgroupInvocationID.x*pcs.nb0 ) / 4); + + for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) { + out_[pcs.outOff + dst_off] = inA[pcs.inAOff + src0_off] + inB[pcs.inBOff + src1_off]; + + src0_off += gl_WorkGroupSize.x*pcs.ne00; + src1_off += gl_WorkGroupSize.x*pcs.ne10; + dst_off += gl_WorkGroupSize.x*pcs.ne0; } }