From 904109ed0d97c9b656a5e8bf612925f739bb8166 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 26 Nov 2024 09:45:05 -0600 Subject: [PATCH] vulkan: fix group_norm (#10496) Fix bad calculation of the end of the range. Add a backend test that covers the bad case (taken from stable diffusion). Fixes https://github.com/leejet/stable-diffusion.cpp/issues/439. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp | 2 +- tests/test-backend-ops.cpp | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 49527fdf4..da1cfd24e 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -7157,7 +7157,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const int32_t max_period = tensor->op_params[1]; tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period); } else if (tensor->op == GGML_OP_POOL_2D) { - enum ggml_op_pool op = static_cast(dst->op_params[0]); + enum ggml_op_pool op = static_cast(tensor->op_params[0]); const int32_t k0 = tensor->op_params[1]; const int32_t k1 = tensor->op_params[2]; const int32_t s0 = tensor->op_params[3]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp index 5ad9b28da..b6a0d5645 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp @@ -19,7 +19,7 @@ void main() { const uint tid = gl_LocalInvocationID.x; const uint start = gl_WorkGroupID.x * group_size + tid; - const uint end = start + group_size; + const uint end = (gl_WorkGroupID.x + 1) * group_size; tmp[tid] = 0.0f; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6376b0e4c..da66ed856 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3796,7 +3796,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_upscale()); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true)); test_cases.emplace_back(new test_upscale_ext()); - test_cases.emplace_back(new test_group_norm()); + test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1})); + test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1})); test_cases.emplace_back(new test_acc()); test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_arange());