metal : fix support check

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-04 13:40:52 +02:00
parent e9565ccf9a
commit 13b87f212e
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -949,6 +949,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_LEAKY_RELU:
return true;
case GGML_OP_FLASH_ATTN_EXT:
if (op->src[1]->type != op->src[2]->type) {
return false;
}
return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
@ -2893,6 +2896,7 @@ static void ggml_metal_encode_node(
GGML_ASSERT(ne11 % 32 == 0);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == src2->type);
GGML_ASSERT(ggml_are_same_shape (src1, src2));
@ -3165,7 +3169,7 @@ static void ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} else {
// half1x4 kernel
// half4x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!