#version 450

#include "types.comp"

// Work-group width; also the capacity of the shared index buffer below.
#define BLOCK_SIZE 1024
// Value of p.order that requests ascending output; any other value sorts descending.
#define ASC 0

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

// Input values; row `r`, column `c` is read at data_a[r * p.ncols + c].
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
// Output; for each row, the column indices of data_a arranged in sorted order.
layout (binding = 1) buffer D {int data_d[];};

layout (push_constant) uniform parameter {
    uint ncols;      // number of real elements per row
    uint ncols_pad;  // number of slots the sorting network runs over
    uint order;      // ASC for ascending, any other value for descending
} p;
// NOTE(review): the network below presumes ncols_pad is a power of two with
// ncols <= ncols_pad <= BLOCK_SIZE -- confirm against the host dispatch code.

// Per-work-group scratch holding the column indices of the row being sorted.
shared int dst_row[BLOCK_SIZE];

// Exchanges the two index entries dst_row[idx0] and dst_row[idx1].
void swap(uint idx0, uint idx1) {
    int tmp = dst_row[idx0];
    dst_row[idx0] = dst_row[idx1];
    dst_row[idx1] = tmp;
}

// Returns true when the element referenced by column index idx_a has to be
// placed after the element referenced by idx_b.
//
// Indices >= p.ncols are padding: they rank after every real element and are
// never used to read data_a (the checks below return before any load).
bool sorts_after(uint row_offset, uint idx_a, uint idx_b) {
    if (idx_a >= p.ncols) {
        return true;
    }
    if (idx_b >= p.ncols) {
        return false;
    }
    return p.order == ASC ? data_a[row_offset + idx_a] > data_a[row_offset + idx_b]
                          : data_a[row_offset + idx_a] < data_a[row_offset + idx_b];
}

// Sorts the column indices of one row (selected by gl_WorkGroupID.y) by the
// values in data_a and writes the first p.ncols sorted indices to data_d.
void main() {
    // bitonic sort
    const uint col = gl_LocalInvocationID.x;
    const uint row = gl_WorkGroupID.y;
    const uint row_offset = row * p.ncols;

    // Invocations past the padded width own no slot. They must not return
    // early: barrier() has to be reached by every invocation of the work
    // group, so they stay in the loops below and simply skip the work.
    const bool has_slot = col < p.ncols_pad;

    // initialize indices
    if (has_slot) {
        dst_row[col] = int(col);
    }
    barrier();

    // Loop bounds depend only on push constants, so all invocations execute
    // the same number of iterations and therefore the same barriers.
    for (uint k = 2; k <= p.ncols_pad; k *= 2) {
        for (uint j = k / 2; j > 0; j /= 2) {
            const uint ixj = col ^ j;
            // Only the lower-numbered invocation of each (col, ixj) pair
            // performs the compare-exchange, so each pair is handled once.
            if (has_slot && ixj > col) {
                const uint idx_col = uint(dst_row[col]);
                const uint idx_ixj = uint(dst_row[ixj]);
                // Bit k of col selects the direction of this pair: forward
                // when clear, reversed when set.
                const bool out_of_order = (col & k) == 0
                    ? sorts_after(row_offset, idx_col, idx_ixj)
                    : sorts_after(row_offset, idx_ixj, idx_col);
                if (out_of_order) {
                    swap(col, ixj);
                }
            }
            barrier();
        }
    }

    // Padding slots are not written out.
    if (has_slot && col < p.ncols) {
        data_d[row_offset + col] = dst_row[col];
    }
}