mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
cuda : avoid __hisinf branches
This commit is contained in:
parent
92472ea22c
commit
c51f27c0db
89
ggml-cuda.cu
89
ggml-cuda.cu
@ -6513,9 +6513,9 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
half S[Q];
|
half S[Q];
|
||||||
half M[Q];
|
half M[Q];
|
||||||
|
|
||||||
for(int i = 0; i < Q; i++) {
|
for (int i = 0; i < Q; ++i) {
|
||||||
S[i] = __float2half(0.0f);
|
S[i] = __float2half(0.0f);
|
||||||
M[i] = __float2half(-INFINITY);
|
M[i] = CUDART_MIN_DENORM_FP16;
|
||||||
}
|
}
|
||||||
|
|
||||||
// assume K and V are same shape
|
// assume K and V are same shape
|
||||||
@ -6609,69 +6609,44 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
half smax = __float2half(-INFINITY);
|
half smax = __float2half(-INFINITY);
|
||||||
|
|
||||||
// online softmax
|
// online softmax
|
||||||
if (C == 32) {
|
for (int j = 0; j < Q; ++j) {
|
||||||
for (int j = 0; j < Q; ++j) {
|
const half m = M[j];
|
||||||
const int p = lane_id;
|
|
||||||
|
for (int p0 = 0; p0 < C; p0 += NW) {
|
||||||
|
const int p = p0 + lane_id;
|
||||||
|
|
||||||
const half m = M[j];
|
|
||||||
const half s = ss[j*T + p];
|
const half s = ss[j*T + p];
|
||||||
|
|
||||||
smax = warp_reduce_max(__hmax(smax, s));
|
smax = __hmax(smax, s);
|
||||||
M[j] = warp_reduce_max(__hmax(M[j], s));
|
M[j] = __hmax(M[j], s);
|
||||||
|
}
|
||||||
|
|
||||||
const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
|
M[j] = warp_reduce_max(M[j]);
|
||||||
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
|
|
||||||
|
|
||||||
S[j] = S[j]*ms + warp_reduce_sum(vs);
|
const half ms = hexp(m - M[j]);
|
||||||
|
|
||||||
// create a QxQ diagonal matrix for rescaling the output
|
// create a QxQ diagonal matrix for rescaling the output
|
||||||
if (p == j) {
|
if (lane_id == j) {
|
||||||
ss[j*T + C + j] = ms;
|
ss[j*T + C + j] = ms;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// local sum
|
||||||
|
half ls = 0.0f;
|
||||||
|
|
||||||
|
for (int p0 = 0; p0 < C; p0 += NW) {
|
||||||
|
const int p = p0 + lane_id;
|
||||||
|
|
||||||
|
const half s = ss[j*T + p];
|
||||||
|
|
||||||
|
const half vs = hexp(s - M[j]);
|
||||||
|
|
||||||
|
ls += vs;
|
||||||
|
|
||||||
// the P matrix from the paper (Q rows, C columns)
|
// the P matrix from the paper (Q rows, C columns)
|
||||||
ss[j*T + p] = vs;
|
ss[j*T + p] = vs;
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
for (int j = 0; j < Q; ++j) {
|
|
||||||
const half m = M[j];
|
|
||||||
|
|
||||||
for (int p0 = 0; p0 < C; p0 += NW) {
|
S[j] = S[j]*ms + warp_reduce_sum(ls);
|
||||||
const int p = p0 + lane_id;
|
|
||||||
|
|
||||||
const half s = ss[j*T + p];
|
|
||||||
|
|
||||||
smax = __hmax(smax, s);
|
|
||||||
M[j] = __hmax(M[j], s);
|
|
||||||
}
|
|
||||||
|
|
||||||
M[j] = warp_reduce_max(M[j]);
|
|
||||||
|
|
||||||
const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
|
|
||||||
|
|
||||||
// create a QxQ diagonal matrix for rescaling the output
|
|
||||||
if (lane_id == j) {
|
|
||||||
ss[j*T + C + j] = ms;
|
|
||||||
}
|
|
||||||
|
|
||||||
// local sum
|
|
||||||
half ls = 0.0f;
|
|
||||||
|
|
||||||
for (int p0 = 0; p0 < C; p0 += NW) {
|
|
||||||
const int p = p0 + lane_id;
|
|
||||||
|
|
||||||
const half s = ss[j*T + p];
|
|
||||||
|
|
||||||
const half vs = __hisinf(s) == -1 ? __float2half(0.0f) : hexp(s - M[j]);
|
|
||||||
|
|
||||||
ls += vs;
|
|
||||||
|
|
||||||
// the P matrix from the paper (Q rows, C columns)
|
|
||||||
ss[j*T + p] = vs;
|
|
||||||
}
|
|
||||||
|
|
||||||
S[j] = S[j]*ms + warp_reduce_sum(ls);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
smax = warp_reduce_max(smax);
|
smax = warp_reduce_max(smax);
|
||||||
@ -6736,7 +6711,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
// reduce the warps sequentially
|
// reduce the warps sequentially
|
||||||
for (int sg = 1; sg < num_warps; ++sg) {
|
for (int sg = 1; sg < num_warps; ++sg) {
|
||||||
half S = __float2half(0.0f);
|
half S = __float2half(0.0f);
|
||||||
half M = __float2half(-INFINITY);
|
half M = CUDART_MIN_DENORM_FP16;
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
@ -6762,8 +6737,8 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
|
|
||||||
M = __hmax(M0, M1);
|
M = __hmax(M0, M1);
|
||||||
|
|
||||||
const half ms0 = __hisinf(M0) == -1 ? __float2half(0.0f) : hexp(M0 - M);
|
const half ms0 = hexp(M0 - M);
|
||||||
const half ms1 = __hisinf(M1) == -1 ? __float2half(0.0f) : hexp(M1 - M);
|
const half ms1 = hexp(M1 - M);
|
||||||
|
|
||||||
S = S0*ms0 + S1*ms1;
|
S = S0*ms0 + S1*ms1;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user