From c6c1132e5e6658b3c209433ed5ef75067ef31a2f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 29 Jan 2024 18:22:28 +0200 Subject: [PATCH] tests : more --- ggml-metal.m | 9 +++++++++ ggml-metal.metal | 3 +++ ggml.c | 5 ----- tests/test-backend-ops.cpp | 29 ++++++++++------------------- 4 files changed, 22 insertions(+), 24 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index a637f0487..4b5fd0bb8 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -137,7 +137,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -505,7 +508,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -2166,7 +2172,10 @@ static bool ggml_metal_graph_compute( switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; default: { GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); diff --git a/ggml-metal.metal b/ggml-metal.metal index 08c000cc4..be059d78f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2326,7 +2326,10 @@ kernel void kernel_flash_attn_ext_f16( template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; kernel void kernel_cpy_f16_f16( device const half * src0, diff --git a/ggml.c b/ggml.c index e8a5fcfa4..57271a1ad 100644 --- a/ggml.c +++ b/ggml.c @@ -13554,11 +13554,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t D = neq0; const int64_t N = neq1; - const int64_t P = nek1 - N; GGML_ASSERT(ne0 == D); GGML_ASSERT(ne2 == N); - GGML_ASSERT(P >= 0); GGML_ASSERT(nbq0 == sizeof(float)); GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); @@ -13569,7 +13567,6 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nev0 == D); GGML_ASSERT(neq1 == N); - GGML_ASSERT(nek1 == N + P); GGML_ASSERT(nev0 == D); // dst cannot be transposed or permuted @@ -13608,8 +13605,6 @@ static void ggml_compute_forward_flash_attn_ext_f16( float scale = 1.0f; memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 0ce498e9e..f57e8ab1a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1726,25 +1726,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_leaky_relu()); - test_cases.emplace_back(new test_attn(64, 32, 512, 8)); - test_cases.emplace_back(new test_attn(64, 32, 512, 7)); - test_cases.emplace_back(new test_attn(64, 32, 512, 1)); - test_cases.emplace_back(new test_attn(80, 32, 512, 8)); - test_cases.emplace_back(new test_attn(80, 32, 512, 7)); - test_cases.emplace_back(new test_attn(80, 32, 512, 1)); - test_cases.emplace_back(new test_attn(128, 32, 512, 8)); - test_cases.emplace_back(new test_attn(128, 32, 512, 7)); - test_cases.emplace_back(new test_attn(128, 32, 512, 1)); - - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 1)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 1)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 7)); - test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1)); + for (int hs : { 64, 80, 96, 112, 128, 256, }) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, 2048, 4096, }) { + for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { + test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + } + } + } + } #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer