cuda : use half2 in softmax

This commit is contained in:
Georgi Gerganov 2024-02-03 15:00:25 +02:00
parent c51f27c0db
commit b958151e3f
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -6451,12 +6451,14 @@ static __global__ void flash_attn_ext_f16(
const int T = D + num_warps*SH; // shared memory size per query in (half)
const int T2 = T/2; // shared memory size per query in (half2)
const int C2 = C/2;
extern __shared__ half __flash_attn_f16_shmem[];
// pq
half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data
half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2
half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2
half16x16_acc zr;
half16x16_acc lo[Q16][D16];
@ -6606,19 +6608,19 @@ static __global__ void flash_attn_ext_f16(
}
// used to detect blocks full of -INF
half smax = __float2half(-INFINITY);
half2 smax = make_half2(-INFINITY, -INFINITY);
// online softmax
for (int j = 0; j < Q; ++j) {
const half m = M[j];
for (int p0 = 0; p0 < C; p0 += NW) {
for (int p0 = 0; p0 < C2; p0 += NW) {
const int p = p0 + lane_id;
const half s = ss[j*T + p];
const half2 s = ss2[j*T2 + p];
smax = __hmax(smax, s);
M[j] = __hmax(M[j], s);
smax = __hmax2(smax, s);
M[j] = __hmax(M[j], __hmax(s.x, s.y));
}
M[j] = warp_reduce_max(M[j]);
@ -6631,28 +6633,31 @@ static __global__ void flash_attn_ext_f16(
}
// local sum
half ls = 0.0f;
half2 ls = make_half2(0.0f, 0.0f);
half2 M2 = make_half2(M[j], M[j]);
for (int p0 = 0; p0 < C; p0 += NW) {
for (int p0 = 0; p0 < C2; p0 += NW) {
const int p = p0 + lane_id;
const half s = ss[j*T + p];
const half2 s = ss2[j*T2 + p];
const half vs = hexp(s - M[j]);
const half2 vs = h2exp(s - M2);
ls += vs;
// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs;
ss2[j*T2 + p] = vs;
}
S[j] = S[j]*ms + warp_reduce_sum(ls);
ls = warp_reduce_sum(ls);
S[j] = S[j]*ms + ls.x + ls.y;
}
smax = warp_reduce_max(smax);
// skip -INF blocks
if (__hisinf(smax) == -1) {
if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) {
continue;
}