mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-30 21:34:36 +00:00
558 lines
20 KiB
Plaintext
558 lines
20 KiB
Plaintext
#include "fattn.cuh"
|
|
|
|
#if __CUDA_ARCH__ >= CC_VOLTA
// Aliases for 16x16x16 WMMA fragments holding half-precision data (tensor cores, Volta and newer).
using half16x16_a   = nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    16, 16, 16, half, nvcuda::wmma::row_major>; // left operand, row-major
using half16x16_b   = nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::row_major>; // right operand, row-major
using half16x16_bT  = nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, half, nvcuda::wmma::col_major>; // right operand loaded column-major (i.e. transposed)
using half16x16_acc = nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half>;                           // accumulator kept in half precision
#endif
|
|
|
|
// Flash-attention forward kernel built on 16x16x16 WMMA half fragments; based on the Metal version.
// Q is read as f32, K/V/mask as f16, the result is written as f32.
//
// Template parameters:
//   D - head size; must be a multiple of 16.
//   Q - queries handled per block; must be a multiple of 16 and at most WARP_SIZE
//       (lane j keeps the running softmax denominator of query j).
//   C - KV cache entries handled per warp per iteration; must be a multiple of 16 and of 2*WARP_SIZE
//       (the softmax loops over C/2 half2 values in steps of WARP_SIZE without a bounds check).
//
// Launch layout expected by the indexing below:
//   threadIdx.x = lane (blockDim.x == WARP_SIZE), threadIdx.y = warp,
//   blockIdx.x selects a group of Q queries, blockIdx.y the head (iq2), blockIdx.z the outer batch index (iq3).
//   Dynamic shared memory: Q*(D + num_warps*(C + Q)) halves.
//
// Preconditions (not checked here):
//   - the mask, if present, is f16 with row stride nb31 bytes and at least Q readable rows starting at row iq1;
//     it is indexed by query/key position only (same mask for every head).
//   - V has the same shape and strides as K (nb21..nb23 are taken from nb11..nb13).
//   - NOTE(review): the last KV tile of a warp reads keys/values/mask up to ic + C - 1 even if that exceeds ne11;
//     presumably the KV cache and mask are padded accordingly - TODO confirm.
// ne00, ne10, ne31, ne0 and ne3 are currently unused.
template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
static __global__ void flash_attn_ext_f16(
        const char * __restrict__ q,
        const char * __restrict__ k,
        const char * __restrict__ v,
        const char * __restrict__ mask,
        float      * __restrict__ dst,
        float scale,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        int ne10,
        int ne11,
        int ne12,
        int ne13,
        int ne31,
        int nb31,
        int nb01,
        int nb02,
        int nb03,
        int nb11,
        int nb12,
        int nb13,
        int ne0,
        int ne1,
        int ne2,
        int ne3) {
#if __CUDA_ARCH__ >= CC_VOLTA
    const int warp_id = threadIdx.y;
    const int lane_id = threadIdx.x;

    const int num_warps = blockDim.y; // number of warps

    const int iq3 = blockIdx.z;
    const int iq2 = blockIdx.y;
    const int iq1 = blockIdx.x * Q;

    const int D16 = D/16;
    const int Q16 = Q/16;
    const int C16 = C/16;

    const int NW = WARP_SIZE;
    const int SH = (C + Q); // shared memory per warp in (half): C attention columns + Q columns for a diagonal matrix

    const int T  = D + num_warps*SH; // shared memory size per query in (half)
    const int T2 = T/2;              // shared memory size per query in (half2)
    const int C2 = C/2;
    const int D2 = D/2;

    // Row j of shared memory (row stride T halves): [ query j (D) | warp 0 scratch (SH) | warp 1 scratch (SH) | ... ]
    extern __shared__ half __flash_attn_f16_shmem[];
    half  * sq  = (half  *) (__flash_attn_f16_shmem +              0*D); // holds the query data
    half2 * sq2 = (half2 *) (__flash_attn_f16_shmem +              0*D); // same as above but in half2
    half  * ss  = (half  *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // this warp's scratch for attention and diagonal matrices
    half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2

    half16x16_acc zr;           // all-zero accumulator
    half16x16_acc lo[Q16][D16]; // this warp's unnormalized output O (Q x D)

    // load heads from Q to shared memory (converted to half); rows at or past ne01 are zero-filled
#pragma unroll
    for (int j0 = 0; j0 < Q; j0 += num_warps) {
        const int j = j0 + warp_id;
        if (j >= Q) {
            break;
        }

        const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));

#pragma unroll
        for (int i0 = 0; i0 < D2; i0 += NW) {
            const int i = i0 + lane_id;
            if (i >= D2) {
                break;
            }

            if (iq1 + j < ne01) {
                sq2[j*T2 + i] = __float22half2_rn(q2[i]);
            } else {
                sq2[j*T2 + i] = make_half2(0.0f, 0.0f);
            }
        }
    }

    nvcuda::wmma::fill_fragment(zr, 0.0);

    // zero out lo
    for (int j = 0; j < Q16; ++j) {
        for (int i = 0; i < D16; ++i) {
            nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
        }
    }

    // zero out this warp's scratch region (SH halves in each of the Q rows)
    for (int j = 0; j < Q; ++j) {
        for (int i0 = 0; i0 < SH; i0 += NW) {
            const int i = i0 + lane_id;
            if (i >= SH) {
                break;
            }

            ss[j*T + i] = 0.0;
        }
    }

    __syncthreads();

    {
        // Running softmax denominator; only lane j's copy is meaningful for query j.
        // NOTE(review): S accumulates in half precision and could overflow for very long contexts
        // with nearly uniform attention - consider float if that matters.
        half S = __float2half(0.0f);
        half M[Q]; // running row maxima (identical across lanes after warp_reduce_max)

        // Start from a large negative *finite* value. A tiny positive start (e.g. CUDART_MIN_DENORM_FP16)
        // keeps M ~ 0 for rows whose scores are all negative, so hexp(s - M) underflows to 0 in half precision
        // and the final division by S yields inf/NaN. -INFINITY is avoided because a fully masked block would
        // then evaluate (-inf) - (-inf) = NaN.
        const half M_init = __float2half(-65504.0f/2); // 65504 is the largest finite half
        for (int i = 0; i < Q; ++i) {
            M[i] = M_init;
        }

        // assume K and V are same shape
        const int ne22 = ne12;
        const int ne23 = ne13;

        const int nb21 = nb11;
        const int nb22 = nb12;
        const int nb23 = nb13;

        // broadcast
        const int rk2 = ne02/ne12;
        const int rk3 = ne03/ne13;

        const int rv2 = ne02/ne22;
        const int rv3 = ne03/ne23;

        // k indices
        const int ik2 = iq2 / rk2;
        const int ik3 = iq3 / rk3;

        // v indices
        const int iv2 = iq2 / rv2;
        const int iv3 = iq3 / rv3;

        // load the queries from shared memory into fragments
        half16x16_a mq[Q16][D16];
        for (int j = 0; j < Q16; ++j) {
            for (int i = 0; i < D16; ++i) {
                nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
            }
        }

        // pointer to the mask (rows start at query iq1)
        const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;

        // prepare diagonal scale matrix
        half16x16_b mscale;
        for (int i = 0; i < 16; ++i) {
            ss[i*T + i] = __float2half(scale);
        }
        __syncwarp(); // order the per-lane writes above before the collective load
        nvcuda::wmma::load_matrix_sync(mscale, ss, T);

        // loop over the KV cache
        // each warp handles blocks of Q rows and C columns
        for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) {
            const int ic = ic0 + warp_id*C;
            if (ic >= ne11) {
                break;
            }

            // Q*K^T
            {
#pragma unroll
                for (int cc = 0; cc < C16; ++cc) {
                    half16x16_acc mqk[Q16];
                    for (int j = 0; j < Q16; ++j) {
                        nvcuda::wmma::fill_fragment(mqk[j], 0);
                    }

                    const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));

                    for (int i = 0; i < D16; ++i) {
                        half16x16_bT mk; // transposed key
                        nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));

                        for (int j = 0; j < Q16; ++j) {
                            nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
                        }
                    }

                    // mqk = mqk*scale + mask
                    for (int j = 0; j < Q16; ++j) {
                        half16x16_a   mqka;
                        half16x16_acc mm;

                        if (mp) {
                            nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
                        }

                        // convert accumulator to matrix_a
                        nvcuda::wmma::store_matrix_sync(      ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                        nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);

                        nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr);
                        nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                    }
                }
            }

            // the per-lane reads below access elements written by other lanes via store_matrix_sync
            __syncwarp();

            // used to detect blocks full of -INF
            half2 smax = make_half2(-INFINITY, -INFINITY);

            // online softmax
            for (int j = 0; j < Q; ++j) {
                const half m = M[j];

                for (int p0 = 0; p0 < C2; p0 += NW) {
                    const int p = p0 + lane_id;

                    const half2 s = ss2[j*T2 + p];

                    smax = __hmax2(smax, s);
                    M[j] = __hmax(M[j], __hmax(s.x, s.y));
                }

                M[j] = warp_reduce_max(M[j]);

                // local sum
                half2 ls = make_half2(0.0f, 0.0f);
                half2 M2 = make_half2(M[j], M[j]);

                for (int p0 = 0; p0 < C2; p0 += NW) {
                    const int p = p0 + lane_id;

                    const half2 s = ss2[j*T2 + p];

                    const half2 vs = h2exp(s - M2);

                    ls += vs;

                    // the P matrix from the paper (Q rows, C columns)
                    ss2[j*T2 + p] = vs;
                }

                ls = warp_reduce_sum(ls);

                const half ms = hexp(m - M[j]);

                // create a QxQ diagonal matrix for rescaling the output
                if (lane_id == j) {
                    ss[j*T + C + j] = ms;

                    S = S*ms + ls.x + ls.y;
                }
            }

            // the collective loads below read P and the diagonal entries written by individual lanes above
            __syncwarp();

            smax = warp_reduce_max(smax);

            // skip -INF blocks
            if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) {
                continue;
            }

            // O = diag(ms)*O
            for (int j = 0; j < Q16; ++j) {
                half16x16_a mm;
                half16x16_b lob;

                nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);

                for (int i = 0; i < D16; ++i) {
                    // convert accumulator to matrix_b
                    nvcuda::wmma::store_matrix_sync(     ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
                    nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);

                    nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr);
                }
            }

            // restore zeros
            for (int j = 0; j < Q16; ++j) {
                nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
            }

            // O = O + (Q*K^T)*V
            {
                for (int cc = 0; cc < C16; ++cc) {
                    const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));

                    half16x16_b mv[D16];
                    for (int i = 0; i < D16; ++i) {
                        nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half));
                    }

                    half16x16_a ms[Q16];
                    for (int j = 0; j < Q16; ++j) {
                        nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T);
                    }

                    for (int j = 0; j < Q16; ++j) {
                        for (int i = 0; i < D16; ++i) {
                            nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]);
                        }
                    }
                }
            }
        }

        // these are needed for reducing the results from the warps (reuse the ss buffer)
        if (lane_id < Q) {
            ss[lane_id*T + 0] = S;
            ss[lane_id*T + 1] = M[lane_id];
        }
    }

    // reduce the warps sequentially
    for (int sg = 1; sg < num_warps; ++sg) {
        __syncthreads();

        // each warp stores its output to shared memory, reusing sq
        if (warp_id == sg) {
            for (int j = 0; j < Q16; ++j) {
                for (int i = 0; i < D16; ++i) {
                    nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
                }
            }
        }

        __syncthreads();

        // the first warp accumulates the results from the other warps
        if (warp_id == 0) {
            for (int j = lane_id; j < Q; j += NW) {
                const half S0 = ss[j*T +         0];
                const half S1 = ss[j*T + sg*SH + 0];

                const half M0 = ss[j*T +         1];
                const half M1 = ss[j*T + sg*SH + 1];

                const half M = __hmax(M0, M1);

                const half ms0 = hexp(M0 - M);
                const half ms1 = hexp(M1 - M);

                const half S = S0*ms0 + S1*ms1;

                ss[j*T + 0] = S;
                ss[j*T + 1] = M;

                ss[j*T + C + j        ] = ms0;
                ss[j*T + C + j + sg*SH] = ms1;
            }

            // the collective loads below read diagonal entries written by individual lanes above
            __syncwarp();

            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
            for (int j = 0; j < Q16; ++j) {
                half16x16_a   ms0;
                half16x16_a   ms1;
                half16x16_b   t;
                half16x16_acc t2;

                nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j,         T);
                nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);

                for (int i = 0; i < D16; ++i) {
                    nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                    nvcuda::wmma::mma_sync(t2, ms1, t, zr);

                    // convert accumulator to matrix_b
                    nvcuda::wmma::store_matrix_sync(   sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
                    nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);

                    nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
                }
            }
        }
    }

    // store result to shared memory (reuse sq)
    if (warp_id == 0) {
        for (int j = 0; j < Q16; ++j) {
            for (int i = 0; i < D16; ++i) {
                nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
            }
        }
    }

    // the per-lane reads below access S (written by lane j) and O (written via store_matrix_sync) from other lanes
    __syncwarp();

    // final rescale with 1/S and store to global memory
    if (warp_id == 0) {
        for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
            const half S = ss[j*T + 0];

            for (int i0 = 0; i0 < D; i0 += NW) {
                const int i = i0 + lane_id;
                if (i >= D) {
                    break;
                }

                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
            }
        }
    }
