mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-09 02:01:44 +00:00
cont : mul mm id
ggml-ci
This commit is contained in:
parent
eea1f7e532
commit
3855622da9
@ -509,6 +509,24 @@ typedef struct {
|
|||||||
int16_t r3;
|
int16_t r3;
|
||||||
} ggml_metal_kargs_mul_mv;
|
} ggml_metal_kargs_mul_mv;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int32_t nei0;
|
||||||
|
int32_t nei1;
|
||||||
|
uint64_t nbi1;
|
||||||
|
int32_t ne00;
|
||||||
|
int32_t ne02;
|
||||||
|
uint64_t nb01;
|
||||||
|
uint64_t nb02;
|
||||||
|
int32_t ne11;
|
||||||
|
int32_t ne12;
|
||||||
|
int32_t ne13;
|
||||||
|
uint64_t nb10;
|
||||||
|
uint64_t nb11;
|
||||||
|
uint64_t nb12;
|
||||||
|
int32_t ne0;
|
||||||
|
int32_t ne1;
|
||||||
|
} ggml_metal_kargs_mul_mm_id;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int32_t nei0;
|
int32_t nei0;
|
||||||
int32_t nei1;
|
int32_t nei1;
|
||||||
|
@ -2297,27 +2297,30 @@ static void ggml_metal_encode_node(
|
|||||||
default: GGML_ABORT("MUL_MAT_ID not implemented");
|
default: GGML_ABORT("MUL_MAT_ID not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_metal_kargs_mul_mm_id args = {
|
||||||
|
/*.nei0 =*/ ne20,
|
||||||
|
/*.nei1 =*/ ne21,
|
||||||
|
/*.nbi1 =*/ nb21,
|
||||||
|
/*.ne00 =*/ ne00,
|
||||||
|
/*.ne02 =*/ ne02,
|
||||||
|
/*.nb01 =*/ nb01,
|
||||||
|
/*.nb02 =*/ nb02,
|
||||||
|
/*.ne11 =*/ ne11,
|
||||||
|
/*.ne12 =*/ ne12,
|
||||||
|
/*.ne13 =*/ ne13,
|
||||||
|
/*.nb10 =*/ nb10,
|
||||||
|
/*.nb11 =*/ nb11,
|
||||||
|
/*.nb12 =*/ nb12,
|
||||||
|
/*.ne0 =*/ ne0,
|
||||||
|
/*.ne1 =*/ ne1,
|
||||||
|
};
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||||
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
|
||||||
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
|
||||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
|
|
||||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
|
|
||||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
|
|
||||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
|
|
||||||
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
|
|
||||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
|
||||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
|
||||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
|
||||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
|
||||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
|
||||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
|
||||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
|
|
||||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
|
||||||
|
|
||||||
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
|
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
|
||||||
|
|
||||||
|
@ -5755,31 +5755,32 @@ kernel void kernel_mul_mm(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
|
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
|
||||||
|
// TODO: this kernel needs to be reimplemented from scratch for better performance
|
||||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||||
void kernel_mul_mm_id_impl(
|
void kernel_mul_mm_id_impl(
|
||||||
device const uchar * src0,
|
int32_t ne00,
|
||||||
device const uchar * src1,
|
int32_t ne02,
|
||||||
threadgroup ushort2 * rowids,
|
uint64_t nb01,
|
||||||
device float * dst,
|
uint64_t nb02,
|
||||||
constant int64_t & ne00,
|
int32_t ne11,
|
||||||
constant int64_t & ne02,
|
int32_t ne12,
|
||||||
constant uint64_t & nb01,
|
uint64_t nb10,
|
||||||
constant uint64_t & nb02,
|
uint64_t nb11,
|
||||||
constant int64_t & ne11,
|
uint64_t nb12,
|
||||||
constant int64_t & ne12,
|
int32_t ne0,
|
||||||
constant uint64_t & nb10,
|
int32_t ne1,
|
||||||
constant uint64_t & nb11,
|
|
||||||
constant uint64_t & nb12,
|
|
||||||
constant int64_t & ne0,
|
|
||||||
int64_t ne1,
|
|
||||||
int64_t ne0ne1,
|
int64_t ne0ne1,
|
||||||
threadgroup uchar * shared_memory,
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
threadgroup ushort2 * rowids,
|
||||||
|
device char * dst,
|
||||||
|
threadgroup char * shmem,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiitg[[thread_index_in_threadgroup]],
|
ushort tiitg[[thread_index_in_threadgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
threadgroup half * sa = (threadgroup half *)(shmem);
|
||||||
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
|
||||||
|
|
||||||
const uint r0 = tgpig.y;
|
const uint r0 = tgpig.y;
|
||||||
const uint r1 = tgpig.x;
|
const uint r1 = tgpig.x;
|
||||||
@ -5796,9 +5797,9 @@ void kernel_mul_mm_id_impl(
|
|||||||
|
|
||||||
simdgroup_half8x8 ma[4];
|
simdgroup_half8x8 ma[4];
|
||||||
simdgroup_float8x8 mb[2];
|
simdgroup_float8x8 mb[2];
|
||||||
simdgroup_float8x8 c_res[8];
|
simdgroup_float8x8 mc[8];
|
||||||
for (int i = 0; i < 8; i++){
|
for (int i = 0; i < 8; i++){
|
||||||
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
||||||
}
|
}
|
||||||
short il = (tiitg % THREAD_PER_ROW);
|
short il = (tiitg % THREAD_PER_ROW);
|
||||||
|
|
||||||
@ -5836,11 +5837,14 @@ void kernel_mul_mm_id_impl(
|
|||||||
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
||||||
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
||||||
|
|
||||||
|
#pragma unroll(BLOCK_SIZE_K/8)
|
||||||
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
||||||
|
#pragma unroll(4)
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
|
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
|
||||||
}
|
}
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
#pragma unroll(2)
|
||||||
for (int i = 0; i < 2; i++) {
|
for (int i = 0; i < 2; i++) {
|
||||||
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
|
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
|
||||||
}
|
}
|
||||||
@ -5848,29 +5852,42 @@ void kernel_mul_mm_id_impl(
|
|||||||
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
||||||
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
||||||
|
|
||||||
|
#pragma unroll(8)
|
||||||
for (int i = 0; i < 8; i++){
|
for (int i = 0; i < 8; i++){
|
||||||
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
|
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
|
threadgroup float * temp_str = ((threadgroup float *) shmem) \
|
||||||
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
||||||
for (int i = 0; i < 8; i++) {
|
for (int i = 0; i < 8; i++) {
|
||||||
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
device float * C = dst + (BLOCK_SIZE_M * r0);
|
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
||||||
threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
|
threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
|
||||||
int joff = jid[0] * ne0 + jid[1] * ne0ne1;
|
int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
|
||||||
for (int i = 0; i < n_rows; i++) {
|
|
||||||
*(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
|
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff;
|
||||||
|
device float4 * D4 = (device float4 *) D;
|
||||||
|
|
||||||
|
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
|
||||||
|
threadgroup float4 * C4 = (threadgroup float4 *) C;
|
||||||
|
|
||||||
|
int i = 0;
|
||||||
|
for (; i < n_rows/4; i++) {
|
||||||
|
*(D4 + i) = *(C4 + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
i *= 4;
|
||||||
|
for (; i < n_rows; i++) {
|
||||||
|
*(D + i) = *(C + i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -5879,48 +5896,34 @@ void kernel_mul_mm_id_impl(
|
|||||||
|
|
||||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||||
kernel void kernel_mul_mm_id(
|
kernel void kernel_mul_mm_id(
|
||||||
device const uchar * src0s,
|
constant ggml_metal_kargs_mul_mm_id & args,
|
||||||
device const uchar * src1,
|
device const char * src0s,
|
||||||
device float * dst,
|
device const char * src1,
|
||||||
device const uchar * ids,
|
device char * dst,
|
||||||
constant int64_t & nei0,
|
device const char * ids,
|
||||||
constant int64_t & nei1,
|
threadgroup char * shmem [[threadgroup(0)]],
|
||||||
constant uint64_t & nbi1,
|
|
||||||
constant int64_t & ne00,
|
|
||||||
constant int64_t & ne02,
|
|
||||||
constant uint64_t & nb01,
|
|
||||||
constant uint64_t & nb02,
|
|
||||||
constant int64_t & ne11,
|
|
||||||
constant int64_t & ne12,
|
|
||||||
constant int64_t & ne13,
|
|
||||||
constant uint64_t & nb10,
|
|
||||||
constant uint64_t & nb11,
|
|
||||||
constant uint64_t & nb12,
|
|
||||||
constant int64_t & ne0,
|
|
||||||
constant int64_t & ne1,
|
|
||||||
constant uint64_t & nb1,
|
|
||||||
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiitg[[thread_index_in_threadgroup]],
|
ushort tiitg[[thread_index_in_threadgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
const int32_t i02 = tgpig.z;
|
const int32_t i02 = tgpig.z;
|
||||||
|
|
||||||
tgpig.z = 0;
|
tgpig.z = 0;
|
||||||
|
|
||||||
device const uchar * src0 = src0s + i02*nb02;
|
device const char * src0 = src0s + i02*args.nb02;
|
||||||
|
|
||||||
// row indices
|
// row indices
|
||||||
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
|
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
|
||||||
|
|
||||||
// TODO: parallelize this loop
|
// TODO: parallelize this loop
|
||||||
int64_t _ne1 = 0;
|
int64_t _ne1 = 0;
|
||||||
for (ushort ii1 = 0; ii1 < nei1; ii1++) {
|
for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
|
||||||
for (ushort ii0 = 0; ii0 < nei0; ii0++) {
|
for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
|
||||||
int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
|
int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
|
||||||
if (id == i02) {
|
if (id == i02) {
|
||||||
//if (tiitg == 0) {
|
if (tiitg == 0) {
|
||||||
rowids[_ne1] = ushort2(ii0, ii1);
|
rowids[_ne1] = ushort2(ii0, ii1);
|
||||||
//}
|
}
|
||||||
_ne1++;
|
_ne1++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -5929,23 +5932,23 @@ kernel void kernel_mul_mm_id(
|
|||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
||||||
|
args.ne00,
|
||||||
|
args.ne02,
|
||||||
|
args.nb01,
|
||||||
|
args.nb02,
|
||||||
|
args.ne11,
|
||||||
|
args.ne12,
|
||||||
|
args.nb10,
|
||||||
|
args.nb11,
|
||||||
|
args.nb12,
|
||||||
|
args.ne0,
|
||||||
|
_ne1,
|
||||||
|
(int64_t)args.ne0*args.ne1,
|
||||||
src0,
|
src0,
|
||||||
src1,
|
src1,
|
||||||
rowids,
|
rowids,
|
||||||
dst,
|
dst,
|
||||||
ne00,
|
shmem,
|
||||||
ne02,
|
|
||||||
nb01,
|
|
||||||
nb02,
|
|
||||||
ne11,
|
|
||||||
ne12,
|
|
||||||
nb10,
|
|
||||||
nb11,
|
|
||||||
nb12,
|
|
||||||
ne0,
|
|
||||||
_ne1,
|
|
||||||
ne0*ne1,
|
|
||||||
shared_memory,
|
|
||||||
tgpig,
|
tgpig,
|
||||||
tiitg,
|
tiitg,
|
||||||
sgitg);
|
sgitg);
|
||||||
|
Loading…
Reference in New Issue
Block a user