cuda : simplify softmax

This commit is contained in:
Georgi Gerganov 2024-02-03 18:31:55 +02:00
parent e04ff39181
commit cfd9732b2e
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -6512,11 +6512,10 @@ static __global__ void flash_attn_ext_f16(
__syncthreads(); __syncthreads();
{ {
half S[Q]; half S = __float2half(0.0f);
half M[Q]; half M[Q];
for (int i = 0; i < Q; ++i) { for (int i = 0; i < Q; ++i) {
S[i] = __float2half(0.0f);
M[i] = CUDART_MIN_DENORM_FP16; M[i] = CUDART_MIN_DENORM_FP16;
} }
@ -6626,13 +6625,6 @@ static __global__ void flash_attn_ext_f16(
M[j] = warp_reduce_max(M[j]); M[j] = warp_reduce_max(M[j]);
const half ms = hexp(m - M[j]);
// create a QxQ diagonal matrix for rescaling the output
if (lane_id == j) {
ss[j*T + C + j] = ms;
}
// local sum // local sum
half2 ls = make_half2(0.0f, 0.0f); half2 ls = make_half2(0.0f, 0.0f);
half2 M2 = make_half2(M[j], M[j]); half2 M2 = make_half2(M[j], M[j]);
@ -6652,7 +6644,14 @@ static __global__ void flash_attn_ext_f16(
ls = warp_reduce_sum(ls); ls = warp_reduce_sum(ls);
S[j] = S[j]*ms + ls.x + ls.y; const half ms = hexp(m - M[j]);
// create a QxQ diagonal matrix for rescaling the output
if (lane_id == j) {
ss[j*T + C + j] = ms;
S = S*ms + ls.x + ls.y;
}
} }
smax = warp_reduce_max(smax); smax = warp_reduce_max(smax);
@ -6709,8 +6708,8 @@ static __global__ void flash_attn_ext_f16(
// these are needed for reducing the results from the simdgroups (reuse the ss buffer) // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (int j = 0; j < Q; ++j) { for (int j = 0; j < Q; ++j) {
if (lane_id == 0) { if (lane_id == j) {
ss[j*T + 0] = S[j]; ss[j*T + 0] = S;
ss[j*T + 1] = M[j]; ss[j*T + 1] = M[j];
} }
} }