Compare commits

...

3 Commits

Author SHA1 Message Date
Akarshan Biswas
eea1e6e277
Merge 33b692934f into a5b57b08ce 2024-09-22 13:12:15 +05:30
Johannes Gäßler
a5b57b08ce
CUDA: enable Gemma FA for HIP/Pascal (#9581)
Some checks are pending
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full-cuda.Dockerfile platforms:linux/amd64 tag:full-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/full.Dockerfile platforms:linux/amd64,linux/arm64 tag:full]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-cuda.Dockerfile platforms:linux/amd64 tag:light-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli-intel.Dockerfile platforms:linux/amd64 tag:light-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-cli.Dockerfile platforms:linux/amd64,linux/arm64 tag:light]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-cuda.Dockerfile platforms:linux/amd64 tag:server-cuda]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server-intel.Dockerfile platforms:linux/amd64 tag:server-intel]) (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/llama-server.Dockerfile platforms:linux/amd64,linux/arm64 tag:server]) (push) Waiting to run
Nix CI / nix-eval (macos-latest) (push) Waiting to run
Nix CI / nix-eval (ubuntu-latest) (push) Waiting to run
Nix CI / nix-build (macos-latest) (push) Waiting to run
Nix CI / nix-build (ubuntu-latest) (push) Waiting to run
flake8 Lint / Lint (push) Waiting to run
2024-09-22 09:34:52 +02:00
Akarshan Biswas
33b692934f
Revert "[SYCL] fallback mmvq (#9088)"
This reverts commit 50addec9a5.
2024-09-21 20:02:59 +05:30
5 changed files with 13 additions and 15 deletions

View File

@ -2976,19 +2976,19 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_LEAKY_RELU:
case GGML_OP_RWKV_WKV:
return true;
case GGML_OP_FLASH_ATTN_EXT:
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
#else
if (op->src[0]->ne[0] == 128) {
return true;
}
case GGML_OP_FLASH_ATTN_EXT: {
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
return true;
}
return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
if (op->src[0]->ne[0] == 128) {
return true;
}
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
return true;
}
const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
}
case GGML_OP_CROSS_ENTROPY_LOSS:
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_OPT_STEP_ADAMW:

View File

@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
}
if (!fast_fp16_available(cc)) {
if (Q->ne[1] <= 8) {
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);

View File

@ -3496,8 +3496,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE
&& (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda || src1->ne[1] > MMVQ_MIN_BATCH_SIZE);
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;

View File

@ -134,7 +134,6 @@ typedef sycl::float2 dfloat2;
#endif // GGML_SYCL_F16
#define MMVQ_MAX_BATCH_SIZE 8
#define MMVQ_MIN_BATCH_SIZE 4
static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

View File

@ -3599,7 +3599,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
if (hs != 128 && logit_softcap != 0.0f) continue;
for (int nh : { 32, }) {
for (int kv : { 512, 1024, }) {
for (int nb : { 1, 2, 4, 8, }) {
for (int nb : { 1, 3, 32, 35, }) {
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
}