Vulkan: Implement VK_KHR_cooperative_matrix support in the matrix matrix multiplication shader

This commit is contained in:
0cc4m 2024-11-18 06:01:31 +00:00
parent cc98896db8
commit 2455bbc8bd
3 changed files with 332 additions and 68 deletions

View File

@ -1,7 +1,8 @@
#include "ggml-vulkan.h" #include "ggml-vulkan.h"
#include <vulkan/vulkan_core.h> #include <vulkan/vulkan_core.h>
#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) #if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS)
#include <chrono> #include <chrono>
#include "ggml-cpu.h"
#endif #endif
#include <vulkan/vulkan.hpp> #include <vulkan/vulkan.hpp>
@ -168,6 +169,11 @@ struct vk_device_struct {
uint32_t shader_core_count; uint32_t shader_core_count;
bool uma; bool uma;
bool coopmat_support;
uint32_t coop_mat_m;
uint32_t coop_mat_n;
uint32_t coop_mat_k;
size_t idx; size_t idx;
vk_matmul_pipeline pipeline_matmul_f32; vk_matmul_pipeline pipeline_matmul_f32;
@ -1236,19 +1242,30 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
// mulmat // mulmat
// Matrix cores require different warp group sizes
const uint32_t tm_l = device->coopmat_support ? device->coop_mat_m : 4;
const uint32_t tm_m = device->coopmat_support ? device->coop_mat_m : 4;
const uint32_t tm_s = device->coopmat_support ? device->coop_mat_m : 2;
const uint32_t tn_l = device->coopmat_support ? device->coop_mat_n : 4;
const uint32_t tn_m = device->coopmat_support ? device->coop_mat_n : 2;
const uint32_t tn_s = device->coopmat_support ? device->coop_mat_n : 2;
const uint32_t tk_l = device->coopmat_support ? device->coop_mat_k : 1;
const uint32_t tk_m = device->coopmat_support ? device->coop_mat_k : 1;
const uint32_t tk_s = device->coopmat_support ? device->coop_mat_k : 1;
std::vector<uint32_t> l_warptile, m_warptile, s_warptile, std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq; l_warptile_mmq, m_warptile_mmq, s_warptile_mmq;
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms, std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms; l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms;
uint32_t l_align, m_align, s_align; uint32_t l_align, m_align, s_align;
l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size }; l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size }; m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size }; s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size }; l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size }; m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
s_warptile_mmq = { subgroup_size_16, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size }; s_warptile_mmq = { subgroup_size_16, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
@ -1324,7 +1341,52 @@ static void ggml_vk_load_shaders(vk_device& device) {
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness)); compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness));
}; };
if (device->fp16) { if (device->coopmat_support) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
#undef CREATE_MM
} else if (device->fp16) {
// Create 6 variants, {s,m,l}x{unaligned,aligned} // Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
@ -1647,7 +1709,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) { if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
device->max_memory_allocation_size = std::stoi(GGML_VK_FORCE_MAX_ALLOCATION_SIZE); device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
} else if (maintenance4_support) { } else if (maintenance4_support) {
device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize); device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
} else { } else {
@ -1666,6 +1728,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
bool fp16_storage = false; bool fp16_storage = false;
bool fp16_compute = false; bool fp16_compute = false;
bool pipeline_robustness = false; bool pipeline_robustness = false;
device->coopmat_support = false;
for (const auto& properties : ext_props) { for (const auto& properties : ext_props) {
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
@ -1674,6 +1737,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
fp16_compute = true; fp16_compute = true;
} else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) { } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
pipeline_robustness = true; pipeline_robustness = true;
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
device->coopmat_support = true;
device->coop_mat_m = 0;
device->coop_mat_n = 0;
device->coop_mat_k = 0;
} }
} }
@ -1719,14 +1787,28 @@ static vk_device ggml_vk_get_device(size_t idx) {
vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
vk11_features.pNext = &vk12_features; vk11_features.pNext = &vk12_features;
// Pointer to the last chain element
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;
VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features; VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
pl_robustness_features.pNext = nullptr; pl_robustness_features.pNext = nullptr;
pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT; pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;
pl_robustness_features.pipelineRobustness = VK_FALSE; pl_robustness_features.pipelineRobustness = VK_FALSE;
if (pipeline_robustness) { if (pipeline_robustness) {
vk12_features.pNext = &pl_robustness_features; last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features;
device_extensions.push_back("VK_EXT_pipeline_robustness"); device_extensions.push_back("VK_EXT_pipeline_robustness");
last_struct = (VkBaseOutStructure *)&pl_robustness_features;
}
VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
coopmat_features.pNext = nullptr;
coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
coopmat_features.cooperativeMatrix = VK_FALSE;
if (device->coopmat_support) {
last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
last_struct = (VkBaseOutStructure *)&coopmat_features;
} }
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
@ -1735,6 +1817,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->pipeline_robustness = pl_robustness_features.pipelineRobustness; device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
if (!vk11_features.storageBuffer16BitAccess) { if (!vk11_features.storageBuffer16BitAccess) {
std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl; std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
throw std::runtime_error("Unsupported device"); throw std::runtime_error("Unsupported device");
@ -1749,6 +1833,54 @@ static vk_device ggml_vk_get_device(size_t idx) {
if (device->fp16) { if (device->fp16) {
device_extensions.push_back("VK_KHR_shader_float16_int8"); device_extensions.push_back("VK_KHR_shader_float16_int8");
} }
if (device->coopmat_support) {
// Query supported shapes
std::vector<VkCooperativeMatrixPropertiesKHR> cm_props;
PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR =
(PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR");
uint32_t cm_props_num;
pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr);
cm_props.resize(cm_props_num);
for (auto& prop : cm_props) {
prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
}
pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data());
VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size());
for (auto& prop : cm_props) {
VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope));
if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 &&
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 &&
(vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32
) {
device->coop_mat_m = prop.MSize;
device->coop_mat_n = prop.NSize;
device->coop_mat_k = prop.KSize;
break;
}
}
if (device->coop_mat_m == 0) {
// No suitable matmul mode found
GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
device->coopmat_support = false;
}
}
if (device->coopmat_support) {
device_extensions.push_back("VK_KHR_cooperative_matrix");
}
device->name = GGML_VK_NAME + std::to_string(idx); device->name = GGML_VK_NAME + std::to_string(idx);
device_create_info = { device_create_info = {
@ -1821,12 +1953,15 @@ static void ggml_vk_print_gpu_info(size_t idx) {
bool fp16_storage = false; bool fp16_storage = false;
bool fp16_compute = false; bool fp16_compute = false;
bool coopmat_support = false;
for (auto properties : ext_props) { for (auto properties : ext_props) {
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
fp16_storage = true; fp16_storage = true;
} else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
fp16_compute = true; fp16_compute = true;
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
coopmat_support = true;
} }
} }
@ -1857,8 +1992,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
fp16 = fp16 && vk12_features.shaderFloat16; fp16 = fp16 && vk12_features.shaderFloat16;
std::string device_name = props2.properties.deviceName.data(); std::string device_name = props2.properties.deviceName.data();
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu\n", GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %d\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size); idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, coopmat_support);
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n"); GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
@ -2809,7 +2944,7 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
if (m <= 32 || n <= 32) { if (m <= 32 || n <= 32) {
return aligned ? mmp->a_s : mmp->s; return aligned ? mmp->a_s : mmp->s;
} }
if (m <= 64 || n <= 64) { if (m <= 64 || n <= 64 || ctx->device->coopmat_support) {
return aligned ? mmp->a_m : mmp->m; return aligned ? mmp->a_m : mmp->m;
} }
return aligned ? mmp->a_l : mmp->l; return aligned ? mmp->a_l : mmp->l;
@ -4981,19 +5116,27 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
for (size_t i = 0; i < x_ne; i++) { for (size_t i = 0; i < x_ne; i++) {
if (std::is_same<float, X_TYPE>()) { if (std::is_same<float, X_TYPE>()) {
x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
// x[i] = 1.0f;
// x[i] = i + 1;
// x[i] = (i % k == i / k) ? 1.0f : 0.0f;
} else if (std::is_same<ggml_fp16_t, X_TYPE>()) { } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
// x[i] = ggml_fp32_to_fp16(1.0f);
// x[i] = ggml_fp32_to_fp16(i + 1);
// x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
} }
for (size_t i = 0; i < y_ne; i++) { for (size_t i = 0; i < y_ne; i++) {
if (std::is_same<float, Y_TYPE>()) { if (std::is_same<float, Y_TYPE>()) {
// y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
y[i] = (i % k == i / k) ? 1.0f : 0.0f; // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
// y[i] = i + 1;
} else if (std::is_same<ggml_fp16_t, Y_TYPE>()) { } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
// y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
// y[i] = ggml_fp32_to_fp16(i + 1);
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
@ -5077,7 +5220,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
double err = std::fabs(d[i] - d_chk[i]); double err = std::fabs(d[i] - d_chk[i]);
avg_err += err; avg_err += err;
if (err > 0.05f && first_err_n == -1) { if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
first_err_b = i / (m * n); first_err_b = i / (m * n);
first_err_n = (i % (m * n)) / m; first_err_n = (i % (m * n)) / m;
first_err_m = (i % (m * n)) % m; first_err_m = (i % (m * n)) % m;
@ -5090,12 +5233,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
if (avg_err > 0.1) { if (avg_err > 0.1 || std::isnan(avg_err)) {
std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
std::cerr << "Actual result: " << std::endl << std::endl; std::cerr << "Actual result: " << std::endl << std::endl;
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
std::cerr << std::endl;
ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
std::cerr << "Expected result: " << std::endl << std::endl; std::cerr << "Expected result: " << std::endl << std::endl;
ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
@ -5472,6 +5613,10 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K); ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K);
ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_IQ4_NL); ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_IQ4_NL);
ggml_vk_test_matmul<float, float>(ctx, 4, 4, 4, 1, 1, 1, 0);
ggml_vk_test_matmul<float, float>(ctx, 16, 16, 16, 1, 1, 1, 0);
ggml_vk_test_matmul<float, float>(ctx, 32, 32, 16, 1, 1, 1, 0);
ggml_vk_test_matmul<ggml_fp16_t, ggml_fp16_t>(ctx, 512, 512, 100, 32, 100, 1, 2); ggml_vk_test_matmul<ggml_fp16_t, ggml_fp16_t>(ctx, 512, 512, 100, 32, 100, 1, 2);
ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0); ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);

View File

@ -7,6 +7,12 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif #endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_shader_subgroup_basic : enable
#endif
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif #endif
@ -57,6 +63,7 @@ layout (push_constant) uniform parameter
#endif #endif
} p; } p;
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
layout (constant_id = 1) const uint BM = 64; layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64; layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
@ -65,13 +72,26 @@ layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2; layout (constant_id = 6) const uint WMITER = 2;
layout (constant_id = 7) const uint TM = 4; layout (constant_id = 7) const uint TM = 4;
layout (constant_id = 8) const uint TN = 2; layout (constant_id = 8) const uint TN = 2;
layout (constant_id = 9) const uint WARP = 32; layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
layout (constant_id = 10) const uint WARP = 32;
shared FLOAT_TYPE buf_a[BM * (BK+1)]; #ifdef COOPMAT
shared FLOAT_TYPE buf_b[BN * (BK+1)]; #define SHMEM_STRIDE (BK + 8)
#else
#define SHMEM_STRIDE (BK + 1)
#endif
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
shared u16vec2 row_ids[3072]; shared u16vec2 row_ids[3072];
#endif // MUL_MAT_ID
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef COOPMAT
shared float coopmat_stage[TM * TN * NUM_WARPS];
#endif #endif
void main() { void main() {
@ -98,17 +118,32 @@ void main() {
const uint ik = gl_WorkGroupID.x / blocks_m; const uint ik = gl_WorkGroupID.x / blocks_m;
const uint ic = gl_WorkGroupID.y; const uint ic = gl_WorkGroupID.y;
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER; const uint WSUBM = WM / WMITER;
const uint WSUBN = WN / WNITER; const uint WSUBN = WN / WNITER;
#ifdef COOPMAT
const uint warp_i = gl_SubgroupID;
const uint tiw = gl_SubgroupInvocationID;
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
const uint storestride = WARP / TM;
const uint store_r = tiw % TM;
const uint store_c = tiw / TM;
#else
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint tiw = gl_LocalInvocationID.x % WARP; const uint tiw = gl_LocalInvocationID.x % WARP;
const uint tiwr = tiw % (WSUBM / TM); const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM); const uint tiwc = tiw / (WSUBM / TM);
#endif
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
@ -156,6 +191,15 @@ void main() {
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
#endif #endif
#ifdef COOPMAT
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<float, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<float, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
}
#else
float sums[WMITER * TM * WNITER * TN]; float sums[WMITER * TM * WNITER * TN];
FLOAT_TYPE cache_a[WMITER * TM]; FLOAT_TYPE cache_a[WMITER * TM];
FLOAT_TYPE cache_b[WNITER * TN]; FLOAT_TYPE cache_b[WNITER * TN];
@ -163,14 +207,15 @@ void main() {
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = 0.0f; sums[i] = 0.0f;
} }
#endif
[[unroll]] for (uint block = start_k; block < end_k; block += BK) { [[dont_unroll]] for (uint block = start_k; block < end_k; block += BK) {
[[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
#if defined(DATA_A_F32) || defined(DATA_A_F16) #if defined(DATA_A_F32) || defined(DATA_A_F16)
#if LOAD_VEC_A == 8 #if LOAD_VEC_A == 8
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
@ -181,21 +226,21 @@ void main() {
buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
#elif LOAD_VEC_A == 4 #elif LOAD_VEC_A == 4
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
#else #else
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
} else { } else {
buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f); buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
} }
#endif #endif
#elif defined(DATA_A_Q4_0) #elif defined(DATA_A_Q4_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -208,7 +253,7 @@ void main() {
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q4_1) #elif defined(DATA_A_Q4_1)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -222,7 +267,7 @@ void main() {
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q5_0) #elif defined(DATA_A_Q5_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -237,7 +282,7 @@ void main() {
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q5_1) #elif defined(DATA_A_Q5_1)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -253,7 +298,7 @@ void main() {
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q8_0) #elif defined(DATA_A_Q8_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = (idx & 0xF) * 2; const uint iqs = (idx & 0xF) * 2;
@ -265,7 +310,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q2_K) #elif defined(DATA_A_Q2_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -284,7 +329,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q3_K) #elif defined(DATA_A_Q3_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -308,7 +353,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
#elif defined(DATA_A_Q4_K) #elif defined(DATA_A_Q4_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -336,7 +381,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
#elif defined(DATA_A_Q5_K) #elif defined(DATA_A_Q5_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -367,7 +412,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
#elif defined(DATA_A_Q6_K) #elif defined(DATA_A_Q6_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127 const uint iqs = idx % 128; // 0..127
@ -386,7 +431,7 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
#elif defined(DATA_A_IQ4_NL) #elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint ib = idx / 16; const uint ib = idx / 16;
const uint iqs = idx & 0xF; const uint iqs = idx & 0xF;
@ -407,7 +452,7 @@ void main() {
#else #else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif #endif
const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
@ -423,24 +468,24 @@ void main() {
#else #else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif #endif
const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y); buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
#elif !MUL_MAT_ID #elif !MUL_MAT_ID
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
} else { } else {
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
} }
#else #else
const uint row_i = ic * BN + loadc_b + l; const uint row_i = ic * BN + loadc_b + l;
if (row_i < _ne1) { if (row_i < _ne1) {
const u16vec2 row_idx = row_ids[row_i]; const u16vec2 row_idx = row_ids[row_i];
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
} else { } else {
buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
} }
#endif #endif
} }
@ -450,16 +495,30 @@ void main() {
pos_a += BK / LOAD_VEC_A; pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B; pos_b += BK / LOAD_VEC_B;
#ifdef COOPMAT
for (uint i = 0; i < BK; i += TK) {
for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
// Load from shared into cache
coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
}
}
}
#else
for (uint i = 0; i < BK; i++) { for (uint i = 0; i < BK; i++) {
// Load from shared into cache // Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint j = 0; j < TM; j++) { [[unroll]] for (uint j = 0; j < TM; j++) {
cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i]; cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
} }
} }
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint j = 0; j < TN; j++) { [[unroll]] for (uint j = 0; j < TN; j++) {
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i]; cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
} }
} }
@ -474,6 +533,7 @@ void main() {
} }
} }
} }
#endif
barrier(); barrier();
} }
@ -485,6 +545,53 @@ void main() {
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif #endif
#ifdef COOPMAT
#ifdef MUL_MAT_ID
for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
#else
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
if (is_aligned && is_in_bounds) {
// Full coopMat is within bounds and stride_d is aligned with 16B
coopMatStore(sums[cm_col * cms_per_row + cm_row], data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
} else if (is_in_bounds) {
// Full coopMat is within bounds, but stride_d is not aligned
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
// Partial coopMat is within bounds
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
}
}
#endif // MUL_MAT_ID
#else
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
@ -496,7 +603,7 @@ void main() {
if (row_i >= _ne1) break; if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i]; const u16vec2 row_idx = row_ids[row_i];
#endif #endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
@ -504,9 +611,10 @@ void main() {
if (dr_warp + cr < p.M && dc_warp + cc < p.N) { if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
} }
#endif #endif // MUL_MAT_ID
} }
} }
} }
} }
#endif // COOPMAT
} }

