From 610394fff83368d5465b62f8c8add3737a39e42a Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Thu, 18 Jan 2024 15:32:55 -0500 Subject: [PATCH] fix supported ops for kompute backend --- ggml-kompute.cpp | 41 ++++++++++++++++++++++++-------------- tests/test-backend-ops.cpp | 5 ++++- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 007367611..720a66986 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1316,27 +1316,13 @@ static bool ggml_kompute_supports_op(const struct ggml_tensor * op) { case GGML_OP_VIEW: case GGML_OP_TRANSPOSE: case GGML_OP_PERMUTE: - case GGML_OP_CONCAT: case GGML_OP_ADD: - case GGML_OP_ACC: case GGML_OP_MUL: - case GGML_OP_DIV: case GGML_OP_SCALE: - case GGML_OP_SQR: - case GGML_OP_SUM_ROWS: case GGML_OP_SOFT_MAX: case GGML_OP_RMS_NORM: - case GGML_OP_GROUP_NORM: case GGML_OP_NORM: - case GGML_OP_ALIBI: case GGML_OP_ROPE: - case GGML_OP_IM2COL: - case GGML_OP_UPSCALE: - case GGML_OP_PAD: - case GGML_OP_ARGSORT: - case GGML_OP_LEAKY_RELU: - case GGML_OP_MUL_MAT: - case GGML_OP_MUL_MAT_ID: return true; case GGML_OP_DUP: case GGML_OP_CPY: @@ -1357,8 +1343,33 @@ static bool ggml_kompute_supports_op(const struct ggml_tensor * op) { } return true; case GGML_OP_DIAG_MASK_INF: - case GGML_OP_GET_ROWS: return op->ne[3] == 1; + case GGML_OP_GET_ROWS: + switch (op->src[0]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q6_K: + return op->ne[3] == 1; + default: + ; + } + return false; + case GGML_OP_MUL_MAT: + if (op->src[1]->type != GGML_TYPE_F32 || ggml_is_transposed(op->src[0]) || ggml_is_transposed(op->src[1])) + return false; + + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q6_K: + return true; + default: + ; + } default: ; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d9b8b106a..a0063bbb9 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -360,7 +360,10 @@ struct test_case { // check if backends support op bool supported = true; for (ggml_backend_t backend : {backend1, backend2}) { - if (!ggml_backend_supports_op(backend, out)) { + if ( + !ggml_backend_supports_op(backend, out) + || (op_desc(out) == "MOE" && !strcmp(ggml_backend_name(backend), "Kompute")) + ) { printf("not supported [%s] ", ggml_backend_name(backend)); supported = false; }