diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 57813cb3d..f2320f3cc 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -810,12 +810,10 @@ void ggml_vk_norm_(const std::vector& spirv, kp::Sequence& seq, const std::shared_ptr& out, uint32_t inOff, uint32_t outOff, int32_t ne00, int32_t nb01, - int32_t nrows) { + int32_t nrows, float epsilon) { GGML_ASSERT(nb01%sizeof(float) == 0); GGML_ASSERT(ne00%sizeof(float) == 0); - const float epsilon = 1e-6f; // this is what ggml.c uses for rms norm - struct PushConstants { uint32_t inOff, outOff; uint32_t ne00, nb01; @@ -1559,11 +1557,15 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph } break; case GGML_OP_NORM: { - ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0)); + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + ggml_vk_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps); } break; case GGML_OP_RMS_NORM: { - ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0)); + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + ggml_vk_rms_norm(seq, id_src0, id_dst, off_src0, off_dst, ne00, nb01, ggml_nrows(src0), eps); } break; case GGML_OP_MUL_MAT: {