View File

@ -196,8 +196,8 @@ static uint32_t compile_count = 0;
static std::mutex compile_count_mutex; static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond; static std::condition_variable compile_count_cond;
void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) { void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, const std::string& suffix) {
std::string name = _name + (fp16 ? "" : "_fp32"); std::string name = _name + suffix;
std::string out_fname = join_paths(output_dir, name + ".spv"); std::string out_fname = join_paths(output_dir, name + ".spv");
std::string in_path = join_paths(input_dir, in_fname); std::string in_path = join_paths(input_dir, in_fname);
@ -254,7 +254,7 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
} }
static std::vector<std::future<void>> compiles; static std::vector<std::future<void>> compiles;
void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) { void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, const std::string& suffix = "") {
{ {
// wait until fewer than N compiles are in progress. // wait until fewer than N compiles are in progress.
// 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
@ -265,16 +265,17 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
} }
compile_count++; compile_count++;
} }
compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16)); compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, suffix));
} }
void matmul_shaders(bool fp16, bool matmul_id) { void matmul_shaders(bool fp16, bool coopmat, bool matmul_id) {
std::string load_vec = fp16 ? "8" : "4"; std::string load_vec = fp16 ? "8" : "4";
std::string aligned_b_type_f32 = fp16 ? "mat2x4" : "vec4"; std::string aligned_b_type_f32 = fp16 ? "mat2x4" : "vec4";
std::string aligned_b_type_f16 = fp16 ? "f16mat2x4" : "f16vec4"; std::string aligned_b_type_f16 = fp16 ? "f16mat2x4" : "f16vec4";
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", fp16 ? "float16_t" : "float"}}; std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", fp16 ? "float16_t" : "float"}};
std::string shader_name = "matmul"; std::string shader_name = "matmul";
std::string suffix = "";
if (matmul_id) { if (matmul_id) {
base_dict["MUL_MAT_ID"] = "1"; base_dict["MUL_MAT_ID"] = "1";
@ -283,14 +284,20 @@ void matmul_shaders(bool fp16, bool matmul_id) {
if (fp16) { if (fp16) {
base_dict["FLOAT16"] = "1"; base_dict["FLOAT16"] = "1";
} else {
suffix = "_fp32";
}
if (coopmat) {
base_dict["COOPMAT"] = "1";
suffix = "_coopmat";
} }
// Shaders with f16 B_TYPE // Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16); string_to_spv(shader_name + "_f32_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), suffix);
string_to_spv(shader_name + "_f32_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16); string_to_spv(shader_name + "_f32_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), suffix);
string_to_spv(shader_name + "_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16); string_to_spv(shader_name + "_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), suffix);
string_to_spv(shader_name + "_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16); string_to_spv(shader_name + "_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), suffix);
for (const auto& tname : type_names) { for (const auto& tname : type_names) {
std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string data_a_key = "DATA_A_" + to_uppercase(tname);
@ -298,8 +305,8 @@ void matmul_shaders(bool fp16, bool matmul_id) {
std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2"; std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2";
// For aligned matmul loads // For aligned matmul loads
std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2"; std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16); string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), suffix);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16); string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), suffix);
} }
} }
@ -307,9 +314,13 @@ void process_shaders() {
std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl; std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}}; std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
for (const auto& fp16 : {false, true}) { for (const auto& matmul_id : {false, true}) {
matmul_shaders(fp16, false); // Float32
matmul_shaders(fp16, true); matmul_shaders(false, false, matmul_id);
// Float16
matmul_shaders(true, false, matmul_id);
// Float16 CoopMat
matmul_shaders(true, true, matmul_id);
} }
for (const auto& tname : type_names) { for (const auto& tname : type_names) {