mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
metal : matrix-matrix multiplication kernel (#2615)
* metal: matrix-matrix multiplication kernel This commit removes MPS and uses custom matrix-matrix multiplication kernels for all quantization types. This commit also adds grouped-query attention to support llama2 70B. * metal: fix performance degradation from gqa Integers are slow on the GPU, and 64-bit divides are extremely slow. In the context of GQA, we introduce a 64-bit divide that cannot be optimized out by the compiler, which results in a decrease of ~8% in inference performance. This commit fixes that issue by calculating a part of the offset with a 32-bit divide. Naturally, this limits the size of a single matrix to ~4GB. However, this limitation should suffice for the near future. * metal: fix bugs for GQA and perplexity test. I mixed up ne02 and nb02 in previous commit.
This commit is contained in:
parent
b5ffb2849d
commit
bf83bff674
@ -296,7 +296,6 @@ if (LLAMA_METAL)
|
|||||||
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
||||||
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
||||||
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
||||||
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
|
|
||||||
|
|
||||||
set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
|
set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)
|
||||||
|
|
||||||
@ -313,7 +312,6 @@ if (LLAMA_METAL)
|
|||||||
${FOUNDATION_LIBRARY}
|
${FOUNDATION_LIBRARY}
|
||||||
${METAL_FRAMEWORK}
|
${METAL_FRAMEWORK}
|
||||||
${METALKIT_FRAMEWORK}
|
${METALKIT_FRAMEWORK}
|
||||||
${METALPERFORMANCE_FRAMEWORK}
|
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
2
Makefile
2
Makefile
@ -283,7 +283,7 @@ endif # LLAMA_CLBLAST
|
|||||||
ifdef LLAMA_METAL
|
ifdef LLAMA_METAL
|
||||||
CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
|
CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
|
||||||
CXXFLAGS += -DGGML_USE_METAL
|
CXXFLAGS += -DGGML_USE_METAL
|
||||||
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
|
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
|
||||||
OBJS += ggml-metal.o
|
OBJS += ggml-metal.o
|
||||||
endif # LLAMA_METAL
|
endif # LLAMA_METAL
|
||||||
|
|
||||||
|
@ -14,8 +14,6 @@
|
|||||||
with pkgs.darwin.apple_sdk_11_0.frameworks; [
|
with pkgs.darwin.apple_sdk_11_0.frameworks; [
|
||||||
Accelerate
|
Accelerate
|
||||||
MetalKit
|
MetalKit
|
||||||
MetalPerformanceShaders
|
|
||||||
MetalPerformanceShadersGraph
|
|
||||||
]
|
]
|
||||||
else if isAarch32 && isDarwin then
|
else if isAarch32 && isDarwin then
|
||||||
with pkgs.darwin.apple_sdk.frameworks; [
|
with pkgs.darwin.apple_sdk.frameworks; [
|
||||||
|
169
ggml-metal.m
169
ggml-metal.m
@ -5,7 +5,6 @@
|
|||||||
#import <Foundation/Foundation.h>
|
#import <Foundation/Foundation.h>
|
||||||
|
|
||||||
#import <Metal/Metal.h>
|
#import <Metal/Metal.h>
|
||||||
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
|
||||||
|
|
||||||
#undef MIN
|
#undef MIN
|
||||||
#undef MAX
|
#undef MAX
|
||||||
@ -79,6 +78,14 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
|
||||||
GGML_METAL_DECL_KERNEL(rope);
|
GGML_METAL_DECL_KERNEL(rope);
|
||||||
GGML_METAL_DECL_KERNEL(alibi_f32);
|
GGML_METAL_DECL_KERNEL(alibi_f32);
|
||||||
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
||||||
@ -110,13 +117,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
ctx->n_buffers = 0;
|
ctx->n_buffers = 0;
|
||||||
ctx->concur_list_len = 0;
|
ctx->concur_list_len = 0;
|
||||||
|
|
||||||
// determine if we can use MPS
|
|
||||||
if (MPSSupportsMTLDevice(ctx->device)) {
|
|
||||||
fprintf(stderr, "%s: using MPS\n", __func__);
|
|
||||||
} else {
|
|
||||||
fprintf(stderr, "%s: not using MPS\n", __func__);
|
|
||||||
GGML_ASSERT(false && "MPS not supported");
|
|
||||||
}
|
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
// compile from source string and show compile log
|
// compile from source string and show compile log
|
||||||
@ -196,6 +196,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
|
||||||
GGML_METAL_ADD_KERNEL(rope);
|
GGML_METAL_ADD_KERNEL(rope);
|
||||||
GGML_METAL_ADD_KERNEL(alibi_f32);
|
GGML_METAL_ADD_KERNEL(alibi_f32);
|
||||||
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
||||||
@ -506,7 +514,7 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
|
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
|
||||||
|
|
||||||
id<MTLComputeCommandEncoder> encoder = nil;
|
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
||||||
|
|
||||||
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
||||||
const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
|
const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
|
||||||
@ -515,10 +523,6 @@ void ggml_metal_graph_compute(
|
|||||||
const int i = has_concur ? ctx->concur_list[ind] : ind;
|
const int i = has_concur ? ctx->concur_list[ind] : ind;
|
||||||
|
|
||||||
if (i == -1) {
|
if (i == -1) {
|
||||||
if (encoder == nil) {
|
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
|
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -592,10 +596,6 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ggml_nelements(src1) == ne10) {
|
if (ggml_nelements(src1) == ne10) {
|
||||||
// src1 is a row
|
// src1 is a row
|
||||||
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
||||||
@ -613,10 +613,6 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ggml_nelements(src1) == ne10) {
|
if (ggml_nelements(src1) == ne10) {
|
||||||
// src1 is a row
|
// src1 is a row
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
||||||
@ -634,10 +630,6 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
}
|
|
||||||
|
|
||||||
const float scale = *(const float *) src1->data;
|
const float scale = *(const float *) src1->data;
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_scale];
|
[encoder setComputePipelineState:ctx->pipeline_scale];
|
||||||
@ -653,10 +645,6 @@ void ggml_metal_graph_compute(
|
|||||||
switch (ggml_get_unary_op(gf->nodes[i])) {
|
switch (ggml_get_unary_op(gf->nodes[i])) {
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
}
|
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_silu];
|
[encoder setComputePipelineState:ctx->pipeline_silu];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
@ -667,10 +655,6 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_UNARY_OP_RELU:
|
case GGML_UNARY_OP_RELU:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
}
|
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_relu];
|
[encoder setComputePipelineState:ctx->pipeline_relu];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
@ -681,10 +665,6 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_UNARY_OP_GELU:
|
case GGML_UNARY_OP_GELU:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
}
|
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_gelu];
|
[encoder setComputePipelineState:ctx->pipeline_gelu];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
@ -701,10 +681,6 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
}
|
|
||||||
|
|
||||||
const int nth = 32;
|
const int nth = 32;
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
||||||
@ -719,10 +695,6 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
}
|
|
||||||
|
|
||||||
const int n_past = ((int32_t *)(dst->op_params))[0];
|
const int n_past = ((int32_t *)(dst->op_params))[0];
|
||||||
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
||||||
@ -740,53 +712,43 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
GGML_ASSERT(ne00 == ne10);
|
GGML_ASSERT(ne00 == ne10);
|
||||||
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
|
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
|
||||||
|
uint gqa = ne12/ne02;
|
||||||
GGML_ASSERT(ne03 == ne13);
|
GGML_ASSERT(ne03 == ne13);
|
||||||
|
|
||||||
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||||
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||||
if (ggml_is_contiguous(src0) &&
|
if (ggml_is_contiguous(src0) &&
|
||||||
ggml_is_contiguous(src1) &&
|
ggml_is_contiguous(src1) &&
|
||||||
(src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
|
src1t == GGML_TYPE_F32 &&
|
||||||
|
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
||||||
if (encoder != nil) {
|
ne00%32 == 0 &&
|
||||||
[encoder endEncoding];
|
ne11 > 1) {
|
||||||
encoder = nil;
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
||||||
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
||||||
|
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
||||||
|
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
|
||||||
|
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
|
||||||
|
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
|
||||||
|
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
|
||||||
|
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
|
||||||
|
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
||||||
}
|
}
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||||
// for F32 x F32 we use MPS
|
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||||
MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
|
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
||||||
matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
|
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
||||||
|
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
||||||
MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
|
||||||
matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
|
||||||
|
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
|
||||||
MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||||
matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
|
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||||
|
|
||||||
MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
|
|
||||||
initWithDevice:ctx->device transposeLeft:false transposeRight:true
|
|
||||||
resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
|
|
||||||
|
|
||||||
// we need to do ne12 multiplications
|
|
||||||
// TODO: is there a way to do this in parallel - currently very slow ..
|
|
||||||
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
|
|
||||||
for (int64_t i02 = 0; i02 < ne12; ++i02) {
|
|
||||||
size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
|
|
||||||
size_t offs_src1_cur = offs_src1 + i02*nb12;
|
|
||||||
size_t offs_dst_cur = offs_dst + i02*nb2;
|
|
||||||
|
|
||||||
MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
|
|
||||||
MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
|
|
||||||
MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
|
|
||||||
|
|
||||||
[mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
|
|
||||||
}
|
}
|
||||||
} else {
|
else {
|
||||||
if (encoder == nil) {
|
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
}
|
|
||||||
|
|
||||||
int nth0 = 32;
|
int nth0 = 32;
|
||||||
int nth1 = 1;
|
int nth1 = 1;
|
||||||
|
|
||||||
@ -885,23 +847,24 @@ void ggml_metal_graph_compute(
|
|||||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
||||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
||||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
||||||
|
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
||||||
|
|
||||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
||||||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
else if (src0t == GGML_TYPE_Q3_K) {
|
else if (src0t == GGML_TYPE_Q3_K) {
|
||||||
#ifdef GGML_QKK_64
|
#ifdef GGML_QKK_64
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
#else
|
#else
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
else if (src0t == GGML_TYPE_Q5_K) {
|
else if (src0t == GGML_TYPE_Q5_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
else if (src0t == GGML_TYPE_Q6_K) {
|
else if (src0t == GGML_TYPE_Q6_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
} else {
|
} else {
|
||||||
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
@ -910,10 +873,6 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
||||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
||||||
@ -939,10 +898,6 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
}
|
|
||||||
|
|
||||||
float eps;
|
float eps;
|
||||||
memcpy(&eps, dst->op_params, sizeof(float));
|
memcpy(&eps, dst->op_params, sizeof(float));
|
||||||
|
|
||||||
@ -962,10 +917,6 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
}
|
|
||||||
|
|
||||||
const float eps = 1e-5f;
|
const float eps = 1e-5f;
|
||||||
|
|
||||||
const int nth = 256;
|
const int nth = 256;
|
||||||
@ -984,10 +935,6 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_ALIBI:
|
case GGML_OP_ALIBI:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
}
|
|
||||||
|
|
||||||
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
GGML_ASSERT((src0t == GGML_TYPE_F32));
|
||||||
|
|
||||||
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
|
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
|
||||||
@ -1027,10 +974,6 @@ void ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
}
|
|
||||||
|
|
||||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((int32_t *) dst->op_params)[2];
|
||||||
@ -1071,10 +1014,6 @@ void ggml_metal_graph_compute(
|
|||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
case GGML_OP_CONT:
|
case GGML_OP_CONT:
|
||||||
{
|
{
|
||||||
if (encoder == nil) {
|
|
||||||
encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
}
|
|
||||||
|
|
||||||
const int nth = 32;
|
const int nth = 32;
|
||||||
|
|
||||||
switch (src0t) {
|
switch (src0t) {
|
||||||
|
969
ggml-metal.metal
969
ggml-metal.metal
File diff suppressed because it is too large
Load Diff
18
llama.cpp
18
llama.cpp
@ -1845,7 +1845,7 @@ static bool llama_eval_internal(
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef GGML_USE_METAL
|
#ifdef GGML_USE_METAL
|
||||||
if (lctx.ctx_metal && N == 1) {
|
if (lctx.ctx_metal) {
|
||||||
// TODO: disabled until #2413 is resolved
|
// TODO: disabled until #2413 is resolved
|
||||||
//if (!ggml_metal_if_optimized(lctx.ctx_metal)) {
|
//if (!ggml_metal_if_optimized(lctx.ctx_metal)) {
|
||||||
// ggml_metal_graph_find_concurrency(lctx.ctx_metal, gf);
|
// ggml_metal_graph_find_concurrency(lctx.ctx_metal, gf);
|
||||||
@ -1857,22 +1857,6 @@ static bool llama_eval_internal(
|
|||||||
ggml_metal_get_tensor(lctx.ctx_metal, embeddings);
|
ggml_metal_get_tensor(lctx.ctx_metal, embeddings);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// IMPORTANT:
|
|
||||||
// Since we don't have efficient Matrix x Matrix Metal multiplication yet, we fallback to vanilla
|
|
||||||
// ggml_graph_compute(). It uses Apple's Accelerate CBLAS API which takes advantage of the ANE or the AMX
|
|
||||||
// coprocessor.
|
|
||||||
//
|
|
||||||
// When we implement Matrix x Matrix Metal multiplication, we can avoid this branch.
|
|
||||||
// But for now, we have focused only on Matrix x Vector Metal multiplication.
|
|
||||||
//
|
|
||||||
// TODO: avoid these syncs via shared memory (ref #1696)
|
|
||||||
//
|
|
||||||
if (lctx.ctx_metal) {
|
|
||||||
// We need to sync the GPU KV cache with the CPU KV cache
|
|
||||||
ggml_metal_get_tensor(lctx.ctx_metal, kv_self.k);
|
|
||||||
ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v);
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
|
ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
Loading…
Reference in New Issue
Block a user