cont : mul mm id

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-10 10:32:15 +02:00
parent eea1f7e532
commit 3855622da9
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 118 additions and 94 deletions

View File

@ -509,6 +509,24 @@ typedef struct {
int16_t r3;
} ggml_metal_kargs_mul_mv;
typedef struct {
int32_t nei0;
int32_t nei1;
uint64_t nbi1;
int32_t ne00;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
int32_t ne0;
int32_t ne1;
} ggml_metal_kargs_mul_mm_id;
typedef struct {
int32_t nei0;
int32_t nei1;

View File

@ -2297,27 +2297,30 @@ static void ggml_metal_encode_node(
default: GGML_ABORT("MUL_MAT_ID not implemented");
}
ggml_metal_kargs_mul_mm_id args = {
/*.nei0 =*/ ne20,
/*.nei1 =*/ ne21,
/*.nbi1 =*/ nb21,
/*.ne00 =*/ ne00,
/*.ne02 =*/ ne02,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.ne13 =*/ ne13,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];

View File

@ -5755,31 +5755,32 @@ kernel void kernel_mul_mm(
}
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
// TODO: this kernel needs to be reimplemented from scratch for better performance
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
void kernel_mul_mm_id_impl(
device const uchar * src0,
device const uchar * src1,
int32_t ne00,
int32_t ne02,
uint64_t nb01,
uint64_t nb02,
int32_t ne11,
int32_t ne12,
uint64_t nb10,
uint64_t nb11,
uint64_t nb12,
int32_t ne0,
int32_t ne1,
int64_t ne0ne1,
device const char * src0,
device const char * src1,
threadgroup ushort2 * rowids,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
int64_t ne1,
int64_t ne0ne1,
threadgroup uchar * shared_memory,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
device char * dst,
threadgroup char * shmem,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup half * sa = (threadgroup half *)(shared_memory);
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
threadgroup half * sa = (threadgroup half *)(shmem);
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
const uint r0 = tgpig.y;
const uint r1 = tgpig.x;
@ -5796,9 +5797,9 @@ void kernel_mul_mm_id_impl(
simdgroup_half8x8 ma[4];
simdgroup_float8x8 mb[2];
simdgroup_float8x8 c_res[8];
simdgroup_float8x8 mc[8];
for (int i = 0; i < 8; i++){
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}
short il = (tiitg % THREAD_PER_ROW);
@ -5836,11 +5837,14 @@ void kernel_mul_mm_id_impl(
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
#pragma unroll(BLOCK_SIZE_K/8)
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
#pragma unroll(4)
for (int i = 0; i < 4; i++) {
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
}
simdgroup_barrier(mem_flags::mem_none);
#pragma unroll(2)
for (int i = 0; i < 2; i++) {
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
}
@ -5848,29 +5852,42 @@ void kernel_mul_mm_id_impl(
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
#pragma unroll(8)
for (int i = 0; i < 8; i++){
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
}
}
}
{
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
threadgroup float * temp_str = ((threadgroup float *) shmem) \
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
for (int i = 0; i < 8; i++) {
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
device float * C = dst + (BLOCK_SIZE_M * r0);
if (sgitg == 0) {
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
int joff = jid[0] * ne0 + jid[1] * ne0ne1;
for (int i = 0; i < n_rows; i++) {
*(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
threadgroup float4 * C4 = (threadgroup float4 *) C;
int i = 0;
for (; i < n_rows/4; i++) {
*(D4 + i) = *(C4 + i);
}
i *= 4;
for (; i < n_rows; i++) {
*(D + i) = *(C + i);
}
}
}
@ -5879,48 +5896,34 @@ void kernel_mul_mm_id_impl(
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm_id(
device const uchar * src0s,
device const uchar * src1,
device float * dst,
device const uchar * ids,
constant int64_t & nei0,
constant int64_t & nei1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint64_t & nb1,
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
constant ggml_metal_kargs_mul_mm_id & args,
device const char * src0s,
device const char * src1,
device char * dst,
device const char * ids,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const int32_t i02 = tgpig.z;
tgpig.z = 0;
device const uchar * src0 = src0s + i02*nb02;
device const char * src0 = src0s + i02*args.nb02;
// row indices
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
// TODO: parallelize this loop
int64_t _ne1 = 0;
for (ushort ii1 = 0; ii1 < nei1; ii1++) {
for (ushort ii0 = 0; ii0 < nei0; ii0++) {
int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
if (id == i02) {
//if (tiitg == 0) {
if (tiitg == 0) {
rowids[_ne1] = ushort2(ii0, ii1);
//}
}
_ne1++;
}
}
@ -5929,23 +5932,23 @@ kernel void kernel_mul_mm_id(
threadgroup_barrier(mem_flags::mem_threadgroup);
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
args.ne00,
args.ne02,
args.nb01,
args.nb02,
args.ne11,
args.ne12,
args.nb10,
args.nb11,
args.nb12,
args.ne0,
_ne1,
(int64_t)args.ne0*args.ne1,
src0,
src1,
rowids,
dst,
ne00,
ne02,
nb01,
nb02,
ne11,
ne12,
nb10,
nb11,
nb12,
ne0,
_ne1,
ne0*ne1,
shared_memory,
shmem,
tgpig,
tiitg,
sgitg);