mirror of https://github.com/ggerganov/llama.cpp.git
minor

commit 545b03491c
parent 8f6ad68427

 ggml-metal.m     | 14
 ggml-metal.metal | 14
diff --git a/ggml-metal.m b/ggml-metal.m
@@ -994,7 +994,7 @@ void ggml_metal_graph_compute(
     GGML_ASSERT(ne03 == ne13);

     // find the break-even point where the matrix-matrix kernel becomes more efficient compared
-    // to the matrix-vector kernel. the numbers below are measure on M2 Ultra
+    // to the matrix-vector kernel. the numbers below are measured on M2 Ultra
     // not sure if this translates across all chips
     int ne11_mm_min = 1;

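For orientation: ne11 is the number of rows in src1, i.e. the batch of vectors multiplied against src0, and ne11_mm_min is the empirically measured break-even batch size. A minimal C sketch of the resulting selection, with hypothetical names (the real code inlines this test into the condition shown in the next hunk):

    // Pick a kernel per the break-even rule in the comment above: a batch
    // larger than ne11_mm_min amortizes the matrix-matrix kernel's overhead.
    typedef enum { KERNEL_MUL_MV, KERNEL_MUL_MM } mul_kernel_t;

    static mul_kernel_t pick_mul_kernel(int ne11, int ne11_mm_min) {
        return ne11 > ne11_mm_min ? KERNEL_MUL_MM : KERNEL_MUL_MV;
    }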
@@ -1015,12 +1015,13 @@ void ggml_metal_graph_compute(

     // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
     // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-    if (!ggml_is_transposed(src0) &&
+    if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+        !ggml_is_transposed(src0) &&
         !ggml_is_transposed(src1) &&
         src1t == GGML_TYPE_F32 &&
-        [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-        ne00%32 == 0 &&
+        ne00 % 32 == 0 &&
         ne11 > ne11_mm_min) {
+        //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
         switch (src0->type) {
             case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
             case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
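The edit above is a pure reordering plus a commented-out debug printf: the MTLGPUFamilyApple7 capability test now leads the condition, so AMD GPUs and pre-A14 chips fail fast via short-circuit evaluation before any per-tensor checks run, and ne00%32 gains spaces. A hedged C restatement of the eligibility test (boolean parameters stand in for the Objective-C calls; all names here are ours):

    #include <stdbool.h>

    // Clauses in the order the committed code now evaluates them.
    static bool can_use_mul_mm(bool device_is_apple7, bool src0_transposed,
                               bool src1_transposed, bool src1_is_f32,
                               int ne00, int ne11, int ne11_mm_min) {
        return device_is_apple7  && // hardware gate first: A14+/M1+ SoCs only
               !src0_transposed  && // tiled kernel needs non-transposed operands
               !src1_transposed  &&
               src1_is_f32       && // activations must be f32
               ne00 % 32 == 0    && // K must be a multiple of the 32-wide K block
               ne11 > ne11_mm_min;  // batch large enough to beat the mat-vec path
    }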
@@ -1049,11 +1050,12 @@ void ggml_metal_graph_compute(
         [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
         [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
         [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-        [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+        [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
     } else {
         int nth0 = 32;
         int nth1 = 1;
         int nrows = 1;
+        //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);

         // use custom matrix x vector kernel
         switch (src0t) {
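The dispatch change is whitespace-only, but the grid math deserves a note: each threadgroup of 128 threads (four 32-lane simdgroups) computes one 64x32 output tile, 64 rows of src0 by 32 rows of src1, so both dimensions are rounded up with ceil-division while z spans the batch. A plain-C sketch with a hypothetical struct standing in for MTLSize:

    #include <stdint.h>

    typedef struct { uint32_t x, y, z; } grid3; // stand-in for MTLSize

    static grid3 mul_mm_grid(int64_t ne01, int64_t ne11, int64_t ne12) {
        grid3 g;
        g.x = (uint32_t)((ne11 + 31)/32); // ceil(ne11/32): tiles across src1 rows
        g.y = (uint32_t)((ne01 + 63)/64); // ceil(ne01/64): tiles across src0 rows
        g.z = (uint32_t) ne12;            // one layer per matrix in the batch
        return g;                         // dispatched with 128 threads per group
    }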
@@ -1175,7 +1177,7 @@ void ggml_metal_graph_compute(
         [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];

         if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
-            src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
+            src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
             [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
         }
         else if (src0t == GGML_TYPE_Q4_K) {
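Here only a space lands before the trailing comment, but the dispatch line documents the matrix-vector path: for these quantization types each threadgroup produces 8 output rows, hence ceil(ne01/8) threadgroups along the first axis. The rounding idiom, spelled out as a tiny C helper (the name is ours):

    // ceil(a/b) for positive integers; (ne01 + 7)/8 is ceil_div(ne01, 8)
    static inline int ceil_div(int a, int b) {
        return (a + b - 1)/b;
    }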
diff --git a/ggml-metal.metal b/ggml-metal.metal
@@ -13,8 +13,8 @@ typedef struct {

 #define QK4_1 32
 typedef struct {
-    half d;          // delta
-    half m;          // min
+    half    d;              // delta
+    half    m;              // min
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;

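For reference, a block_q4_1 packs 32 weights into 20 bytes: 16 bytes of 4-bit quants plus a half scale d and offset m, and each weight reconstructs as w = d*q + m with q in 0..15. A minimal CPU-side C sketch of that dequantization (float in place of half; the low/high-nibble layout mirrors ggml's reference C code, while the Metal side uses its own dequantize function):

    #include <stdint.h>

    #define QK4_1 32

    typedef struct {
        float   d;              // delta (scale)
        float   m;              // min (offset)
        uint8_t qs[QK4_1 / 2];  // 32 quants, two 4-bit values per byte
    } block_q4_1_f;

    // Write the 32 reconstructed weights of one block into y.
    static void dequantize_block_q4_1(const block_q4_1_f * b, float * y) {
        for (int j = 0; j < QK4_1/2; ++j) {
            y[j]           = b->d * (b->qs[j] & 0x0F) + b->m; // low nibble
            y[j + QK4_1/2] = b->d * (b->qs[j] >> 4)   + b->m; // high nibble
        }
    }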
@@ -2397,7 +2397,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
                                + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));

    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
-        //load data and store to threadgroup memory
+        // load data and store to threadgroup memory
         half4x4 temp_a;
         dequantize_func(x, il, temp_a);
         threadgroup_barrier(mem_flags::mem_threadgroup);
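Only the comment spacing changes, but this loop is the core of the tiled multiply: the kernel walks K in BLOCK_SIZE_K chunks, each thread dequantizes a half4x4 slice of src0 and helps stage it, together with the f32 src1 tile, into threadgroup memory behind barriers. A compilable C outline of the pattern, with the per-step work left as comments (a schematic, not the actual kernel):

    // K-loop staging pattern of a tiled GEMM kernel, in outline.
    static void tiled_mm_outline(int ne00 /* K extent */) {
        const int BLOCK_SIZE_K = 32;
        for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
            // 1. dequantize this thread's slice of src0 into registers
            // 2. barrier: the previous shared tile must be fully consumed
            // 3. store the slice (and the matching src1 slice) to shared memory
            // 4. barrier, then the simdgroups multiply out of the shared tiles
        }
    }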
@@ -2417,7 +2417,7 @@ kernel void kernel_mul_mm(device const uchar * src0,

         threadgroup_barrier(mem_flags::mem_threadgroup);

-        //load matrices from threadgroup memory and conduct outer products
+        // load matrices from threadgroup memory and conduct outer products
         threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
         threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));

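Again just comment spacing; for orientation, lsma and lsmb select each simdgroup's halves of the staged A (half) and B (float) tiles, which are then consumed as 8x8 simdgroup matrices with multiply-accumulate. A scalar C model of one such 8x8 step (what a simdgroup multiply-accumulate computes across its 32 lanes in hardware):

    // C += A * B on 8x8 tiles; a scalar stand-in for one simdgroup MMA step.
    static void mma_8x8(const float A[8][8], const float B[8][8], float C[8][8]) {
        for (int i = 0; i < 8; ++i)
            for (int j = 0; j < 8; ++j)
                for (int k = 0; k < 8; ++k)
                    C[i][j] += A[i][k] * B[k][j];
    }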
@@ -2444,25 +2444,25 @@ kernel void kernel_mul_mm(device const uchar * src0,
         }

         if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
-            device float *C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
+            device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
                             + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
             for (int i = 0; i < 8; i++) {
                 simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
             }
         } else {
             // block is smaller than 64x32, we should avoid writing data outside of the matrix
             threadgroup_barrier(mem_flags::mem_threadgroup);
-            threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+            threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
                                         + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
             for (int i = 0; i < 8; i++) {
                 simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
             }

             threadgroup_barrier(mem_flags::mem_threadgroup);
-            device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
-            if (sgitg==0) {
+            device float * C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+            if (sgitg == 0) {
                 for (int i = 0; i < n_rows; i++) {
-                    for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
+                    for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                         *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                     }
                 }
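The substance of this last hunk is unchanged and worth spelling out: interior 64x32 tiles are written straight to device memory with simdgroup_store, but an edge tile is first staged into threadgroup memory and then copied out element by element so nothing lands past ne0/ne1. A single-threaded C sketch of that guarded writeback (the kernel strides the j loop across threads, j starting at tiitg and advancing by BLOCK_SIZE_N, on simdgroup 0 only):

    #define BLOCK_SIZE_M 64

    // Mirror of the kernel's guarded copy: only the valid n_rows x n_cols
    // region is written, with the same i + j*stride indexing as above.
    static void store_edge_tile(float * C, int ne0, const float * temp_str,
                                int n_rows, int n_cols) {
        for (int i = 0; i < n_rows; ++i) {
            for (int j = 0; j < n_cols; ++j) {
                C[i + j * ne0] = temp_str[i + j * BLOCK_SIZE_M];
            }
        }
    }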