Vulkan: Add VK_AMD_shader_core_properties2 support to read Compute Unit count for split_k logic
Some checks failed
Python Type-Check / pyright type-check (push) Has been cancelled

This commit is contained in:
0cc4m 2024-12-03 20:23:06 +00:00
parent 9622fbe373
commit 7002d6c71a

View File

@@ -1701,6 +1701,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
bool maintenance4_support = false; bool maintenance4_support = false;
bool sm_builtins = false; bool sm_builtins = false;
bool amd_shader_core_properties2 = false;
// Check if maintenance4 is supported // Check if maintenance4 is supported
for (const auto& properties : ext_props) { for (const auto& properties : ext_props) {
@@ -1708,6 +1709,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
maintenance4_support = true; maintenance4_support = true;
} else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) { } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
sm_builtins = true; sm_builtins = true;
} else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) {
amd_shader_core_properties2 = true;
} }
} }
@@ -1716,6 +1719,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
vk::PhysicalDeviceMaintenance4Properties props4; vk::PhysicalDeviceMaintenance4Properties props4;
vk::PhysicalDeviceSubgroupProperties subgroup_props; vk::PhysicalDeviceSubgroupProperties subgroup_props;
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
props2.pNext = &props3; props2.pNext = &props3;
props3.pNext = &subgroup_props; props3.pNext = &subgroup_props;
@@ -1729,6 +1733,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
last_struct->pNext = (VkBaseOutStructure *)&sm_props; last_struct->pNext = (VkBaseOutStructure *)&sm_props;
last_struct = (VkBaseOutStructure *)&sm_props; last_struct = (VkBaseOutStructure *)&sm_props;
} }
if (amd_shader_core_properties2) {
last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
}
device->physical_device.getProperties2(&props2); device->physical_device.getProperties2(&props2);
device->properties = props2.properties; device->properties = props2.properties;
@@ -1748,6 +1756,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
if (sm_builtins) { if (sm_builtins) {
device->shader_core_count = sm_props.shaderSMCount; device->shader_core_count = sm_props.shaderSMCount;
} else if (amd_shader_core_properties2) {
device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
} else { } else {
device->shader_core_count = 0; device->shader_core_count = 0;
} }
@@ -1822,7 +1832,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
vk11_features.pNext = &vk12_features; vk11_features.pNext = &vk12_features;
// Pointer to the last chain element // Pointer to the last chain element
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features; last_struct = (VkBaseOutStructure *)&vk12_features;
VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features; VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
pl_robustness_features.pNext = nullptr; pl_robustness_features.pNext = nullptr;
@@ -1952,6 +1962,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
// Shaders // Shaders
// Disable matmul tile sizes early if performance low or not supported // Disable matmul tile sizes early if performance low or not supported
switch (device->vendor_id) { switch (device->vendor_id) {
#ifndef GGML_VULKAN_RUN_TESTS
case VK_VENDOR_ID_AMD: case VK_VENDOR_ID_AMD:
case VK_VENDOR_ID_INTEL: case VK_VENDOR_ID_INTEL:
device->mul_mat_l = false; device->mul_mat_l = false;
@@ -1969,6 +1980,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->mul_mat_id_m = true; device->mul_mat_id_m = true;
device->mul_mat_id_s = false; device->mul_mat_id_s = false;
break; break;
#endif
default: default:
device->mul_mat_l = true; device->mul_mat_l = true;
device->mul_mat_m = true; device->mul_mat_m = true;