mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-11 19:21:46 +00:00
metal : gemma2 flash attention support (#9159)
This commit is contained in:
parent
f12ceaca0c
commit
0c41e03ceb
@ -802,15 +802,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
|||||||
if (op->src[0]->ne[0] == 256) {
|
if (op->src[0]->ne[0] == 256) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
{
|
|
||||||
float logit_softcap;
|
|
||||||
|
|
||||||
memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap));
|
|
||||||
|
|
||||||
if (logit_softcap != 0.0f) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
@ -2633,9 +2624,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
|
|
||||||
float scale;
|
float scale;
|
||||||
float max_bias;
|
float max_bias;
|
||||||
|
float logit_softcap;
|
||||||
|
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
|
||||||
|
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
||||||
|
memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
|
||||||
|
|
||||||
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
|
if (logit_softcap != 0.0f) {
|
||||||
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
scale /= logit_softcap;
|
||||||
|
}
|
||||||
|
|
||||||
const uint32_t n_head = src0->ne[2];
|
const uint32_t n_head = src0->ne[2];
|
||||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||||
@ -2686,30 +2682,31 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
} else {
|
} else {
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
|
||||||
}
|
}
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
||||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
|
||||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
|
||||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
|
||||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
|
||||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
|
||||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
|
||||||
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
|
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
|
||||||
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
|
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
|
||||||
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
|
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
|
||||||
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
||||||
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
||||||
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
||||||
[encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
|
[encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
|
||||||
[encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
|
[encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
|
||||||
[encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
|
[encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
|
||||||
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
|
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
|
||||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
|
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
|
||||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
|
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
|
||||||
[encoder setBytes:&scale length:sizeof( float) atIndex:23];
|
[encoder setBytes:&scale length:sizeof( float) atIndex:23];
|
||||||
[encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
|
[encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
|
||||||
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
|
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
|
||||||
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
|
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
|
||||||
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
|
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
|
||||||
|
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
|
||||||
|
|
||||||
if (!use_vec_kernel) {
|
if (!use_vec_kernel) {
|
||||||
// half8x8 kernel
|
// half8x8 kernel
|
||||||
|
@ -1976,6 +1976,7 @@ typedef void (flash_attn_ext_f16_t)(
|
|||||||
constant float & m0,
|
constant float & m0,
|
||||||
constant float & m1,
|
constant float & m1,
|
||||||
constant uint32_t & n_head_log2,
|
constant uint32_t & n_head_log2,
|
||||||
|
constant float & logit_softcap,
|
||||||
threadgroup half * shared,
|
threadgroup half * shared,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
@ -2014,6 +2015,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
constant float & m0,
|
constant float & m0,
|
||||||
constant float & m1,
|
constant float & m1,
|
||||||
constant uint32_t & n_head_log2,
|
constant uint32_t & n_head_log2,
|
||||||
|
constant float & logit_softcap,
|
||||||
threadgroup half * shared [[threadgroup(0)]],
|
threadgroup half * shared [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
@ -2142,14 +2144,19 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
const short tx = tiisg%4;
|
const short tx = tiisg%4;
|
||||||
const short ty = tiisg/4;
|
const short ty = tiisg/4;
|
||||||
|
|
||||||
|
// mqk = mqk*scale
|
||||||
|
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
||||||
|
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
||||||
|
|
||||||
|
if (logit_softcap != 0.0f) {
|
||||||
|
ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
|
||||||
|
ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
|
||||||
|
}
|
||||||
|
|
||||||
if (mask != q) {
|
if (mask != q) {
|
||||||
// mqk = mqk*scale + mask*slope
|
// mqk = mqk + mask*slope
|
||||||
ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
|
ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
|
||||||
ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
|
ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
|
||||||
} else {
|
|
||||||
// mqk = mqk*scale
|
|
||||||
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
|
||||||
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2345,6 +2352,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||||||
constant float & m0,
|
constant float & m0,
|
||||||
constant float & m1,
|
constant float & m1,
|
||||||
constant uint32_t & n_head_log2,
|
constant uint32_t & n_head_log2,
|
||||||
|
constant float & logit_softcap,
|
||||||
threadgroup half * shared [[threadgroup(0)]],
|
threadgroup half * shared [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||||
@ -2479,7 +2487,13 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||||||
|
|
||||||
// mqk = mqk*scale + mask*slope
|
// mqk = mqk*scale + mask*slope
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
|
mqk *= scale;
|
||||||
|
|
||||||
|
if (logit_softcap != 0.0f) {
|
||||||
|
mqk = logit_softcap*precise::tanh(mqk);
|
||||||
|
}
|
||||||
|
|
||||||
|
mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;
|
||||||
|
|
||||||
ss4[cc] = mqk;
|
ss4[cc] = mqk;
|
||||||
}
|
}
|
||||||
|
@ -2487,7 +2487,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||||||
}
|
}
|
||||||
|
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void usage(char ** argv) {
|
static void usage(char ** argv) {
|
||||||
|
Loading…
Reference in New Issue
Block a user