cuda : avoid warp_reduce for smax

This commit is contained in:
Georgi Gerganov 2024-02-03 13:17:47 +02:00
parent b68a112204
commit b150abe83e
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -6621,7 +6621,6 @@ static __global__ void flash_attn_ext_f16(
M[j] = __hmax(M[j], s);
}
smax = warp_reduce_max(smax);
M[j] = warp_reduce_max(M[j]);
const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
@@ -6649,6 +6648,8 @@ static __global__ void flash_attn_ext_f16(
}
}
smax = warp_reduce_max(smax);
// skip -INF blocks
if (__hisinf(smax) == -1) {
continue;