Process four elements per workgroup in the element-wise Kompute shaders.

The add, mul, silu, relu and gelu shaders now make each workgroup handle
four consecutive elements (gl_WorkGroupID.x * 4 up to gl_WorkGroupID.x * 4 + 3),
and the host passes ggml_nelements(dst)/4 as the count argument instead of
ggml_nelements(dst).

The integer division truncates and the shaders do not bounds-check, so an
element count that is not a multiple of 4 would leave the trailing 1-3 output
elements unwritten. Every changed call site therefore asserts divisibility
before dispatching.

NOTE(review): GGML_ASSERT is assumed to be in scope in ggml-vulkan.cpp (via
ggml.h) - confirm before applying.
NOTE(review): indentation of context lines and the exact bytes of the three
whitespace-only changes were reconstructed; apply with whitespace-tolerant
matching if strict matching fails. The blob hashes are those of the original
patch and do not reflect the added assertions.

diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index 265933832..b70b7ac45 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -1358,7 +1358,8 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
                         // src1 is a row
                         ggml_vk_addrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00);
                     } else {
-                        ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst));
+                        GGML_ASSERT(ggml_nelements(dst) % 4 == 0); // op_add.comp handles 4 elements per workgroup
+                        ggml_vk_add(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
                     }
                 } break;
             case GGML_OP_MUL:
@@ -1367,7 +1368,8 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
                         // src1 is a row
                         ggml_vk_mulrow(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst), ne00);
                     } else {
-                        ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst));
+                        GGML_ASSERT(ggml_nelements(dst) % 4 == 0); // op_mul.comp handles 4 elements per workgroup
+                        ggml_vk_mul(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ggml_nelements(dst)/4);
                     }
                 } break;
             case GGML_OP_SCALE:
@@ -1379,15 +1381,18 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
                 switch (ggml_get_unary_op(gf->nodes[i])) {
                     case GGML_UNARY_OP_SILU:
                         {
-                            ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst));
+                            GGML_ASSERT(ggml_nelements(dst) % 4 == 0); // op_silu.comp handles 4 elements per workgroup
+                            ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
                         } break;
                     case GGML_UNARY_OP_RELU:
                         {
-                            ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst));
+                            GGML_ASSERT(ggml_nelements(dst) % 4 == 0); // op_relu.comp handles 4 elements per workgroup
+                            ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
                         } break;
                     case GGML_UNARY_OP_GELU:
                         {
-                            ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst));
+                            GGML_ASSERT(ggml_nelements(dst) % 4 == 0); // op_gelu.comp handles 4 elements per workgroup
+                            ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4);
                         } break;
                     default:
                         {
@@ -1427,9 +1432,9 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
                         ggml_is_transposed(src1)) {
                         fprintf(stderr, "%s: %s: matmul on tranposed tensor not supported: %u/%u\n", __func__, ggml_op_name(dst->op), src0t, src1t);
                         goto not_implemented;
-                    }    
+                    }
 
-                    switch (src0t) {    
+                    switch (src0t) {
                         case GGML_TYPE_F32:
                             ggml_vk_mul_mat_mat_f32(seq,
                                                     id_src0, id_src1, id_dst,
@@ -1459,7 +1464,7 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
                                 goto not_implemented;
                             }
                     }
-                    
+
                 } break;
             case GGML_OP_GET_ROWS:
                 {
diff --git a/kompute/op_add.comp b/kompute/op_add.comp
index f242864dd..314116aac 100644
--- a/kompute/op_add.comp
+++ b/kompute/op_add.comp
@@ -23,7 +23,10 @@ layout(push_constant) uniform PushConstants {
 } pcs;
 
 void main() {
-    const uint i = gl_WorkGroupID.x;
+    const uint baseIndex = gl_WorkGroupID.x * 4;
 
-    out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i) + pcs.inBOff];
-}
\ No newline at end of file
+    for (uint x = 0; x < 4; x++) {
+        const uint i = baseIndex + x;
+        out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[i + pcs.inBOff];
+    }
+}
diff --git a/kompute/op_gelu.comp b/kompute/op_gelu.comp
index c9f8ce3cf..f74a14f7e 100644
--- a/kompute/op_gelu.comp
+++ b/kompute/op_gelu.comp
@@ -20,8 +20,11 @@ layout(push_constant) uniform PushConstants {
 } pcs;
 
 void main() {
-    const uint i = gl_WorkGroupID.x;
-    const float x = in_[i + pcs.inOff];
+    const uint baseIndex = gl_WorkGroupID.x * 4;
 
-    out_[i + pcs.outOff] = 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x)));
+    for (uint x = 0; x < 4; x++) {
+        const uint i = baseIndex + x;
+        const float y = in_[i + pcs.inOff];
+        out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y)));
+    }
 }
diff --git a/kompute/op_mul.comp b/kompute/op_mul.comp
index 31849b941..662ea8177 100644
--- a/kompute/op_mul.comp
+++ b/kompute/op_mul.comp
@@ -23,7 +23,10 @@ layout(push_constant) uniform PushConstants {
 } pcs;
 
 void main() {
-    const uint i = gl_WorkGroupID.x;
+    const uint baseIndex = gl_WorkGroupID.x * 4;
 
-    out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i) + pcs.inBOff];
+    for (uint x = 0; x < 4; x++) {
+        const uint i = baseIndex + x;
+        out_[i + pcs.outOff] = inA[i + pcs.inAOff] * inB[(i) + pcs.inBOff];
+    }
 }
\ No newline at end of file
diff --git a/kompute/op_relu.comp b/kompute/op_relu.comp
index 41f46be96..c6ed044a3 100644
--- a/kompute/op_relu.comp
+++ b/kompute/op_relu.comp
@@ -20,7 +20,10 @@ layout(push_constant) uniform PushConstants {
 } pcs;
 
 void main() {
-    const uint i = gl_WorkGroupID.x;
+    const uint baseIndex = gl_WorkGroupID.x * 4;
 
-    out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
+    for (uint x = 0; x < 4; x++) {
+        const uint i = baseIndex + x;
+        out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
+    }
 }
diff --git a/kompute/op_silu.comp b/kompute/op_silu.comp
index c5acac281..8c7bfe321 100644
--- a/kompute/op_silu.comp
+++ b/kompute/op_silu.comp
@@ -19,7 +19,11 @@ layout(push_constant) uniform PushConstants {
     uint outOff;
 } pcs;
 void main() {
-    const uint i = gl_WorkGroupID.x;
-    const float x = in_[i + pcs.inOff];
-    out_[i + pcs.outOff] = x / (1.0 + exp(-x));
+    const uint baseIndex = gl_WorkGroupID.x * 4;
+
+    for (uint x = 0; x < 4; x++) {
+        const uint i = baseIndex + x;
+        const float y = in_[i + pcs.inOff];
+        out_[i + pcs.outOff] = y / (1.0 + exp(-y));
+    }
 }