metal : fattn quantization (wip)

Georgi Gerganov 2024-11-02 18:29:08 +02:00
parent 1926d6e39d
commit fdc2bb17b6
2 changed files with 342 additions and 16 deletions

@@ -258,6 +258,7 @@ enum ggml_metal_kernel_type {
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
@@ -706,6 +707,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, support_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, support_simdgroup_reduction);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
@@ -862,12 +864,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_LEAKY_RELU:
return true;
case GGML_OP_FLASH_ATTN_EXT:
if (op->src[1]->type != GGML_TYPE_F16) {
return false;
}
if (op->src[2]->type != GGML_TYPE_F16) {
return false;
}
if (op->src[0]->ne[0] == 256) {
return false;
}
@@ -2861,7 +2857,11 @@ static void ggml_metal_encode_node(
bool use_vec_kernel = false;
if (ne01 >= 4 || (ne00%128 != 0)) {
if (src1->type == GGML_TYPE_Q8_0 && src2->type == GGML_TYPE_Q8_0) {
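// WIP: quantized K/V is routed to the dedicated Q8_0 vec kernel; only head size 128 is wired up so far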
use_vec_kernel = true;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline;
} else if (ne01 >= 4 || (ne00%128 != 0)) {
switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;

@@ -2293,7 +2293,7 @@ kernel void kernel_leaky_relu_f32(
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
}
typedef void (flash_attn_ext_f16_t)(
typedef void (flash_attn_ext_t)(
device const char * q,
device const char * k,
device const char * v,
@@ -2652,12 +2652,12 @@ kernel void kernel_flash_attn_ext_f16(
}
}
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<80>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<96>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<112>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<128>;
//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext_f16<256>;
template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
kernel void kernel_flash_attn_ext_vec_f16(
@@ -2941,8 +2941,334 @@ kernel void kernel_flash_attn_ext_vec_f16(
}
}
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext_vec_f16<128>;
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext_vec_f16<256>;
template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg);
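// note: block_q8_0 packs QK8_0 = 32 signed 8-bit quants together with a single half scale d (value[i] = d*qs[i]);
// il (0 or 1) selects which 16-value half of the block fills the dequantized 4x4 register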
template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
kernel void kernel_flash_attn_ext_q8_0(
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device float * dst,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant uint64_t & nb21,
constant uint64_t & nb22,
constant uint64_t & nb23,
constant uint64_t & nb31,
constant int64_t & ne1,
constant int64_t & ne2,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
constant float & logit_softcap,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
const short iq3 = tgpig[2];
const short iq2 = tgpig[1];
const short iq1 = tgpig[0];
const short D4 = D/4;
const short NW = N_SIMDWIDTH;
const short SH = (C + Q); // shared memory per simdgroup in (half)
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
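// e.g. for the instantiated D = 128 with the default Q = 1, C = 32: D4 = 32, SH = 33 and T = 128 + 66*nsg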
float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
const uint32_t h = iq2;
const float base = h < n_head_log2 ? m0 : m1;
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = pow(base, exp);
}
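// the host side precomputes n_head_log2 = 2^floor(log2(n_head)), m0 = 2^(-max_bias/n_head_log2) and
// m1 = 2^(-max_bias/2/n_head_log2), so heads below n_head_log2 get slope m0^(h+1) and the rest m1^(2*(h - n_head_log2) + 1)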
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in float4
threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results
// per-thread accumulator for the output (the O matrix from the paper), stored as half4 chunks of the head dimension
half4 lo[D4/NW];
// load heads from Q to shared memory
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
for (short i = tiisg; i < D4; i += NW) {
if (iq1 < ne01) {
sq4[i] = (half4) q4[i];
} else {
sq4[i] = 0.0h;
}
}
// zero out lo
for (short i = tiisg; i < D4; i += NW) {
lo[i/NW] = 0.0h;
}
// zero out shared memory SH
for (short i = tiisg; i < SH/4; i += NW) {
ss4[i] = 0.0h;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
{
float S = { 0.0h };
float M = { -FLT_MAX/2 };
// assume K and V are same shape
const short ne22 = ne12;
const short ne23 = ne13;
// broadcast
const short rk2 = ne02/ne12;
const short rk3 = ne03/ne13;
const short rv2 = ne02/ne22;
const short rv3 = ne03/ne23;
// k indices
const short ik2 = iq2 / rk2;
const short ik3 = iq3 / rk3;
// v indices
const short iv2 = iq2 / rv2;
const short iv3 = iq3 / rv3;
// load the queries from shared memory into local memory
float4 mq[D4];
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
mq[i] = (float4) sq4[i];
}
// pointer to the mask
device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31);
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
const int ic = ic0 + C*sgitg;
if (ic >= ne11) {
break;
}
// Q*K^T
{
#pragma unroll
for (short cc = 0; cc < C/4; ++cc) {
float4 mqk = { 0.0h };
//device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
//device const block_q8_0 * pk = (device const block_q8_0 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13));
#pragma unroll
//for (short ii = 0; ii < D4; ii += NW) {
// const short i = ii + tiisg;
// float4x4 mk;
// mk[0] = (float4) pk4[i + 0*(nb11/8)];
// mk[1] = (float4) pk4[i + 1*(nb11/8)];
// mk[2] = (float4) pk4[i + 2*(nb11/8)];
// mk[3] = (float4) pk4[i + 3*(nb11/8)];
// mqk += (float4) (mq[i] * mk);
//}
//for (short ii = 0; ii < D/16; ii += NW) {
// const short i = ii + tiisg%8;
// const short j = tiisg/8;
// device const block_q8_0 * pk = (device const block_q8_0 *) ((device const char *) k + ((ic + 4*cc + j)*nb11 + ik2*nb12 + ik3*nb13));
// float4x4 mk;
// dequantize_q8_0<float4x4>(pk + i/2, i%2, mk);
// //mqk += mq[4*i + 0]*mk[0] + mq[4*i + 1]*mk[1] + mq[4*i + 2]*mk[2] + mq[4*i + 3]*mk[3];
// mqk[j] += dot(mq[4*i + 0], mk[0]) + dot(mq[4*i + 1], mk[1]) + dot(mq[4*i + 2], mk[2]) + dot(mq[4*i + 3], mk[3]);
//}
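// live path: dequantize K on the fly; each thread owns float4 chunk i of the head dimension, a block_q8_0 spans
// 8 float4s, so pk + i/8 selects the block, (i/4)%2 its 16-value half and mk[i%4] the float4 inside it, while j
// walks the 4 K rows of this cc iteration (WIP: a full 4x4 is dequantized per thread but only one float4 is used)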
for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg;
for (short j = 0; j < 4; ++j) {
device const block_q8_0 * pk = (device const block_q8_0 *) ((device const char *) k + ((ic + 4*cc + j)*nb11 + ik2*nb12 + ik3*nb13));
float4x4 mk;
dequantize_q8_0<float4x4>(pk + i/8, (i/4)%2, mk);
mqk[j] += dot(mq[i], mk[i%4]);
}
}
// reduce the results from the threads in the simdgroup
mqk += simd_shuffle_down(mqk, 16);
mqk += simd_shuffle_down(mqk, 8);
mqk += simd_shuffle_down(mqk, 4);
mqk += simd_shuffle_down(mqk, 2);
mqk += simd_shuffle_down(mqk, 1);
// mqk = mqk*scale + mask*slope
if (tiisg == 0) {
mqk *= scale;
if (logit_softcap != 0.0f) {
mqk = logit_softcap*precise::tanh(mqk);
}
mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;
ss4[cc] = mqk;
}
}
}
// online softmax
{
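// streaming softmax: M is the running max and S the running sum of exp(score - M); when a new block of
// scores raises M, the previous sum and the O accumulator are rescaled by exp(M_old - M_new)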
const short p = tiisg;
const float m = M;
const float s = ss[p];
M = simd_max(max(M, s));
const float ms = exp(m - M);
const float vs = exp(s - M);
S = S*ms + simd_sum(vs);
// the P matrix from the paper (Q rows, C columns)
ss[p] = vs;
// O = diag(ms)*O
#pragma unroll
for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg;
lo[i/NW] *= ms;
}
}
// O = O + (Q*K^T)*V
{
#pragma unroll
for (short cc = 0; cc < C/4; ++cc) {
//device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23));
#pragma unroll
//for (short ii = 0; ii < D4; ii += NW) {
// const short i = ii + tiisg;
// lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
// lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
// lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
// lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
//}
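// live path: dequantize V on the fly using the same block indexing as the K path above; ss[4*cc + j] is the
// softmax weight of cache row ic + 4*cc + j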
for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg; // 0..31
for (short j = 0; j < 4; ++j) {
device const block_q8_0 * pv = (device const block_q8_0 *) ((device const char *) v + ((ic + 4*cc + j)*nb21 + iv2*nb22 + iv3*nb23));
half4x4 mv;
dequantize_q8_0<half4x4>(pv + i/8, (i/4)%2, mv);
lo[i/NW] += mv[i%4] * ss[4*cc + j];
}
}
}
}
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
if (tiisg == 0) {
ss[0] = S;
ss[1] = M;
}
}
// store results to shared memory
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
sr4[i] = lo[ii/NW];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// parallel reduce
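// tree reduction over the simdgroups: each round merges pairs of partial (S, M, O) results using the same
// exp(M_i - M) rescaling as the online softmax, until simdgroup 0 holds the combined output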
for (short r = nsg/2; r > 0; r >>= 1) {
if (sgitg < r) {
const float S0 = ss[ 0];
const float S1 = ss[r*SH + 0];
const float M0 = ss[ 1];
const float M1 = ss[r*SH + 1];
const float M = max(M0, M1);
const float ms0 = exp(M0 - M);
const float ms1 = exp(M1 - M);
const float S = S0*ms0 + S1*ms1;
if (tiisg == 0) {
ss[0] = S;
ss[1] = M;
}
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
device float4 * dst4 = (device float4 *) dst;
// final rescale with 1/S and store to global memory
if (sgitg == 0) {
const float S = ss[0];
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S;
}
}
}
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext_q8_0<128>;
template<typename T0, typename T1>
kernel void kernel_cpy(