mtgpu: disable flash attention on qy1 (MTT S80); disable q3_k and mul_mat_batched_cublas

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
This commit is contained in:
Xiaodong Ye 2024-09-22 12:47:59 +08:00
parent e40b33dcad
commit 43ff5f36c2
3 changed files with 25 additions and 1 deletions

View File

@ -2829,6 +2829,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
return false;
}
#ifdef GGML_USE_MUSA
if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
!ggml_is_transposed(a) && !ggml_is_transposed(b)) {
return false;
}
#endif // GGML_USE_MUSA
switch (a->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
@ -2852,6 +2858,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
#ifdef GGML_USE_MUSA
if (a->type == GGML_TYPE_Q3_K) {
return false;
}
#endif // GGML_USE_MUSA
return true;
default:
return false;
@ -2977,6 +2988,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_RWKV_WKV:
return true;
case GGML_OP_FLASH_ATTN_EXT: {
#ifndef FLASH_ATTN_AVAILABLE
return false;
#endif
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
return true;
}

View File

@ -50,6 +50,8 @@
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
#define CC_QY1 210
#define CC_QY2 220
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
@ -134,6 +136,10 @@ typedef float2 dfloat2;
#define INT8_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
#define FLASH_ATTN_AVAILABLE
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
static constexpr bool fast_fp16_available(const int cc) {
return cc >= CC_PASCAL && cc != 610;
}

View File

@ -44,13 +44,17 @@ static __global__ void flash_attn_tile_ext_f32(
const int ne1,
const int ne2,
const int ne3) {
#ifndef FLASH_ATTN_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
}
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.