#include "common.cuh"
|
|
#include "fattn-common.cuh"
|
|
#include "fattn-tile-f16.cuh"
|
|
#include "fattn-tile-f32.cuh"
|
|
#include "fattn-vec-f16.cuh"
|
|
#include "fattn-vec-f32.cuh"
|
|
#include "fattn-wmma-f16.cuh"
|
|
#include "fattn.cuh"
|
|
|
|
#include <cstdint>
|
|
|
|
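// Dispatcher for the fused WMMA (tensor core) FlashAttention kernels.
// Selects a template instantiation of ggml_cuda_flash_attn_ext_wmma_f16_case
// based on the head size (Q->ne[0]) and the number of query columns (Q->ne[1]),
// and accumulates KQ in float instead of half when the op requests higher
// precision (op_params[3] != GGML_PREC_DEFAULT).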
static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q   = dst->src[0];

    const int32_t precision = KQV->op_params[3];

    if (precision != GGML_PREC_DEFAULT) {
        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
            constexpr int cols_per_block = 16;
            switch (Q->ne[0]) {
                case 64:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
                    break;
                case 80:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
                    break;
                case 96:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
                    break;
                case 112:
                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
                    break;
                case 128:
                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                    break;
                case 256:
                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
                    break;
                default:
                    GGML_ABORT("fatal error");
                    break;
            }
        } else {
            constexpr int cols_per_block = 32;
            switch (Q->ne[0]) {
                case 64:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
                    break;
                case 80:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
                    break;
                case 96:
                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
                    break;
                case 112:
                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
                    break;
                case 128:
                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                    break;
                // case 256:
                //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                //     break;
                default:
                    GGML_ABORT("fatal error");
                    break;
            }
        }
        return;
    }

    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
        constexpr int cols_per_block = 8;
        switch (Q->ne[0]) {
            case 64:
                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
                break;
            case 96:
                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
                break;
            case 128:
                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
                break;
            case 256:
                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
                break;
            default:
                GGML_ABORT("fatal error");
                break;
        }
        return;
    }

    if (Q->ne[1] <= 32) {
        constexpr int cols_per_block = 16;
        switch (Q->ne[0]) {
            case 64:
                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
                break;
            case 80:
                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
                break;
            case 96:
                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
                break;
            case 112:
                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
                break;
            case 128:
                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
                break;
            case 256:
                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
                break;
            default:
                GGML_ABORT("fatal error");
                break;
        }
        return;
    }

    constexpr int cols_per_block = 32;
    switch (Q->ne[0]) {
        case 64:
            ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
            break;
        case 80:
            ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
            break;
        case 96:
            ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
            break;
        case 112:
            ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
            break;
        case 128:
            ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
            break;
        case 256:
            ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}

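// Dispatch to the f16-accumulator vector FlashAttention kernel for a given
// head size D and K/V tensor types; falls through if no instantiation matches.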
#define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
        ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
        return;                                                             \
    }                                                                       \

static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_tensor * Q = dst->src[0];
    ggml_tensor * K = dst->src[1];
    ggml_tensor * V = dst->src[2];

#ifdef GGML_CUDA_FA_ALL_QUANTS
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 )

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16)

    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#else
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#endif // GGML_CUDA_FA_ALL_QUANTS

    on_no_fattn_vec_case(Q->ne[0]);
}

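// Same dispatch as above but with float accumulators, used when fp16
// arithmetic is slow on the device or when the op requests higher precision.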
#define FATTN_VEC_F32_CASE(D, type_K, type_V)                               \
    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
        ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
        return;                                                             \
    }                                                                       \

static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_tensor * Q = dst->src[0];
    ggml_tensor * K = dst->src[1];
    ggml_tensor * V = dst->src[2];

#ifdef GGML_CUDA_FA_ALL_QUANTS
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16)

    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#else
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#endif // GGML_CUDA_FA_ALL_QUANTS

    on_no_fattn_vec_case(Q->ne[0]);
}

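// Top-level FlashAttention dispatch: chooses between the vector, tile and
// WMMA kernels based on the device (AMD vs. NVIDIA compute capability),
// whether fast fp16 math and fp16 tensor-core MMA are available, the
// requested precision, the head size and the batch size.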
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q   = dst->src[0];

    ggml_cuda_set_device(ctx.device);
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const int32_t precision = KQV->op_params[3];

    // On AMD the tile kernels perform poorly, use the vec kernel instead:
    if (cc >= CC_OFFSET_AMD) {
        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        }
        return;
    }

    if (!fast_fp16_available(cc)) {
        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
        }
        return;
    }

    if (!fp16_mma_available(cc)) {
        if (Q->ne[1] <= 8) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
        }
        return;
    }

    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
        if (precision == GGML_PREC_DEFAULT) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
            return;
        } else if (Q->ne[0] <= 128) {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
            return;
        }
    }

    ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
}