cuda : use half2 in softmax
commit b958151e3f
parent c51f27c0db

ggml-cuda.cu (29 changed lines)
@@ -6451,12 +6451,14 @@ static __global__ void flash_attn_ext_f16(

     const int T  = D + num_warps*SH; // shared memory size per query in (half)
     const int T2 = T/2;              // shared memory size per query in (half2)
+    const int C2 = C/2;

     extern __shared__ half __flash_attn_f16_shmem[];
     // pq
     half  * sq  = (half  *) (__flash_attn_f16_shmem +              0*D); // holds the query data
     half2 * sq2 = (half2 *) (__flash_attn_f16_shmem +              0*D); // same as above but in half2
     half  * ss  = (half  *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2

     half16x16_acc zr;
     half16x16_acc lo[Q16][D16];
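This hunk only adds a half2 view (ss2) over the same shared-memory scratch storage, alongside the existing sq2, and halves the addressing constants (T2 = T/2, C2 = C/2); no data is moved. Below is a minimal standalone sketch of that aliasing pattern, not the kernel itself: the buffer name, the size D, and the sum written to out are illustrative only.

#include <cuda_fp16.h>
#include <cstdio>

// Toy kernel: one dynamic shared-memory buffer of halves, viewed both as
// half (scalar) and as half2 (packed pairs). Each thread writes two scalars
// and reads them back as a single 4-byte half2 load.
__global__ void alias_half2_demo(float * out, const int D) {
    extern __shared__ half shmem[];

    half  * s  = (half  *) shmem; // scalar view
    half2 * s2 = (half2 *) shmem; // half2 view of the same storage

    const int D2 = D/2;           // stride in half2 units, as T2 = T/2 in the patch
    const int i  = threadIdx.x;

    if (2*i + 1 < D) {
        s[2*i + 0] = __float2half((float) (2*i + 0));
        s[2*i + 1] = __float2half((float) (2*i + 1));
    }
    __syncthreads();

    if (i < D2) {
        const half2 v = s2[i];    // two halves in one access
        out[i] = __half2float(v.x) + __half2float(v.y);
    }
}

int main() {
    const int D = 64;
    float * out = nullptr;
    cudaMallocManaged(&out, (D/2)*sizeof(float));
    alias_half2_demo<<<1, D, D*sizeof(half)>>>(out, D);
    cudaDeviceSynchronize();
    printf("out[0] = %.1f (expect 1.0)\n", out[0]);
    cudaFree(out);
    return 0;
}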
@@ -6606,19 +6608,19 @@ static __global__ void flash_attn_ext_f16(
        }

        // used to detect blocks full of -INF
-       half smax = __float2half(-INFINITY);
+       half2 smax = make_half2(-INFINITY, -INFINITY);

        // online softmax
        for (int j = 0; j < Q; ++j) {
            const half m = M[j];

-           for (int p0 = 0; p0 < C; p0 += NW) {
+           for (int p0 = 0; p0 < C2; p0 += NW) {
                const int p = p0 + lane_id;

-               const half s = ss[j*T + p];
+               const half2 s = ss2[j*T2 + p];

-               smax = __hmax(smax, s);
-               M[j] = __hmax(M[j], s);
+               smax = __hmax2(smax, s);
+               M[j] = __hmax(M[j], __hmax(s.x, s.y));
            }

            M[j] = warp_reduce_max(M[j]);
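This hunk vectorizes the running maximum of the online softmax: each lane now loads one half2 from ss2, keeps a pairwise maximum in smax via __hmax2, and folds the pair into the scalar row maximum M[j] with __hmax. A minimal standalone sketch of the same pattern follows; warp_reduce_max here is my own illustrative shuffle-based version (the kernel's helper of the same name is not shown in this diff), and __hmax/__hmax2 assume an sm_80+ GPU.

#include <cuda_fp16.h>
#include <cmath>
#include <cstdio>

// Illustrative warp-wide max over half2: elementwise __hmax2 across lanes.
// (The kernel's own warp_reduce_max helper is not part of this diff.)
static __device__ half2 warp_reduce_max(half2 x) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
    }
    return x;
}

// One warp finds the maximum of 64 halves: each lane owns one packed half2,
// keeps a pairwise running maximum, then collapses the pair to a scalar.
__global__ void rowmax_half2_demo(const half * vals, float * out) {
    const int lane = threadIdx.x;

    half2 smax = make_half2(-INFINITY, -INFINITY);

    const half2 s = ((const half2 *) vals)[lane]; // two values per lane
    smax = __hmax2(smax, s);                      // pairwise maximum, as in the patch

    smax = warp_reduce_max(smax);

    if (lane == 0) {
        // __hmax(.x, .y) is how the patch folds the pair into the scalar M[j]
        *out = __half2float(__hmax(smax.x, smax.y));
    }
}

int main() {
    half  * vals = nullptr;
    float * out  = nullptr;
    cudaMallocManaged(&vals, 64*sizeof(half));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < 64; ++i) vals[i] = __float2half((float) i);
    rowmax_half2_demo<<<1, 32>>>(vals, out);
    cudaDeviceSynchronize();
    printf("max = %.1f (expect 63.0)\n", *out);
    cudaFree(vals); cudaFree(out);
    return 0;
}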
@@ -6631,28 +6633,31 @@ static __global__ void flash_attn_ext_f16(
            }

            // local sum
-           half ls = 0.0f;
+           half2 ls = make_half2(0.0f, 0.0f);
+           half2 M2 = make_half2(M[j], M[j]);

-           for (int p0 = 0; p0 < C; p0 += NW) {
+           for (int p0 = 0; p0 < C2; p0 += NW) {
                const int p = p0 + lane_id;

-               const half s = ss[j*T + p];
+               const half2 s = ss2[j*T2 + p];

-               const half vs = hexp(s - M[j]);
+               const half2 vs = h2exp(s - M2);

                ls += vs;

                // the P matrix from the paper (Q rows, C columns)
-               ss[j*T + p] = vs;
+               ss2[j*T2 + p] = vs;
            }

-           S[j] = S[j]*ms + warp_reduce_sum(ls);
+           ls = warp_reduce_sum(ls);
+
+           S[j] = S[j]*ms + ls.x + ls.y;
        }

        smax = warp_reduce_max(smax);

        // skip -INF blocks
-       if (__hisinf(smax) == -1) {
+       if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) {
            continue;
        }
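This hunk applies the same treatment to the local sum: the row maximum is broadcast into both halves of M2, h2exp produces two exponentials per call, the partial sums stay packed as half2 through the warp reduction, and only at the end are ls.x and ls.y added into the scalar accumulator S[j]. A minimal standalone sketch under the same assumptions (illustrative warp_reduce_sum, sm_53+ for half2 arithmetic and h2exp; the data and the row maximum m are made up for the example):

#include <cuda_fp16.h>
#include <cstdio>

// Illustrative warp-wide sum over half2: the two partial sums stay packed.
// (The kernel's own warp_reduce_sum helper is not part of this diff.)
static __device__ half2 warp_reduce_sum(half2 x) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x;
}

// Softmax denominator for one row of 64 values: sum(exp(s - m)) computed
// two elements at a time, mirroring the ls/M2 part of the patch.
__global__ void softmax_sum_half2_demo(const half * vals, const float m, float * out) {
    const int lane = threadIdx.x;

    const half2 M2 = make_half2(m, m);            // broadcast the row maximum
    const half2 s  = ((const half2 *) vals)[lane];

    half2 ls = h2exp(s - M2);                     // two exponentials per call

    ls = warp_reduce_sum(ls);                     // reduce while still packed

    if (lane == 0) {
        *out = __half2float(ls.x) + __half2float(ls.y); // collapse only at the end
    }
}

int main() {
    half  * vals = nullptr;
    float * out  = nullptr;
    cudaMallocManaged(&vals, 64*sizeof(half));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < 64; ++i) vals[i] = __float2half(0.0f);
    softmax_sum_half2_demo<<<1, 32>>>(vals, 0.0f, out);
    cudaDeviceSynchronize();
    printf("sum = %.1f (expect 64.0, i.e. 64 * exp(0))\n", *out);
    cudaFree(vals); cudaFree(out);
    return 0;
}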