mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-28 12:24:35 +00:00)
cuda : avoid warp_reduce for smax
commit b150abe83e
parent b68a112204
@@ -6621,7 +6621,6 @@ static __global__ void flash_attn_ext_f16(
                     M[j] = __hmax(M[j], s);
                 }
 
-                smax = warp_reduce_max(smax);
                 M[j] = warp_reduce_max(M[j]);
 
                 const half ms = __hisinf(m) == -1 ? __float2half(0.0f) : hexp(m - M[j]);
@@ -6649,6 +6648,8 @@ static __global__ void flash_attn_ext_f16(
                 }
             }
 
+            smax = warp_reduce_max(smax);
+
             // skip -INF blocks
             if (__hisinf(smax) == -1) {
                 continue;
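
The net effect of the two hunks: the warp-wide reduction of smax is hoisted out of the per-row loop (where it ran once per j alongside the M[j] reduction) and now runs a single time, just before the -INF skip check that consumes it. The per-row M[j] reduction is unchanged. For reference, a minimal sketch of a shuffle-based warp max reduction for half values, assuming the usual 32-lane butterfly pattern; the actual warp_reduce_max definition lives elsewhere in the file and may differ:

#include <cuda_fp16.h>

// Sketch (assumption, not the file's actual definition): each of the 32
// lanes exchanges its value with the lane at XOR distance `mask`, halving
// the distance each step, so after 5 steps every lane holds the warp-wide
// maximum.
static __device__ __forceinline__ half warp_reduce_max(half x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = __hmax(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
}

Because every lane ends up with the same warp-wide maximum, the hoisted reduction still guarantees that the subsequent __hisinf(smax) check makes the same skip/continue decision on all lanes of the warp.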