tests : more

This commit is contained in:
Georgi Gerganov 2024-01-29 18:22:28 +02:00
parent abeaf0d90e
commit c6c1132e5e
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 22 additions and 24 deletions

View File

@ -137,7 +137,10 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@ -505,7 +508,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@ -2166,7 +2172,10 @@ static bool ggml_metal_graph_compute(
switch (ne00) { switch (ne00) {
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
default: default:
{ {
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);

View File

@ -2326,7 +2326,10 @@ kernel void kernel_flash_attn_ext_f16(
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96, 8, 32>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>;
kernel void kernel_cpy_f16_f16( kernel void kernel_cpy_f16_f16(
device const half * src0, device const half * src0,

5
ggml.c
View File

@ -13554,11 +13554,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int64_t D = neq0; const int64_t D = neq0;
const int64_t N = neq1; const int64_t N = neq1;
const int64_t P = nek1 - N;
GGML_ASSERT(ne0 == D); GGML_ASSERT(ne0 == D);
GGML_ASSERT(ne2 == N); GGML_ASSERT(ne2 == N);
GGML_ASSERT(P >= 0);
GGML_ASSERT(nbq0 == sizeof(float)); GGML_ASSERT(nbq0 == sizeof(float));
GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
@ -13569,7 +13567,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
GGML_ASSERT(nev0 == D); GGML_ASSERT(nev0 == D);
GGML_ASSERT(neq1 == N); GGML_ASSERT(neq1 == N);
GGML_ASSERT(nek1 == N + P);
GGML_ASSERT(nev0 == D); GGML_ASSERT(nev0 == D);
// dst cannot be transposed or permuted // dst cannot be transposed or permuted
@ -13608,8 +13605,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
float scale = 1.0f; float scale = 1.0f;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
// loop over n_batch and n_head // loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) { for (int ir = ir0; ir < ir1; ++ir) {
// q indices // q indices

View File

@ -1726,25 +1726,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_pad());
test_cases.emplace_back(new test_leaky_relu()); test_cases.emplace_back(new test_leaky_relu());
test_cases.emplace_back(new test_attn(64, 32, 512, 8)); for (int hs : { 64, 80, 96, 112, 128, 256, }) {
test_cases.emplace_back(new test_attn(64, 32, 512, 7)); for (int nh : { 32, }) {
test_cases.emplace_back(new test_attn(64, 32, 512, 1)); for (int kv : { 512, 1024, 2048, 4096, }) {
test_cases.emplace_back(new test_attn(80, 32, 512, 8)); for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {
test_cases.emplace_back(new test_attn(80, 32, 512, 7)); test_cases.emplace_back(new test_attn (hs, nh, kv, nb));
test_cases.emplace_back(new test_attn(80, 32, 512, 1)); test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
test_cases.emplace_back(new test_attn(128, 32, 512, 8)); }
test_cases.emplace_back(new test_attn(128, 32, 512, 7)); }
test_cases.emplace_back(new test_attn(128, 32, 512, 1)); }
}
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 8));
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 7));
test_cases.emplace_back(new test_flash_attn_ext(64, 32, 512, 1));
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 8));
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 7));
test_cases.emplace_back(new test_flash_attn_ext(80, 32, 512, 1));
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8));
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 7));
test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 1));
#if !defined(__SANITIZE_THREAD__) #if !defined(__SANITIZE_THREAD__)
// FIXME: these tests use too much memory with thread sanitizer // FIXME: these tests use too much memory with thread sanitizer