mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 21:39:52 +00:00
metal : unify mul_mv_id kernels (#6556)
This commit is contained in:
parent
4cc120c744
commit
fbbc030ba9
@ -1926,7 +1926,12 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
{
|
{
|
||||||
nth0 = 4;
|
nth0 = 4;
|
||||||
nth1 = 16;
|
nth1 = 16;
|
||||||
|
#if QK_K == 64
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
||||||
|
#else
|
||||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
||||||
|
#endif
|
||||||
|
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
|
1323
ggml-metal.metal
1323
ggml-metal.metal
File diff suppressed because it is too large
Load Diff
1
ggml.c
1
ggml.c
@ -11012,7 +11012,6 @@ static void ggml_compute_forward_mul_mat_id(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// initialize matrix_row_counts
|
// initialize matrix_row_counts
|
||||||
GGML_ASSERT(wdata == wdata_src1_end);
|
|
||||||
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
|
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
|
||||||
|
|
||||||
// group rows by src0 matrix
|
// group rows by src0 matrix
|
||||||
|
@ -2014,6 +2014,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||||||
for (int n_mats : {2, 4, 8}) {
|
for (int n_mats : {2, 4, 8}) {
|
||||||
for (int id = 0; id < n_mats; id++) {
|
for (int id = 0; id < n_mats; id++) {
|
||||||
for (bool v : {false, true}) {
|
for (bool v : {false, true}) {
|
||||||
|
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 1, 256, v));
|
||||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v));
|
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user