From 336afbcb76af366b903dd42b812b3abd178f681a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 Sep 2023 16:14:29 +0300 Subject: [PATCH] metal : relax conditions on fast matrix multiplication kernel --- examples/metal/metal.cpp | 4 +- ggml-metal.h | 2 +- ggml-metal.m | 44 +++++++++++++------ ggml-metal.metal | 92 ++++++++++++++++++++++++++-------------- ggml.c | 20 +++++++-- llama.cpp | 50 +++++++++++++++++++--- 6 files changed, 154 insertions(+), 58 deletions(-) diff --git a/examples/metal/metal.cpp b/examples/metal/metal.cpp index c05a4fa93..50a651fa1 100644 --- a/examples/metal/metal.cpp +++ b/examples/metal/metal.cpp @@ -52,7 +52,7 @@ int main(int argc, char ** argv) { ggml_metal_set_tensor(ctx_metal, input); // warmup - ggml_metal_graph_compute(ctx_metal, &gf); + ggml_metal_graph_compute(ctx_metal, &gf, false); const int n_iter = 16; @@ -60,7 +60,7 @@ int main(int argc, char ** argv) { // the actual inference happens here for (int i = 0; i < n_iter; ++i) { - ggml_metal_graph_compute(ctx_metal, &gf); + ggml_metal_graph_compute(ctx_metal, &gf, false); } const int64_t t1 = ggml_time_us(); diff --git a/ggml-metal.h b/ggml-metal.h index fca28d37e..4e36cc129 100644 --- a/ggml-metal.h +++ b/ggml-metal.h @@ -77,7 +77,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx); // same as ggml_graph_compute but uses Metal // creates gf->n_threads command buffers in parallel -void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf); +void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool concurrent); #ifdef __cplusplus } diff --git a/ggml-metal.m b/ggml-metal.m index 4f3f14e24..7ec31c21b 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -66,6 +66,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(soft_max_4); GGML_METAL_DECL_KERNEL(diag_mask_inf); GGML_METAL_DECL_KERNEL(diag_mask_inf_8); + GGML_METAL_DECL_KERNEL(get_rows_f32); GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_1); @@ -145,7 +146,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { ctx->n_buffers = 0; ctx->concur_list_len = 0; - ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT); + ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); #ifdef GGML_SWIFT // load the default.metallib file @@ -175,7 +176,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"]; NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; - NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; + NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]); NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error]; @@ -224,6 +225,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(soft_max_4); GGML_METAL_ADD_KERNEL(diag_mask_inf); GGML_METAL_ADD_KERNEL(diag_mask_inf_8); + GGML_METAL_ADD_KERNEL(get_rows_f32); GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_1); @@ -293,7 +295,9 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(gelu); GGML_METAL_DEL_KERNEL(soft_max); GGML_METAL_DEL_KERNEL(soft_max_4); + GGML_METAL_DEL_KERNEL(diag_mask_inf); GGML_METAL_DEL_KERNEL(diag_mask_inf_8); + GGML_METAL_DEL_KERNEL(get_rows_f32); GGML_METAL_DEL_KERNEL(get_rows_f16); GGML_METAL_DEL_KERNEL(get_rows_q4_0); GGML_METAL_DEL_KERNEL(get_rows_q4_1); @@ -386,6 +390,7 @@ static id ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru for (int i = 0; i < ctx->n_buffers; ++i) { const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; + //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name); if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) { *offs = (size_t) ioffs; @@ -605,14 +610,15 @@ void ggml_metal_graph_find_concurrency( void ggml_metal_graph_compute( struct ggml_metal_context * ctx, - struct ggml_cgraph * gf) { + struct ggml_cgraph * gf, + bool concurrent) { @autoreleasepool { // if there is ctx->concur_list, dispatch concurrently // else fallback to serial dispatch MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; - const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR; + const bool has_concur = concurrent && ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR; const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes; edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial; @@ -723,6 +729,7 @@ void ggml_metal_graph_compute( case GGML_OP_ADD: { GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); // utilize float4 GGML_ASSERT(ne00 % 4 == 0); @@ -730,6 +737,7 @@ void ggml_metal_graph_compute( if (ggml_nelements(src1) == ne10) { // src1 is a row + GGML_ASSERT(ne11 == 1); [encoder setComputePipelineState:ctx->pipeline_add_row]; } else { [encoder setComputePipelineState:ctx->pipeline_add]; @@ -746,6 +754,7 @@ void ggml_metal_graph_compute( case GGML_OP_MUL: { GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); // utilize float4 GGML_ASSERT(ne00 % 4 == 0); @@ -753,6 +762,7 @@ void ggml_metal_graph_compute( if (ggml_nelements(src1) == ne10) { // src1 is a row + GGML_ASSERT(ne11 == 1); [encoder setComputePipelineState:ctx->pipeline_mul_row]; } else { [encoder setComputePipelineState:ctx->pipeline_mul]; @@ -768,6 +778,8 @@ void ggml_metal_graph_compute( } break; case GGML_OP_SCALE: { + GGML_ASSERT(ggml_is_contiguous(src0)); + const float scale = *(const float *) src1->data; [encoder setComputePipelineState:ctx->pipeline_scale]; @@ -867,8 +879,8 @@ void ggml_metal_graph_compute( // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if (ggml_is_contiguous(src0) && - ggml_is_contiguous(src1) && + if (!ggml_is_transposed(src0) && + !ggml_is_transposed(src1) && src1t == GGML_TYPE_F32 && [ctx->device supportsFamily:MTLGPUFamilyApple7] && ne00%32 == 0 && @@ -893,9 +905,12 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; - [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { @@ -911,7 +926,9 @@ void ggml_metal_graph_compute( nth1 = 1; if (ne11 * ne12 < 4) { [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row]; - } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + //} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + } else if (false) { + // TODO: with ggml_mul_mat_pad this kernel no longer seems to be needed [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4]; nrows = ne11; } else { @@ -1045,6 +1062,7 @@ void ggml_metal_graph_compute( case GGML_OP_GET_ROWS: { switch (src0->type) { + case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; @@ -1060,9 +1078,9 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5]; const int64_t n = ggml_nelements(src1); diff --git a/ggml-metal.metal b/ggml-metal.metal index f45b1490f..ea8b42844 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -38,7 +38,7 @@ kernel void kernel_add_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb, + constant int64_t & nb, uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] + src1[tpig % nb]; } @@ -1321,7 +1321,6 @@ kernel void kernel_mul_mat_q3_K_f32( dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row]; } } - } #else kernel void kernel_mul_mat_q3_K_f32( @@ -1865,6 +1864,15 @@ kernel void kernel_mul_mat_q6_K_f32( //============================= templates and their specializations ============================= +// NOTE: this is not dequantizing - we are simply fitting the template +template +void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { + float4x4 temp = *(((device float4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + template void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { half4x4 temp = *(((device half4x4 *)src)); @@ -1875,7 +1883,6 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) template void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 1); const float d1 = il ? (xb->d / 16.h) : xb->d; const float d2 = d1 / 256.f; @@ -1887,12 +1894,10 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; } - } template void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 2); const float d1 = il ? (xb->d / 16.h) : xb->d; const float d2 = d1 / 256.f; @@ -1964,7 +1969,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg for (int i = 0; i < 16; ++i) { reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); } - #else float kcoef = il&1 ? 1.f/16.f : 1.f; uint16_t kmask = il&1 ? 0xF0 : 0x0F; @@ -2110,22 +2114,25 @@ kernel void kernel_get_rows( // each block_q contains 16*nl weights template kernel void kernel_mul_mm(device const uchar * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & gqa, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & ne12, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & gqa, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup half * sa = ((threadgroup half *)shared_memory); + threadgroup half * sa = (threadgroup half *)(shared_memory); threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); const uint r0 = tgpig.y; @@ -2138,7 +2145,7 @@ kernel void kernel_mul_mm(device const uchar * src0, short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - simdgroup_half8x8 ma[4]; + simdgroup_half8x8 ma[4]; simdgroup_float8x8 mb[2]; simdgroup_float8x8 c_res[8]; for (int i = 0; i < 8; i++){ @@ -2146,10 +2153,15 @@ kernel void kernel_mul_mm(device const uchar * src0, } short il = (tiitg % THREAD_PER_ROW); - uint offset0 = im/gqa*nb02; ushort offset1 = il/nl; - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; - device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \ - + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1; + + uint offset0 = im/gqa*nb02; + ushort offset1 = il/nl; + + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; + device const float * y = (device const float *)(src1 + + nb12 * im + + nb11 * (r1 * BLOCK_SIZE_N + thread_col) + + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { //load data and store to threadgroup memory @@ -2229,6 +2241,7 @@ kernel void kernel_mul_mm(device const uchar * src0, typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \ constant uint64_t &, constant uint64_t &, uint, uint, uint); +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; @@ -2239,14 +2252,27 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; -typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\ - constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \ - constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint); +typedef void (mat_mm_t)( + device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & ne12, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & gqa, + threadgroup uchar *, uint3, uint, uint); -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; diff --git a/ggml.c b/ggml.c index a9cffb439..96edebeb8 100644 --- a/ggml.c +++ b/ggml.c @@ -4303,10 +4303,21 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) { } size_t ggml_nbytes(const struct ggml_tensor * tensor) { - size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type); - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; + size_t nbytes; + size_t blck_size = ggml_blck_size(tensor->type); + if (blck_size == 1) { + nbytes = ggml_type_size(tensor->type); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; + } } + else { + nbytes = tensor->ne[0]*tensor->nb[0]/blck_size; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; + } + } + return nbytes; } @@ -18340,7 +18351,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) { GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n", i, node->ne[0], node->ne[1], - ggml_op_name(node->op)); + ggml_op_name(node->op), + ggml_get_name(node)); } for (int i = 0; i < GGML_OP_COUNT; i++) { diff --git a/llama.cpp b/llama.cpp index cbaf8edac..80b9993f4 100644 --- a/llama.cpp +++ b/llama.cpp @@ -438,6 +438,50 @@ static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * ggml_graph_compute(graph, &plan); } +//// EXPERIMENTAL: +//// +//// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad" +//// the idea is to represent the original matrix multiplication: +//// +//// Z = X @ Y +//// +//// with the sum of two matrix multiplications: +//// +//// Z = (X_0 @ Y_0) + (X_1 @ Y_1) +//// +//// here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad" +//// and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more +//// general-purpose kernels +//// +//static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y, int pad = 32) { +////#if !defined(GGML_USE_METAL) +//// return ggml_mul_mat(ctx, x, y); +////#endif +// +// // use padding only if dimension 0 is at least 8 times larger than the padding +// // else we won't get much benefit from the optimization +// const int n_pad_req = 8; +// +// if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) { +// return ggml_mul_mat(ctx, x, y); +// } +// +// struct ggml_tensor * x_0 = ggml_view_3d(ctx, x, (x->ne[0]/pad)*pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0); +// struct ggml_tensor * x_1 = ggml_view_3d(ctx, x, x->ne[0]%pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_0->ne[0]*x_0->nb[0]); +// +// struct ggml_tensor * y_0 = ggml_view_3d(ctx, y, (y->ne[0]/pad)*pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0); +// struct ggml_tensor * y_1 = ggml_view_3d(ctx, y, y->ne[0]%pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_0->ne[0]*y_0->nb[0]); +// +// return ggml_add(ctx, +// ggml_mul_mat(ctx, x_0, y_0), +// ggml_mul_mat(ctx, x_1, y_1)); +//} +// +//// TODO: check if other backends benefit from this and enable for all +//#if defined(GGML_USE_METAL) +////#define ggml_mul_mat ggml_mul_mat_pad +//#endif + // // llama helpers // @@ -2968,11 +3012,7 @@ static bool llama_eval_internal( #ifdef GGML_USE_METAL if (lctx.ctx_metal) { ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); - ggml_metal_graph_compute(lctx.ctx_metal, gf); - ggml_metal_get_tensor (lctx.ctx_metal, res); - if (!lctx.embedding.empty()) { - ggml_metal_get_tensor(lctx.ctx_metal, embeddings); - } + ggml_metal_graph_compute(lctx.ctx_metal, gf, n_tokens > 1); } else { ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); }