From 1613ef8d8eb2479ba55c4d598e08c8f3f18a0fed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 1 May 2024 14:46:37 +0200 Subject: [PATCH] CUDA: CUDART < 11.7 workaround for __hmax, __hmax2 (#7019) --- ggml-cuda/common.cuh | 45 +++++++++++++++++++++++++++++++++++++++----- ggml-cuda/fattn.cu | 6 +++--- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 156eba6d1..b2627b7b4 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -137,7 +137,8 @@ #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__) #define WARP_SIZE 32 -#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed) +#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed) +#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons #define CC_PASCAL 600 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products @@ -293,20 +294,54 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } +static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) + +#if CUDART_VERSION >= CUDART_HMAX + return __hmax(a, b); +#else + return __half2float(a) > __half2float(b) ? a : b; +#endif // CUDART_VERSION >= CUDART_HMAX + +#else + GGML_UNUSED(a); + GGML_UNUSED(b); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX +} +static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) + +#if CUDART_VERSION >= CUDART_HMAX + return __hmax2(a, b); +#else + half2 ret; + reinterpret_cast(ret.x) = __low2float(a) > __low2float(b) ? __low2half(a) : __low2half(b); + reinterpret_cast(ret.y) = __high2float(a) > __high2float(b) ? __high2half(a) : __high2half(b); + return ret; +#endif // CUDART_VERSION >= CUDART_HMAX + +#else + GGML_UNUSED(a); + GGML_UNUSED(b); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX +} + static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { - x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); } return x; #else GGML_UNUSED(x); NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL } -#if CUDART_VERSION < 12000 +#if CUDART_VERSION < CUDART_HMASK static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) { const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b))); const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b))); diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index df1e80068..c8a11d173 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -116,7 +116,7 @@ static __global__ void flash_attn_vec_ext_f16( sum2 = warp_reduce_sum(sum2); half sum = __low2half(sum2) + __high2half(sum2); sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f); - kqmax_new = __hmax(kqmax_new, sum); + kqmax_new = ggml_cuda_hmax(kqmax_new, sum); if (threadIdx.x == 0) { KQ[i_KQ] = sum; } @@ -416,9 +416,9 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + threadIdx.x; KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); + KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); } - KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; KQ_max_scale_h2[j0/nwarps] = h2exp(diff); const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));