cuda : add flash_attn kernel (wip)

Georgi Gerganov 2024-02-01 19:47:11 +02:00
parent 2e46013749
commit b957b8f5f6
2 changed files with 735 additions and 3 deletions

ggml-cuda.cu

@@ -108,6 +108,7 @@
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <mma.h>
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
@@ -655,6 +656,19 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
static __device__ __forceinline__ half warp_reduce_sum(half x) {
#if __CUDA_ARCH__ >= CC_VOLTA
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = __hadd(__shfl_xor_sync(0xffffffff, x, mask, 32), x);
}
return x;
#else
(void) x;
NO_DEVICE_CODE;
#endif
}
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
@@ -676,6 +690,18 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
}
static __device__ __forceinline__ half warp_reduce_max(half x) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = __hmax(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
}
return x;
#else
(void) x;
NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
}
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
GGML_UNUSED(a);
@@ -989,6 +1015,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
@@ -6249,6 +6276,528 @@ static __global__ void pool2d_nchw_kernel(
o_ptr[cur_oh * ow + cur_ow] = res;
}
#define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256
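// fused FP32 attention kernel: each block computes the output for one (head, query) pair;
// the score row for the whole sequence is staged in shared memory, softmax-normalized with
// block-wide max/sum reductions, and then multiplied by V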
template<int block_size, int k_seq_len, int k_head_dim>
static __global__ void flash_attn_f32(
const float* __restrict__ q,
const float* __restrict__ k,
const float* __restrict__ v,
float* __restrict__ kqv,
float kq_scale,
int head_dim, int seq_len, int num_heads) {
const int head = blockIdx.x / seq_len;
const int head_size = head_dim * seq_len;
const int s = blockIdx.x % seq_len;
extern __shared__ char flash_attn_shmem_f32[];
float* S = (float*)flash_attn_shmem_f32;
float* warp_data = (float*)(flash_attn_shmem_f32 + seq_len * sizeof(float));
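// S         : attention scores of this query against every key position (seq_len floats)
// warp_data : one partial value per warp, used for the block-wide max/sum reductions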
// QK^T
#pragma unroll
for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
const int is = threadIdx.x + is0;
if(is >= seq_len) {
break;
}
const int key_offset = is * head_dim + head * head_size;
const int query_offset = s * head_dim + head * head_size;
float tmp = 0.0f;
for(int d = 0; d < head_dim; d++) {
tmp += k[key_offset + d] * q[query_offset + d];
}
S[is] = tmp * kq_scale;
}
__syncthreads();
float max_val = -INFINITY;
// get the max
#pragma unroll
for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
const int is = threadIdx.x + is0;
if(is >= seq_len) {
break;
}
max_val = fmaxf(max_val, S[is]);
}
max_val = warp_reduce_max(max_val);
{ // get max from all threads
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
warp_data[warp_id] = max_val;
}
__syncthreads();
max_val = warp_data[lane_id];
max_val = warp_reduce_max(max_val);
}
// softmax(QK^T)
float sum = 0.0f;
#pragma unroll
for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
const int is = threadIdx.x + is0;
if(is >= seq_len) {
break;
}
float tmp = expf(S[is] - max_val);
sum += tmp;
S[is] = tmp;
}
__syncthreads();
sum = warp_reduce_sum(sum);
{ // softmax sum partials
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
warp_data[warp_id] = sum;
}
__syncthreads();
sum = warp_data[lane_id];
sum = warp_reduce_sum(sum);
}
float inv_sum = 1.0f / sum;
#pragma unroll
for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
const int is = threadIdx.x + is0;
if(is >= seq_len) {
break;
}
S[is] *= inv_sum;
}
__syncthreads();
// softmax(QK^T)V
#pragma unroll
for (int d0 = 0; d0 < k_head_dim; d0 += block_size) {
const int d = threadIdx.x + d0;
if(d >= head_dim) {
break;
}
const int dst_index = d + s * head_dim + head * head_size;
const int value_offset = d * seq_len + head * head_size;
float temp = 0.0f;
#pragma unroll
for(int ic = 0; ic < k_seq_len;ic++) {
if(ic >= seq_len) {
break;
}
temp += v[value_offset + ic] * S[ic];
}
kqv[dst_index] = temp;
}
}
#if __CUDA_ARCH__ >= CC_VOLTA
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
#endif
// based on metal version
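// each block processes Q query rows of one head; the warps of the block split the KV cache
// into chunks of C columns, accumulate Q*K^T and the output with 16x16 WMMA tiles, and keep
// a per-row running max M and sum S for the online softmax; the per-warp partial results are
// merged at the end by warp 0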
template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
static __global__ void flash_attn_ext_f16(
const char* __restrict__ q,
const char* __restrict__ k,
const char* __restrict__ v,
const char* __restrict__ mask,
float* __restrict__ dst,
float scale,
int ne00,
int ne01,
int ne02,
int ne03,
int ne10,
int ne11,
int ne12,
int ne13,
int ne31,
int nb31,
int nb01,
int nb02,
int nb03,
int nb11,
int nb12,
int nb13,
int ne0,
int ne1,
int ne2,
int ne3) {
#if __CUDA_ARCH__ >= CC_VOLTA
const int warp_id = threadIdx.y;
const int lane_id = threadIdx.x;
const int num_warps = blockDim.y; // number of warps
const int iq3 = blockIdx.z;
const int iq2 = blockIdx.y;
const int iq1 = blockIdx.x * Q;
const int D2 = D/2;
const int D16 = D/16;
const int Q16 = Q/16;
const int NW = WARP_SIZE;
const int SH = (C + Q); // shared memory per simdgroup in (half)
const int T = D + num_warps*SH; // shared memory size per query in (half)
const int T2 = T/2; // shared memory size per query in (half2)
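// shared memory layout per query row (in half): D entries of query data followed by one
// SH = C + Q sized scratch region per warp (C attention scores + Q diagonal rescale values)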
extern __shared__ half __flash_attn_f16_shmem[];
// shared memory pointers
half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data
half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2
half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
half16x16_acc zr;
half16x16_acc lo[Q16][D16];
// load heads from Q to shared memory
for (int64_t j = warp_id; j < Q; j += num_warps) {
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
for (int64_t i = lane_id; i < D2; i += NW) {
if (iq1 + j < ne01) {
sq2[j*T2 + i] = __float22half2_rn(q2[i]);
} else {
sq2[j*T2 + i] = make_half2(0.0, 0.0);
}
}
}
nvcuda::wmma::fill_fragment(zr, 0.0);
// zero out lo
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
}
}
// zero out shared memory SH
for (int64_t j = 0; j < Q; ++j) {
for (int64_t i = lane_id; i < SH; i += NW) {
ss[j*T + i] = 0.0;
}
}
__syncthreads();
{
half S[Q];
half M[Q];
for(int i = 0; i < Q; i++) {
S[i] = __float2half(0.0f);
M[i] = __float2half(-INFINITY);
}
// assume K and V are same shape
const int ne22 = ne12;
const int ne23 = ne13;
const int nb21 = nb11;
const int nb22 = nb12;
const int nb23 = nb13;
// broadcast
const int rk2 = ne02/ne12;
const int rk3 = ne03/ne13;
const int rv2 = ne02/ne22;
const int rv3 = ne03/ne23;
// k indices
const int ik2 = iq2 / rk2;
const int ik3 = iq3 / rk3;
// v indices
const int iv2 = iq2 / rv2;
const int iv3 = iq3 / rv3;
// load the queries from shared memory into local memory
half16x16_a mq[Q16][D16];
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
}
}
// pointer to the mask
const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
// prepare diagonal scale matrix
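// right-multiplying by diag(scale) applies the scale factor to every element of the Q*K^T
// tile in a single MMA, with the mask supplied through the accumulator operand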
half16x16_b mscale;
for (int i = 0; i < 16; ++i) {
ss[i*T + i] = __float2half(scale);
}
nvcuda::wmma::load_matrix_sync(mscale, ss, T);
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
// Q*K^T
{
for (int cc = 0; cc < C/16; ++cc) {
half16x16_acc mqk[Q16];
for (int64_t j = 0; j < Q16; ++j) {
nvcuda::wmma::fill_fragment(mqk[j], 0);
}
const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
for (int64_t i = 0; i < D16; ++i) {
half16x16_bT mk; // transposed key
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
for (int64_t j = 0; j < Q16; ++j) {
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
}
}
// mqk = mqk*scale + mask
for (int64_t j = 0; j < Q16; ++j) {
half16x16_a mqka;
half16x16_acc mm;
if(mp) {
nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
}
// convert accumulator to matrix_a
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);
nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr);
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
}
}
}
// used to detect blocks full of -INF
half smax = __float2half(-INFINITY);
// online softmax
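// running-max update per query row j:
//   M_new = max(M_old, s),  ms = exp(M_old - M_new),  vs = exp(s - M_new)
//   S_new = S_old*ms + sum(vs)
// ms later rescales the accumulated output O via the diagonal value stored at ss[j*T + C + j]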
if (C == 32) {
for (int64_t j = 0; j < Q; ++j) {
const int64_t p = lane_id;
const half m = M[j];
const half s = ss[j*T + p];
smax = warp_reduce_max(__hmax(smax, s));
M[j] = warp_reduce_max(__hmax(M[j], s));
const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);
S[j] = S[j]*ms + warp_reduce_sum(vs);
// create a QxQ diagonal matrix for rescaling the output
if (p == j) {
ss[j*T + C + j] = ms;
}
// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs;
}
} else {
for (int64_t j = 0; j < Q; ++j) {
const half m = M[j];
for (int64_t p = lane_id; p < C; p += NW) {
const half s = ss[j*T + p];
smax = __hmax(smax, s);
M[j] = __hmax(M[j], s);
}
smax = warp_reduce_max(smax);
M[j] = warp_reduce_max(M[j]);
const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
// create a QxQ diagonal matrix for rescaling the output
if (lane_id == j) {
ss[j*T + C + j] = ms;
}
// local sum
half ls = 0.0f;
for (int64_t p = lane_id; p < C; p += NW) {
const half s = ss[j*T + p];
const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);
ls += vs;
// the P matrix from the paper (Q rows, C columns)
ss[j*T + p] = vs;
}
S[j] = S[j]*ms + warp_reduce_sum(ls);
}
}
// skip -INF blocks
if (__hisinf(smax)) {
continue;
}
// O = diag(ms)*O
for (int64_t j = 0; j < Q16; ++j) {
half16x16_a mm;
half16x16_b lob;
nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
for (int64_t i = 0; i < D16; ++i) {
// convert accumulator to matrix_b
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
}
// restore zeros
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
}
// O = O + (Q*K^T)*V
{
for (int cc = 0; cc < C/16; ++cc) {
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
half16x16_b mk[D16];
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
}
half16x16_a mv[Q16];
for (int64_t j = 0; j < Q16; ++j) {
nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
}
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
}
}
}
}
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (int64_t j = 0; j < Q; ++j) {
if (lane_id == 0) {
ss[j*T + 0] = S[j];
ss[j*T + 1] = M[j];
}
}
}
// reduce the warps sequentially
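// merge rule for two partial results (S0, M0, O0) and (S1, M1, O1):
//   M = max(M0, M1),  ms0 = exp(M0 - M),  ms1 = exp(M1 - M)
//   S = S0*ms0 + S1*ms1,  O = diag(ms0)*O0 + diag(ms1)*O1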
for (int64_t sg = 1; sg < num_warps; ++sg) {
half S = __float2half(0.0f);
half M = __float2half(-INFINITY);
__syncthreads();
// each simdgroup stores its output to shared memory, reusing sq
if (warp_id == sg) {
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
}
}
}
__syncthreads();
// the first simdgroup accumulates the results from the other simdgroups
if (warp_id == 0) {
for (int64_t j = 0; j < Q; ++j) {
const half S0 = ss[j*T + 0];
const half S1 = ss[j*T + sg*SH + 0];
const half M0 = ss[j*T + 1];
const half M1 = ss[j*T + sg*SH + 1];
M = __hmax(M0, M1);
const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M);
const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M);
S = S0*ms0 + S1*ms1;
if (lane_id == 0) {
ss[j*T + 0] = S;
ss[j*T + 1] = M;
ss[j*T + C + j ] = ms0;
ss[j*T + C + j + sg*SH] = ms1;
}
}
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
for (int64_t j = 0; j < Q16; ++j) {
half16x16_a ms0;
half16x16_a ms1;
half16x16_b t;
half16x16_acc t2;
nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::fill_fragment(t2, 0.0);
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
nvcuda::wmma::mma_sync(t2, ms1, t, t2);
// convert accumulator to matrix_b
nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);
nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
}
}
}
}
// store result to shared memory (reuse sq)
if (warp_id == 0) {
for (int64_t j = 0; j < Q16; ++j) {
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
}
}
}
// final rescale with 1/S and store to global memory
if (warp_id == 0) {
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
const half S = ss[j*T + 0];
for (int64_t i = lane_id; i < D; i += NW) {
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
}
}
}
#else
NO_DEVICE_CODE;
#endif
}
template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
@@ -7682,6 +8231,13 @@ static void im2col_cuda(const float* x, T* dst,
im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}
static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) {
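// one block per (head, query) pair; dynamic shared memory holds the score row plus per-warp
// reduction scratch; the k_seq_len/k_head_dim template bounds are hard-coded to 1024/64 for
// this WIP kernel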
int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float);
int num_blocks = num_heads * seq_len;
flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE, 1024, 64><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
q, k, v, dst, kq_scale, d_head, seq_len, num_heads);
}
// buffer pool for cuda
#define MAX_CUDA_BUFFERS 256
@@ -8659,7 +9215,7 @@ static void ggml_cuda_op_dequantize_mul_mat_vec(
src1_dfloat = src1_dfloat_a.alloc(ne00);
ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
ne00, 1, sizeof(float), 0, 0,
ne00, 1, sizeof(half), 0, 0, stream);
ne00, 1, sizeof(half), 0, 0, 0, 0, 0, 0, stream);
}
#else
const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
@@ -10284,6 +10840,170 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
}
}
inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV) {
GGML_ASSERT(Q->type == GGML_TYPE_F32);
GGML_ASSERT(K->type == GGML_TYPE_F32);
GGML_ASSERT(V->type == GGML_TYPE_F32);
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);
ggml_cuda_set_device(g_main_device);
const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra;
const int64_t d_head = Q->ne[0];
const int64_t sequence_length = Q->ne[1];
const int64_t num_heads = Q->ne[2];
GGML_ASSERT(Q->ne[0] == d_head);
GGML_ASSERT(K->ne[0] == d_head);
GGML_ASSERT(V->ne[1] == d_head);
GGML_ASSERT(Q->ne[1] == sequence_length);
GGML_ASSERT(K->ne[1] == sequence_length);
GGML_ASSERT(V->ne[0] == sequence_length);
GGML_ASSERT(Q->ne[2] == num_heads);
GGML_ASSERT(K->ne[2] == num_heads);
GGML_ASSERT(V->ne[2] == num_heads);
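// standard scaled dot-product attention scale: 1/sqrt(d_head)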
float KQ_scale = 1.0f / sqrtf((float)d_head);
flash_attn_f32_cuda(
(float *) src0_extra->data_device[g_main_device], // Query
(float *) src1_extra->data_device[g_main_device], // Key
(float *) src2_extra->data_device[g_main_device], // Value
(float *) dst_extra->data_device[g_main_device], // dst
KQ_scale, d_head, sequence_length, num_heads, main_stream);
}
inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
GGML_ASSERT(Q->type == GGML_TYPE_F32);
GGML_ASSERT(K->type == GGML_TYPE_F16);
GGML_ASSERT(V->type == GGML_TYPE_F16);
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
ggml_cuda_set_device(g_main_device);
const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
ggml_tensor_extra_gpu * src3_extra = mask ? (ggml_tensor_extra_gpu *) mask->extra : nullptr;
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra;
float scale;
memcpy(&scale, KQV->op_params, sizeof(float));
#define NQPB 16
#define NCPW 128
const int nqpb = NQPB; // queries per block
const int ncpw = NCPW; // cache values per warp (does not work for other values)
const int nwarps_max = 8; // TODO: we don't want to launch too many warps. How many is too many?
// TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 2;
dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
dim3 block_dim(32, nwarps, 1);
const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
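// shmem = Q*(D + nwarps*(C + Q)) half elements, matching the per-query layout T used in the
// kernel; sizeof(float)/2 == sizeof(half)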
switch (Q->ne[0])
{
case 16:
flash_attn_ext_f16<16, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
(const char *) src2_extra->data_device[g_main_device], // Value
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
(float *) dst_extra->data_device[g_main_device], // dst
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
break;
case 64:
flash_attn_ext_f16<64, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
(const char *) src2_extra->data_device[g_main_device], // Value
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
(float *) dst_extra->data_device[g_main_device], // dst
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
break;
case 80:
flash_attn_ext_f16<80, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
(const char *) src2_extra->data_device[g_main_device], // Value
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
(float *) dst_extra->data_device[g_main_device], // dst
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
break;
case 128:
flash_attn_ext_f16<128, NQPB, NCPW>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
(const char *) src2_extra->data_device[g_main_device], // Value
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
(float *) dst_extra->data_device[g_main_device], // dst
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
break;
default:
break;
}
}
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
}
@@ -10573,6 +11293,10 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
case GGML_OP_ARGSORT:
func = ggml_cuda_argsort;
break;
case GGML_OP_FLASH_ATTN:
break;
case GGML_OP_FLASH_ATTN_EXT:
break;
default:
return false;
}
@@ -10587,7 +11311,13 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return true;
}
func(tensor->src[0], tensor->src[1], tensor);
if(tensor->op == GGML_OP_FLASH_ATTN) {
ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor);
} else if(tensor->op == GGML_OP_FLASH_ATTN_EXT) {
ggml_cuda_flash_attn_ext(tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
} else {
func(tensor->src[0], tensor->src[1], tensor);
}
return true;
}
@@ -11403,6 +12133,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
return true;
default:
return false;

llama.cpp

@@ -6881,7 +6881,8 @@ static int llama_decode_internal(
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
// note: we pad the n_kv because certain GPU kernels require it (e.g. ggml_flash_attn_ext)
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(128, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);