metal : GGML_OP_CONCAT

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-10 13:03:25 +02:00
parent 5d4cbc0845
commit e418ccf209
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 75 additions and 67 deletions

View File

@ -419,6 +419,34 @@ typedef struct {
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
#if defined(GGML_COMMON_DECL_METAL_KARGS) #if defined(GGML_COMMON_DECL_METAL_KARGS)
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
int32_t dim;
} ggml_metal_kargs_concat;
typedef struct { typedef struct {
int32_t ne00; int32_t ne00;
int32_t ne01; int32_t ne01;

View File

@ -1193,35 +1193,39 @@ static void ggml_metal_encode_node(
const int32_t dim = ((const int32_t *) dst->op_params)[0]; const int32_t dim = ((const int32_t *) dst->op_params)[0];
ggml_metal_kargs_concat args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.ne13 =*/ ne13,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.dim =*/ dim,
};
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
[encoder setBytes:&dim length:sizeof(dim) atIndex:27];
const int nth = MIN(1024, ne0); const int nth = MIN(1024, ne0);

View File

@ -1893,7 +1893,7 @@ void kernel_mul_mv_impl(
float sumf = 0; float sumf = 0;
for (int i = tiisg; i < args.ne00/4; i += 32) { for (int i = tiisg; i < args.ne00/4; i += 32) {
sumf += dot((T14) x4[i], y4[i]); sumf += dot((float4) x4[i], (float4) y4[i]);
} }
float all_sum = simd_sum(sumf); float all_sum = simd_sum(sumf);
@ -3876,55 +3876,31 @@ kernel void kernel_cpy_f32_iq4_nl(
} }
kernel void kernel_concat( kernel void kernel_concat(
constant ggml_metal_kargs_concat & args,
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device char * dst, device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int32_t & dim,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]], ushort3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) { ushort3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z; const int i3 = tgpig.z;
const int64_t i2 = tgpig.y; const int i2 = tgpig.y;
const int64_t i1 = tgpig.x; const int i1 = tgpig.x;
int64_t o[4] = {0, 0, 0, 0}; int o[4] = {0, 0, 0, 0};
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
device const float * x; device const float * x;
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); x = (device const float *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00);
} else { } else {
x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
} }
device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
*y = *x; *y = *x;
} }