From a934b2cb8a1cbe2aad1ca10a119df60bbcf8d5d1 Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Tue, 14 Nov 2023 11:59:58 -0500 Subject: [PATCH] vulkan : assert various kernel requirements --- ggml-vulkan.cpp | 47 ++++++++++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 74d9fceb6..d4d6d1b87 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -1416,27 +1416,34 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph case GGML_OP_SCALE: { const float scale = *(const float *) src1->data; - ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/8, scale); + int64_t n = ggml_nelements(dst); + GGML_ASSERT(n % 8 == 0); + ggml_vk_scale(seq, id_src0, id_dst, off_src0, off_dst, n/8, scale); } break; case GGML_OP_UNARY: - switch (ggml_get_unary_op(gf->nodes[i])) { - case GGML_UNARY_OP_SILU: - { - ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4); - } break; - case GGML_UNARY_OP_RELU: - { - ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/4); - } break; - case GGML_UNARY_OP_GELU: - { - ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, ggml_nelements(dst)/8); - } break; - default: - { - fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ASSERT(false); - } + { + int64_t n = ggml_nelements(dst); + GGML_ASSERT(n % 4 == 0); + switch (ggml_get_unary_op(gf->nodes[i])) { + case GGML_UNARY_OP_SILU: + { + ggml_vk_silu(seq, id_src0, id_dst, off_src0, off_dst, n/4); + } break; + case GGML_UNARY_OP_RELU: + { + ggml_vk_relu(seq, id_src0, id_dst, off_src0, off_dst, n/4); + } break; + case GGML_UNARY_OP_GELU: + { + GGML_ASSERT(n % 8 == 0); + ggml_vk_gelu(seq, id_src0, id_dst, off_src0, off_dst, n/8); + } break; + default: + { + fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); + GGML_ASSERT(false); + } + } } break; case GGML_OP_SOFT_MAX: { @@ -1455,6 +1462,8 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph } break; case GGML_OP_RMS_NORM: { + GGML_ASSERT(ne00 % 4 == 0); + float eps; memcpy(&eps, dst->op_params, sizeof(float)); ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps);