mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 23:34:35 +00:00
cuda : optimize argmax
This commit is contained in:
parent
9abe9eeae9
commit
35386e8904
@ -1,57 +1,68 @@
|
|||||||
#include "common.cuh"
|
#include <algorithm>
|
||||||
#include "argmax.cuh"
|
|
||||||
#include "sum.cuh"
|
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
static __global__ void argmax_f32(
|
#include "argmax.cuh"
|
||||||
const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) {
|
#include "common.cuh"
|
||||||
|
#include "sum.cuh"
|
||||||
|
|
||||||
int argmax_thread = 0;
|
static __global__ void argmax_f32(const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) {
|
||||||
const int64_t row0 = (int64_t)blockIdx.x*WARP_SIZE;
|
const int64_t row = blockIdx.x;
|
||||||
|
|
||||||
#pragma unroll
|
float maxval = -FLT_MAX;
|
||||||
for (int64_t row1 = 0; row1 < WARP_SIZE; ++row1) {
|
int argmax = -1;
|
||||||
const int64_t row = row0 + row1;
|
const float * rowx = x + row * ncols;
|
||||||
|
|
||||||
if (row >= nrows) {
|
for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
|
||||||
break;
|
const float val = rowx[col];
|
||||||
|
if (val > maxval) {
|
||||||
|
maxval = val;
|
||||||
|
argmax = col;
|
||||||
}
|
}
|
||||||
|
|
||||||
float maxval = -FLT_MAX;
|
|
||||||
int argmax = -1;
|
|
||||||
|
|
||||||
for (int32_t col = threadIdx.x; col < ncols; col += WARP_SIZE) {
|
|
||||||
const float val = x[row*ncols + col];
|
|
||||||
const int bigger = val > maxval;
|
|
||||||
const int not_bigger = bigger ^ 0x00000001;
|
|
||||||
|
|
||||||
maxval = maxval*not_bigger + val*bigger;
|
|
||||||
argmax = argmax*not_bigger + col*bigger;
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
|
||||||
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);
|
|
||||||
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);
|
|
||||||
const int bigger = val > maxval;
|
|
||||||
const int not_bigger = bigger ^ 0x00000001;
|
|
||||||
|
|
||||||
maxval = maxval*not_bigger + val*bigger;
|
|
||||||
argmax = argmax*not_bigger + col*bigger;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int store = row1 == threadIdx.x;
|
|
||||||
argmax_thread += store*argmax;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const int row = row0 + threadIdx.x;
|
#pragma unroll
|
||||||
|
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||||
if (row >= nrows) {
|
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
|
||||||
return;
|
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
|
||||||
|
if (val > maxval) {
|
||||||
|
maxval = val;
|
||||||
|
argmax = col;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dst[row] = argmax_thread;
|
const int n_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
|
||||||
|
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||||
|
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||||
|
if (n_warps > 1) {
|
||||||
|
constexpr int max_warps = 1024 / WARP_SIZE;
|
||||||
|
__shared__ float shared_maxval[max_warps];
|
||||||
|
__shared__ int shared_argmax[max_warps];
|
||||||
|
if (lane_id == 0) {
|
||||||
|
shared_maxval[warp_id] = maxval;
|
||||||
|
shared_argmax[warp_id] = argmax;
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (warp_id == 0 && lane_id < n_warps) {
|
||||||
|
maxval = shared_maxval[lane_id];
|
||||||
|
argmax = shared_argmax[lane_id];
|
||||||
|
const unsigned int mask = (1 << n_warps) - 1;
|
||||||
|
#pragma unroll
|
||||||
|
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||||
|
const float val = __shfl_xor_sync(mask, maxval, offset, WARP_SIZE);
|
||||||
|
const int col = __shfl_xor_sync(mask, argmax, offset, WARP_SIZE);
|
||||||
|
if (val > maxval) {
|
||||||
|
maxval = val;
|
||||||
|
argmax = col;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (warp_id == 0 && lane_id == 0) {
|
||||||
|
dst[row] = argmax;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
@ -70,9 +81,8 @@ void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||||||
|
|
||||||
cudaStream_t stream = ctx.stream();
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|
||||||
const int64_t num_blocks = (nrows + WARP_SIZE - 1) / WARP_SIZE;
|
const int64_t num_blocks = nrows;
|
||||||
|
const dim3 blocks_dim(std::min<int64_t>(ne00, 1024), 1, 1);
|
||||||
const dim3 blocks_dim(WARP_SIZE, 1, 1);
|
|
||||||
const dim3 blocks_num(num_blocks, 1, 1);
|
const dim3 blocks_num(num_blocks, 1, 1);
|
||||||
|
|
||||||
argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00, nrows);
|
argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00, nrows);
|
||||||
|
@ -180,8 +180,8 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
|
|||||||
return __reduce_add_sync(0xffffffff, x);
|
return __reduce_add_sync(0xffffffff, x);
|
||||||
#else
|
#else
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||||
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
|
x += __shfl_xor_sync(0xffffffff, x, offset, 32);
|
||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
|
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
|
||||||
@ -189,17 +189,17 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
|
|||||||
|
|
||||||
static __device__ __forceinline__ float warp_reduce_sum(float x) {
|
static __device__ __forceinline__ float warp_reduce_sum(float x) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||||
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
|
x += __shfl_xor_sync(0xffffffff, x, offset, 32);
|
||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||||
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
|
a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
|
||||||
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
|
a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
|
||||||
}
|
}
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
@ -209,16 +209,16 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
|||||||
|
|
||||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||||
const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
|
const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
|
||||||
reinterpret_cast<half&>(a.x) += __low2half(a_other);
|
reinterpret_cast<half&>(a.x) += __low2half(a_other);
|
||||||
reinterpret_cast<half&>(a.y) += __high2half(a_other);
|
reinterpret_cast<half&>(a.y) += __high2half(a_other);
|
||||||
}
|
}
|
||||||
return a;
|
return a;
|
||||||
#else
|
#else
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||||
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
|
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
|
||||||
}
|
}
|
||||||
return a;
|
return a;
|
||||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||||
@ -231,8 +231,8 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
|||||||
|
|
||||||
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||||
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
|
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
|
||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
@ -275,8 +275,8 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
|
|||||||
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
||||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int offset = 16; offset > 0; offset >>= 1) {
|
||||||
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
|
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
|
||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
#else
|
#else
|
||||||
|
@ -69,8 +69,8 @@ static __global__ void quantize_mmq_q8_1(
|
|||||||
|
|
||||||
// Exchange max. abs. value between vals_per_scale/4 threads.
|
// Exchange max. abs. value between vals_per_scale/4 threads.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
|
for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
|
||||||
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
|
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
|
||||||
}
|
}
|
||||||
|
|
||||||
float sum;
|
float sum;
|
||||||
@ -79,8 +79,8 @@ static __global__ void quantize_mmq_q8_1(
|
|||||||
|
|
||||||
// Exchange calculate sum across vals_per_sum/4 threads.
|
// Exchange calculate sum across vals_per_sum/4 threads.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
|
for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
|
||||||
sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
|
sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1155,6 +1155,26 @@ struct test_argmax : public test_case {
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void initialize_tensors(ggml_context * ctx) override {
|
||||||
|
std::random_device rd;
|
||||||
|
std::default_random_engine rng(rd());
|
||||||
|
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||||
|
if (t->type == GGML_TYPE_F32) {
|
||||||
|
// initialize with unique values to avoid ties
|
||||||
|
for (int64_t r = 0; r < ggml_nrows(t); r++) {
|
||||||
|
std::vector<float> data(t->ne[0]);
|
||||||
|
for (int i = 0; i < t->ne[0]; i++) {
|
||||||
|
data[i] = i;
|
||||||
|
}
|
||||||
|
std::shuffle(data.begin(), data.end(), rng);
|
||||||
|
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
init_tensor_uniform(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
double max_nmse_err() override {
|
double max_nmse_err() override {
|
||||||
return 0.0;
|
return 0.0;
|
||||||
}
|
}
|
||||||
@ -3441,6 +3461,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||||||
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
|
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
|
||||||
|
|
||||||
test_cases.emplace_back(new test_argmax());
|
test_cases.emplace_back(new test_argmax());
|
||||||
|
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
|
||||||
|
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
|
||||||
|
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
|
||||||
|
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
|
||||||
|
|
||||||
test_cases.emplace_back(new test_count_equal());
|
test_cases.emplace_back(new test_count_equal());
|
||||||
|
|
||||||
for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
|
for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
|
||||||
@ -3831,6 +3856,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|||||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f));
|
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f));
|
||||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f));
|
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f));
|
||||||
|
|
||||||
|
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
|
||||||
|
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
|
||||||
|
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
|
||||||
|
|
||||||
for (int bs : {1, 512}) {
|
for (int bs : {1, 512}) {
|
||||||
for (ggml_type type_a : all_types) {
|
for (ggml_type type_a : all_types) {
|
||||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||||
|
Loading…
Reference in New Issue
Block a user