Mirror of https://github.com/ggerganov/llama.cpp.git

CUDA: generalize FP16 fattn vec kernel (#7061)

* CUDA: generalize FP16 fattn vec kernel
* disable unsupported head sizes for AMD in test
* try AMD fix
* fix batch size 2-8
* partially revert changes

Parent: f31ec120bc
Commit: a743d76a01
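The core of the change is that flash_attn_vec_ext_f16 now processes ncols Q columns per CUDA block instead of a single one, so batch sizes 2-8 can use the vector kernel. As a rough, heavily simplified sketch (hypothetical kernel name, assuming a Pascal-or-newer target; the real kernel is in the diff below), the per-column softmax state turns from scalars into small register arrays indexed by the column j:

#include <cuda_fp16.h>

#define HALF_MAX_HALF __float2half(65504.0f/2) // same init trick as the kernel: avoids NaN from (-INF) - (-INF)

// Illustrative sketch only: D = head size, ncols = Q columns handled by one block.
template <int D, int ncols>
__global__ void fattn_vec_state_sketch() {
    half  kqmax[ncols];                  // running KQ maximum, one per column
    half  kqsum[ncols] = {0.0f};         // running softmax denominator, one per column
    half2 VKQ  [ncols] = {{0.0f, 0.0f}}; // running V*softmax(KQ) accumulator, one per column
#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        kqmax[j] = -HALF_MAX_HALF;
    }
    // ... the real kernel streams over K/V tiles and updates these arrays column by column ...
    (void) kqsum; (void) VKQ;
}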
@@ -234,122 +234,6 @@ typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 #endif //GGML_CUDA_F16
 
-[[noreturn]]
-static __device__ void no_device_code(
-    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
-
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
-           file_name, line, function_name, arch);
-    GGML_UNUSED(arch_list);
-#else
-    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
-           file_name, line, function_name, arch, arch_list);
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-    __trap();
-
-    GGML_UNUSED(no_device_code); // suppress unused function warning
-}
-
-#ifdef __CUDA_ARCH__
-#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
-#else
-#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
-#endif // __CUDA_ARCH__
-
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
-static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
-    }
-    return a;
-#else
-    GGML_UNUSED(a);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-}
-
-static __device__ __forceinline__ float warp_reduce_max(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-    }
-    return x;
-}
-
-static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-
-#if CUDART_VERSION >= CUDART_HMAX
-    return __hmax(a, b);
-#else
-    return __half2float(a) > __half2float(b) ? a : b;
-#endif // CUDART_VERSION >= CUDART_HMAX
-
-#else
-    GGML_UNUSED(a);
-    GGML_UNUSED(b);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
-}
-static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-
-#if CUDART_VERSION >= CUDART_HMAX
-    return __hmax2(a, b);
-#else
-    half2 ret;
-    reinterpret_cast<half&>(ret.x) =  __low2float(a) >  __low2float(b) ?  __low2half(a) :  __low2half(b);
-    reinterpret_cast<half&>(ret.y) = __high2float(a) > __high2float(b) ? __high2half(a) : __high2half(b);
-    return ret;
-#endif // CUDART_VERSION >= CUDART_HMAX
-
-#else
-    GGML_UNUSED(a);
-    GGML_UNUSED(b);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
-}
-
-static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
-    }
-    return x;
-#else
-    GGML_UNUSED(x);
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-}
-
-#if CUDART_VERSION < CUDART_HMASK
-static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
-    const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
-    const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
-    return mask_low | mask_high;
-}
-#endif // CUDART_VERSION < 12000
-
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
 
@@ -433,11 +317,143 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 }
 #endif // defined(GGML_USE_HIPBLAS)
 
-#define FP16_AVAILABLE     defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) ? \
-    defined(RDNA1) || defined(RDNA2) || defined(RDNA3) : __CUDA_ARCH__ >= CC_PASCAL
+#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
 
 #define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
 
+static bool fp16_mma_available(const int cc) {
+    return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
+}
+
+[[noreturn]]
+static __device__ void no_device_code(
+    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
+           file_name, line, function_name, arch);
+    GGML_UNUSED(arch_list);
+#else
+    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
+           file_name, line, function_name, arch, arch_list);
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    __trap();
+
+    GGML_UNUSED(no_device_code); // suppress unused function warning
+}
+
+#ifdef __CUDA_ARCH__
+#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
+#else
+#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
+#endif // __CUDA_ARCH__
+
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if FP16_AVAILABLE
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
+        reinterpret_cast<half&>(a.x) +=  __low2half(a_other);
+        reinterpret_cast<half&>(a.y) += __high2half(a_other);
+    }
+    return a;
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+
+#else
+    NO_DEVICE_CODE;
+    return a;
+#endif // FP16_AVAILABLE
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
+static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
+#if FP16_AVAILABLE
+
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+    return __float2half(fmaxf(__half2float(a), __half2float(b)));
+#else
+    return __hmax(a, b);
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+
+#else
+    NO_DEVICE_CODE;
+    GGML_UNUSED(b);
+    return a;
+#endif // FP16_AVAILABLE
+}
+
+static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+
+#if CUDART_VERSION >= CUDART_HMAX
+    return __hmax2(a, b);
+#else
+    half2 ret;
+    reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
+    reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
+    return ret;
+#endif // CUDART_VERSION >= CUDART_HMAX
+
+#else
+    GGML_UNUSED(a);
+    GGML_UNUSED(b);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+}
+
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#else
+    GGML_UNUSED(x);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
+#if CUDART_VERSION < CUDART_HMASK
+static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
+    const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
+    const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
+    return mask_low | mask_high;
+}
+#endif // CUDART_VERSION < 12000
+
 // TODO: move to ggml-common.h
 static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
 
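The warp_reduce_* helpers moved above all use the same butterfly pattern: five __shfl_xor_sync steps with masks 16, 8, 4, 2, 1 leave every lane of the warp holding the reduction of all 32 lanes. A minimal usage sketch (a hypothetical test kernel, not part of the commit; assumes the block is launched with exactly 32 threads):

__global__ void warp_sum_demo(const float * in, float * out) {
    float x = in[threadIdx.x]; // one value per lane
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32); // exchange with the lane whose id differs by `mask`
    }
    if (threadIdx.x == 0) {
        *out = x; // after the loop every lane holds the full sum; lane 0 writes it out
    }
}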
@@ -11,8 +11,10 @@
 #define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 #define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
 
-template<int D, int parallel_blocks> // D == head size
-__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
+template<int D, int ncols, int parallel_blocks> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -44,55 +46,77 @@ static __global__ void flash_attn_vec_ext_f16(
 #if FP16_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
-    const int ic  = blockIdx.x / parallel_blocks;           // Index of the Q/QKV column to work on.
+    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
     const int ip  =  blockIdx.x % parallel_blocks;          // Index in group of blocks running for the same column in parallel.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.y              + nb01*ic);
+    const float2 * Q_f2  = (const float2 *) (Q + nb02* blockIdx.y              + nb01*ic0);
     const half2  * K_h2  = (const half2  *) (K + nb12*(blockIdx.y / gqa_ratio));
     const half   * V_h   = (const half   *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *)  mask + ne11*ic;
+    const half   * maskh = (const half   *)  mask + ne11*ic0;
 
     const int stride_KV  = nb11 / sizeof(half);
    const int stride_KV2 = nb11 / sizeof(half2);
 
-    constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
+    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+    constexpr int nwarps = D / WARP_SIZE;
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
-    __builtin_assume(tid < nwarps*WARP_SIZE);
+    __builtin_assume(tid < D);
 
-    __shared__ half KQ[nwarps*WARP_SIZE];
-    KQ[tid] = -INFINITY;
+    __shared__ half KQ[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        KQ[j*D + tid] = -HALF_MAX_HALF;
+    }
     half2 * KQ2 = (half2 *) KQ;
 
-    half kqmax = -HALF_MAX_HALF;
-    half kqsum = 0.0f;
+    half kqmax[ncols];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqmax[j] = -HALF_MAX_HALF;
+    }
+    half kqsum[ncols] = {0.0f};
 
-    __shared__ half kqmax_shared[WARP_SIZE];
-    __shared__ half kqsum_shared[WARP_SIZE];
-    if (threadIdx.y == 0) {
-        kqmax_shared[threadIdx.x] = -HALF_MAX_HALF;
-        kqsum_shared[threadIdx.x] = 0.0f;
+    __shared__ half kqmax_shared[ncols][WARP_SIZE];
+    __shared__ half kqsum_shared[ncols][WARP_SIZE];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (threadIdx.y == 0) {
+            kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF;
+            kqsum_shared[j][threadIdx.x] = 0.0f;
+        }
     }
     __syncthreads();
 
     // Convert Q to half2 and store in registers:
-    half2 Q_h2[(D/2 + WARP_SIZE - 1) / WARP_SIZE];
+    half2 Q_h2[ncols][D/(2*WARP_SIZE)];
 #pragma unroll
-    for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-        const int i = i0 + threadIdx.x;
-        if (i0 + WARP_SIZE > D/2 && i >= D/2) {
-            break;
-        }
+    for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
 
-        Q_h2[i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(Q_f2[i].x, Q_f2[i].y);
+            const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
+            Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+        }
     }
 
-    half2 VKQ = make_half2(0.0f, 0.0f); // Each thread calculates a single VKQ value.
+    half2 VKQ[ncols] = {{0.0f, 0.0f}};
 
     const int k_start = parallel_blocks == 1 ? 0 : ip*D;
     for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
-        half kqmax_new = kqmax;
+
+        // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
+        // see https://github.com/ggerganov/llama.cpp/pull/7061 .
+        // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
+        half kqmax_new = kqmax[0];
+        half kqmax_new_arr[ncols];
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            kqmax_new_arr[j] = kqmax[j];
+        }
 
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
             const int i_KQ = i_KQ_0 + threadIdx.y;
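A quick back-of-the-envelope check (illustrative only, the numbers are assumptions taken from the dispatch later in this diff) that the enlarged KQ buffer stays cheap: for the largest vec-kernel case, D = 128 and ncols = 8, the shared array holds ncols*D halfs, well under the 48 KiB of static shared memory available per block.

constexpr int D_example     = 128;
constexpr int ncols_example = 8;
constexpr int kq_bytes      = ncols_example * D_example * 2; // sizeof(half) == 2
static_assert(kq_bytes == 2048, "KQ buffer for D=128, ncols=8 is 2 KiB per block");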
@@ -101,89 +125,112 @@ static __global__ void flash_attn_vec_ext_f16(
                 break;
             }
 
-            half2 sum2 = make_half2(0.0f, 0.0f);
+            half2 sum2[ncols] = {{0.0f, 0.0f}};
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
                 const int k_KQ = k_KQ_0 + threadIdx.x;
-                if (k_KQ_0 + WARP_SIZE > D/2 && k_KQ >= D/2) {
-                    break;
-                }
 
                 const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
-                sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
-            }
-
-            sum2 = warp_reduce_sum(sum2);
-            half sum = __low2half(sum2) + __high2half(sum2);
-            sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f);
-            kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
-            if (threadIdx.x == 0) {
-                KQ[i_KQ] = sum;
-            }
-        }
-
-        kqmax_new = warp_reduce_max(kqmax_new);
-        if (threadIdx.x == 0) {
-            kqmax_shared[threadIdx.y] = kqmax_new;
-        }
-        __syncthreads();
-        kqmax_new = kqmax_shared[threadIdx.x];
-        kqmax_new = warp_reduce_max(kqmax_new);
-
-        const half KQ_max_scale = hexp(kqmax - kqmax_new);
-        kqmax = kqmax_new;
-
-        const half val = hexp(KQ[tid] - kqmax);
-        kqsum = kqsum*KQ_max_scale + val;
-        KQ[tid] = val;
-
-        VKQ *= __half2half2(KQ_max_scale);
-
-        __syncthreads();
-
-        if (tid < D) {
 #pragma unroll
-            for (int k0 = 0; k0 < D; k0 += 2) {
-                if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
-                    break;
+                for (int j = 0; j < ncols; ++j) {
+                    sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE];
+                }
+            }
+
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                sum2[j] = warp_reduce_sum(sum2[j]);
+                half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
+                sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+
+                if (ncols == 1) {
+                    kqmax_new        = ggml_cuda_hmax(kqmax_new, sum);
+                } else {
+                    kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
                 }
 
-                half2 V_k;
-                reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
-                reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
-                VKQ += V_k*KQ2[k0/2];
+                if (threadIdx.x == 0) {
+                    KQ[j*D + i_KQ] = sum;
+                }
+            }
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
+
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = kqmax_new_j;
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const half val = hexp(KQ[j*D + tid] - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale + val;
+            KQ[j*D + tid] = val;
+
+            VKQ[j] *= __half2half2(KQ_max_scale);
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k0 = 0; k0 < D; k0 += 2) {
+            if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
+                break;
+            }
+
+            half2 V_k;
+            reinterpret_cast<half&>(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid];
+            reinterpret_cast<half&>(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid];
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
             }
         }
 
         __syncthreads();
     }
 
-    if (tid >= D) {
-        kqsum = 0.0f;
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqsum[j] = warp_reduce_sum(kqsum[j]);
+        if (threadIdx.x == 0) {
+            kqsum_shared[j][threadIdx.y] = kqsum[j];
+        }
     }
 
-    kqsum = warp_reduce_sum(kqsum);
-    if (threadIdx.x == 0) {
-        kqsum_shared[threadIdx.y] = kqsum;
-    }
     __syncthreads();
-    kqsum = kqsum_shared[threadIdx.x];
-    kqsum = warp_reduce_sum(kqsum);
 
-    if (tid >= D) {
-        return;
+#pragma unroll
+    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
+        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
+
+        half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
+        if (parallel_blocks == 1) {
+            dst_val /= kqsum[j_VKQ];
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }
 
-    half dst_val = (__low2half(VKQ) + __high2half(VKQ));
-    if (parallel_blocks == 1) {
-        dst_val /= kqsum;
+    if (parallel_blocks != 1 && tid != 0) {
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            dst_meta[(ic0 + j)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j], kqsum[j]);
+        }
     }
-    dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = dst_val;
-
-    if (parallel_blocks == 1 || tid != 0) {
-        return;
-    }
-    dst_meta[ic*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax, kqsum);
-
 #else
     NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
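The rescaling in the hunk above (KQ_max_scale = hexp(kqmax[j] - kqmax_new_j), then scaling kqsum[j] and VKQ[j]) is the standard streaming, numerically stable softmax: whenever a larger maximum is seen, previously accumulated sums are rescaled so that exp() never overflows. A simplified scalar float analogue of what the kernel does per column in half precision (illustrative only, hypothetical names):

#include <math.h>

struct OnlineSoftmaxState {
    float m   = -INFINITY; // running maximum of the scores seen so far
    float s   = 0.0f;      // running sum of exp(score - m)
    float acc = 0.0f;      // running sum of exp(score - m) * value
};

static void online_softmax_update(OnlineSoftmaxState * st, float score, float value) {
    const float m_new = fmaxf(st->m, score);
    const float scale = expf(st->m - m_new); // KQ_max_scale in the kernel
    const float p     = expf(score - m_new); // val in the kernel
    st->s   = st->s*scale   + p;
    st->acc = st->acc*scale + p*value;
    st->m   = m_new;
}
// After all scores are consumed the normalized result is acc / s; the kernel likewise
// divides by kqsum at the very end when parallel_blocks == 1.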
@@ -191,7 +238,9 @@ static __global__ void flash_attn_vec_ext_f16(
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -573,7 +622,9 @@ static __global__ void flash_attn_ext_f16(
 }
 
 template<int D, int parallel_blocks> // D == head size
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_combine_results(
         const float  * __restrict__ VKQ_parts,
         const float2 * __restrict__ VKQ_meta,
@@ -642,7 +693,7 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
 
-template <int D, int parallel_blocks> void launch_fattn_vec_f16(
+template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f16(
         const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
         ggml_cuda_pool & pool, cudaStream_t main_stream
 ) {
@@ -656,13 +707,13 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
 
     constexpr int nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*Q->ne[1], Q->ne[2], Q->ne[3]);
+    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
     const int  shmem = 0;
 
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));
 
-    flash_attn_vec_ext_f16<D, parallel_blocks>
+    flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
         <<<blocks_num, block_dim, shmem, main_stream>>> (
             (const char *) Q->data,
             (const char *) K->data,
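The grid-size change above is a ceiling division: each block now covers cols_per_block Q columns, so the x dimension shrinks accordingly. A minimal standalone illustration of the arithmetic (hypothetical helper name):

// With cols_per_block = 4 and parallel_blocks = 4, a batch of ne1 = 6 Q columns yields
// (6 + 4 - 1) / 4 = 2 column groups, i.e. 8 blocks along grid.x per attention head.
static int blocks_x_sketch(int ne1, int cols_per_block, int parallel_blocks) {
    const int col_groups = (ne1 + cols_per_block - 1) / cols_per_block; // ceil division
    return parallel_blocks * col_groups;
}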
@@ -783,10 +834,99 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     ggml_cuda_set_device(ctx.device);
 
+    const int cc  = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
     const int32_t precision = KQV->op_params[1];
 
+    if (!fp16_mma_available(cc)) {
+        GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+        GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+
+        if (Q->ne[1] == 1) {
+            constexpr int cols_per_block  = 1;
+            constexpr int parallel_blocks = 4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+            return;
+        }
+
+        if (Q->ne[1] == 2) {
+            constexpr int cols_per_block  = 2;
+            constexpr int parallel_blocks = 4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+            return;
+        }
+
+        if (Q->ne[1] <= 4) {
+            constexpr int cols_per_block  = 4;
+            constexpr int parallel_blocks = 4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+            return;
+        }
+
+        if (Q->ne[1] <= 8) {
+            constexpr int cols_per_block  = 8;
+            constexpr int parallel_blocks = 4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+            return;
+        }
+
+        constexpr int cols_per_block  = 8;
+        constexpr int parallel_blocks = 1;
+        switch (Q->ne[0]) {
+            case 64:
+                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                break;
+            case 128:
+                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                break;
+            default:
+                GGML_ASSERT(false);
+                break;
+        }
+        return;
+    }
+
     if (precision != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
             constexpr int cols_per_block = 16;
@@ -845,16 +985,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+        constexpr int cols_per_block  = 1;
         constexpr int parallel_blocks = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                 break;
             case 128:
-                launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                 break;
             case 256:
-                launch_fattn_vec_f16<256, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
+                launch_fattn_vec_f16<256, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                 break;
             default:
                 GGML_ASSERT(false);
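The dispatch added above picks the number of Q columns per block from the batch size when tensor cores are unavailable: 1, 2, 4 or 8 columns, with parallel_blocks = 4 for small batches and parallel_blocks = 1 once the batch alone fills the GPU. The selection rule, written out as a standalone helper for clarity (hypothetical name; the real code inlines it as the chain of ifs above):

static int pick_cols_per_block(int n_cols) {
    if (n_cols == 1) return 1;
    if (n_cols == 2) return 2;
    if (n_cols <= 4) return 4;
    return 8; // batches up to 8 also use parallel_blocks = 4, larger ones use parallel_blocks = 1
}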
@@ -15519,13 +15519,6 @@ struct llama_context * llama_new_context_with_model(
         cparams.flash_attn = false;
     }
 
-#ifdef GGML_USE_HIPBLAS
-    if (cparams.flash_attn) {
-        LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with HIPBLAS builds - forcing off\n", __func__);
-        cparams.flash_attn = false;
-    }
-#endif
-
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
    }
@@ -2175,7 +2175,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
 
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    for (int hs : { 64, 128, }) { // other head sizes not implemented
+#else
     for (int hs : { 64, 80, 128, 256, }) {
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
     for (int nh : { 32, }) {
     for (int kv : { 512, 1024, }) {
     for (int nb : { 1, 2, 4, 8, }) {