mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-01 06:14:35 +00:00
metal : GGML_OP_ADD, GGML_OP_SUB, GGML_OP_MUL, GGML_OP_DIV
Some checks are pending
flake8 Lint / Lint (push) Waiting to run
Some checks are pending
flake8 Lint / Lint (push) Waiting to run
This commit is contained in:
parent
e418ccf209
commit
54c859e195
@ -447,6 +447,34 @@ typedef struct {
|
||||
int32_t dim;
|
||||
} ggml_metal_kargs_concat;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne10;
|
||||
int32_t ne11;
|
||||
int32_t ne12;
|
||||
int32_t ne13;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
uint64_t offs;
|
||||
} ggml_metal_kargs_bin;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
|
@ -1243,8 +1243,6 @@ static void ggml_metal_encode_node(
|
||||
|
||||
bool bcast_row = false;
|
||||
|
||||
int64_t nb = ne00; // used by the "row" kernels
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
||||
@ -1253,7 +1251,6 @@ static void ggml_metal_encode_node(
|
||||
// src1 is a row
|
||||
GGML_ASSERT(ne11 == 1);
|
||||
|
||||
nb = ne00 / 4;
|
||||
switch (dst->op) {
|
||||
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
|
||||
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
|
||||
@ -1273,36 +1270,39 @@ static void ggml_metal_encode_node(
|
||||
}
|
||||
}
|
||||
|
||||
ggml_metal_kargs_bin args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne10 =*/ ne10,
|
||||
/*.ne11 =*/ ne11,
|
||||
/*.ne12 =*/ ne12,
|
||||
/*.ne13 =*/ ne13,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
/*.offs =*/ offs,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
||||
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
||||
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
||||
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
||||
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
||||
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
||||
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
||||
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
||||
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
||||
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
||||
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
||||
[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||
|
||||
if (bcast_row) {
|
||||
const int64_t n = ggml_nelements(dst)/4;
|
||||
@ -1400,35 +1400,39 @@ static void ggml_metal_encode_node(
|
||||
|
||||
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
|
||||
|
||||
ggml_metal_kargs_bin args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ pnb1,
|
||||
/*.nb02 =*/ pnb2,
|
||||
/*.nb03 =*/ pnb3,
|
||||
/*.ne10 =*/ ne10,
|
||||
/*.ne11 =*/ ne11,
|
||||
/*.ne12 =*/ ne12,
|
||||
/*.ne13 =*/ ne13,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ pnb1,
|
||||
/*.nb2 =*/ pnb2,
|
||||
/*.nb3 =*/ pnb3,
|
||||
/*.offs =*/ offs,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
||||
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
|
||||
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
|
||||
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
|
||||
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
||||
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
||||
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
||||
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
||||
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
||||
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
||||
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
|
||||
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
|
||||
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
|
||||
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||
|
||||
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
|
||||
|
||||
|
@ -493,200 +493,106 @@ enum ggml_sort_order {
|
||||
// pros: works for non-contiguous tensors, supports broadcast across all dims
|
||||
// cons: not very efficient
|
||||
kernel void kernel_add(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
constant int64_t & offs,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig.z;
|
||||
const int64_t i02 = tgpig.y;
|
||||
const int64_t i01 = tgpig.x;
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int i01 = tgpig.x;
|
||||
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
const int i13 = i03%args.ne13;
|
||||
const int i12 = i02%args.ne12;
|
||||
const int i11 = i01%args.ne11;
|
||||
|
||||
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
|
||||
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
||||
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
|
||||
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
|
||||
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
|
||||
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
||||
const int i10 = i0 % ne10;
|
||||
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10));
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_sub(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
constant int64_t & offs,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig.z;
|
||||
const int64_t i02 = tgpig.y;
|
||||
const int64_t i01 = tgpig.x;
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int i01 = tgpig.x;
|
||||
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
const int i13 = i03%args.ne13;
|
||||
const int i12 = i02%args.ne12;
|
||||
const int i11 = i01%args.ne11;
|
||||
|
||||
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
|
||||
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
||||
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
|
||||
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
|
||||
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
|
||||
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
||||
const int i10 = i0 % ne10;
|
||||
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_mul(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig.z;
|
||||
const int64_t i02 = tgpig.y;
|
||||
const int64_t i01 = tgpig.x;
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int i01 = tgpig.x;
|
||||
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
const int i13 = i03%args.ne13;
|
||||
const int i12 = i02%args.ne12;
|
||||
const int i11 = i01%args.ne11;
|
||||
|
||||
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
||||
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
||||
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
||||
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
|
||||
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
|
||||
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
||||
const int i10 = i0 % ne10;
|
||||
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_div(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant uint64_t & nb13,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig.z;
|
||||
const int64_t i02 = tgpig.y;
|
||||
const int64_t i01 = tgpig.x;
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int i01 = tgpig.x;
|
||||
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
const int i13 = i03%args.ne13;
|
||||
const int i12 = i02%args.ne12;
|
||||
const int i11 = i01%args.ne11;
|
||||
|
||||
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
|
||||
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
||||
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
|
||||
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
|
||||
device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
|
||||
device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
||||
const int i10 = i0 % ne10;
|
||||
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = i0%args.ne10;
|
||||
*((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
|
||||
}
|
||||
}
|
||||
|
||||
@ -740,38 +646,42 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat
|
||||
// assumption: src1 is a row
|
||||
// broadcast src1 into src0
|
||||
kernel void kernel_add_row(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
constant uint64_t & nb [[buffer(28)]],
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const uint nb = args.ne00/4;
|
||||
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
||||
}
|
||||
|
||||
kernel void kernel_sub_row(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
constant uint64_t & nb [[buffer(28)]],
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const uint nb = args.ne00/4;
|
||||
dst[tpig] = src0[tpig] - src1[tpig % nb];
|
||||
}
|
||||
|
||||
kernel void kernel_mul_row(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
constant uint64_t & nb [[buffer(28)]],
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const uint nb = args.ne00/4;
|
||||
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
||||
}
|
||||
|
||||
kernel void kernel_div_row(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
constant uint64_t & nb [[buffer(28)]],
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const uint nb = args.ne00/4;
|
||||
dst[tpig] = src0[tpig] / src1[tpig % nb];
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user