2024-03-25 12:50:23 +00:00
|
|
|
#include "argsort.cuh"
|
|
|
|
|
|
|
|
// Exchanges the contents of a and b in place (device-side helper).
// T must be copy-constructible and copy-assignable.
template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
    T saved = b;
    b = a;
    a = saved;
}
|
|
|
|
|
|
|
|
// Returns true when, in the requested sort order, the element referenced by column index a
// must be placed after the element referenced by column index b.
// Indices >= ncols are padding. They always compare as "must follow" any real element, so they
// end up at the tail of the row for both ascending and descending order.
template<ggml_sort_order order>
static __device__ __forceinline__ bool argsort_must_follow(const float * x_row, const int a, const int b, const int ncols) {
    if (a >= ncols) {
        return true;
    }
    if (b >= ncols) {
        return false;
    }
    return order == GGML_SORT_ORDER_ASC ? x_row[a] > x_row[b] : x_row[a] < x_row[b];
}

// Bitonic argsort. For each row of x (ncols floats per row, rows stored back to back), writes the
// column indices that order that row to the matching row of dst (ncols ints per row).
//
// Launch layout:
//   - One thread block per row. The row index is blockIdx.x + blockIdx.y * gridDim.x, so both a
//     (nrows, 1) and a (1, nrows) grid are valid.
//   - Any blockDim.x >= 1 is valid. Each thread handles padded columns threadIdx.x,
//     threadIdx.x + blockDim.x, ...
// Dynamic shared memory: ncols_pad * sizeof(int) bytes.
// Precondition: ncols_pad is a power of two and ncols_pad >= ncols.
template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
    // 64-bit row offset: row * ncols can exceed INT_MAX for large tensors
    const int64_t row = (int64_t) blockIdx.y * gridDim.x + blockIdx.x;

    const float * x_row   = x   + row * ncols;
    int         * dst_out = dst + row * ncols;

    extern __shared__ int dst_row[];

    // initialize indices; entries in [ncols, ncols_pad) are padding
    for (int col = threadIdx.x; col < ncols_pad; col += blockDim.x) {
        dst_row[col] = col;
    }
    __syncthreads();

    // The k/j loop bounds are identical for every thread of the block, so every thread
    // reaches every __syncthreads() below (no early return before a barrier).
    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            // Within one (k, j) step the pairs (col, col ^ j) are disjoint. Each pair is handled
            // only by the thread that owns its lower index, so there are no shared-memory races.
            for (int col = threadIdx.x; col < ncols_pad; col += blockDim.x) {
                const int ixj = col ^ j;
                if (ixj <= col) {
                    continue;
                }
                // Runs with (col & k) == 0 are ordered in the requested direction and the others
                // in the opposite direction, which forms the bitonic sequences for the next k.
                const bool must_swap = (col & k) == 0
                    ? argsort_must_follow<order>(x_row, dst_row[col], dst_row[ixj], ncols)
                    : argsort_must_follow<order>(x_row, dst_row[ixj], dst_row[col], ncols);
                if (must_swap) {
                    ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                }
            }
            __syncthreads();
        }
    }

    // copy the result to dst without the padding
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        dst_out[col] = dst_row[col];
    }
}

// Smallest power of two that is >= x (returns 1 for x <= 1).
static int next_power_of_2(int x) {
    int n = 1;
    while (n < x) {
        n *= 2;
    }
    return n;
}

// Enqueues on `stream` an argsort of nrows rows of ncols floats each from x into dst (int
// column indices, ncols per row).
// x and dst are dereferenced by the kernel, so they must be device-accessible pointers.
// order must be GGML_SORT_ORDER_ASC or GGML_SORT_ORDER_DESC; anything else asserts.
static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
    if (nrows <= 0) {
        return; // nothing to sort; a zero-sized grid is an invalid launch configuration
    }

    // bitonic sort requires the row length to be a power of 2, so pad it up
    const int ncols_pad = next_power_of_2(ncols);

    // Rows go on grid x, which allows up to 2^31 - 1 blocks (grid y is limited to 65535).
    // The kernel strides over the padded row, so the block can be capped at the 1024-thread
    // per-block limit. Rows with ncols_pad > 1024 are then still sorted, as long as the index
    // buffer fits in shared memory.
    const int max_block_size = 1024;
    const dim3 block_dims(ncols_pad < max_block_size ? ncols_pad : max_block_size, 1, 1);
    const dim3 block_nums(nrows, 1, 1);
    const size_t shared_mem = ncols_pad * sizeof(int);

    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

    if (order == GGML_SORT_ORDER_ASC) {
        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
    } else if (order == GGML_SORT_ORDER_DESC) {
        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
    } else {
        GGML_ASSERT(false);
    }
}
|
|
|
|
|
|
|
|
// CUDA implementation of the argsort op. For every row of src0 (rows of length src0->ne[0]),
// writes to dst the column indices that sort that row.
// Checked requirements:
//   - src0 is a contiguous F32 tensor.
//   - dst is a contiguous I32 tensor.
//   - The row length and row count fit in int.
// The sort direction is read from dst->op_params[0] as a ggml_sort_order.
// The work is enqueued on ctx.stream().
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    int * dst_d = (int *)dst->data; // dst is I32 (asserted below)
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_is_contiguous(src0));
    // the kernel writes row r at dst_d + r * ncols, which assumes a contiguous dst
    GGML_ASSERT(ggml_is_contiguous(dst));

    const int64_t ncols = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    // the launcher takes int sizes; refuse to truncate them silently
    GGML_ASSERT(ncols == (int64_t)(int) ncols);
    GGML_ASSERT(nrows == (int64_t)(int) nrows);

    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];

    argsort_f32_i32_cuda(src0_d, dst_d, (int) ncols, (int) nrows, order, stream);
}
|