#define FATTN_KQ_STRIDE 256
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
#define SOFTMAX_FTZ_THRESHOLD -20.0f           // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.

// Combines the partial VKQ results produced by parallel_blocks independent blocks
// into the final attention output, rescaling each part to a common KQ maximum.
template<int D, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_combine_results(
        const float  * __restrict__ VKQ_parts,
        const float2 * __restrict__ VKQ_meta,
        float        * __restrict__ dst) {
    VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
    VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
    dst       +=                 D * gridDim.y*blockIdx.x;

    const int tid = threadIdx.x;
    __builtin_assume(tid < D);

    // Per-block metadata: meta[l].x is the block's KQ maximum, meta[l].y its softmax denominator.
    __shared__ float2 meta[parallel_blocks];
    if (tid < 2*parallel_blocks) {
        ((float *) meta)[threadIdx.x] = ((const float *) VKQ_meta)[blockIdx.y*(2*parallel_blocks) + tid];
    }

    __syncthreads();

    float kqmax = meta[0].x;
#pragma unroll
    for (int l = 1; l < parallel_blocks; ++l) {
        kqmax = max(kqmax, meta[l].x);
    }

    float VKQ_numerator   = 0.0f;
    float VKQ_denominator = 0.0f;
#pragma unroll
    for (int l = 0; l < parallel_blocks; ++l) {
        const float diff = meta[l].x - kqmax;
        float KQ_max_scale = expf(diff);
        // Flush scales below the threshold to zero by masking their bits instead of branching.
        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
        *((uint32_t *) &KQ_max_scale) &= ftz_mask;

        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
        VKQ_denominator += KQ_max_scale * meta[l].y;
    }

    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
}
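
// A minimal host-side sketch (not part of ggml) of the same combination math, for
// illustration only: each partial result l carries a numerator plus (KQ max, denominator)
// metadata, and rescaling both by expf(meta[l].x - kqmax) merges the per-block softmaxes
// into one numerically stable softmax. The name combine_partial_results is hypothetical.
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

static float combine_partial_results(
        const std::vector<float> & parts,                         // per-block numerators for one output element
        const std::vector<std::pair<float, float>> & meta_host) { // per-block (KQ max, softmax denominator)
    // Global maximum across all partial blocks.
    float kqmax = meta_host[0].first;
    for (size_t l = 1; l < meta_host.size(); ++l) {
        kqmax = fmaxf(kqmax, meta_host[l].first);
    }

    float numerator   = 0.0f;
    float denominator = 0.0f;
    for (size_t l = 0; l < meta_host.size(); ++l) {
        const float diff  = meta_host[l].first - kqmax;
        const float scale = diff > SOFTMAX_FTZ_THRESHOLD ? expf(diff) : 0.0f; // same flush-to-zero behavior as the kernel
        numerator   += scale * parts[l];
        denominator += scale * meta_host[l].second;
    }
    return numerator / denominator;
}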