metal : relax conditions on fast matrix multiplication kernel

This commit is contained in:
Georgi Gerganov 2023-09-14 16:14:29 +03:00
parent 71ca2fad7d
commit 336afbcb76
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
6 changed files with 154 additions and 58 deletions

View File

@ -52,7 +52,7 @@ int main(int argc, char ** argv) {
ggml_metal_set_tensor(ctx_metal, input); ggml_metal_set_tensor(ctx_metal, input);
// warmup // warmup
ggml_metal_graph_compute(ctx_metal, &gf); ggml_metal_graph_compute(ctx_metal, &gf, false);
const int n_iter = 16; const int n_iter = 16;
@ -60,7 +60,7 @@ int main(int argc, char ** argv) {
// the actual inference happens here // the actual inference happens here
for (int i = 0; i < n_iter; ++i) { for (int i = 0; i < n_iter; ++i) {
ggml_metal_graph_compute(ctx_metal, &gf); ggml_metal_graph_compute(ctx_metal, &gf, false);
} }
const int64_t t1 = ggml_time_us(); const int64_t t1 = ggml_time_us();

View File

@ -77,7 +77,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
// same as ggml_graph_compute but uses Metal // same as ggml_graph_compute but uses Metal
// creates gf->n_threads command buffers in parallel // creates gf->n_threads command buffers in parallel
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf); void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool concurrent);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -66,6 +66,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(soft_max_4); GGML_METAL_DECL_KERNEL(soft_max_4);
GGML_METAL_DECL_KERNEL(diag_mask_inf); GGML_METAL_DECL_KERNEL(diag_mask_inf);
GGML_METAL_DECL_KERNEL(diag_mask_inf_8); GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
GGML_METAL_DECL_KERNEL(get_rows_f32);
GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1); GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@ -145,7 +146,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
ctx->n_buffers = 0; ctx->n_buffers = 0;
ctx->concur_list_len = 0; ctx->concur_list_len = 0;
ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT); ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
#ifdef GGML_SWIFT #ifdef GGML_SWIFT
// load the default.metallib file // load the default.metallib file
@ -175,7 +176,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
//NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"]; //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]); metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error]; NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
@ -224,6 +225,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(soft_max_4); GGML_METAL_ADD_KERNEL(soft_max_4);
GGML_METAL_ADD_KERNEL(diag_mask_inf); GGML_METAL_ADD_KERNEL(diag_mask_inf);
GGML_METAL_ADD_KERNEL(diag_mask_inf_8); GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
GGML_METAL_ADD_KERNEL(get_rows_f32);
GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1); GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@ -293,7 +295,9 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(gelu); GGML_METAL_DEL_KERNEL(gelu);
GGML_METAL_DEL_KERNEL(soft_max); GGML_METAL_DEL_KERNEL(soft_max);
GGML_METAL_DEL_KERNEL(soft_max_4); GGML_METAL_DEL_KERNEL(soft_max_4);
GGML_METAL_DEL_KERNEL(diag_mask_inf);
GGML_METAL_DEL_KERNEL(diag_mask_inf_8); GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
GGML_METAL_DEL_KERNEL(get_rows_f32);
GGML_METAL_DEL_KERNEL(get_rows_f16); GGML_METAL_DEL_KERNEL(get_rows_f16);
GGML_METAL_DEL_KERNEL(get_rows_q4_0); GGML_METAL_DEL_KERNEL(get_rows_q4_0);
GGML_METAL_DEL_KERNEL(get_rows_q4_1); GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@ -386,6 +390,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
for (int i = 0; i < ctx->n_buffers; ++i) { for (int i = 0; i < ctx->n_buffers; ++i) {
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
//metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) { if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
*offs = (size_t) ioffs; *offs = (size_t) ioffs;
@ -605,14 +610,15 @@ void ggml_metal_graph_find_concurrency(
void ggml_metal_graph_compute( void ggml_metal_graph_compute(
struct ggml_metal_context * ctx, struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) { struct ggml_cgraph * gf,
bool concurrent) {
@autoreleasepool { @autoreleasepool {
// if there is ctx->concur_list, dispatch concurrently // if there is ctx->concur_list, dispatch concurrently
// else fallback to serial dispatch // else fallback to serial dispatch
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR; const bool has_concur = concurrent && ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes; const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial; edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
@ -723,6 +729,7 @@ void ggml_metal_graph_compute(
case GGML_OP_ADD: case GGML_OP_ADD:
{ {
GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
// utilize float4 // utilize float4
GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(ne00 % 4 == 0);
@ -730,6 +737,7 @@ void ggml_metal_graph_compute(
if (ggml_nelements(src1) == ne10) { if (ggml_nelements(src1) == ne10) {
// src1 is a row // src1 is a row
GGML_ASSERT(ne11 == 1);
[encoder setComputePipelineState:ctx->pipeline_add_row]; [encoder setComputePipelineState:ctx->pipeline_add_row];
} else { } else {
[encoder setComputePipelineState:ctx->pipeline_add]; [encoder setComputePipelineState:ctx->pipeline_add];
@ -746,6 +754,7 @@ void ggml_metal_graph_compute(
case GGML_OP_MUL: case GGML_OP_MUL:
{ {
GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
// utilize float4 // utilize float4
GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(ne00 % 4 == 0);
@ -753,6 +762,7 @@ void ggml_metal_graph_compute(
if (ggml_nelements(src1) == ne10) { if (ggml_nelements(src1) == ne10) {
// src1 is a row // src1 is a row
GGML_ASSERT(ne11 == 1);
[encoder setComputePipelineState:ctx->pipeline_mul_row]; [encoder setComputePipelineState:ctx->pipeline_mul_row];
} else { } else {
[encoder setComputePipelineState:ctx->pipeline_mul]; [encoder setComputePipelineState:ctx->pipeline_mul];
@ -768,6 +778,8 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_OP_SCALE: case GGML_OP_SCALE:
{ {
GGML_ASSERT(ggml_is_contiguous(src0));
const float scale = *(const float *) src1->data; const float scale = *(const float *) src1->data;
[encoder setComputePipelineState:ctx->pipeline_scale]; [encoder setComputePipelineState:ctx->pipeline_scale];
@ -867,8 +879,8 @@ void ggml_metal_graph_compute(
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if (ggml_is_contiguous(src0) && if (!ggml_is_transposed(src0) &&
ggml_is_contiguous(src1) && !ggml_is_transposed(src1) &&
src1t == GGML_TYPE_F32 && src1t == GGML_TYPE_F32 &&
[ctx->device supportsFamily:MTLGPUFamilyApple7] && [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne00%32 == 0 && ne00%32 == 0 &&
@ -893,9 +905,12 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
[encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else { } else {
@ -911,7 +926,9 @@ void ggml_metal_graph_compute(
nth1 = 1; nth1 = 1;
if (ne11 * ne12 < 4) { if (ne11 * ne12 < 4) {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row]; [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { //} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
} else if (false) {
// TODO: with ggml_mul_mat_pad this kernel no longer seems to be needed
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4]; [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
nrows = ne11; nrows = ne11;
} else { } else {
@ -1045,6 +1062,7 @@ void ggml_metal_graph_compute(
case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS:
{ {
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
@ -1060,9 +1078,9 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5]; [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
const int64_t n = ggml_nelements(src1); const int64_t n = ggml_nelements(src1);

View File

@ -38,7 +38,7 @@ kernel void kernel_add_row(
device const float4 * src0, device const float4 * src0,
device const float4 * src1, device const float4 * src1,
device float4 * dst, device float4 * dst,
constant int64_t & nb, constant int64_t & nb,
uint tpig[[thread_position_in_grid]]) { uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig % nb]; dst[tpig] = src0[tpig] + src1[tpig % nb];
} }
@ -1321,7 +1321,6 @@ kernel void kernel_mul_mat_q3_K_f32(
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row]; dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
} }
} }
} }
#else #else
kernel void kernel_mul_mat_q3_K_f32( kernel void kernel_mul_mat_q3_K_f32(
@ -1865,6 +1864,15 @@ kernel void kernel_mul_mat_q6_K_f32(
//============================= templates and their specializations ============================= //============================= templates and their specializations =============================
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
float4x4 temp = *(((device float4x4 *)src));
for (int i = 0; i < 16; i++){
reg[i/4][i%4] = temp[i/4][i%4];
}
}
template <typename type4x4> template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
half4x4 temp = *(((device half4x4 *)src)); half4x4 temp = *(((device half4x4 *)src));
@ -1875,7 +1883,6 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
template <typename type4x4> template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1); device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const float d1 = il ? (xb->d / 16.h) : xb->d; const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f; const float d2 = d1 / 256.f;
@ -1887,12 +1894,10 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
} }
} }
template <typename type4x4> template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2); device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const float d1 = il ? (xb->d / 16.h) : xb->d; const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f; const float d2 = d1 / 256.f;
@ -1964,7 +1969,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
for (int i = 0; i < 16; ++i) { for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
} }
#else #else
float kcoef = il&1 ? 1.f/16.f : 1.f; float kcoef = il&1 ? 1.f/16.f : 1.f;
uint16_t kmask = il&1 ? 0xF0 : 0x0F; uint16_t kmask = il&1 ? 0xF0 : 0x0F;
@ -2110,22 +2114,25 @@ kernel void kernel_get_rows(
// each block_q contains 16*nl weights // each block_q contains 16*nl weights
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)> template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm(device const uchar * src0, kernel void kernel_mul_mm(device const uchar * src0,
device const float * src1, device const uchar * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne02, constant int64_t & ne02,
constant int64_t & nb01, constant int64_t & nb01,
constant int64_t & nb02, constant int64_t & nb02,
constant int64_t & ne12, constant int64_t & ne12,
constant int64_t & ne0, constant int64_t & nb10,
constant int64_t & ne1, constant int64_t & nb11,
constant uint & gqa, constant int64_t & nb12,
threadgroup uchar * shared_memory [[threadgroup(0)]], constant int64_t & ne0,
uint3 tgpig[[threadgroup_position_in_grid]], constant int64_t & ne1,
uint tiitg[[thread_index_in_threadgroup]], constant uint & gqa,
uint sgitg[[simdgroup_index_in_threadgroup]]) { threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup half * sa = ((threadgroup half *)shared_memory); threadgroup half * sa = (threadgroup half *)(shared_memory);
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
const uint r0 = tgpig.y; const uint r0 = tgpig.y;
@ -2138,7 +2145,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
simdgroup_half8x8 ma[4]; simdgroup_half8x8 ma[4];
simdgroup_float8x8 mb[2]; simdgroup_float8x8 mb[2];
simdgroup_float8x8 c_res[8]; simdgroup_float8x8 c_res[8];
for (int i = 0; i < 8; i++){ for (int i = 0; i < 8; i++){
@ -2146,10 +2153,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
} }
short il = (tiitg % THREAD_PER_ROW); short il = (tiitg % THREAD_PER_ROW);
uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; uint offset0 = im/gqa*nb02;
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \ ushort offset1 = il/nl;
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
device const float * y = (device const float *)(src1
+ nb12 * im
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
//load data and store to threadgroup memory //load data and store to threadgroup memory
@ -2229,6 +2241,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \ typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
constant uint64_t &, constant uint64_t &, uint, uint, uint); constant uint64_t &, constant uint64_t &, uint, uint, uint);
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>; template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
@ -2239,14 +2252,27 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>; template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>; template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\ typedef void (mat_mm_t)(
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \ device const uchar * src0,
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint); device const uchar * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & ne12,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & gqa,
threadgroup uchar *, uint3, uint, uint);
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>; template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;

20
ggml.c
View File

@ -4303,10 +4303,21 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
} }
size_t ggml_nbytes(const struct ggml_tensor * tensor) { size_t ggml_nbytes(const struct ggml_tensor * tensor) {
size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type); size_t nbytes;
for (int i = 1; i < GGML_MAX_DIMS; ++i) { size_t blck_size = ggml_blck_size(tensor->type);
nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; if (blck_size == 1) {
nbytes = ggml_type_size(tensor->type);
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
}
} }
else {
nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
}
}
return nbytes; return nbytes;
} }
@ -18340,7 +18351,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n", GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
i, i,
node->ne[0], node->ne[1], node->ne[0], node->ne[1],
ggml_op_name(node->op)); ggml_op_name(node->op),
ggml_get_name(node));
} }
for (int i = 0; i < GGML_OP_COUNT; i++) { for (int i = 0; i < GGML_OP_COUNT; i++) {

View File

@ -438,6 +438,50 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
ggml_graph_compute(graph, &plan); ggml_graph_compute(graph, &plan);
} }
//// EXPERIMENTAL:
////
//// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
//// the idea is to represent the original matrix multiplication:
////
//// Z = X @ Y
////
//// with the sum of two matrix multiplications:
////
//// Z = (X_0 @ Y_0) + (X_1 @ Y_1)
////
//// here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad"
//// and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more
//// general-purpose kernels
////
//static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y, int pad = 32) {
////#if !defined(GGML_USE_METAL)
//// return ggml_mul_mat(ctx, x, y);
////#endif
//
// // use padding only if dimension 0 is at least 8 times larger than the padding
// // else we won't get much benefit from the optimization
// const int n_pad_req = 8;
//
// if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) {
// return ggml_mul_mat(ctx, x, y);
// }
//
// struct ggml_tensor * x_0 = ggml_view_3d(ctx, x, (x->ne[0]/pad)*pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0);
// struct ggml_tensor * x_1 = ggml_view_3d(ctx, x, x->ne[0]%pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_0->ne[0]*x_0->nb[0]);
//
// struct ggml_tensor * y_0 = ggml_view_3d(ctx, y, (y->ne[0]/pad)*pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0);
// struct ggml_tensor * y_1 = ggml_view_3d(ctx, y, y->ne[0]%pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_0->ne[0]*y_0->nb[0]);
//
// return ggml_add(ctx,
// ggml_mul_mat(ctx, x_0, y_0),
// ggml_mul_mat(ctx, x_1, y_1));
//}
//
//// TODO: check if other backends benefit from this and enable for all
//#if defined(GGML_USE_METAL)
////#define ggml_mul_mat ggml_mul_mat_pad
//#endif
// //
// llama helpers // llama helpers
// //
@ -2968,11 +3012,7 @@ static bool llama_eval_internal(
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
if (lctx.ctx_metal) { if (lctx.ctx_metal) {
ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
ggml_metal_graph_compute(lctx.ctx_metal, gf); ggml_metal_graph_compute(lctx.ctx_metal, gf, n_tokens > 1);
ggml_metal_get_tensor (lctx.ctx_metal, res);
if (!lctx.embedding.empty()) {
ggml_metal_get_tensor(lctx.ctx_metal, embeddings);
}
} else { } else {
ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
} }