mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 19:50:17 +00:00
sync : ggml (part 3, Metal)
This commit is contained in:
parent
6b1cf54197
commit
95e9d8a780
@ -98,6 +98,8 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void);
|
|||||||
|
|
||||||
GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
|
GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
|
||||||
|
|
||||||
|
GGML_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
|
||||||
|
|
||||||
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
|
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
666
ggml-metal.m
666
ggml-metal.m
@ -62,6 +62,8 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
|
GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
|
||||||
GGML_METAL_DECL_KERNEL(mul);
|
GGML_METAL_DECL_KERNEL(mul);
|
||||||
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
|
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
|
||||||
|
GGML_METAL_DECL_KERNEL(div);
|
||||||
|
GGML_METAL_DECL_KERNEL(div_row);
|
||||||
GGML_METAL_DECL_KERNEL(scale);
|
GGML_METAL_DECL_KERNEL(scale);
|
||||||
GGML_METAL_DECL_KERNEL(scale_4);
|
GGML_METAL_DECL_KERNEL(scale_4);
|
||||||
GGML_METAL_DECL_KERNEL(silu);
|
GGML_METAL_DECL_KERNEL(silu);
|
||||||
@ -112,10 +114,24 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
|
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
|
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
|
||||||
GGML_METAL_DECL_KERNEL(rope_f32);
|
GGML_METAL_DECL_KERNEL(rope_f32);
|
||||||
GGML_METAL_DECL_KERNEL(rope_f16);
|
GGML_METAL_DECL_KERNEL(rope_f16);
|
||||||
GGML_METAL_DECL_KERNEL(alibi_f32);
|
GGML_METAL_DECL_KERNEL(alibi_f32);
|
||||||
GGML_METAL_DECL_KERNEL(im2col_f16);
|
GGML_METAL_DECL_KERNEL(im2col_f16);
|
||||||
|
GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
|
||||||
|
GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
|
||||||
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
||||||
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
||||||
GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
|
GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
|
||||||
@ -126,6 +142,7 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
||||||
GGML_METAL_DECL_KERNEL(concat);
|
GGML_METAL_DECL_KERNEL(concat);
|
||||||
GGML_METAL_DECL_KERNEL(sqr);
|
GGML_METAL_DECL_KERNEL(sqr);
|
||||||
|
GGML_METAL_DECL_KERNEL(sum_rows);
|
||||||
|
|
||||||
#undef GGML_METAL_DECL_KERNEL
|
#undef GGML_METAL_DECL_KERNEL
|
||||||
};
|
};
|
||||||
@ -169,12 +186,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||||
GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
|
GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
|
||||||
|
|
||||||
id <MTLDevice> device;
|
id<MTLDevice> device;
|
||||||
NSString * s;
|
NSString * s;
|
||||||
|
|
||||||
#if TARGET_OS_OSX
|
#if TARGET_OS_OSX
|
||||||
@ -250,6 +265,29 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if TARGET_OS_OSX
|
||||||
|
// print MTL GPU family:
|
||||||
|
GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
|
||||||
|
|
||||||
|
// determine max supported GPU family
|
||||||
|
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||||
|
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
||||||
|
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
|
||||||
|
if ([ctx->device supportsFamily:i]) {
|
||||||
|
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
||||||
|
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
|
||||||
|
if (ctx->device.maxTransferRate != 0) {
|
||||||
|
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
|
||||||
|
} else {
|
||||||
|
GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// load kernels
|
// load kernels
|
||||||
{
|
{
|
||||||
NSError * error = nil;
|
NSError * error = nil;
|
||||||
@ -271,6 +309,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(add_row);
|
GGML_METAL_ADD_KERNEL(add_row);
|
||||||
GGML_METAL_ADD_KERNEL(mul);
|
GGML_METAL_ADD_KERNEL(mul);
|
||||||
GGML_METAL_ADD_KERNEL(mul_row);
|
GGML_METAL_ADD_KERNEL(mul_row);
|
||||||
|
GGML_METAL_ADD_KERNEL(div);
|
||||||
|
GGML_METAL_ADD_KERNEL(div_row);
|
||||||
GGML_METAL_ADD_KERNEL(scale);
|
GGML_METAL_ADD_KERNEL(scale);
|
||||||
GGML_METAL_ADD_KERNEL(scale_4);
|
GGML_METAL_ADD_KERNEL(scale_4);
|
||||||
GGML_METAL_ADD_KERNEL(silu);
|
GGML_METAL_ADD_KERNEL(silu);
|
||||||
@ -322,11 +362,25 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
|
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
|
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
|
||||||
}
|
}
|
||||||
GGML_METAL_ADD_KERNEL(rope_f32);
|
GGML_METAL_ADD_KERNEL(rope_f32);
|
||||||
GGML_METAL_ADD_KERNEL(rope_f16);
|
GGML_METAL_ADD_KERNEL(rope_f16);
|
||||||
GGML_METAL_ADD_KERNEL(alibi_f32);
|
GGML_METAL_ADD_KERNEL(alibi_f32);
|
||||||
GGML_METAL_ADD_KERNEL(im2col_f16);
|
GGML_METAL_ADD_KERNEL(im2col_f16);
|
||||||
|
GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
|
||||||
|
GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
|
||||||
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
||||||
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
||||||
GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
|
GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
|
||||||
@ -337,33 +391,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
||||||
GGML_METAL_ADD_KERNEL(concat);
|
GGML_METAL_ADD_KERNEL(concat);
|
||||||
GGML_METAL_ADD_KERNEL(sqr);
|
GGML_METAL_ADD_KERNEL(sqr);
|
||||||
|
GGML_METAL_ADD_KERNEL(sum_rows);
|
||||||
|
|
||||||
#undef GGML_METAL_ADD_KERNEL
|
#undef GGML_METAL_ADD_KERNEL
|
||||||
}
|
}
|
||||||
|
|
||||||
#if TARGET_OS_OSX
|
|
||||||
// print MTL GPU family:
|
|
||||||
GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
|
|
||||||
|
|
||||||
// determine max supported GPU family
|
|
||||||
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
|
||||||
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
|
||||||
for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
|
|
||||||
if ([ctx->device supportsFamily:i]) {
|
|
||||||
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
|
||||||
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MiB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
|
||||||
if (ctx->device.maxTransferRate != 0) {
|
|
||||||
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MiB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
|
||||||
} else {
|
|
||||||
GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -377,6 +409,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(add_row);
|
GGML_METAL_DEL_KERNEL(add_row);
|
||||||
GGML_METAL_DEL_KERNEL(mul);
|
GGML_METAL_DEL_KERNEL(mul);
|
||||||
GGML_METAL_DEL_KERNEL(mul_row);
|
GGML_METAL_DEL_KERNEL(mul_row);
|
||||||
|
GGML_METAL_DEL_KERNEL(div);
|
||||||
|
GGML_METAL_DEL_KERNEL(div_row);
|
||||||
GGML_METAL_DEL_KERNEL(scale);
|
GGML_METAL_DEL_KERNEL(scale);
|
||||||
GGML_METAL_DEL_KERNEL(scale_4);
|
GGML_METAL_DEL_KERNEL(scale_4);
|
||||||
GGML_METAL_DEL_KERNEL(silu);
|
GGML_METAL_DEL_KERNEL(silu);
|
||||||
@ -428,11 +462,25 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
|
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
|
||||||
}
|
}
|
||||||
GGML_METAL_DEL_KERNEL(rope_f32);
|
GGML_METAL_DEL_KERNEL(rope_f32);
|
||||||
GGML_METAL_DEL_KERNEL(rope_f16);
|
GGML_METAL_DEL_KERNEL(rope_f16);
|
||||||
GGML_METAL_DEL_KERNEL(alibi_f32);
|
GGML_METAL_DEL_KERNEL(alibi_f32);
|
||||||
GGML_METAL_DEL_KERNEL(im2col_f16);
|
GGML_METAL_DEL_KERNEL(im2col_f16);
|
||||||
|
GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
|
||||||
|
GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
|
||||||
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
||||||
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
||||||
GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
|
GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
|
||||||
@ -443,6 +491,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
||||||
GGML_METAL_DEL_KERNEL(concat);
|
GGML_METAL_DEL_KERNEL(concat);
|
||||||
GGML_METAL_DEL_KERNEL(sqr);
|
GGML_METAL_DEL_KERNEL(sqr);
|
||||||
|
GGML_METAL_DEL_KERNEL(sum_rows);
|
||||||
|
|
||||||
#undef GGML_METAL_DEL_KERNEL
|
#undef GGML_METAL_DEL_KERNEL
|
||||||
|
|
||||||
@ -486,6 +535,13 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
|
|||||||
return ctx->concur_list;
|
return ctx->concur_list;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// temporarily defined here for compatibility between ggml-backend and the old API
|
||||||
|
struct ggml_backend_metal_buffer_context {
|
||||||
|
void * data;
|
||||||
|
|
||||||
|
id<MTLBuffer> metal;
|
||||||
|
};
|
||||||
|
|
||||||
// finds the Metal buffer that contains the tensor data on the GPU device
|
// finds the Metal buffer that contains the tensor data on the GPU device
|
||||||
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
|
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
|
||||||
// Metal buffer based on the host memory pointer
|
// Metal buffer based on the host memory pointer
|
||||||
@ -495,8 +551,17 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
|||||||
|
|
||||||
const int64_t tsize = ggml_nbytes(t);
|
const int64_t tsize = ggml_nbytes(t);
|
||||||
|
|
||||||
if (t->buffer && t->buffer->backend && t->buffer->backend->context) {
|
// compatibility with ggml-backend
|
||||||
ctx = t->buffer->backend->context;
|
if (t->buffer && t->buffer->buft == ggml_backend_metal_buffer_type()) {
|
||||||
|
struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) t->buffer->context;
|
||||||
|
|
||||||
|
const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
|
||||||
|
|
||||||
|
GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
|
||||||
|
|
||||||
|
*offs = (size_t) ioffs;
|
||||||
|
|
||||||
|
return buf_ctx->metal;
|
||||||
}
|
}
|
||||||
|
|
||||||
// find the view that contains the tensor fully
|
// find the view that contains the tensor fully
|
||||||
@ -721,6 +786,52 @@ void ggml_metal_graph_find_concurrency(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
||||||
|
switch (op->op) {
|
||||||
|
case GGML_OP_UNARY:
|
||||||
|
switch (ggml_get_unary_op(op)) {
|
||||||
|
case GGML_UNARY_OP_SILU:
|
||||||
|
case GGML_UNARY_OP_RELU:
|
||||||
|
case GGML_UNARY_OP_GELU:
|
||||||
|
return true;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case GGML_OP_NONE:
|
||||||
|
case GGML_OP_RESHAPE:
|
||||||
|
case GGML_OP_VIEW:
|
||||||
|
case GGML_OP_TRANSPOSE:
|
||||||
|
case GGML_OP_PERMUTE:
|
||||||
|
case GGML_OP_CONCAT:
|
||||||
|
case GGML_OP_ADD:
|
||||||
|
case GGML_OP_MUL:
|
||||||
|
case GGML_OP_DIV:
|
||||||
|
case GGML_OP_SCALE:
|
||||||
|
case GGML_OP_SQR:
|
||||||
|
case GGML_OP_SUM_ROWS:
|
||||||
|
case GGML_OP_SOFT_MAX:
|
||||||
|
case GGML_OP_RMS_NORM:
|
||||||
|
case GGML_OP_NORM:
|
||||||
|
case GGML_OP_ALIBI:
|
||||||
|
case GGML_OP_ROPE:
|
||||||
|
case GGML_OP_IM2COL:
|
||||||
|
case GGML_OP_ARGSORT:
|
||||||
|
case GGML_OP_DUP:
|
||||||
|
case GGML_OP_CPY:
|
||||||
|
case GGML_OP_CONT:
|
||||||
|
case GGML_OP_MUL_MAT:
|
||||||
|
case GGML_OP_MUL_MAT_ID:
|
||||||
|
return true;
|
||||||
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
|
case GGML_OP_GET_ROWS:
|
||||||
|
{
|
||||||
|
return op->ne[0] % 4 == 0;
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
void ggml_metal_graph_compute(
|
void ggml_metal_graph_compute(
|
||||||
struct ggml_metal_context * ctx,
|
struct ggml_metal_context * ctx,
|
||||||
struct ggml_cgraph * gf) {
|
struct ggml_cgraph * gf) {
|
||||||
@ -791,6 +902,8 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_metal_supports_op(dst));
|
||||||
|
|
||||||
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
||||||
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
||||||
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
||||||
@ -883,6 +996,8 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
|
case GGML_OP_MUL:
|
||||||
|
case GGML_OP_DIV:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||||
@ -896,11 +1011,21 @@ void ggml_metal_graph_compute(
|
|||||||
GGML_ASSERT(ne11 == 1);
|
GGML_ASSERT(ne11 == 1);
|
||||||
|
|
||||||
nb = ne00 / 4;
|
nb = ne00 / 4;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
switch (dst->op) {
|
||||||
|
case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
|
||||||
|
case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
|
||||||
|
case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
|
||||||
|
default: GGML_ASSERT(false);
|
||||||
|
}
|
||||||
|
|
||||||
bcast_row = true;
|
bcast_row = true;
|
||||||
} else {
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_add];
|
switch (dst->op) {
|
||||||
|
case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
|
||||||
|
case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
|
||||||
|
case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
|
||||||
|
default: GGML_ASSERT(false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
@ -941,31 +1066,6 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL:
|
|
||||||
{
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
|
||||||
|
|
||||||
// utilize float4
|
|
||||||
GGML_ASSERT(ne00 % 4 == 0);
|
|
||||||
const int64_t nb = ne00/4;
|
|
||||||
|
|
||||||
if (ggml_nelements(src1) == ne10) {
|
|
||||||
// src1 is a row
|
|
||||||
GGML_ASSERT(ne11 == 1);
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
|
||||||
} else {
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul];
|
|
||||||
}
|
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
||||||
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
|
|
||||||
|
|
||||||
const int64_t n = ggml_nelements(dst)/4;
|
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
||||||
} break;
|
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
@ -1038,6 +1138,40 @@ void ggml_metal_graph_compute(
|
|||||||
const int64_t n = ggml_nelements(dst);
|
const int64_t n = ggml_nelements(dst);
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_SUM_ROWS:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_sum_rows];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||||
|
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||||
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||||
|
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
||||||
|
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
||||||
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
||||||
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
||||||
|
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
||||||
|
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
||||||
|
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
||||||
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
||||||
|
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
||||||
|
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
||||||
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
||||||
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
||||||
|
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
||||||
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
|
||||||
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
|
||||||
|
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
|
||||||
|
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
|
||||||
|
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
|
||||||
|
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
|
||||||
|
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
|
||||||
|
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
} break;
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
{
|
{
|
||||||
int nth = 32; // SIMD width
|
int nth = 32; // SIMD width
|
||||||
@ -1092,13 +1226,17 @@ void ggml_metal_graph_compute(
|
|||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne00 == ne10);
|
GGML_ASSERT(ne00 == ne10);
|
||||||
GGML_ASSERT(ne03 == ne13);
|
|
||||||
|
|
||||||
const uint gqa = ne12/ne02;
|
// TODO: assert that dim2 and dim3 are contiguous
|
||||||
|
GGML_ASSERT(ne12 % ne02 == 0);
|
||||||
|
GGML_ASSERT(ne13 % ne03 == 0);
|
||||||
|
|
||||||
|
const uint r2 = ne12/ne02;
|
||||||
|
const uint r3 = ne13/ne03;
|
||||||
|
|
||||||
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
||||||
// to the matrix-vector kernel
|
// to the matrix-vector kernel
|
||||||
int ne11_mm_min = src0t == GGML_TYPE_F16 ? 1 : 16;
|
int ne11_mm_min = 1;
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
// the numbers below are measured on M2 Ultra for 7B and 13B models
|
// the numbers below are measured on M2 Ultra for 7B and 13B models
|
||||||
@ -1159,9 +1297,10 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
||||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
||||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
||||||
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
|
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
|
||||||
|
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
|
||||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||||
} else {
|
} else {
|
||||||
int nth0 = 32;
|
int nth0 = 32;
|
||||||
int nth1 = 1;
|
int nth1 = 1;
|
||||||
@ -1197,90 +1336,60 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne02 == 1);
|
|
||||||
GGML_ASSERT(ne12 == 1);
|
|
||||||
|
|
||||||
nth0 = 8;
|
nth0 = 8;
|
||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne02 == 1);
|
|
||||||
GGML_ASSERT(ne12 == 1);
|
|
||||||
|
|
||||||
nth0 = 8;
|
nth0 = 8;
|
||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne02 == 1);
|
|
||||||
GGML_ASSERT(ne12 == 1);
|
|
||||||
|
|
||||||
nth0 = 8;
|
nth0 = 8;
|
||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q5_1:
|
case GGML_TYPE_Q5_1:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne02 == 1);
|
|
||||||
GGML_ASSERT(ne12 == 1);
|
|
||||||
|
|
||||||
nth0 = 8;
|
nth0 = 8;
|
||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne02 == 1);
|
|
||||||
GGML_ASSERT(ne12 == 1);
|
|
||||||
|
|
||||||
nth0 = 8;
|
nth0 = 8;
|
||||||
nth1 = 8;
|
nth1 = 8;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q2_K:
|
case GGML_TYPE_Q2_K:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne02 == 1);
|
|
||||||
GGML_ASSERT(ne12 == 1);
|
|
||||||
|
|
||||||
nth0 = 2;
|
nth0 = 2;
|
||||||
nth1 = 32;
|
nth1 = 32;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q3_K:
|
case GGML_TYPE_Q3_K:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne02 == 1);
|
|
||||||
GGML_ASSERT(ne12 == 1);
|
|
||||||
|
|
||||||
nth0 = 2;
|
nth0 = 2;
|
||||||
nth1 = 32;
|
nth1 = 32;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_K:
|
case GGML_TYPE_Q4_K:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne02 == 1);
|
|
||||||
GGML_ASSERT(ne12 == 1);
|
|
||||||
|
|
||||||
nth0 = 4; //1;
|
nth0 = 4; //1;
|
||||||
nth1 = 8; //32;
|
nth1 = 8; //32;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q5_K:
|
case GGML_TYPE_Q5_K:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne02 == 1);
|
|
||||||
GGML_ASSERT(ne12 == 1);
|
|
||||||
|
|
||||||
nth0 = 2;
|
nth0 = 2;
|
||||||
nth1 = 32;
|
nth1 = 32;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q6_K:
|
case GGML_TYPE_Q6_K:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne02 == 1);
|
|
||||||
GGML_ASSERT(ne12 == 1);
|
|
||||||
|
|
||||||
nth0 = 2;
|
nth0 = 2;
|
||||||
nth1 = 32;
|
nth1 = 32;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
|
||||||
@ -1309,34 +1418,127 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
||||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
||||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
||||||
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
||||||
|
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
||||||
|
|
||||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
||||||
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
||||||
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
|
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
else if (src0t == GGML_TYPE_Q4_K) {
|
else if (src0t == GGML_TYPE_Q4_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
else if (src0t == GGML_TYPE_Q3_K) {
|
else if (src0t == GGML_TYPE_Q3_K) {
|
||||||
#ifdef GGML_QKK_64
|
#ifdef GGML_QKK_64
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
#else
|
#else
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
else if (src0t == GGML_TYPE_Q5_K) {
|
else if (src0t == GGML_TYPE_Q5_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
else if (src0t == GGML_TYPE_Q6_K) {
|
else if (src0t == GGML_TYPE_Q6_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
} else {
|
} else {
|
||||||
int64_t ny = (ne11 + nrows - 1)/nrows;
|
int64_t ny = (ne11 + nrows - 1)/nrows;
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_MUL_MAT_ID:
|
||||||
|
{
|
||||||
|
//GGML_ASSERT(ne00 == ne10);
|
||||||
|
//GGML_ASSERT(ne03 == ne13);
|
||||||
|
|
||||||
|
GGML_ASSERT(src0t == GGML_TYPE_I32);
|
||||||
|
|
||||||
|
const int n_as = ne00;
|
||||||
|
|
||||||
|
// TODO: make this more general
|
||||||
|
GGML_ASSERT(n_as <= 8);
|
||||||
|
|
||||||
|
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
||||||
|
|
||||||
|
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
||||||
|
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
||||||
|
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
||||||
|
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
|
||||||
|
|
||||||
|
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
||||||
|
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
||||||
|
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
||||||
|
const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
|
||||||
|
|
||||||
|
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
|
||||||
|
|
||||||
|
GGML_ASSERT(!ggml_is_transposed(src2));
|
||||||
|
GGML_ASSERT(!ggml_is_transposed(src1));
|
||||||
|
|
||||||
|
GGML_ASSERT(ne20 % 32 == 0);
|
||||||
|
// !!!!!!!!! TODO: this assert is probably required but not sure!
|
||||||
|
//GGML_ASSERT(ne20 >= 64);
|
||||||
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
const uint r2 = ne12/ne22;
|
||||||
|
const uint r3 = ne13/ne23;
|
||||||
|
|
||||||
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
||||||
|
// to the matrix-vector kernel
|
||||||
|
int ne11_mm_min = 0;
|
||||||
|
|
||||||
|
const int idx = ((int32_t *) dst->op_params)[0];
|
||||||
|
|
||||||
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||||
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||||
|
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
||||||
|
ne11 > ne11_mm_min) {
|
||||||
|
switch (src2->type) {
|
||||||
|
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
|
||||||
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
|
||||||
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
|
||||||
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
|
||||||
|
case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
|
||||||
|
case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
|
||||||
|
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
|
||||||
|
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
|
||||||
|
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
|
||||||
|
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
|
||||||
|
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
|
||||||
|
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
|
||||||
|
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
||||||
|
}
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
|
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
|
||||||
|
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
|
||||||
|
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
|
||||||
|
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
|
||||||
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
||||||
|
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
|
||||||
|
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
|
||||||
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
||||||
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
||||||
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
||||||
|
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
|
||||||
|
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
|
||||||
|
[encoder setBytes:&idx length:sizeof(idx) atIndex:15];
|
||||||
|
// TODO: how to make this an array? read Metal docs
|
||||||
|
for (int j = 0; j < n_as; ++j) {
|
||||||
|
struct ggml_tensor * src_cur = dst->src[2 + j];
|
||||||
|
|
||||||
|
size_t offs_src_cur = 0;
|
||||||
|
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
||||||
|
|
||||||
|
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
|
||||||
|
}
|
||||||
|
|
||||||
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||||
|
}
|
||||||
|
} break;
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
{
|
{
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
@ -1560,6 +1762,27 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_ARGSORT:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT( dst->type == GGML_TYPE_I32);
|
||||||
|
|
||||||
|
const int nrows = ggml_nrows(src0);
|
||||||
|
|
||||||
|
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
||||||
|
|
||||||
|
switch (order) {
|
||||||
|
case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
|
||||||
|
case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
|
||||||
|
default: GGML_ASSERT(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
||||||
|
} break;
|
||||||
case GGML_OP_DUP:
|
case GGML_OP_DUP:
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
case GGML_OP_CONT:
|
case GGML_OP_CONT:
|
||||||
@ -1655,6 +1878,132 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
// backend interface
|
// backend interface
|
||||||
|
|
||||||
|
static id<MTLDevice> g_backend_device = nil;
|
||||||
|
static int g_backend_device_ref_count = 0;
|
||||||
|
|
||||||
|
static id<MTLDevice> ggml_backend_metal_get_device(void) {
|
||||||
|
if (g_backend_device == nil) {
|
||||||
|
g_backend_device = MTLCreateSystemDefaultDevice();
|
||||||
|
}
|
||||||
|
|
||||||
|
g_backend_device_ref_count++;
|
||||||
|
|
||||||
|
return g_backend_device;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_metal_free_device(void) {
|
||||||
|
assert(g_backend_device_ref_count > 0);
|
||||||
|
|
||||||
|
g_backend_device_ref_count--;
|
||||||
|
|
||||||
|
if (g_backend_device_ref_count == 0) {
|
||||||
|
[g_backend_device release];
|
||||||
|
g_backend_device = nil;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||||
|
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
|
||||||
|
|
||||||
|
return ctx->data;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||||
|
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
|
||||||
|
|
||||||
|
[ctx->metal release];
|
||||||
|
ggml_backend_metal_free_device();
|
||||||
|
|
||||||
|
free(ctx->data);
|
||||||
|
free(ctx);
|
||||||
|
|
||||||
|
UNUSED(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||||
|
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
|
||||||
|
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
||||||
|
|
||||||
|
memcpy((char *)tensor->data + offset, data, size);
|
||||||
|
|
||||||
|
UNUSED(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||||
|
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
|
||||||
|
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
||||||
|
|
||||||
|
memcpy(data, (const char *)tensor->data + offset, size);
|
||||||
|
|
||||||
|
UNUSED(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
|
||||||
|
ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
|
||||||
|
|
||||||
|
UNUSED(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
|
||||||
|
ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
|
||||||
|
|
||||||
|
UNUSED(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
static struct ggml_backend_buffer_i metal_backend_buffer_i = {
|
||||||
|
/* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
|
||||||
|
/* .get_base = */ ggml_backend_metal_buffer_get_base,
|
||||||
|
/* .init_tensor = */ NULL,
|
||||||
|
/* .set_tensor = */ ggml_backend_metal_buffer_set_tensor,
|
||||||
|
/* .get_tensor = */ ggml_backend_metal_buffer_get_tensor,
|
||||||
|
/* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
|
||||||
|
/* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
|
||||||
|
};
|
||||||
|
|
||||||
|
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||||
|
struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
|
||||||
|
|
||||||
|
const size_t size_page = sysconf(_SC_PAGESIZE);
|
||||||
|
|
||||||
|
size_t size_aligned = size;
|
||||||
|
if ((size_aligned % size_page) != 0) {
|
||||||
|
size_aligned += (size_page - (size_aligned % size_page));
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx->data = ggml_metal_host_malloc(size);
|
||||||
|
ctx->metal = [ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
|
||||||
|
length:size_aligned
|
||||||
|
options:MTLResourceStorageModeShared
|
||||||
|
deallocator:nil];
|
||||||
|
|
||||||
|
return ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||||
|
return 32;
|
||||||
|
UNUSED(buft);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
|
||||||
|
return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend);
|
||||||
|
|
||||||
|
GGML_UNUSED(buft);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
||||||
|
static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
|
||||||
|
/* .iface = */ {
|
||||||
|
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
|
||||||
|
/* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
|
||||||
|
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
||||||
|
/* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
|
||||||
|
},
|
||||||
|
/* .context = */ NULL,
|
||||||
|
};
|
||||||
|
|
||||||
|
return &ggml_backend_buffer_type_metal;
|
||||||
|
}
|
||||||
|
|
||||||
static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
||||||
return "Metal";
|
return "Metal";
|
||||||
|
|
||||||
@ -1667,69 +2016,12 @@ static void ggml_backend_metal_free(ggml_backend_t backend) {
|
|||||||
free(backend);
|
free(backend);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
|
|
||||||
return (void *)buffer->context;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
|
||||||
free(buffer->context);
|
|
||||||
UNUSED(buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
static struct ggml_backend_buffer_i metal_backend_buffer_i = {
|
|
||||||
/* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
|
|
||||||
/* .get_base = */ ggml_backend_metal_buffer_get_base,
|
|
||||||
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
|
||||||
/* .init_tensor = */ NULL, // no initialization required
|
|
||||||
/* .free_tensor = */ NULL, // no cleanup required
|
|
||||||
};
|
|
||||||
|
|
||||||
static ggml_backend_buffer_t ggml_backend_metal_alloc_buffer(ggml_backend_t backend, size_t size) {
|
|
||||||
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
|
||||||
|
|
||||||
void * data = ggml_metal_host_malloc(size);
|
|
||||||
|
|
||||||
// TODO: set proper name of the buffers
|
|
||||||
ggml_metal_add_buffer(ctx, "backend", data, size, 0);
|
|
||||||
|
|
||||||
return ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
|
|
||||||
}
|
|
||||||
|
|
||||||
static size_t ggml_backend_metal_get_alignment(ggml_backend_t backend) {
|
|
||||||
return 32;
|
|
||||||
UNUSED(backend);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
|
||||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
|
|
||||||
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
|
||||||
|
|
||||||
memcpy((char *)tensor->data + offset, data, size);
|
|
||||||
|
|
||||||
UNUSED(backend);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
|
||||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
|
|
||||||
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
|
|
||||||
|
|
||||||
memcpy(data, (const char *)tensor->data + offset, size);
|
|
||||||
|
|
||||||
UNUSED(backend);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
|
static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
|
||||||
UNUSED(backend);
|
UNUSED(backend);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_backend_metal_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
|
static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) {
|
||||||
ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
|
return ggml_backend_metal_buffer_type();
|
||||||
|
|
||||||
UNUSED(backend);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ggml_backend_metal_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
|
|
||||||
ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
|
|
||||||
|
|
||||||
UNUSED(backend);
|
UNUSED(backend);
|
||||||
}
|
}
|
||||||
@ -1741,32 +2033,43 @@ static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml
|
|||||||
}
|
}
|
||||||
|
|
||||||
static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
||||||
return true;
|
return ggml_metal_supports_op(op);
|
||||||
|
|
||||||
UNUSED(backend);
|
UNUSED(backend);
|
||||||
UNUSED(op);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct ggml_backend_i metal_backend_i = {
|
static struct ggml_backend_i metal_backend_i = {
|
||||||
/* .get_name = */ ggml_backend_metal_name,
|
/* .get_name = */ ggml_backend_metal_name,
|
||||||
/* .free = */ ggml_backend_metal_free,
|
/* .free = */ ggml_backend_metal_free,
|
||||||
/* .alloc_buffer = */ ggml_backend_metal_alloc_buffer,
|
/* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
|
||||||
/* .get_alignment = */ ggml_backend_metal_get_alignment,
|
/* .set_tensor_async = */ NULL,
|
||||||
/* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
|
/* .get_tensor_async = */ NULL,
|
||||||
/* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
|
/* .cpy_tensor_from_async = */ NULL,
|
||||||
/* .synchronize = */ ggml_backend_metal_synchronize,
|
/* .cpy_tensor_to_async = */ NULL,
|
||||||
/* .cpy_tensor_from = */ ggml_backend_metal_cpy_tensor_from,
|
/* .synchronize = */ ggml_backend_metal_synchronize,
|
||||||
/* .cpy_tensor_to = */ ggml_backend_metal_cpy_tensor_to,
|
/* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
|
||||||
/* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm
|
/* .graph_plan_free = */ NULL,
|
||||||
/* .graph_plan_free = */ NULL,
|
/* .graph_plan_compute = */ NULL,
|
||||||
/* .graph_plan_compute = */ NULL,
|
/* .graph_compute = */ ggml_backend_metal_graph_compute,
|
||||||
/* .graph_compute = */ ggml_backend_metal_graph_compute,
|
/* .supports_op = */ ggml_backend_metal_supports_op,
|
||||||
/* .supports_op = */ ggml_backend_metal_supports_op,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_backend_t ggml_backend_metal_init(void) {
|
// TODO: make a common log callback for all backends in ggml-backend
|
||||||
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
|
static void ggml_backend_log_callback(enum ggml_log_level level, const char * msg, void * user_data) {
|
||||||
|
fprintf(stderr, "%s", msg);
|
||||||
|
|
||||||
ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
|
UNUSED(level);
|
||||||
|
UNUSED(user_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_backend_t ggml_backend_metal_init(void) {
|
||||||
|
ggml_metal_log_set_callback(ggml_backend_log_callback, NULL);
|
||||||
|
|
||||||
|
struct ggml_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
|
||||||
|
|
||||||
|
if (ctx == NULL) {
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
|
ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
|
||||||
|
|
||||||
@ -1783,7 +2086,18 @@ bool ggml_backend_is_metal(ggml_backend_t backend) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
||||||
|
GGML_ASSERT(ggml_backend_is_metal(backend));
|
||||||
|
|
||||||
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
||||||
|
|
||||||
ggml_metal_set_n_cb(ctx, n_cb);
|
ggml_metal_set_n_cb(ctx, n_cb);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
|
||||||
|
return ggml_backend_metal_init();
|
||||||
|
|
||||||
|
GGML_UNUSED(params);
|
||||||
|
GGML_UNUSED(user_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_BACKEND_REGISTER("Metal", ggml_backend_reg_metal_init, ggml_backend_metal_buffer_type(), NULL)
|
||||||
|
755
ggml-metal.metal
755
ggml-metal.metal
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user