Move the subgroups and printf into common.

This commit is contained in:
Adam Treat 2023-10-02 09:00:55 -04:00 committed by cebtenzzre
parent 93306f16d0
commit 601905e75e
3 changed files with 5 additions and 8 deletions

View File

@ -12,6 +12,8 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int8: require #extension GL_EXT_shader_explicit_arithmetic_types_int8: require
#extension GL_EXT_shader_explicit_arithmetic_types_int16: require #extension GL_EXT_shader_explicit_arithmetic_types_int16: require
#extension GL_EXT_control_flow_attributes: enable #extension GL_EXT_control_flow_attributes: enable
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_EXT_debug_printf : enable
#define QK4_0 32 #define QK4_0 32
#define QR4_0 2 #define QR4_0 2

View File

@ -6,9 +6,6 @@
* this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc. * this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc.
*/ */
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_EXT_debug_printf : enable
void main() { void main() {
const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT); const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT);
const uint r0 = gl_WorkGroupID.x; const uint r0 = gl_WorkGroupID.x;
@ -27,9 +24,9 @@ void main() {
uint yb = y + ix * BLOCKS_IN_QUANT + il; uint yb = y + ix * BLOCKS_IN_QUANT + il;
debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n", //debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n",
gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize, // gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize,
gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z); // gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z);
for (uint ib = ix; ib < nb; ib += gl_SubgroupSize/2) { for (uint ib = ix; ib < nb; ib += gl_SubgroupSize/2) {
for (int row = 0; row < N_ROWS; row++) { for (int row = 0; row < N_ROWS; row++) {

View File

@ -10,8 +10,6 @@
#include "common.comp" #include "common.comp"
#extension GL_KHR_shader_subgroup_arithmetic : require
layout(local_size_x_id = 0) in; layout(local_size_x_id = 0) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; }; layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };