mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 17:21:46 +00:00
cont : mul mm id
ggml-ci
This commit is contained in:
parent
eea1f7e532
commit
3855622da9
@ -509,6 +509,24 @@ typedef struct {
|
||||
int16_t r3;
|
||||
} ggml_metal_kargs_mul_mv;
|
||||
|
||||
typedef struct {
|
||||
int32_t nei0;
|
||||
int32_t nei1;
|
||||
uint64_t nbi1;
|
||||
int32_t ne00;
|
||||
int32_t ne02;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
int32_t ne11;
|
||||
int32_t ne12;
|
||||
int32_t ne13;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
} ggml_metal_kargs_mul_mm_id;
|
||||
|
||||
typedef struct {
|
||||
int32_t nei0;
|
||||
int32_t nei1;
|
||||
|
@ -2297,27 +2297,30 @@ static void ggml_metal_encode_node(
|
||||
default: GGML_ABORT("MUL_MAT_ID not implemented");
|
||||
}
|
||||
|
||||
ggml_metal_kargs_mul_mm_id args = {
|
||||
/*.nei0 =*/ ne20,
|
||||
/*.nei1 =*/ ne21,
|
||||
/*.nbi1 =*/ nb21,
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.ne11 =*/ ne11,
|
||||
/*.ne12 =*/ ne12,
|
||||
/*.ne13 =*/ ne13,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
||||
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
||||
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
|
||||
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
|
||||
|
||||
|
@ -5755,31 +5755,32 @@ kernel void kernel_mul_mm(
|
||||
}
|
||||
|
||||
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
|
||||
// TODO: this kernel needs to be reimplemented from scratch for better performance
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||
void kernel_mul_mm_id_impl(
|
||||
device const uchar * src0,
|
||||
device const uchar * src1,
|
||||
int32_t ne00,
|
||||
int32_t ne02,
|
||||
uint64_t nb01,
|
||||
uint64_t nb02,
|
||||
int32_t ne11,
|
||||
int32_t ne12,
|
||||
uint64_t nb10,
|
||||
uint64_t nb11,
|
||||
uint64_t nb12,
|
||||
int32_t ne0,
|
||||
int32_t ne1,
|
||||
int64_t ne0ne1,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
threadgroup ushort2 * rowids,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
int64_t ne1,
|
||||
int64_t ne0ne1,
|
||||
threadgroup uchar * shared_memory,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
||||
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
||||
threadgroup half * sa = (threadgroup half *)(shmem);
|
||||
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
|
||||
|
||||
const uint r0 = tgpig.y;
|
||||
const uint r1 = tgpig.x;
|
||||
@ -5796,9 +5797,9 @@ void kernel_mul_mm_id_impl(
|
||||
|
||||
simdgroup_half8x8 ma[4];
|
||||
simdgroup_float8x8 mb[2];
|
||||
simdgroup_float8x8 c_res[8];
|
||||
simdgroup_float8x8 mc[8];
|
||||
for (int i = 0; i < 8; i++){
|
||||
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
||||
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
||||
}
|
||||
short il = (tiitg % THREAD_PER_ROW);
|
||||
|
||||
@ -5836,11 +5837,14 @@ void kernel_mul_mm_id_impl(
|
||||
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
||||
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
||||
|
||||
#pragma unroll(BLOCK_SIZE_K/8)
|
||||
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < 4; i++) {
|
||||
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
|
||||
}
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
#pragma unroll(2)
|
||||
for (int i = 0; i < 2; i++) {
|
||||
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
|
||||
}
|
||||
@ -5848,29 +5852,42 @@ void kernel_mul_mm_id_impl(
|
||||
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
||||
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
||||
|
||||
#pragma unroll(8)
|
||||
for (int i = 0; i < 8; i++){
|
||||
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
||||
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
|
||||
threadgroup float * temp_str = ((threadgroup float *) shmem) \
|
||||
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
||||
simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
device float * C = dst + (BLOCK_SIZE_M * r0);
|
||||
if (sgitg == 0) {
|
||||
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
||||
threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
|
||||
int joff = jid[0] * ne0 + jid[1] * ne0ne1;
|
||||
for (int i = 0; i < n_rows; i++) {
|
||||
*(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
|
||||
int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
|
||||
|
||||
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff;
|
||||
device float4 * D4 = (device float4 *) D;
|
||||
|
||||
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
|
||||
threadgroup float4 * C4 = (threadgroup float4 *) C;
|
||||
|
||||
int i = 0;
|
||||
for (; i < n_rows/4; i++) {
|
||||
*(D4 + i) = *(C4 + i);
|
||||
}
|
||||
|
||||
i *= 4;
|
||||
for (; i < n_rows; i++) {
|
||||
*(D + i) = *(C + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -5879,48 +5896,34 @@ void kernel_mul_mm_id_impl(
|
||||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||
kernel void kernel_mul_mm_id(
|
||||
device const uchar * src0s,
|
||||
device const uchar * src1,
|
||||
device float * dst,
|
||||
device const uchar * ids,
|
||||
constant int64_t & nei0,
|
||||
constant int64_t & nei1,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint64_t & nb1,
|
||||
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
constant ggml_metal_kargs_mul_mm_id & args,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
device const char * ids,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
const int32_t i02 = tgpig.z;
|
||||
|
||||
tgpig.z = 0;
|
||||
|
||||
device const uchar * src0 = src0s + i02*nb02;
|
||||
device const char * src0 = src0s + i02*args.nb02;
|
||||
|
||||
// row indices
|
||||
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
|
||||
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
|
||||
|
||||
// TODO: parallelize this loop
|
||||
int64_t _ne1 = 0;
|
||||
for (ushort ii1 = 0; ii1 < nei1; ii1++) {
|
||||
for (ushort ii0 = 0; ii0 < nei0; ii0++) {
|
||||
int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
|
||||
for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
|
||||
for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
|
||||
int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
|
||||
if (id == i02) {
|
||||
//if (tiitg == 0) {
|
||||
if (tiitg == 0) {
|
||||
rowids[_ne1] = ushort2(ii0, ii1);
|
||||
//}
|
||||
}
|
||||
_ne1++;
|
||||
}
|
||||
}
|
||||
@ -5929,23 +5932,23 @@ kernel void kernel_mul_mm_id(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
||||
args.ne00,
|
||||
args.ne02,
|
||||
args.nb01,
|
||||
args.nb02,
|
||||
args.ne11,
|
||||
args.ne12,
|
||||
args.nb10,
|
||||
args.nb11,
|
||||
args.nb12,
|
||||
args.ne0,
|
||||
_ne1,
|
||||
(int64_t)args.ne0*args.ne1,
|
||||
src0,
|
||||
src1,
|
||||
rowids,
|
||||
dst,
|
||||
ne00,
|
||||
ne02,
|
||||
nb01,
|
||||
nb02,
|
||||
ne11,
|
||||
ne12,
|
||||
nb10,
|
||||
nb11,
|
||||
nb12,
|
||||
ne0,
|
||||
_ne1,
|
||||
ne0*ne1,
|
||||
shared_memory,
|
||||
shmem,
|
||||
tgpig,
|
||||
tiitg,
|
||||
sgitg);
|
||||
|
Loading…
Reference in New Issue
Block a user