metal : unify mul_mv_id kernels (#6556)

This commit is contained in:
slaren 2024-04-12 18:13:20 +02:00 committed by GitHub
parent 4cc120c744
commit fbbc030ba9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 208 additions and 1122 deletions

View File

@ -1926,7 +1926,12 @@ static enum ggml_status ggml_metal_graph_compute(
{ {
nth0 = 4; nth0 = 4;
nth1 = 16; nth1 = 16;
#if QK_K == 64
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
#else
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
#endif
} break; } break;
default: default:
{ {

File diff suppressed because it is too large Load Diff

1
ggml.c
View File

@ -11012,7 +11012,6 @@ static void ggml_compute_forward_mul_mat_id(
} }
// initialize matrix_row_counts // initialize matrix_row_counts
GGML_ASSERT(wdata == wdata_src1_end);
memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
// group rows by src0 matrix // group rows by src0 matrix

View File

@ -2014,6 +2014,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (int n_mats : {2, 4, 8}) { for (int n_mats : {2, 4, 8}) {
for (int id = 0; id < n_mats; id++) { for (int id = 0; id < n_mats; id++) {
for (bool v : {false, true}) { for (bool v : {false, true}) {
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 1, 256, v));
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v)); test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v));
} }
} }