Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2025-01-09 18:21:45 +00:00)
float -> half regs
Commit 3ab47eb746 (parent e121d82f6a)
@@ -2898,8 +2898,8 @@ kernel void kernel_flash_attn_ext(
     threadgroup_barrier(mem_flags::mem_threadgroup);

     {
-        float S[Q] = { [0 ... Q-1] = 0.0f };
-        float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
+        half S[Q] = { [0 ... Q-1] = 0.0f };
+        half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };

         // thread indices inside the simdgroup
         // TODO: see if we can utilize quad-group functions for better performance
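The sentinel in the second declaration changes along with the register type: the running maximum must be seeded with a large negative value that the type can actually represent. -FLT_MAX/2 (on the order of -1.7e38) does not fit in a 16-bit half, whose largest finite value is 65504 (__FLT16_MAX__), hence the switch to -__FLT16_MAX__/2. A minimal host-side sketch in plain C++ (not Metal) comparing the two sentinels; only the values themselves are taken from the diff:

#include <cfloat>
#include <cstdio>

int main(void) {
    const float m_f32 = -FLT_MAX/2;    // sentinel before this change, ~ -1.7e38
    const float m_f16 = -65504.0f/2;   // sentinel after the change: __FLT16_MAX__ is 65504
    printf("f32 sentinel: %g\nf16 sentinel: %g\n", m_f32, m_f16);
    return 0;
}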
@@ -2934,14 +2934,14 @@ kernel void kernel_flash_attn_ext(

         const bool has_mask = mask != q;

-        float slope = 1.0f;
+        half slope = 1.0f;

         // ALiBi
         if (max_bias > 0.0f) {
-            const uint32_t h = iq2;
+            const short h = iq2;

-            const float base = h < n_head_log2 ? m0 : m1;
-            const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+            const half base = h < n_head_log2 ? m0 : m1;
+            const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

             slope = pow(base, exph);
         }
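For context, the slope computed in this hunk is the standard per-head ALiBi slope: heads below n_head_log2 use base m0, the remaining heads use m1 with odd exponents. A minimal sketch in plain C++ of the same computation; alibi_slope is a hypothetical helper, and the formulas for m0 and m1 are assumptions based on the usual ggml host-side convention, not part of this diff:

#include <cmath>

// Hypothetical helper: per-head ALiBi slope as the kernel computes it.
// h, n_head and max_bias mirror the kernel arguments; the m0/m1 setup below
// is an assumption about how the host passes them in.
float alibi_slope(int h, int n_head, float max_bias) {
    const int   n_head_log2 = 1 << (int) std::floor(std::log2((float) n_head));
    const float m0 = std::pow(2.0f, -(max_bias)        / n_head_log2);  // assumption
    const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / n_head_log2);  // assumption

    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

    return std::pow(base, (float) exph);  // kernel: slope = pow(base, exph)
}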
@@ -3047,10 +3047,10 @@ kernel void kernel_flash_attn_ext(
             // online softmax
             {
                 for (short j = 0; j < Q; ++j) {
-                    const float m = M[j];
+                    const half m = M[j];

                     // scale and apply the logitcap / mask
-                    float s = ss[j*TS + tiisg]*scale;
+                    half s = ss[j*TS + tiisg]*scale;

                     if (logit_softcap != 0.0f) {
                         s = logit_softcap*precise::tanh(s);
@@ -3061,8 +3061,8 @@ kernel void kernel_flash_attn_ext(

                     M[j] = simd_max(max(M[j], s));

-                    const float ms = exp(m - M[j]);
-                    const float vs = exp(s - M[j]);
+                    const half ms = exp(m - M[j]);
+                    const half vs = exp(s - M[j]);

                     S[j] = S[j]*ms + simd_sum(vs);

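The two hunks above are the per-row online-softmax update: each new score s folds into a running pair (S, M), and the old sum is rescaled whenever the maximum moves. A scalar C++ sketch of the same recurrence (softmax_state and softmax_update are illustrative names; the simd_max/simd_sum reductions across the simdgroup are omitted):

#include <algorithm>
#include <cfloat>
#include <cmath>

// Running softmax state for one query row: M is the running max of the scores
// seen so far, S is the sum of exp(score - M) over those scores.
struct softmax_state {
    float S = 0.0f;
    float M = -FLT_MAX/2;
};

// Fold one new score s into the state (scalar version of the kernel's update).
void softmax_update(softmax_state & st, float s) {
    const float m = st.M;                 // old max
    st.M = std::max(st.M, s);             // kernel: M[j] = simd_max(max(M[j], s))
    const float ms = std::exp(m - st.M);  // rescale factor for the old sum
    const float vs = std::exp(s - st.M);  // contribution of the new score
    st.S = st.S*ms + vs;                  // kernel: S[j] = S[j]*ms + simd_sum(vs)
}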
@@ -3163,8 +3163,8 @@ kernel void kernel_flash_attn_ext(

         // reduce the warps sequentially
         for (short sg = 1; sg < nsg; ++sg) {
-            float S = { 0.0f };
-            float M = { -FLT_MAX/2 };
+            half S = { 0.0f };
+            half M = { -__FLT16_MAX__/2 };

             threadgroup_barrier(mem_flags::mem_threadgroup);

@@ -3180,16 +3180,16 @@ kernel void kernel_flash_attn_ext(
             // the first simdgroup accumulates the results from the other simdgroups
             if (sgitg == 0) {
                 for (short j = 0; j < Q; ++j) {
-                    const float S0 = ss[j*TS + 0];
-                    const float S1 = ss[j*TS + sg*SH + 0];
+                    const half S0 = ss[j*TS + 0];
+                    const half S1 = ss[j*TS + sg*SH + 0];

-                    const float M0 = ss[j*TS + 1];
-                    const float M1 = ss[j*TS + sg*SH + 1];
+                    const half M0 = ss[j*TS + 1];
+                    const half M1 = ss[j*TS + sg*SH + 1];

                     M = max(M0, M1);

-                    const float ms0 = exp(M0 - M);
-                    const float ms1 = exp(M1 - M);
+                    const half ms0 = exp(M0 - M);
+                    const half ms1 = exp(M1 - M);

                     S = S0*ms0 + S1*ms1;

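The last hunk is the cross-simdgroup reduction: each simdgroup leaves its partial (S, M) pair in threadgroup memory, and the first simdgroup merges them by rescaling every partial sum to the common maximum. A scalar C++ sketch of the merge step (softmax_merge is an illustrative name, not a function from the codebase):

#include <algorithm>
#include <cmath>

// Merge two partial online-softmax states (S0, M0) and (S1, M1) into (S, M),
// rescaling each partial sum to the shared maximum before adding.
void softmax_merge(float S0, float M0, float S1, float M1, float & S, float & M) {
    M = std::max(M0, M1);
    const float ms0 = std::exp(M0 - M);   // rescale factor for the first partial
    const float ms1 = std::exp(M1 - M);   // rescale factor for the second partial
    S = S0*ms0 + S1*ms1;
}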