mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 17:21:46 +00:00
f16 vec
This commit is contained in:
parent
8f0ef15265
commit
3b9625032c
@ -3294,10 +3294,10 @@ static void ggml_metal_encode_node(
|
||||
// for each query, we load it as f16 in shared memory (ne00)
|
||||
// and store the attention scores (nqptg x ncpsg) as f32
|
||||
//
|
||||
// 2*ne00*(nsg)
|
||||
// each simdgroup has a full f32 head vector in shared mem to accumulate results
|
||||
// ne00*(nsg)
|
||||
// each simdgroup has a full f16 head vector in shared mem to accumulate results
|
||||
//
|
||||
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
|
||||
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
|
||||
|
||||
int64_t nsgmax = 2;
|
||||
|
||||
|
@ -3219,7 +3219,7 @@ kernel void kernel_flash_attn_ext(
|
||||
// final rescale with 1/S and store to global memory
|
||||
if (sgitg == 0) {
|
||||
for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||
const half S = ss[j*TS + 0];
|
||||
const float S = ss[j*TS + 0];
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
|
||||
@ -3292,19 +3292,21 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_
|
||||
#undef FA_TYPES
|
||||
|
||||
template<
|
||||
typename q4_t,
|
||||
typename q4_t, // query types in shared memory
|
||||
typename q4x4_t,
|
||||
typename k4x4_t,
|
||||
typename v4x4_t,
|
||||
typename s_t, // attention accumulation types
|
||||
typename k4x4_t, // key types in shared memory
|
||||
typename v4x4_t, // value types in shared memory
|
||||
typename qk_t, // Q*K types
|
||||
typename s_t, // soft-max types
|
||||
typename s4_t,
|
||||
typename s4x4_t,
|
||||
typename o4x4_t,
|
||||
typename block_q,
|
||||
typename o4x4_t, // attention accumulation types
|
||||
typename kd4x4_t, // key type in device memory
|
||||
short nl_k,
|
||||
void (*deq_k)(device const block_q *, short, thread k4x4_t &),
|
||||
void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
|
||||
typename vd4x4_t, // key type in device memory
|
||||
short nl_v,
|
||||
void (*deq_v)(device const block_q *, short, thread v4x4_t &),
|
||||
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
|
||||
short D, // head size
|
||||
short Q = 1, // queries per threadgroup
|
||||
short C = 32> // cache items per threadgroup
|
||||
@ -3333,14 +3335,14 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint32_t & n_head_log2,
|
||||
constant uint16_t & n_head_log2,
|
||||
constant float & logit_softcap,
|
||||
threadgroup half * shared [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
ushort3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const short nsg = ntg.y; // number of simdgroups
|
||||
|
||||
const int iq3 = tgpig[2];
|
||||
@ -3353,16 +3355,14 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const short NW4 = NW/4;
|
||||
const short SH = C; // shared memory per simdgroup in (half)
|
||||
|
||||
const short SF = sizeof(s_t)/sizeof(half);
|
||||
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
|
||||
|
||||
const short T = D + SF*nsg*SH; // shared memory size per query in (half)
|
||||
|
||||
//threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in half4x4
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention
|
||||
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + SF*sgitg*SH + 1*D); // same as above but in half4
|
||||
threadgroup s4x4_t * sr4x4 = (threadgroup s4x4_t *) (shared + SF*sgitg*D + Q*T); // scratch buffer for the results
|
||||
//threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in half4
|
||||
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in half4x4
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention
|
||||
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + 2*sgitg*SH + Q*D); // same as above but in half4
|
||||
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
o4x4_t lo[D16/NW4];
|
||||
@ -3374,7 +3374,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
if (iq1 < ne01) {
|
||||
sq4[i] = (q4_t) q4[i];
|
||||
} else {
|
||||
sq4[i] = (q4_t) (float4) 0.0f;
|
||||
sq4[i] = (q4_t) 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
@ -3385,14 +3385,14 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
// zero out shared memory SH
|
||||
for (short i = tiisg; i < SH/4; i += NW) {
|
||||
ss4[i] = (s4_t) (float4) 0.0f;
|
||||
ss4[i] = (s4_t) 0.0f;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
{
|
||||
float S = 0.0f;
|
||||
float M = -FLT_MAX/2;
|
||||
half S = 0.0f;
|
||||
half M = -__FLT16_MAX__/2;
|
||||
|
||||
// thread indices inside the simdgroup
|
||||
const short tx = tiisg%8;
|
||||
@ -3406,25 +3406,25 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const short ikv3 = iq3/(ne03/ne_12_3);
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
k4x4_t mq[D16/NW4];
|
||||
q4x4_t mq[D16/NW4];
|
||||
|
||||
for (short ii = 0; ii < D16; ii += NW4) {
|
||||
mq[ii/NW4] = (k4x4_t) sq4x4[ii + tx];
|
||||
mq[ii/NW4] = sq4x4[ii + tx];
|
||||
}
|
||||
|
||||
// pointer to the mask
|
||||
device const half * mp = (device const half *) (mask + iq1*nb31);
|
||||
|
||||
float slope = 1.0f;
|
||||
half slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const uint32_t h = iq2;
|
||||
const short h = iq2;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
const half base = h < n_head_log2 ? m0 : m1;
|
||||
const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
slope = pow(base, exph);
|
||||
}
|
||||
|
||||
// loop over the KV cache
|
||||
@ -3439,9 +3439,9 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
{
|
||||
// each simdgroup processes 1 query and 4 keys
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
s_t mqk = 0.0;
|
||||
qk_t mqk = 0.0;
|
||||
|
||||
device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
|
||||
#pragma unroll
|
||||
for (short ii = 0; ii < D16; ii += NW4) {
|
||||
@ -3487,20 +3487,18 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
// online softmax
|
||||
{
|
||||
const short p = tiisg;
|
||||
|
||||
const float m = M;
|
||||
const float s = ss[p];
|
||||
const half m = M;
|
||||
const half s = ss[tiisg];
|
||||
|
||||
M = simd_max(max(M, s));
|
||||
|
||||
const float ms = exp(m - M);
|
||||
const float vs = exp(s - M);
|
||||
const half ms = exp(m - M);
|
||||
const half vs = exp(s - M);
|
||||
|
||||
S = S*ms + simd_sum(vs);
|
||||
|
||||
// the P matrix from the paper (Q rows, C columns)
|
||||
ss[p] = vs;
|
||||
ss[tiisg] = vs;
|
||||
|
||||
// O = diag(ms)*O
|
||||
#pragma unroll
|
||||
@ -3515,9 +3513,9 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
{
|
||||
#pragma unroll
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
||||
|
||||
const s4x4_t ms(ss[4*cc + ty]);
|
||||
const v4x4_t ms(ss[4*cc + ty]);
|
||||
|
||||
#pragma unroll
|
||||
for (short ii = 0; ii < D16; ii += NW4) {
|
||||
@ -3526,7 +3524,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
v4x4_t mv;
|
||||
deq_v(pv4 + i/nl_v, i%nl_v, mv);
|
||||
|
||||
lo[ii/NW4] += mv*ms;
|
||||
lo[ii/NW4] += (o4x4_t)(mv*ms);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3572,22 +3570,22 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
// parallel reduce
|
||||
for (short r = nsg/2; r > 0; r >>= 1) {
|
||||
if (sgitg < r) {
|
||||
const float S0 = ss[ 0];
|
||||
const float S1 = ss[r*SH + 0];
|
||||
const half S0 = ss[ 0];
|
||||
const half S1 = ss[r*SH + 0];
|
||||
|
||||
const float M0 = ss[ 1];
|
||||
const float M1 = ss[r*SH + 1];
|
||||
const half M0 = ss[ 1];
|
||||
const half M1 = ss[r*SH + 1];
|
||||
|
||||
const float M = max(M0, M1);
|
||||
const half M = max(M0, M1);
|
||||
|
||||
const float ms0 = exp(M0 - M);
|
||||
const float ms1 = exp(M1 - M);
|
||||
const half ms0 = exp(M0 - M);
|
||||
const half ms1 = exp(M1 - M);
|
||||
|
||||
const float S = S0*ms0 + S1*ms1;
|
||||
const half S = S0*ms0 + S1*ms1;
|
||||
|
||||
if (tiisg == 0) {
|
||||
ss[0] = (s_t) S;
|
||||
ss[1] = (s_t) M;
|
||||
ss[0] = S;
|
||||
ss[1] = M;
|
||||
}
|
||||
|
||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||
@ -3611,33 +3609,31 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: can use half instead of float precision for some extra perf
|
||||
// however, by default use F32 since the op should be mostly memory bandwidth bound
|
||||
|
||||
#define FA_TYPES \
|
||||
half4, half4x4, \
|
||||
float4x4, \
|
||||
float4x4, \
|
||||
half4x4, \
|
||||
half4x4, \
|
||||
float, \
|
||||
float, float4, float4x4, \
|
||||
float4x4
|
||||
half4x4
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
|
||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, 1, dequantize_bf16, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, 1, dequantize_bf16, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
|
||||
|
||||
#undef FA_TYPES
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user