diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp index 8d0a57913..2ec1af5c7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp @@ -52,13 +52,16 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { #endif #ifndef MUL_MAT_ID - const uint i13 = batch_idx / p.ne12; - const uint i12 = batch_idx % p.ne12; + uint batch_idx_a = 0; + if (batch_idx != 0) { + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; - const uint i03 = i13 / p.broadcast3; - const uint i02 = i12 / p.broadcast2; + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; - const uint batch_idx_a = i03 * p.ne02 + i02; + batch_idx_a = i03 * p.ne02 + i02; + } #else const uint expert_id = data_ids[expert_idx]; #endif