metal : mul mat struct (wip)

This commit is contained in:
Georgi Gerganov 2024-11-09 17:54:40 +02:00
parent 593bc1aef5
commit 626e126e48
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 113 additions and 102 deletions

View File

@ -470,6 +470,23 @@ typedef struct {
uint16_t n_head_log2;
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
typedef struct {
int32_t ne00;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne12;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int16_t r2;
int16_t r3;
} ggml_metal_kargs_mul_mm;
#endif
#endif // GGML_COMMON_DECL

View File

@ -1959,24 +1959,29 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("MUL MAT-MAT not implemented");
}
ggml_metal_kargs_mul_mm args = {
/*.ne00 =*/ ne00,
/*.ne02 =*/ ne02,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne12 =*/ ne12,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.r2 =*/ r2,
/*.r3 =*/ r3,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
[encoder setBytes:&args length:sizeof(args) atIndex:3];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
@ -2707,31 +2712,31 @@ static void ggml_metal_encode_node(
}
ggml_metal_kargs_rope args = {
.ne00 = ne00,
.ne01 = ne01,
.ne02 = ne02,
.ne03 = ne03,
.nb00 = nb00,
.nb01 = nb01,
.nb02 = nb02,
.nb03 = nb03,
.ne0 = ne0,
.ne1 = ne1,
.ne2 = ne2,
.ne3 = ne3,
.nb0 = nb0,
.nb1 = nb1,
.nb2 = nb2,
.nb3 = nb3,
.n_past = n_past,
.n_dims = n_dims,
.n_ctx_orig = n_ctx_orig,
.freq_base = freq_base,
.freq_scale = freq_scale,
.ext_factor = ext_factor,
.attn_factor = attn_factor,
.beta_fast = beta_fast,
.beta_slow = beta_slow,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.n_past =*/ n_past,
/*.n_dims =*/ n_dims,
/*.n_ctx_orig =*/ n_ctx_orig,
/*.freq_base =*/ freq_base,
/*.freq_scale =*/ freq_scale,
/*.ext_factor =*/ ext_factor,
/*.attn_factor =*/ attn_factor,
/*.beta_fast =*/ beta_fast,
/*.beta_slow =*/ beta_slow,
};
[encoder setComputePipelineState:pipeline];
@ -3229,27 +3234,27 @@ static void ggml_metal_encode_node(
}
ggml_metal_kargs_flash_attn_ext args = {
.ne01 = ne01,
.ne02 = ne02,
.ne03 = ne03,
.nb01 = nb01,
.nb02 = nb02,
.nb03 = nb03,
.ne11 = ne11,
.ne_12_2 = ne12,
.ne_12_3 = ne13,
.nb_12_1 = nb11,
.nb_12_2 = nb12,
.nb_12_3 = nb13,
.nb31 = nb31,
.ne1 = ne1,
.ne2 = ne2,
.scale = scale,
.max_bias = max_bias,
.m0 = m0,
.m1 = m1,
.n_head_log2 = n_head_log2,
.logit_softcap = logit_softcap,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne11 =*/ ne11,
/*.ne_12_2 =*/ ne12,
/*.ne_12_3 =*/ ne13,
/*.nb_12_1 =*/ nb11,
/*.nb_12_2 =*/ nb12,
/*.nb_12_3 =*/ nb13,
/*.nb31 =*/ nb31,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.scale =*/ scale,
/*.max_bias =*/ max_bias,
/*.m0 =*/ m0,
/*.m1 =*/ m1,
/*.n_head_log2 =*/ n_head_log2,
/*.logit_softcap =*/ logit_softcap,
};
[encoder setComputePipelineState:pipeline];

View File

@ -3256,7 +3256,6 @@ kernel void kernel_flash_attn_ext_vec(
constant ggml_metal_kargs_flash_attn_ext & args,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
@ -6201,38 +6200,26 @@ kernel void kernel_get_rows_i32(
// each block_q contains 16*nl weights
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_mul_mm(device const uchar * src0,
device const uchar * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
kernel void kernel_mul_mm(
device const char * src0,
device const char * src1,
device char * dst,
constant ggml_metal_kargs_mul_mm & args,
threadgroup char * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup T * sa = (threadgroup T *)(shared_memory);
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
const uint r0 = tgpig.y;
const uint r1 = tgpig.x;
const uint im = tgpig.z;
const int r0 = tgpig.y;
const int r1 = tgpig.x;
const int im = tgpig.z;
// if this block is of 64x32 shape or smaller
short n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
short n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
// a thread shouldn't load data outside of the matrix
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
@ -6248,20 +6235,20 @@ kernel void kernel_mul_mm(device const uchar * src0,
short il = (tiitg % THREAD_PER_ROW);
const uint i12 = im%ne12;
const uint i13 = im/ne12;
const int i12 = im%args.ne12;
const int i13 = im/args.ne12;
uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03;
ushort offset1 = il/nl;
int offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
short offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*nb01 + offset0) + offset1;
device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1;
device const float * y = (device const float *)(src1
+ nb13 * i13
+ nb12 * i12
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+ args.nb13*i13
+ args.nb12*i12
+ args.nb11*(r1 * BLOCK_SIZE_N + thread_col)
+ args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
// load data and store to threadgroup memory
T4x4 temp_a;
dequantize_func(x, il, temp_a);
@ -6308,11 +6295,13 @@ kernel void kernel_mul_mm(device const uchar * src0,
}
}
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) {
device float * C = (device float *) dst +
(BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) + \
(BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
}
} else {
// block is smaller than 64x32, we should avoid writing data outside of the matrix
@ -6327,7 +6316,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
if (sgitg == 0) {
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0;
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);