mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-10 18:51:45 +00:00
metal : mul mat struct (wip)
This commit is contained in:
parent
4fd6fc5ab8
commit
7670809ad4
@ -470,6 +470,23 @@ typedef struct {
|
||||
uint16_t n_head_log2;
|
||||
float logit_softcap;
|
||||
} ggml_metal_kargs_flash_attn_ext;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne02;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne12;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int16_t r2;
|
||||
int16_t r3;
|
||||
} ggml_metal_kargs_mul_mm;
|
||||
#endif
|
||||
|
||||
#endif // GGML_COMMON_DECL
|
||||
|
@ -1959,24 +1959,29 @@ static void ggml_metal_encode_node(
|
||||
default: GGML_ABORT("MUL MAT-MAT not implemented");
|
||||
}
|
||||
|
||||
ggml_metal_kargs_mul_mm args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne12 =*/ ne12,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.r2 =*/ r2,
|
||||
/*.r3 =*/ r3,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
||||
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11];
|
||||
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
||||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
|
||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
} else {
|
||||
@ -2707,31 +2712,31 @@ static void ggml_metal_encode_node(
|
||||
}
|
||||
|
||||
ggml_metal_kargs_rope args = {
|
||||
.ne00 = ne00,
|
||||
.ne01 = ne01,
|
||||
.ne02 = ne02,
|
||||
.ne03 = ne03,
|
||||
.nb00 = nb00,
|
||||
.nb01 = nb01,
|
||||
.nb02 = nb02,
|
||||
.nb03 = nb03,
|
||||
.ne0 = ne0,
|
||||
.ne1 = ne1,
|
||||
.ne2 = ne2,
|
||||
.ne3 = ne3,
|
||||
.nb0 = nb0,
|
||||
.nb1 = nb1,
|
||||
.nb2 = nb2,
|
||||
.nb3 = nb3,
|
||||
.n_past = n_past,
|
||||
.n_dims = n_dims,
|
||||
.n_ctx_orig = n_ctx_orig,
|
||||
.freq_base = freq_base,
|
||||
.freq_scale = freq_scale,
|
||||
.ext_factor = ext_factor,
|
||||
.attn_factor = attn_factor,
|
||||
.beta_fast = beta_fast,
|
||||
.beta_slow = beta_slow,
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
/*.n_past =*/ n_past,
|
||||
/*.n_dims =*/ n_dims,
|
||||
/*.n_ctx_orig =*/ n_ctx_orig,
|
||||
/*.freq_base =*/ freq_base,
|
||||
/*.freq_scale =*/ freq_scale,
|
||||
/*.ext_factor =*/ ext_factor,
|
||||
/*.attn_factor =*/ attn_factor,
|
||||
/*.beta_fast =*/ beta_fast,
|
||||
/*.beta_slow =*/ beta_slow,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
@ -3229,27 +3234,27 @@ static void ggml_metal_encode_node(
|
||||
}
|
||||
|
||||
ggml_metal_kargs_flash_attn_ext args = {
|
||||
.ne01 = ne01,
|
||||
.ne02 = ne02,
|
||||
.ne03 = ne03,
|
||||
.nb01 = nb01,
|
||||
.nb02 = nb02,
|
||||
.nb03 = nb03,
|
||||
.ne11 = ne11,
|
||||
.ne_12_2 = ne12,
|
||||
.ne_12_3 = ne13,
|
||||
.nb_12_1 = nb11,
|
||||
.nb_12_2 = nb12,
|
||||
.nb_12_3 = nb13,
|
||||
.nb31 = nb31,
|
||||
.ne1 = ne1,
|
||||
.ne2 = ne2,
|
||||
.scale = scale,
|
||||
.max_bias = max_bias,
|
||||
.m0 = m0,
|
||||
.m1 = m1,
|
||||
.n_head_log2 = n_head_log2,
|
||||
.logit_softcap = logit_softcap,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne11 =*/ ne11,
|
||||
/*.ne_12_2 =*/ ne12,
|
||||
/*.ne_12_3 =*/ ne13,
|
||||
/*.nb_12_1 =*/ nb11,
|
||||
/*.nb_12_2 =*/ nb12,
|
||||
/*.nb_12_3 =*/ nb13,
|
||||
/*.nb31 =*/ nb31,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.scale =*/ scale,
|
||||
/*.max_bias =*/ max_bias,
|
||||
/*.m0 =*/ m0,
|
||||
/*.m1 =*/ m1,
|
||||
/*.n_head_log2 =*/ n_head_log2,
|
||||
/*.logit_softcap =*/ logit_softcap,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
|
@ -3259,7 +3259,6 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
constant ggml_metal_kargs_flash_attn_ext & args,
|
||||
threadgroup half * shared [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
@ -6210,38 +6209,26 @@ kernel void kernel_get_rows_i32(
|
||||
|
||||
// each block_q contains 16*nl weights
|
||||
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
||||
kernel void kernel_mul_mm(device const uchar * src0,
|
||||
device const uchar * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
kernel void kernel_mul_mm(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
constant ggml_metal_kargs_mul_mm & args,
|
||||
threadgroup char * shared_memory [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
threadgroup T * sa = (threadgroup T *)(shared_memory);
|
||||
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
||||
|
||||
const uint r0 = tgpig.y;
|
||||
const uint r1 = tgpig.x;
|
||||
const uint im = tgpig.z;
|
||||
const int r0 = tgpig.y;
|
||||
const int r1 = tgpig.x;
|
||||
const int im = tgpig.z;
|
||||
|
||||
// if this block is of 64x32 shape or smaller
|
||||
short n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
||||
short n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
||||
short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
||||
short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
||||
|
||||
// a thread shouldn't load data outside of the matrix
|
||||
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||
@ -6257,20 +6244,20 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||
|
||||
short il = (tiitg % THREAD_PER_ROW);
|
||||
|
||||
const uint i12 = im%ne12;
|
||||
const uint i13 = im/ne12;
|
||||
const int i12 = im%args.ne12;
|
||||
const int i13 = im/args.ne12;
|
||||
|
||||
uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03;
|
||||
ushort offset1 = il/nl;
|
||||
int offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
short offset1 = il/nl;
|
||||
|
||||
device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*nb01 + offset0) + offset1;
|
||||
device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1;
|
||||
device const float * y = (device const float *)(src1
|
||||
+ nb13 * i13
|
||||
+ nb12 * i12
|
||||
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
|
||||
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
+ args.nb13*i13
|
||||
+ args.nb12*i12
|
||||
+ args.nb11*(r1 * BLOCK_SIZE_N + thread_col)
|
||||
+ args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
|
||||
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
||||
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
|
||||
// load data and store to threadgroup memory
|
||||
T4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
@ -6317,11 +6304,13 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||
}
|
||||
}
|
||||
|
||||
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
|
||||
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
|
||||
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
|
||||
if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) {
|
||||
device float * C = (device float *) dst +
|
||||
(BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) + \
|
||||
(BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
|
||||
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
|
||||
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
|
||||
}
|
||||
} else {
|
||||
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
||||
@ -6336,7 +6325,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
||||
|
||||
if (sgitg == 0) {
|
||||
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
||||
device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0;
|
||||
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0;
|
||||
device float4 * D4 = (device float4 *) D;
|
||||
|
||||
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
|
||||
|
Loading…
Reference in New Issue
Block a user