mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-08 09:41:45 +00:00
metal : fix support check
ggml-ci
This commit is contained in:
parent
e9565ccf9a
commit
13b87f212e
@@ -949,6 +949,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
+ if (op->src[1]->type != op->src[2]->type) {
|
||||
+ return false;
|
||||
+ }
|
||||
return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
@@ -2893,6 +2896,7 @@ static void ggml_metal_encode_node(
|
||||
GGML_ASSERT(ne11 % 32 == 0);
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
+ GGML_ASSERT(src1->type == src2->type);
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape (src1, src2));
|
||||
|
||||
@@ -3165,7 +3169,7 @@ static void ggml_metal_encode_node(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
} else {
|
||||
- // half1x4 kernel
|
||||
+ // half4x4 kernel
|
||||
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user