cuda : use half2 in softmax

This commit is contained in:
Georgi Gerganov 2024-02-03 15:00:25 +02:00
parent c51f27c0db
commit b958151e3f
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -6451,12 +6451,14 @@ static __global__ void flash_attn_ext_f16(
const int T = D + num_warps*SH; // shared memory size per query in (half) const int T = D + num_warps*SH; // shared memory size per query in (half)
const int T2 = T/2; // shared memory size per query in (half2) const int T2 = T/2; // shared memory size per query in (half2)
const int C2 = C/2;
extern __shared__ half __flash_attn_f16_shmem[]; extern __shared__ half __flash_attn_f16_shmem[];
// pq // pq
half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data
half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2
half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2
half16x16_acc zr; half16x16_acc zr;
half16x16_acc lo[Q16][D16]; half16x16_acc lo[Q16][D16];
@ -6606,19 +6608,19 @@ static __global__ void flash_attn_ext_f16(
} }
// used to detect blocks full of -INF // used to detect blocks full of -INF
half smax = __float2half(-INFINITY); half2 smax = make_half2(-INFINITY, -INFINITY);
// online softmax // online softmax
for (int j = 0; j < Q; ++j) { for (int j = 0; j < Q; ++j) {
const half m = M[j]; const half m = M[j];
for (int p0 = 0; p0 < C; p0 += NW) { for (int p0 = 0; p0 < C2; p0 += NW) {
const int p = p0 + lane_id; const int p = p0 + lane_id;
const half s = ss[j*T + p]; const half2 s = ss2[j*T2 + p];
smax = __hmax(smax, s); smax = __hmax2(smax, s);
M[j] = __hmax(M[j], s); M[j] = __hmax(M[j], __hmax(s.x, s.y));
} }
M[j] = warp_reduce_max(M[j]); M[j] = warp_reduce_max(M[j]);
@ -6631,28 +6633,31 @@ static __global__ void flash_attn_ext_f16(
} }
// local sum // local sum
half ls = 0.0f; half2 ls = make_half2(0.0f, 0.0f);
half2 M2 = make_half2(M[j], M[j]);
for (int p0 = 0; p0 < C; p0 += NW) { for (int p0 = 0; p0 < C2; p0 += NW) {
const int p = p0 + lane_id; const int p = p0 + lane_id;
const half s = ss[j*T + p]; const half2 s = ss2[j*T2 + p];
const half vs = hexp(s - M[j]); const half2 vs = h2exp(s - M2);
ls += vs; ls += vs;
// the P matrix from the paper (Q rows, C columns) // the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs; ss2[j*T2 + p] = vs;
} }
S[j] = S[j]*ms + warp_reduce_sum(ls); ls = warp_reduce_sum(ls);
S[j] = S[j]*ms + ls.x + ls.y;
} }
smax = warp_reduce_max(smax); smax = warp_reduce_max(smax);
// skip -INF blocks // skip -INF blocks
if (__hisinf(smax) == -1) { if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) {
continue; continue;
} }