diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 720a66986..520cd1fd7 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1467,7 +1467,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph switch (dst->op) { case GGML_OP_ADD: { - if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) { + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { // src1 is a row ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4, ne00); } else { diff --git a/kompute-shaders/op_add.comp b/kompute-shaders/op_add.comp index c86673452..b7b76a79d 100644 --- a/kompute-shaders/op_add.comp +++ b/kompute-shaders/op_add.comp @@ -30,6 +30,7 @@ layout(push_constant) uniform PushConstants { int nb1; int nb2; int nb3; + //int offs; // TODO: needed for GGML_OP_ACC, see metal code } pcs; // general-purpose kernel for addition of two tensors @@ -44,15 +45,14 @@ void main() { const uint i12 = i02 % pcs.ne12; const uint i11 = i01 % pcs.ne11; - uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + gl_SubgroupInvocationID.x*pcs.nb00) / 4); - uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11 + gl_SubgroupInvocationID.x*pcs.nb10) / 4); - uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1 + gl_SubgroupInvocationID.x*pcs.nb0 ) / 4); + int offs = 0; // TMP (see above) + + uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + offs) / 4); + uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11 ) / 4); + uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1 + offs) / 4); for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) { - out_[pcs.outOff + dst_off] = inA[pcs.inAOff + src0_off] + inB[pcs.inBOff + src1_off]; - - src0_off += gl_WorkGroupSize.x*pcs.ne00; - src1_off += gl_WorkGroupSize.x*pcs.ne10; - dst_off += gl_WorkGroupSize.x*pcs.ne0; + const uint i10 = i0 % pcs.ne10; + out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] + inB[pcs.inBOff + src1_off + i10]; } }