cuda : avoid __hisinf branches

This commit is contained in:
Georgi Gerganov 2024-02-03 14:27:36 +02:00
parent 92472ea22c
commit c51f27c0db
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -6513,9 +6513,9 @@ static __global__ void flash_attn_ext_f16(
half S[Q];
half M[Q];
for(int i = 0; i < Q; i++) {
for (int i = 0; i < Q; ++i) {
S[i] = __float2half(0.0f);
M[i] = __float2half(-INFINITY);
M[i] = CUDART_MIN_DENORM_FP16;
}
// assume K and V are same shape
@ -6609,30 +6609,6 @@ static __global__ void flash_attn_ext_f16(
half smax = __float2half(-INFINITY);
// online softmax
if (C == 32) {
for (int j = 0; j < Q; ++j) {
const int p = lane_id;
const half m = M[j];
const half s = ss[j*T + p];
smax = warp_reduce_max(__hmax(smax, s));
M[j] = warp_reduce_max(__hmax(M[j], s));
const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
S[j] = S[j]*ms + warp_reduce_sum(vs);
// create a QxQ diagonal matrix for rescaling the output
if (p == j) {
ss[j*T + C + j] = ms;
}
// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs;
}
} else {
for (int j = 0; j < Q; ++j) {
const half m = M[j];
@ -6647,7 +6623,7 @@ static __global__ void flash_attn_ext_f16(
M[j] = warp_reduce_max(M[j]);
const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
const half ms = hexp(m - M[j]);
// create a QxQ diagonal matrix for rescaling the output
if (lane_id == j) {
@ -6662,7 +6638,7 @@ static __global__ void flash_attn_ext_f16(
const half s = ss[j*T + p];
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
const half vs = hexp(s - M[j]);
ls += vs;
@ -6672,7 +6648,6 @@ static __global__ void flash_attn_ext_f16(
S[j] = S[j]*ms + warp_reduce_sum(ls);
}
}
smax = warp_reduce_max(smax);
@ -6736,7 +6711,7 @@ static __global__ void flash_attn_ext_f16(
// reduce the warps sequentially
for (int sg = 1; sg < num_warps; ++sg) {
half S = __float2half(0.0f);
half M = __float2half(-INFINITY);
half M = CUDART_MIN_DENORM_FP16;
__syncthreads();
@ -6762,8 +6737,8 @@ static __global__ void flash_attn_ext_f16(
M = __hmax(M0, M1);
const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M);
const half ms1 = __hisinf(M1) == -1 ? __float2half(0.0f) : hexp(M1 - M);
const half ms0 = hexp(M0 - M);
const half ms1 = hexp(M1 - M);
S = S0*ms0 + S1*ms1;