From 02c75452c138ffed3994bea994467717036b8e68 Mon Sep 17 00:00:00 2001 From: rhjdvsgsgks <26178113+rhjdvsgsgks@users.noreply.github.com> Date: Fri, 2 Aug 2024 18:46:31 +0000 Subject: [PATCH] vulkan: fix storageBuffer16BitAccess detection on some adreno driver --- ggml/src/ggml-vulkan.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index fa68360b9..0d17d0e47 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -1743,16 +1743,21 @@ static vk_device ggml_vk_get_device(size_t idx) { vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; device_features2.pNext = &vk11_features; + VkPhysicalDevice16BitStorageFeatures storage_16bit; + storage_16bit.pNext = nullptr; + storage_16bit.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES; + vk11_features.pNext = &storage_16bit; + VkPhysicalDeviceVulkan12Features vk12_features; vk12_features.pNext = nullptr; vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; - vk11_features.pNext = &vk12_features; + storage_16bit.pNext = &vk12_features; vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->fp16 = device->fp16 && vk12_features.shaderFloat16; - if (!vk11_features.storageBuffer16BitAccess) { + if (!(vk11_features.storageBuffer16BitAccess && storage_16bit.storageBuffer16BitAccess)) { std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl; throw std::runtime_error("Unsupported device"); } @@ -1860,15 +1865,10 @@ static void ggml_vk_print_gpu_info(size_t idx) { device_features2.pNext = nullptr; device_features2.features = (VkPhysicalDeviceFeatures)device_features; - VkPhysicalDeviceVulkan11Features vk11_features; - vk11_features.pNext = nullptr; - vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; - device_features2.pNext = &vk11_features; - VkPhysicalDeviceVulkan12Features vk12_features; vk12_features.pNext = nullptr; vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; - vk11_features.pNext = &vk12_features; + device_features2.pNext = &vk12_features; vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);