CPU/CUDA: Gemma 2 FlashAttention support (#8542)

* CPU/CUDA: Gemma 2 FlashAttention support
* apply logit_softcap to scale in kernel
* disable logit softcapping tests on Metal
* remove metal check

This commit is contained in:
parent 8f824ffe8e
commit e11bd856d5
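For orientation before the diff: Gemma 2 style logit softcapping replaces each raw attention score scale*(q·k) with logit_softcap * tanh(scale*(q·k) / logit_softcap), and logit_softcap == 0.0f means the cap is disabled. As the commit message says, the division is folded into the existing scale once per call (scale /= logit_softcap in launch_fattn and in the CPU forward pass), so the kernels themselves only multiply the tanh output by logit_softcap. A minimal standalone sketch of that algebra, in plain C++ with illustrative helper names (not code from this commit):

#include <cmath>
#include <cstdio>

// Reference form: cap the scaled score directly.
static float softcap_reference(float qk, float scale, float logit_softcap) {
    if (logit_softcap == 0.0f) {
        return qk*scale; // softcapping disabled
    }
    return logit_softcap*std::tanh(qk*scale/logit_softcap);
}

// Folded form used by the patch: divide the scale once up front,
// then each per-score update is just logit_softcap*tanh(...).
static float softcap_folded(float qk, float scale, float logit_softcap) {
    if (logit_softcap != 0.0f) {
        scale /= logit_softcap;          // done once per call on the host / CPU side
    }
    float s = qk*scale;
    if (logit_softcap != 0.0f) {
        s = logit_softcap*std::tanh(s);  // done per score inside the kernels
    }
    return s;
}

int main() {
    const float qk = 37.5f, scale = 0.125f, cap = 50.0f; // made-up example values
    std::printf("reference = %f, folded = %f\n",
                softcap_reference(qk, scale, cap), softcap_folded(qk, scale, cap));
    return 0;
}

The two forms agree because tanh sees the same argument either way; folding just avoids an extra divide per score. The unified diff of the commit follows.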
@@ -1760,7 +1760,8 @@ extern "C" {
             struct ggml_tensor  * v,
             struct ggml_tensor  * mask,
             float                 scale,
-            float                 max_bias);
+            float                 max_bias,
+            float                 logit_softcap);

     GGML_API void ggml_flash_attn_ext_set_prec(
             struct ggml_tensor * a,
@@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -659,9 +660,15 @@ void launch_fattn(

     float scale = 1.0f;
     float max_bias = 0.0f;
+    float logit_softcap = 0.0f;

     memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0.0f) {
+        scale /= logit_softcap;
+    }

     const uint32_t n_head = Q->ne[2];
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
@@ -675,7 +682,7 @@ void launch_fattn(
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
         (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-        scale, max_bias, m0, m1, n_head_log2,
+        scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
         mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
@@ -4,7 +4,7 @@

 #define FATTN_KQ_STRIDE_TILE_F16 64

-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16(
         for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
             const int j_KQ = j_KQ_0 + threadIdx.y;

-            half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+            half sum;
+            if (use_logit_softcap) {
+                const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                sum = logit_softcap * tanhf(tmp.x + tmp.y);
+            } else {
+                sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+            }
             sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);

             kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
@@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }

-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case 64: {
             constexpr int D = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
@@ -296,24 +309,45 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];

-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);

+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }

     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }

     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    }
 }
@@ -4,7 +4,7 @@

 #define FATTN_KQ_STRIDE_TILE_F32 32

-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -43,6 +44,12 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@@ -151,6 +158,10 @@ static __global__ void flash_attn_tile_ext_f32(
         for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
             const int j_KQ = j_KQ_0 + threadIdx.y;

+            if (use_logit_softcap) {
+                sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+            }
+
             sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;

             kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
@@ -267,20 +278,20 @@ static __global__ void flash_attn_tile_ext_f32(
     }
 }

-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
 void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case 64: {
             constexpr int D = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int D = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
@@ -290,23 +301,45 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
 }

 void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];

+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }

     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }

     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    }
 }
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"

-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -41,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

     constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
@@ -190,6 +197,11 @@ static __global__ void flash_attn_vec_ext_f16(
         for (int j = 0; j < ncols; ++j) {
             half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
             sum = warp_reduce_sum(sum);
+
+            if (use_logit_softcap) {
+                sum = logit_softcap*tanhf(sum);
+            }
+
             sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);

             if (ncols == 1) {
@@ -286,10 +298,10 @@ static __global__ void flash_attn_vec_ext_f16(
 #endif // FP16_AVAILABLE
 }

-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
@@ -297,48 +309,81 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,

 template <int D, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_tensor * KQV = dst;
-    ggml_tensor * Q = dst->src[0];
-    ggml_tensor * K = dst->src[1];
-    ggml_tensor * V = dst->src[2];
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];

-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);

     GGML_ASSERT(K->type == type_K);
     GGML_ASSERT(V->type == type_V);

+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] == 1) {
         constexpr int cols_per_block = 1;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }

     if (Q->ne[1] == 2) {
         constexpr int cols_per_block = 2;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }

     if (Q->ne[1] <= 4) {
         constexpr int cols_per_block = 4;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }

     if (Q->ne[1] <= 8) {
         constexpr int cols_per_block = 8;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }

     constexpr int cols_per_block = 8;
     constexpr int parallel_blocks = 1;
-    ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    }
 }

 #define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"

-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f32(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -40,6 +41,12 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

     constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
@@ -180,6 +187,11 @@ static __global__ void flash_attn_vec_ext_f32(
         for (int j = 0; j < ncols; ++j) {
             float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
             sum = warp_reduce_sum(sum);
+
+            if (use_logit_softcap) {
+                sum = logit_softcap*tanhf(sum);
+            }
+
             sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;

             kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
@@ -267,10 +279,10 @@ static __global__ void flash_attn_vec_ext_f32(
     }
 }

-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
@@ -278,44 +290,78 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,

 template <int D, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_tensor * Q = dst->src[0];
-    ggml_tensor * K = dst->src[1];
-    ggml_tensor * V = dst->src[2];
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];

     GGML_ASSERT(K->type == type_K);
     GGML_ASSERT(V->type == type_V);

+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] == 1) {
         constexpr int cols_per_block = 1;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }

     if (Q->ne[1] == 2) {
         constexpr int cols_per_block = 2;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }

     if (Q->ne[1] <= 4) {
         constexpr int cols_per_block = 4;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }

     if (Q->ne[1] <= 8) {
         constexpr int cols_per_block = 8;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }

     constexpr int cols_per_block = 8;
     constexpr int parallel_blocks = 1;
-    ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    }
 }

 #define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \
@@ -6,7 +6,7 @@
 #endif // FP16_MMA_AVAILABLE

 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
-template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -22,6 +22,7 @@ static __global__ void flash_attn_ext_f16(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -46,6 +47,12 @@ static __global__ void flash_attn_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_MMA_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
@@ -85,6 +92,8 @@ static __global__ void flash_attn_ext_f16(
     const half slopeh = __float2half(slopef);
     const half2 slope2 = make_half2(slopef, slopef);

+    const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
+
     frag_b Q_b[D/16][ncols/frag_n];

     // A single buffer for temporarily holding tiles of KQ and VKQ parts:
@@ -194,6 +203,10 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + threadIdx.x;

                 KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+
+                if (use_logit_softcap) {
+                    KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+                }
             }

             float KQ_max_new = KQ_max_f[j0/nwarps];
@@ -237,6 +250,15 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + threadIdx.x;

                 KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+
+                if (use_logit_softcap) {
+                    // There is no dedicated tangens hyperbolicus function for half2.
+                    KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
+                    KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
+                                           /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+
+                    KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
+                }
             }

             half2 KQ_max_new = KQ_max_h2[j0/nwarps];
@@ -427,6 +449,7 @@ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");

 template <int D, int cols_per_block, typename KQ_acc_t>
 void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];

     constexpr int nwarps = 4;
@@ -435,20 +458,50 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;

+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (4*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 4;
-        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        fattn_kernel_t fattn_kernel;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        } else {
+            constexpr bool use_logit_softcap = true;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        }
         launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 2;
-        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        fattn_kernel_t fattn_kernel;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        } else {
+            constexpr bool use_logit_softcap = true;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        }
         launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     constexpr int parallel_blocks = 1;
-    fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+    fattn_kernel_t fattn_kernel;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        fattn_kernel = flash_attn_ext_f16<
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+    } else {
+        constexpr bool use_logit_softcap = true;
+        fattn_kernel = flash_attn_ext_f16<
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+    }
     launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
 }

@@ -13,7 +13,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];

-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];

     if (precision != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
@@ -301,7 +301,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst

     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];

     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
@@ -802,6 +802,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
             if (op->src[0]->ne[0] == 256) {
                 return false;
             }
+            {
+                float logit_softcap;
+
+                memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap));
+
+                if (logit_softcap != 0.0f) {
+                    return false;
+                }
+            }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -7095,7 +7095,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
         struct ggml_tensor  * v,
         struct ggml_tensor  * mask,
         float                 scale,
-        float                 max_bias) {
+        float                 max_bias,
+        float                 logit_softcap) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)

@@ -7122,7 +7123,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

-    float params[] = { scale, max_bias };
+    float params[] = { scale, max_bias, logit_softcap };
     ggml_set_op_params(result, params, sizeof(params));

     result->op = GGML_OP_FLASH_ATTN_EXT;
@@ -7142,7 +7143,7 @@ void ggml_flash_attn_ext_set_prec(

     const int32_t prec_i32 = (int32_t) prec;

-    ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
+    ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
 }

 // ggml_flash_attn_back
@@ -15273,9 +15274,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(

     float scale = 1.0f;
     float max_bias = 0.0f;
+    float logit_softcap = 0.0f;

     memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }

     const uint32_t n_head = neq2;
     const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
@@ -15339,7 +15346,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
                 kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);

-                s = s*scale + mv; // scale KQ value and apply mask
+                s = s*scale; // scale KQ value
+
+                if (logit_softcap != 0.0f) {
+                    s = logit_softcap*tanhf(s);
+                }
+
+                s += mv; // apply mask

                 const float Mold = M;

@@ -15348,7 +15361,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(

                 const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));

-                if (v->type== GGML_TYPE_F16) {
+                if (v->type == GGML_TYPE_F16) {
                     if (s > M) {
                         // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                         M = s;
@@ -15415,7 +15428,7 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (dst->op_params[2]) {
+    switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
@@ -8874,7 +8874,8 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);

-        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+                                  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);

         if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
@@ -17533,12 +17534,6 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }

-    if (params.flash_attn && model->hparams.attn_soft_cap) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
     if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
         LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
         params.flash_attn = false;
@@ -1652,19 +1652,20 @@ struct test_flash_attn_ext : public test_case {
     const bool mask; // use mask

     const float max_bias; // ALiBi
+    const float logit_softcap; // Gemma 2

     const ggml_type type_KV;

     std::string vars() override {
-        return VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV);
+        return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV);
     }

     double max_nmse_err() override {
         return 5e-4;
     }

-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
-        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {}
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
+        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
@@ -1673,7 +1674,7 @@ struct test_flash_attn_ext : public test_case {
         ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
         ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
         ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
-        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
         return out;
     }
 };
@@ -2437,11 +2438,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     for (bool mask : { true, false } ) {
         for (float max_bias : { 0.0f, 8.0f }) {
             if (!mask && max_bias > 0.0f) continue;
+            for (float logit_softcap : {0.0f, 10.0f}) {
+                if (hs != 128 && logit_softcap != 0.0f) continue;
             for (int nh : { 32, }) {
                 for (int kv : { 512, 1024, }) {
                     for (int nb : { 1, 2, 4, 8, }) {
                         for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                            test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, type_KV));
+                            test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
+                        }
                         }
                     }
                 }
             }
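A side note on the wmma F16 hunk above: its comment "There is no dedicated tangens hyperbolicus function for half2" is why that path computes tanh through h2exp, using the identity tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) before multiplying by logit_softcap_2. A tiny float-only sketch of the same identity (illustration only, not the kernel's half2 code):

#include <cassert>
#include <cmath>

// tanh expressed through exp, mirroring what the half2 path does with h2exp:
// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
static float tanh_via_exp(float x) {
    const float e2x = std::exp(2.0f*x);
    return (e2x - 1.0f)/(e2x + 1.0f);
}

int main() {
    // Agrees with std::tanh for moderate arguments such as capped attention scores.
    for (float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f}) {
        assert(std::fabs(tanh_via_exp(x) - std::tanh(x)) < 1e-5f);
    }
    return 0;
}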