mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 03:44:35 +00:00
cuda : add flash_attn kernel (wip)
This commit is contained in:
parent
2e46013749
commit
b957b8f5f6
735
ggml-cuda.cu
735
ggml-cuda.cu
@ -108,6 +108,7 @@
|
||||
#include <cuda.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <mma.h>
|
||||
|
||||
#if CUDART_VERSION < 11020
|
||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
||||
@ -655,6 +656,19 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ half warp_reduce_sum(half x) {
|
||||
#if __CUDA_ARCH__ >= CC_VOLTA
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
x = __hadd(__shfl_xor_sync(0xffffffff, x, mask, 32), x);
|
||||
}
|
||||
return x;
|
||||
#else
|
||||
(void) x;
|
||||
NO_DEVICE_CODE;
|
||||
#endif
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
@ -676,6 +690,18 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ half warp_reduce_max(half x) {
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
x = __hmax(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
|
||||
}
|
||||
return x;
|
||||
#else
|
||||
(void) x;
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
|
||||
return b;
|
||||
GGML_UNUSED(a);
|
||||
@ -989,6 +1015,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = tmp;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
tmp = s_sum[lane_id];
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
@ -6249,6 +6276,528 @@ static __global__ void pool2d_nchw_kernel(
|
||||
o_ptr[cur_oh * ow + cur_ow] = res;
|
||||
}
|
||||
|
||||
#define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256
|
||||
|
||||
template<int block_size, int k_seq_len, int k_head_dim>
|
||||
static __global__ void flash_attn_f32(
|
||||
const float* __restrict__ q,
|
||||
const float* __restrict__ k,
|
||||
const float* __restrict__ v,
|
||||
float* __restrict__ kqv,
|
||||
float kq_scale,
|
||||
int head_dim, int seq_len, int num_heads) {
|
||||
const int head = blockIdx.x / seq_len;
|
||||
const int head_size = head_dim * seq_len;
|
||||
const int s = blockIdx.x % seq_len;
|
||||
|
||||
extern __shared__ char flash_attn_shmem_f32[];
|
||||
float* S = (float*)flash_attn_shmem_f32;
|
||||
float* warp_data = (float*)(flash_attn_shmem_f32 + seq_len * sizeof(float));
|
||||
|
||||
// QK^T
|
||||
#pragma unroll
|
||||
for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
|
||||
const int is = threadIdx.x + is0;
|
||||
if(is >= seq_len) {
|
||||
break;
|
||||
}
|
||||
|
||||
const int key_offset = is * head_dim + head * head_size;
|
||||
const int query_offset = s * head_dim + head * head_size;
|
||||
|
||||
float tmp = 0.0f;
|
||||
for(int d = 0; d < head_dim; d++) {
|
||||
tmp += k[key_offset + d] * q[query_offset + d];
|
||||
}
|
||||
S[is] = tmp * kq_scale;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float max_val = -INFINITY;
|
||||
// get the max
|
||||
#pragma unroll
|
||||
for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
|
||||
const int is = threadIdx.x + is0;
|
||||
if(is >= seq_len) {
|
||||
break;
|
||||
}
|
||||
|
||||
max_val = fmaxf(max_val , S[is]);
|
||||
}
|
||||
|
||||
max_val = warp_reduce_max(max_val);
|
||||
|
||||
{ // get max from all threads
|
||||
int warp_id = threadIdx.x / WARP_SIZE;
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
warp_data[warp_id] = max_val;
|
||||
}
|
||||
__syncthreads();
|
||||
max_val = warp_data[lane_id];
|
||||
max_val = warp_reduce_max(max_val);
|
||||
}
|
||||
|
||||
// softmax(QK^T)
|
||||
float sum = 0.0f;
|
||||
#pragma unroll
|
||||
for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
|
||||
const int is = threadIdx.x + is0;
|
||||
if(is >= seq_len) {
|
||||
break;
|
||||
}
|
||||
float tmp = expf(S[is] - max_val);
|
||||
sum += tmp;
|
||||
S[is] = tmp;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
sum = warp_reduce_sum(sum);
|
||||
{ // softmax sum partials
|
||||
int warp_id = threadIdx.x / WARP_SIZE;
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
warp_data[warp_id] = sum;
|
||||
}
|
||||
__syncthreads();
|
||||
sum = warp_data[lane_id];
|
||||
sum = warp_reduce_sum(sum);
|
||||
}
|
||||
|
||||
float inv_sum = 1.0f / sum;
|
||||
#pragma unroll
|
||||
for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
|
||||
const int is = threadIdx.x + is0;
|
||||
if(is >= seq_len) {
|
||||
break;
|
||||
}
|
||||
|
||||
S[is] *= inv_sum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// softmax(QK^T)V
|
||||
#pragma unroll
|
||||
for (int d0 = threadIdx.x; d0 < k_head_dim; d0 += block_size) {
|
||||
const int d = threadIdx.x + d0;
|
||||
if(d >= head_dim) {
|
||||
break;
|
||||
}
|
||||
const int dst_index = d + s * head_dim + head * head_size;
|
||||
const int value_offset = d * seq_len + head * head_size;
|
||||
|
||||
float temp = 0.0f;
|
||||
#pragma unroll
|
||||
for(int ic = 0; ic < k_seq_len;ic++) {
|
||||
if(ic >= seq_len) {
|
||||
break;
|
||||
}
|
||||
temp += v[value_offset + ic] * S[ic];
|
||||
}
|
||||
kqv[dst_index] = temp;
|
||||
}
|
||||
}
|
||||
|
||||
#if __CUDA_ARCH__ >= CC_VOLTA
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
|
||||
#endif
|
||||
|
||||
// based on metal version
|
||||
template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
|
||||
static __global__ void flash_attn_ext_f16(
|
||||
const char* __restrict__ q,
|
||||
const char* __restrict__ k,
|
||||
const char* __restrict__ v,
|
||||
const char* __restrict__ mask,
|
||||
float* __restrict__ dst,
|
||||
float scale,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
int ne13,
|
||||
int ne31,
|
||||
int nb31,
|
||||
int nb01,
|
||||
int nb02,
|
||||
int nb03,
|
||||
int nb11,
|
||||
int nb12,
|
||||
int nb13,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3) {
|
||||
#if __CUDA_ARCH__ >= CC_VOLTA
|
||||
const int warp_id = threadIdx.y;
|
||||
const int lane_id = threadIdx.x;
|
||||
|
||||
const int num_warps = blockDim.y; // number of warps
|
||||
const int iq3 = blockIdx.z;
|
||||
const int iq2 = blockIdx.y;
|
||||
const int iq1 = blockIdx.x * Q;
|
||||
|
||||
const int D2 = D/2;
|
||||
const int D16 = D/16;
|
||||
const int Q16 = Q/16;
|
||||
const int NW = WARP_SIZE;
|
||||
const int SH = (C + Q); // shared memory per simdgroup in (half)
|
||||
|
||||
const int T = D + num_warps*SH; // shared memory size per query in (half)
|
||||
const int T2 = T/2; // shared memory size per query in (half2)
|
||||
|
||||
extern __shared__ half __flash_attn_f16_shmem[];
|
||||
// pq
|
||||
half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data
|
||||
half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2
|
||||
half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
|
||||
|
||||
half16x16_acc zr;
|
||||
half16x16_acc lo[Q16][D16];
|
||||
|
||||
// load heads from Q to shared memory
|
||||
for (int64_t j = warp_id; j < Q; j += num_warps) {
|
||||
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
|
||||
|
||||
for (int64_t i = lane_id; i < D2; i += NW) {
|
||||
if (iq1 + j < ne01) {
|
||||
sq2[j*T2 + i] = __float22half2_rn(q2[i]);
|
||||
} else {
|
||||
sq2[j*T2 + i] = make_half2(0.0, 0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nvcuda::wmma::fill_fragment(zr, 0.0);
|
||||
|
||||
// zero out lo
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
|
||||
}
|
||||
}
|
||||
|
||||
// zero out shared memory SH
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
for (int64_t i = lane_id; i < SH; i += NW) {
|
||||
ss[j*T + i] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
{
|
||||
half S[Q];
|
||||
half M[Q];
|
||||
|
||||
for(int i = 0; i < Q; i++) {
|
||||
S[i] = __float2half(0.0f);
|
||||
M[i] = __float2half(-INFINITY);
|
||||
}
|
||||
|
||||
// assume K and V are same shape
|
||||
const int ne22 = ne12;
|
||||
const int ne23 = ne13;
|
||||
|
||||
const int nb21 = nb11;
|
||||
const int nb22 = nb12;
|
||||
const int nb23 = nb13;
|
||||
|
||||
// broadcast
|
||||
const int rk2 = ne02/ne12;
|
||||
const int rk3 = ne03/ne13;
|
||||
|
||||
const int rv2 = ne02/ne22;
|
||||
const int rv3 = ne03/ne23;
|
||||
|
||||
// k indices
|
||||
const int ik2 = iq2 / rk2;
|
||||
const int ik3 = iq3 / rk3;
|
||||
|
||||
// v indices
|
||||
const int iv2 = iq2 / rv2;
|
||||
const int iv3 = iq3 / rv3;
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
half16x16_a mq[Q16][D16];
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
|
||||
}
|
||||
}
|
||||
|
||||
// pointer to the mask
|
||||
const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
|
||||
|
||||
// prepare diagonal scale matrix
|
||||
half16x16_b mscale;
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
ss[i*T + i] = __float2half(scale);
|
||||
}
|
||||
nvcuda::wmma::load_matrix_sync(mscale, ss, T);
|
||||
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
|
||||
// Q*K^T
|
||||
{
|
||||
for (int cc = 0; cc < C/16; ++cc) {
|
||||
half16x16_acc mqk[Q16];
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
nvcuda::wmma::fill_fragment(mqk[j], 0);
|
||||
}
|
||||
|
||||
const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
half16x16_bT mk; // transposed key
|
||||
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
|
||||
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
|
||||
}
|
||||
}
|
||||
|
||||
// mqk = mqk*scale + mask
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
half16x16_a mqka;
|
||||
half16x16_acc mm;
|
||||
if(mp) {
|
||||
nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
|
||||
}
|
||||
|
||||
// convert accumulator to matrix_a
|
||||
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
|
||||
nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);
|
||||
|
||||
nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr);
|
||||
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// used to detect blocks full of -INF
|
||||
half smax = __float2half(-INFINITY);
|
||||
|
||||
// online softmax
|
||||
if (C == 32) {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const int64_t p = lane_id;
|
||||
|
||||
const half m = M[j];
|
||||
const half s = ss[j*T + p];
|
||||
|
||||
smax = warp_reduce_max(__hmax(smax, s));
|
||||
M[j] = warp_reduce_max(__hmax(M[j], s));
|
||||
|
||||
const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
|
||||
const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);
|
||||
|
||||
S[j] = S[j]*ms + warp_reduce_sum(vs);
|
||||
|
||||
// create a QxQ diagonal matrix for rescaling the output
|
||||
if (p == j) {
|
||||
ss[j*T + C + j] = ms;
|
||||
}
|
||||
|
||||
// the P matrix from the paper (Q rows, C columns)
|
||||
ss[j*T + p] = vs;
|
||||
}
|
||||
} else {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const half m = M[j];
|
||||
|
||||
for (int64_t p = lane_id; p < C; p += NW) {
|
||||
const half s = ss[j*T + p];
|
||||
|
||||
smax = __hmax(smax, s);
|
||||
M[j] = __hmax(M[j], s);
|
||||
}
|
||||
|
||||
smax = warp_reduce_max(smax);
|
||||
M[j] = warp_reduce_max(M[j]);
|
||||
|
||||
const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
|
||||
|
||||
// create a QxQ diagonal matrix for rescaling the output
|
||||
if (lane_id == j) {
|
||||
ss[j*T + C + j] = ms;
|
||||
}
|
||||
|
||||
// local sum
|
||||
half ls = 0.0f;
|
||||
|
||||
for (int64_t p = lane_id; p < C; p += NW) {
|
||||
const half s = ss[j*T + p];
|
||||
|
||||
const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);
|
||||
|
||||
ls += vs;
|
||||
|
||||
// the P matrix from the paper (Q rows, C columns)
|
||||
ss[j*T + p] = vs;
|
||||
}
|
||||
|
||||
S[j] = S[j]*ms + warp_reduce_sum(ls);
|
||||
}
|
||||
}
|
||||
|
||||
// skip -INF blocks
|
||||
if (__hisinf(smax)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// O = diag(ms)*O
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
half16x16_a mm;
|
||||
half16x16_b lob;
|
||||
|
||||
nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
|
||||
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
// convert accumulator to matrix_b
|
||||
nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||
nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
|
||||
|
||||
nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
|
||||
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, lo[j][i]);
|
||||
}
|
||||
|
||||
// restore zeros
|
||||
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
|
||||
}
|
||||
|
||||
// O = O + (Q*K^T)*V
|
||||
{
|
||||
for (int cc = 0; cc < C/16; ++cc) {
|
||||
const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||
|
||||
half16x16_b mk[D16];
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
|
||||
}
|
||||
|
||||
half16x16_a mv[Q16];
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
|
||||
}
|
||||
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
if (lane_id == 0) {
|
||||
ss[j*T + 0] = S[j];
|
||||
ss[j*T + 1] = M[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reduce the warps sequentially
|
||||
for (int64_t sg = 1; sg < num_warps; ++sg) {
|
||||
half S = __float2half(0.0f);
|
||||
half M = __float2half(-INFINITY);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// each simdgroup stores its output to shared memory, reusing sq
|
||||
if (warp_id == sg) {
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// the first simdgroup accumulates the results from the other simdgroups
|
||||
if (warp_id == 0) {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const half S0 = ss[j*T + 0];
|
||||
const half S1 = ss[j*T + sg*SH + 0];
|
||||
|
||||
const half M0 = ss[j*T + 1];
|
||||
const half M1 = ss[j*T + sg*SH + 1];
|
||||
|
||||
M = __hmax(M0, M1);
|
||||
|
||||
const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M);
|
||||
const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M);
|
||||
|
||||
S = S0*ms0 + S1*ms1;
|
||||
|
||||
if (lane_id == 0) {
|
||||
ss[j*T + 0] = S;
|
||||
ss[j*T + 1] = M;
|
||||
|
||||
ss[j*T + C + j ] = ms0;
|
||||
ss[j*T + C + j + sg*SH] = ms1;
|
||||
}
|
||||
}
|
||||
|
||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
half16x16_a ms0;
|
||||
half16x16_a ms1;
|
||||
half16x16_b t;
|
||||
half16x16_acc t2;
|
||||
|
||||
nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
|
||||
nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
|
||||
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::fill_fragment(t2, 0.0);
|
||||
nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
|
||||
nvcuda::wmma::mma_sync(t2, ms1, t, t2);
|
||||
|
||||
// convert accumulator to matrix_b
|
||||
nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||
nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);
|
||||
|
||||
nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// store result to shared memory (reuse sq)
|
||||
if (warp_id == 0) {
|
||||
for (int64_t j = 0; j < Q16; ++j) {
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// final rescale with 1/S and store to global memory
|
||||
if (warp_id == 0) {
|
||||
for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||
const half S = ss[j*T + 0];
|
||||
|
||||
for (int64_t i = lane_id; i < D; i += NW) {
|
||||
dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif
|
||||
}
|
||||
|
||||
template<int qk, int qr, dequantize_kernel_t dq>
|
||||
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
|
||||
@ -7682,6 +8231,13 @@ static void im2col_cuda(const float* x, T* dst,
|
||||
im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
|
||||
}
|
||||
|
||||
static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) {
|
||||
int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float);
|
||||
int num_blocks = num_heads * seq_len;
|
||||
flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE, 1024, 64><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
|
||||
q, k, v, dst, kq_scale, d_head, seq_len, num_heads);
|
||||
}
|
||||
|
||||
// buffer pool for cuda
|
||||
#define MAX_CUDA_BUFFERS 256
|
||||
|
||||
@ -8659,7 +9215,7 @@ static void ggml_cuda_op_dequantize_mul_mat_vec(
|
||||
src1_dfloat = src1_dfloat_a.alloc(ne00);
|
||||
ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
|
||||
ne00, 1, sizeof(float), 0, 0,
|
||||
ne00, 1, sizeof(half), 0, 0, stream);
|
||||
ne00, 1, sizeof(half), 0, 0, 0, 0, 0, 0, stream);
|
||||
}
|
||||
#else
|
||||
const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
|
||||
@ -10284,6 +10840,170 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
|
||||
}
|
||||
}
|
||||
|
||||
inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV) {
|
||||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(K->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(V->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
|
||||
GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
|
||||
GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
|
||||
GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);
|
||||
|
||||
ggml_cuda_set_device(g_main_device);
|
||||
const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
|
||||
|
||||
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
|
||||
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
|
||||
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
|
||||
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra;
|
||||
|
||||
const int64_t d_head = Q->ne[0];
|
||||
const int64_t sequence_length = Q->ne[1];
|
||||
const int64_t num_heads = Q->ne[2];
|
||||
|
||||
GGML_ASSERT(Q->ne[0] == d_head);
|
||||
GGML_ASSERT(K->ne[0] == d_head);
|
||||
GGML_ASSERT(V->ne[1] == d_head);
|
||||
|
||||
GGML_ASSERT(Q->ne[1] == sequence_length);
|
||||
GGML_ASSERT(K->ne[1] == sequence_length);
|
||||
GGML_ASSERT(V->ne[0] == sequence_length);
|
||||
|
||||
GGML_ASSERT(Q->ne[2] == num_heads);
|
||||
GGML_ASSERT(K->ne[2] == num_heads);
|
||||
GGML_ASSERT(V->ne[2] == num_heads);
|
||||
|
||||
float KQ_scale = 1.0f / sqrtf((float)d_head);
|
||||
|
||||
flash_attn_f32_cuda(
|
||||
(float *) src0_extra->data_device[g_main_device], // Query
|
||||
(float *) src1_extra->data_device[g_main_device], // Key
|
||||
(float *) src2_extra->data_device[g_main_device], // Value
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
KQ_scale, d_head, sequence_length, num_heads, main_stream);
|
||||
}
|
||||
|
||||
|
||||
inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
|
||||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(K->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(V->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
|
||||
GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
|
||||
GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
|
||||
GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);
|
||||
|
||||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
|
||||
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
||||
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
||||
|
||||
ggml_cuda_set_device(g_main_device);
|
||||
const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
|
||||
|
||||
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
|
||||
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
|
||||
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
|
||||
ggml_tensor_extra_gpu * src3_extra = mask ? (ggml_tensor_extra_gpu *) mask->extra : nullptr;
|
||||
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra;
|
||||
|
||||
float scale;
|
||||
memcpy(&scale, KQV->op_params, sizeof(float));
|
||||
|
||||
#define NQPB 16
|
||||
#define NCPW 128
|
||||
|
||||
const int nqpb = NQPB; // queries per block
|
||||
const int ncpw = NCPW; // cache values per warp (does not work for other values)
|
||||
|
||||
const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
|
||||
// TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
|
||||
const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 2;
|
||||
|
||||
dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
|
||||
dim3 block_dim(32, nwarps, 1);
|
||||
|
||||
const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
|
||||
|
||||
switch (Q->ne[0])
|
||||
{
|
||||
case 16:
|
||||
flash_attn_ext_f16<16, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 64:
|
||||
flash_attn_ext_f16<64, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 80:
|
||||
flash_attn_ext_f16<80, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 128:
|
||||
flash_attn_ext_f16<128, NQPB, NCPW>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||
mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
|
||||
(float *) dst_extra->data_device[g_main_device], // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
|
||||
}
|
||||
@ -10573,6 +11293,10 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
|
||||
case GGML_OP_ARGSORT:
|
||||
func = ggml_cuda_argsort;
|
||||
break;
|
||||
case GGML_OP_FLASH_ATTN:
|
||||
break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@ -10587,7 +11311,13 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
return true;
|
||||
}
|
||||
func(tensor->src[0], tensor->src[1], tensor);
|
||||
if(tensor->op == GGML_OP_FLASH_ATTN) {
|
||||
ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor);
|
||||
} else if(tensor->op == GGML_OP_FLASH_ATTN_EXT) {
|
||||
ggml_cuda_flash_attn_ext(tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
|
||||
} else {
|
||||
func(tensor->src[0], tensor->src[1], tensor);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -11403,6 +12133,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
@ -6881,7 +6881,8 @@ static int llama_decode_internal(
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
// if we start defragmenting the cache, the benefit from this will be more important
|
||||
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
|
||||
// note: we pad the n_kv because certain GPU kernels require it (e.g. ggml_flash_attn_ext)
|
||||
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(128, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
|
||||
//kv_self.n = llama_kv_cache_cell_max(kv_self);
|
||||
|
||||
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
||||
|
Loading…
Reference in New Issue
Block a user