diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b8a21a2cc..a7e9dc9b2 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1760,7 +1760,8 @@ extern "C" { struct ggml_tensor * v, struct ggml_tensor * mask, float scale, - float max_bias); + float max_bias, + float logit_softcap); GGML_API void ggml_flash_attn_ext_set_prec( struct ggml_tensor * a, diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 950fd93df..1fb5c09c3 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -657,11 +658,17 @@ void launch_fattn( const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]); const int shmem = 0; - float scale = 1.0f; - float max_bias = 0.0f; + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; - memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } const uint32_t n_head = Q->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); @@ -675,7 +682,7 @@ void launch_fattn( V_data, mask ? ((const char *) mask->data) : nullptr, (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, - scale, max_bias, m0, m1, n_head_log2, + scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 1b2fd500b..342f2eb66 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -4,7 +4,7 @@ #define FATTN_KQ_STRIDE_TILE_F16 64 -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16( const int ne2, const int ne3) { #ifdef FP16_AVAILABLE + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. @@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16( for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { const int j_KQ = j_KQ_0 + threadIdx.y; - half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); + half sum; + if (use_logit_softcap) { + const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); + sum = logit_softcap * tanhf(tmp.x + tmp.y); + } else { + sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); + } sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum); @@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16( #endif // FP16_AVAILABLE } -template +template void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { constexpr int D = 64; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { @@ -296,24 +309,45 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; - const int32_t precision = KQV->op_params[2]; + const int32_t precision = KQV->op_params[3]; GGML_ASSERT(precision == GGML_PREC_DEFAULT); + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (Q->ne[1] <= 16) { constexpr int cols_per_block = 16; constexpr int parallel_blocks = 4; - launch_fattn_tile_f16_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f16_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f16_64_128(ctx, dst); + } return; } if (Q->ne[1] <= 32) { constexpr int cols_per_block = 32; constexpr int parallel_blocks = 4; - launch_fattn_tile_f16_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f16_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f16_64_128(ctx, dst); + } return; } constexpr int cols_per_block = 32; constexpr int parallel_blocks = 1; - launch_fattn_tile_f16_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f16_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f16_64_128(ctx, dst); + } } diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index f3e68dbfa..827437ca0 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -4,7 +4,7 @@ #define FATTN_KQ_STRIDE_TILE_F32 32 -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -43,6 +44,12 @@ static __global__ void flash_attn_tile_ext_f32( const int ne1, const int ne2, const int ne3) { + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. @@ -151,6 +158,10 @@ static __global__ void flash_attn_tile_ext_f32( for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { const int j_KQ = j_KQ_0 + threadIdx.y; + if (use_logit_softcap) { + sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); + } + sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); @@ -267,20 +278,20 @@ static __global__ void flash_attn_tile_ext_f32( } } -template +template void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; switch (Q->ne[0]) { case 64: { constexpr int D = 64; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; + fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { @@ -290,23 +301,45 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * } void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (Q->ne[1] <= 16) { constexpr int cols_per_block = 16; constexpr int parallel_blocks = 4; - launch_fattn_tile_f32_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f32_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f32_64_128(ctx, dst); + } return; } if (Q->ne[1] <= 32) { constexpr int cols_per_block = 32; constexpr int parallel_blocks = 4; - launch_fattn_tile_f32_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f32_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f32_64_128(ctx, dst); + } return; } constexpr int cols_per_block = 32; constexpr int parallel_blocks = 1; - launch_fattn_tile_f32_64_128(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_f32_64_128(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_f32_64_128(ctx, dst); + } } diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 02a4ad072..448a9a905 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -1,7 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f16( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -41,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16( const int ne2, const int ne3) { #ifdef FP16_AVAILABLE + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16(type_K); @@ -190,6 +197,11 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); sum = warp_reduce_sum(sum); + + if (use_logit_softcap) { + sum = logit_softcap*tanhf(sum); + } + sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); if (ncols == 1) { @@ -286,10 +298,10 @@ static __global__ void flash_attn_vec_ext_f16( #endif // FP16_AVAILABLE } -template +template void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); @@ -297,48 +309,81 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, template void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_tensor * KQV = dst; - ggml_tensor * Q = dst->src[0]; - ggml_tensor * K = dst->src[1]; - ggml_tensor * V = dst->src[2]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; - const int32_t precision = KQV->op_params[2]; + const int32_t precision = KQV->op_params[3]; GGML_ASSERT(precision == GGML_PREC_DEFAULT); GGML_ASSERT(K->type == type_K); GGML_ASSERT(V->type == type_V); + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (Q->ne[1] == 1) { constexpr int cols_per_block = 1; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } return; } if (Q->ne[1] == 2) { constexpr int cols_per_block = 2; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } return; } if (Q->ne[1] <= 4) { constexpr int cols_per_block = 4; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } return; } if (Q->ne[1] <= 8) { constexpr int cols_per_block = 8; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } return; } constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; - ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + } } #define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \ diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 11a5e355f..bf5125902 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -1,7 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f32( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -40,6 +41,12 @@ static __global__ void flash_attn_vec_ext_f32( const int ne1, const int ne2, const int ne3) { + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32(type_K); @@ -180,6 +187,11 @@ static __global__ void flash_attn_vec_ext_f32( for (int j = 0; j < ncols; ++j) { float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]); sum = warp_reduce_sum(sum); + + if (use_logit_softcap) { + sum = logit_softcap*tanhf(sum); + } + sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum); @@ -267,10 +279,10 @@ static __global__ void flash_attn_vec_ext_f32( } } -template +template void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); @@ -278,44 +290,78 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, template void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_tensor * Q = dst->src[0]; - ggml_tensor * K = dst->src[1]; - ggml_tensor * V = dst->src[2]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; GGML_ASSERT(K->type == type_K); GGML_ASSERT(V->type == type_V); + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (Q->ne[1] == 1) { constexpr int cols_per_block = 1; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } return; } if (Q->ne[1] == 2) { constexpr int cols_per_block = 2; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } return; } if (Q->ne[1] <= 4) { constexpr int cols_per_block = 4; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } return; } if (Q->ne[1] <= 8) { constexpr int cols_per_block = 8; constexpr int parallel_blocks = 4; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } return; } constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; - ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + } } #define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \ diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index ae2322242..b10d19d93 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -6,7 +6,7 @@ #endif // FP16_MMA_AVAILABLE // D == head size, VKQ_stride == num VKQ rows calculated in parallel: -template +template #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -22,6 +22,7 @@ static __global__ void flash_attn_ext_f16( const float m0, const float m1, const uint32_t n_head_log2, + const float logit_softcap, const int ne00, const int ne01, const int ne02, @@ -46,6 +47,12 @@ static __global__ void flash_attn_ext_f16( const int ne2, const int ne3) { #ifdef FP16_MMA_AVAILABLE + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. @@ -85,6 +92,8 @@ static __global__ void flash_attn_ext_f16( const half slopeh = __float2half(slopef); const half2 slope2 = make_half2(slopef, slopef); + const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap); + frag_b Q_b[D/16][ncols/frag_n]; // A single buffer for temporarily holding tiles of KQ and VKQ parts: @@ -194,6 +203,10 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + threadIdx.x; KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + + if (use_logit_softcap) { + KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]); + } } float KQ_max_new = KQ_max_f[j0/nwarps]; @@ -237,6 +250,15 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + threadIdx.x; KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + + if (use_logit_softcap) { + // There is no dedicated tangens hyperbolicus function for half2. + KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f)); + KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f)) + /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f)); + + KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2; + } } half2 KQ_max_new = KQ_max_h2[j0/nwarps]; @@ -427,7 +449,8 @@ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); template void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; constexpr int nwarps = 4; @@ -435,20 +458,50 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + if (4*blocks_num_pb1 < 2*nsm) { constexpr int parallel_blocks = 4; - fattn_kernel_t fattn_kernel = flash_attn_ext_f16; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } if (2*blocks_num_pb1 < 2*nsm) { constexpr int parallel_blocks = 2; - fattn_kernel_t fattn_kernel = flash_attn_ext_f16; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); return; } constexpr int parallel_blocks = 1; - fattn_kernel_t fattn_kernel = flash_attn_ext_f16; + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; + } launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 29f608b0f..f87f33b3e 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -13,7 +13,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; - const int32_t precision = KQV->op_params[2]; + const int32_t precision = KQV->op_params[3]; if (precision != GGML_PREC_DEFAULT) { if (Q->ne[1] <= 32 || Q->ne[0] > 128) { @@ -301,7 +301,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - const int32_t precision = KQV->op_params[2]; + const int32_t precision = KQV->op_params[3]; // On AMD the tile kernels perform poorly, use the vec kernel instead: if (cc >= CC_OFFSET_AMD) { diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 995f1934b..32b2e9c31 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -802,6 +802,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx if (op->src[0]->ne[0] == 256) { return false; } + { + float logit_softcap; + + memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap)); + + if (logit_softcap != 0.0f) { + return false; + } + } return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d63c917a5..d0a991503 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7095,7 +7095,8 @@ struct ggml_tensor * ggml_flash_attn_ext( struct ggml_tensor * v, struct ggml_tensor * mask, float scale, - float max_bias) { + float max_bias, + float logit_softcap) { GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) @@ -7122,7 +7123,7 @@ struct ggml_tensor * ggml_flash_attn_ext( int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - float params[] = { scale, max_bias }; + float params[] = { scale, max_bias, logit_softcap }; ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_FLASH_ATTN_EXT; @@ -7142,7 +7143,7 @@ void ggml_flash_attn_ext_set_prec( const int32_t prec_i32 = (int32_t) prec; - ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second + ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second } // ggml_flash_attn_back @@ -15271,11 +15272,17 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - float scale = 1.0f; - float max_bias = 0.0f; + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; - memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } const uint32_t n_head = neq2; const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); @@ -15339,7 +15346,13 @@ static void ggml_compute_forward_flash_attn_ext_f16( const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); - s = s*scale + mv; // scale KQ value and apply mask + s = s*scale; // scale KQ value + + if (logit_softcap != 0.0f) { + s = logit_softcap*tanhf(s); + } + + s += mv; // apply mask const float Mold = M; @@ -15348,7 +15361,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); - if (v->type== GGML_TYPE_F16) { + if (v->type == GGML_TYPE_F16) { if (s > M) { // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f M = s; @@ -15415,7 +15428,7 @@ static void ggml_compute_forward_flash_attn_ext( const struct ggml_tensor * v, const struct ggml_tensor * mask, struct ggml_tensor * dst) { - switch (dst->op_params[2]) { + switch (dst->op_params[3]) { case GGML_PREC_DEFAULT: case GGML_PREC_F32: { diff --git a/src/llama.cpp b/src/llama.cpp index 869b584aa..aeea54cff 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8874,7 +8874,8 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(v, "v", il); - cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias); + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias, + hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f); if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) { ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); @@ -17533,12 +17534,6 @@ struct llama_context * llama_new_context_with_model( params.flash_attn = false; } - if (params.flash_attn && model->hparams.attn_soft_cap) { - LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__); - params.flash_attn = false; - } - - if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) { LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__); params.flash_attn = false; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e2b720844..351b1d567 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1652,19 +1652,20 @@ struct test_flash_attn_ext : public test_case { const bool mask; // use mask const float max_bias; // ALiBi + const float logit_softcap; // Gemma 2 const ggml_type type_KV; std::string vars() override { - return VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV); + return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV); } double max_nmse_err() override { return 5e-4; } - test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, ggml_type type_KV = GGML_TYPE_F16) - : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {} + test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16) + : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {} ggml_tensor * build_graph(ggml_context * ctx) override { const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV)); @@ -1673,7 +1674,7 @@ struct test_flash_attn_ext : public test_case { ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1); ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr; - ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias); + ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap); return out; } }; @@ -2437,11 +2438,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (bool mask : { true, false } ) { for (float max_bias : { 0.0f, 8.0f }) { if (!mask && max_bias > 0.0f) continue; - for (int nh : { 32, }) { - for (int kv : { 512, 1024, }) { - for (int nb : { 1, 2, 4, 8, }) { - for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { - test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, type_KV)); + for (float logit_softcap : {0.0f, 10.0f}) { + if (hs != 128 && logit_softcap != 0.0f) continue; + for (int nh : { 32, }) { + for (int kv : { 512, 1024, }) { + for (int nb : { 1, 2, 4, 8, }) { + for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV)); + } } } }