mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
cuda : use half2 in softmax
This commit is contained in:
parent
c51f27c0db
commit
b958151e3f
29
ggml-cuda.cu
29
ggml-cuda.cu
@ -6451,12 +6451,14 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
const int T = D + num_warps*SH; // shared memory size per query in (half)
|
||||
const int T2 = T/2; // shared memory size per query in (half2)
|
||||
const int C2 = C/2;
|
||||
|
||||
extern __shared__ half __flash_attn_f16_shmem[];
|
||||
// pq
|
||||
half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data
|
||||
half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2
|
||||
half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2
|
||||
|
||||
half16x16_acc zr;
|
||||
half16x16_acc lo[Q16][D16];
|
||||
@ -6606,19 +6608,19 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
// used to detect blocks full of -INF
|
||||
half smax = __float2half(-INFINITY);
|
||||
half2 smax = make_half2(-INFINITY, -INFINITY);
|
||||
|
||||
// online softmax
|
||||
for (int j = 0; j < Q; ++j) {
|
||||
const half m = M[j];
|
||||
|
||||
for (int p0 = 0; p0 < C; p0 += NW) {
|
||||
for (int p0 = 0; p0 < C2; p0 += NW) {
|
||||
const int p = p0 + lane_id;
|
||||
|
||||
const half s = ss[j*T + p];
|
||||
const half2 s = ss2[j*T2 + p];
|
||||
|
||||
smax = __hmax(smax, s);
|
||||
M[j] = __hmax(M[j], s);
|
||||
smax = __hmax2(smax, s);
|
||||
M[j] = __hmax(M[j], __hmax(s.x, s.y));
|
||||
}
|
||||
|
||||
M[j] = warp_reduce_max(M[j]);
|
||||
@ -6631,28 +6633,31 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
// local sum
|
||||
half ls = 0.0f;
|
||||
half2 ls = make_half2(0.0f, 0.0f);
|
||||
half2 M2 = make_half2(M[j], M[j]);
|
||||
|
||||
for (int p0 = 0; p0 < C; p0 += NW) {
|
||||
for (int p0 = 0; p0 < C2; p0 += NW) {
|
||||
const int p = p0 + lane_id;
|
||||
|
||||
const half s = ss[j*T + p];
|
||||
const half2 s = ss2[j*T2 + p];
|
||||
|
||||
const half vs = hexp(s - M[j]);
|
||||
const half2 vs = h2exp(s - M2);
|
||||
|
||||
ls += vs;
|
||||
|
||||
// the P matrix from the paper (Q rows, C columns)
|
||||
ss[j*T + p] = vs;
|
||||
ss2[j*T2 + p] = vs;
|
||||
}
|
||||
|
||||
S[j] = S[j]*ms + warp_reduce_sum(ls);
|
||||
ls = warp_reduce_sum(ls);
|
||||
|
||||
S[j] = S[j]*ms + ls.x + ls.y;
|
||||
}
|
||||
|
||||
smax = warp_reduce_max(smax);
|
||||
|
||||
// skip -INF blocks
|
||||
if (__hisinf(smax) == -1) {
|
||||
if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user