#else
    NO_DEVICE_CODE;
#endif
}
|
|
|
|
// Tile sizes for flash_attn_ext_f16: queries per block and KV cache entries per warp per iteration.
#define NQPB 16
#define NCPW 128

// Launch flash_attn_ext_f16<D, NQPB, NCPW> with the given configuration on `stream`.
// V's strides are not passed: the kernel reuses K's strides for V.
template <int D>
static void launch_flash_attn_ext_f16(
        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV,
        const float scale, const dim3 blocks_num, const dim3 block_dim, const size_t shmem, const cudaStream_t stream) {
    flash_attn_ext_f16<D, NQPB, NCPW>
        <<<blocks_num, block_dim, shmem, stream>>> (
            (const char *) Q->data, // Query
            (const char *) K->data, // Key
            (const char *) V->data, // Value
            mask ? (const char *) mask->data : nullptr, // Mask
            (float *) KQV->data, // dst
            scale,
            Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
            K->ne[0], K->ne[1], K->ne[2], K->ne[3],
            mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
            Q->nb[1], Q->nb[2], Q->nb[3],
            K->nb[1], K->nb[2], K->nb[3],
            KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
        );
}

// Compute KQV = softmax(scale*Q*K^T + mask)*V with the WMMA flash-attention kernel.
//
// Asserted requirements: Q and KQV are F32; K and V are F16; the optional mask is F16 with at least
// GGML_PAD(Q->ne[1], 16) rows; the head size Q->ne[0] is one of 64, 80, 96, 112, 128, 256.
// The scale is read from the first float of KQV->op_params. The kernel is launched asynchronously on
// ctx.stream(); only launch errors are checked here.
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
    GGML_ASSERT(Q->type == GGML_TYPE_F32);
    GGML_ASSERT(K->type == GGML_TYPE_F16);
    GGML_ASSERT(V->type == GGML_TYPE_F16);
    GGML_ASSERT(KQV->type == GGML_TYPE_F32);

    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

    ggml_cuda_set_device(ctx.device);

    const cudaStream_t main_stream = ctx.stream();

    float scale;
    memcpy(&scale, KQV->op_params, sizeof(float));

    const int nqpb = NQPB; // queries per block
    const int ncpw = NCPW; // cache values per warp (does not work for other values)

    // the kernel keeps the softmax denominator of query j in lane j, so a block cannot hold more queries than lanes
    GGML_ASSERT(NQPB <= 32);

    const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
    // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
    // Small query batches split the KV cache across several warps; larger batches use a single warp per block.
    const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1;

    dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
    dim3 block_dim(32, nwarps, 1);

    // One half per element: nqpb query rows of (head size + per-warp scratch of ncpw + nqpb columns).
    // With nwarps <= 8 this stays below 48 KB for all supported head sizes (45056 bytes for D = 256).
    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*sizeof(half);

    // increase shared memory limit to 96KB
    //const size_t shmem_max = 96*1024;
    //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);

    switch (Q->ne[0]) {
        case  64: launch_flash_attn_ext_f16< 64>(Q, K, V, mask, KQV, scale, blocks_num, block_dim, shmem, main_stream); break;
        case  80: launch_flash_attn_ext_f16< 80>(Q, K, V, mask, KQV, scale, blocks_num, block_dim, shmem, main_stream); break;
        case  96: launch_flash_attn_ext_f16< 96>(Q, K, V, mask, KQV, scale, blocks_num, block_dim, shmem, main_stream); break;
        case 112: launch_flash_attn_ext_f16<112>(Q, K, V, mask, KQV, scale, blocks_num, block_dim, shmem, main_stream); break;
        case 128: launch_flash_attn_ext_f16<128>(Q, K, V, mask, KQV, scale, blocks_num, block_dim, shmem, main_stream); break;
        case 256: launch_flash_attn_ext_f16<256>(Q, K, V, mask, KQV, scale, blocks_num, block_dim, shmem, main_stream); break;
        default:
            // previously this silently skipped the launch, leaving KQV uninitialized
            GGML_ASSERT(false && "flash_attn_ext: unsupported head size (supported: 64, 80, 96, 112, 128, 256)");
            break;
    }

    CUDA_CHECK(cudaGetLastError());
}
|