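cuda : avoid __hisinf branches (initialize the running max M to CUDART_MIN_DENORM_FP16 instead of -INFINITY so the hexp() rescale factors need no -inf guard, and fold the C == 32 special case into the general online-softmax loop)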

This commit is contained in:
Georgi Gerganov 2024-02-03 14:27:36 +02:00
parent 92472ea22c
commit c51f27c0db
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -6513,9 +6513,9 @@ static __global__ void flash_attn_ext_f16(
half S[Q]; half S[Q];
half M[Q]; half M[Q];
for(int i = 0; i < Q; i++) { for (int i = 0; i < Q; ++i) {
S[i] = __float2half(0.0f); S[i] = __float2half(0.0f);
M[i] = __float2half(-INFINITY); M[i] = CUDART_MIN_DENORM_FP16;
} }
// assume K and V are same shape // assume K and V are same shape
@ -6609,69 +6609,44 @@ static __global__ void flash_attn_ext_f16(
half smax = __float2half(-INFINITY); half smax = __float2half(-INFINITY);
// online softmax // online softmax
if (C == 32) { for (int j = 0; j < Q; ++j) {
for (int j = 0; j < Q; ++j) { const half m = M[j];
const int p = lane_id;
for (int p0 = 0; p0 < C; p0 += NW) {
const int p = p0 + lane_id;
const half m = M[j];
const half s = ss[j*T + p]; const half s = ss[j*T + p];
smax = warp_reduce_max(__hmax(smax, s)); smax = __hmax(smax, s);
M[j] = warp_reduce_max(__hmax(M[j], s)); M[j] = __hmax(M[j], s);
}
const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]); M[j] = warp_reduce_max(M[j]);
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
S[j] = S[j]*ms + warp_reduce_sum(vs); const half ms = hexp(m - M[j]);
// create a QxQ diagonal matrix for rescaling the output // create a QxQ diagonal matrix for rescaling the output
if (p == j) { if (lane_id == j) {
ss[j*T + C + j] = ms; ss[j*T + C + j] = ms;
} }
// local sum
half ls = 0.0f;
for (int p0 = 0; p0 < C; p0 += NW) {
const int p = p0 + lane_id;
const half s = ss[j*T + p];
const half vs = hexp(s - M[j]);
ls += vs;
// the P matrix from the paper (Q rows, C columns) // the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs; ss[j*T + p] = vs;
} }
} else {
for (int j = 0; j < Q; ++j) {
const half m = M[j];
for (int p0 = 0; p0 < C; p0 += NW) { S[j] = S[j]*ms + warp_reduce_sum(ls);
const int p = p0 + lane_id;
const half s = ss[j*T + p];
smax = __hmax(smax, s);
M[j] = __hmax(M[j], s);
}
M[j] = warp_reduce_max(M[j]);
const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
// create a QxQ diagonal matrix for rescaling the output
if (lane_id == j) {
ss[j*T + C + j] = ms;
}
// local sum
half ls = 0.0f;
for (int p0 = 0; p0 < C; p0 += NW) {
const int p = p0 + lane_id;
const half s = ss[j*T + p];
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
ls += vs;
// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs;
}
S[j] = S[j]*ms + warp_reduce_sum(ls);
}
} }
smax = warp_reduce_max(smax); smax = warp_reduce_max(smax);
@ -6736,7 +6711,7 @@ static __global__ void flash_attn_ext_f16(
// reduce the warps sequentially // reduce the warps sequentially
for (int sg = 1; sg < num_warps; ++sg) { for (int sg = 1; sg < num_warps; ++sg) {
half S = __float2half(0.0f); half S = __float2half(0.0f);
half M = __float2half(-INFINITY); half M = CUDART_MIN_DENORM_FP16;
__syncthreads(); __syncthreads();
@ -6762,8 +6737,8 @@ static __global__ void flash_attn_ext_f16(
M = __hmax(M0, M1); M = __hmax(M0, M1);
const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M); const half ms0 = hexp(M0 - M);
const half ms1 = __hisinf(M1) == -1 ? __float2half(0.0f) : hexp(M1 - M); const half ms1 = hexp(M1 - M);
S = S0*ms0 + S1*ms1; S = S0*ms0 + S1*ms1;