ggml : adapt AMX to tensor->grad removal (#0)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-16 21:38:01 +02:00
parent d6c7f2a669
commit ce65dfe251
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -317,8 +317,6 @@ static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const st
const enum ggml_type type = src0->type; const enum ggml_type type = src0->type;
const int64_t ne0 = op->ne[0]; const int64_t ne0 = op->ne[0];
bool is_training = src0->grad || src1->grad;
// amx kernels enables for Q4_0, Q4_1, Q8_0, F16 // amx kernels enables for Q4_0, Q4_1, Q8_0, F16
// Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256 // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16); bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
@ -326,7 +324,6 @@ static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const st
bool can_use_amx = bool can_use_amx =
is_contiguous_2d(src0) && // src0 must be contiguous is_contiguous_2d(src0) && // src0 must be contiguous
is_contiguous_2d(src1) && // src1 must be contiguous is_contiguous_2d(src1) && // src1 must be contiguous
!is_training && // inference only
src1->type == GGML_TYPE_F32 && // src1 must be float32 src1->type == GGML_TYPE_F32 && // src1 must be float32
has_amx_kernels && // with amx kernel impls has_amx_kernels && // with amx kernel impls
ne0 % (TILE_N * 2) == 0; // out_features is 32x ne0 % (TILE_N * 2) == 0; // out_features is 32x