mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-10 18:51:45 +00:00
metal : GGML_OP_CPY
This commit is contained in:
parent
f46f710ca6
commit
647a7044f5
@ -494,6 +494,25 @@ typedef struct {
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_repeat;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int64_t ne0;
|
||||
int64_t ne1;
|
||||
int64_t ne2;
|
||||
int64_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_cpy;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
|
@ -1377,25 +1377,29 @@ static void ggml_metal_encode_node(
|
||||
|
||||
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
||||
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
||||
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
|
||||
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
|
||||
|
||||
@ -3425,25 +3429,29 @@ static void ggml_metal_encode_node(
|
||||
default: GGML_ABORT("not implemented");
|
||||
}
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
||||
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
||||
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
|
@ -3302,42 +3302,27 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_
|
||||
|
||||
template<typename T0, typename T1>
|
||||
kernel void kernel_cpy(
|
||||
device const void * src0,
|
||||
device void * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig[2];
|
||||
const int64_t i02 = tgpig[1];
|
||||
const int64_t i01 = tgpig[0];
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig[2];
|
||||
const int i02 = tgpig[1];
|
||||
const int i01 = tgpig[0];
|
||||
|
||||
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
||||
const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
||||
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
||||
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
||||
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
||||
|
||||
device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
||||
device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
|
||||
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
dst_data[i00] = (T1) src[0];
|
||||
}
|
||||
}
|
||||
@ -3357,42 +3342,27 @@ template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bf
|
||||
#endif
|
||||
|
||||
kernel void kernel_cpy_f32_q8_0(
|
||||
device const float * src0,
|
||||
device void * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig[2];
|
||||
const int64_t i02 = tgpig[1];
|
||||
const int64_t i01 = tgpig[0];
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig[2];
|
||||
const int i02 = tgpig[1];
|
||||
const int i01 = tgpig[0];
|
||||
|
||||
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
|
||||
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
||||
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
||||
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
||||
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0;
|
||||
|
||||
device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
|
||||
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
|
||||
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
|
||||
float amax = 0.0f; // absolute max
|
||||
|
||||
@ -3415,42 +3385,27 @@ kernel void kernel_cpy_f32_q8_0(
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f32_q4_0(
|
||||
device const float * src0,
|
||||
device void * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig[2];
|
||||
const int64_t i02 = tgpig[1];
|
||||
const int64_t i01 = tgpig[0];
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig[2];
|
||||
const int i02 = tgpig[1];
|
||||
const int i01 = tgpig[0];
|
||||
|
||||
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
|
||||
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
||||
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
||||
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
||||
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0;
|
||||
|
||||
device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
|
||||
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
|
||||
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
@ -3482,42 +3437,27 @@ kernel void kernel_cpy_f32_q4_0(
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f32_q4_1(
|
||||
device const float * src0,
|
||||
device void * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig[2];
|
||||
const int64_t i02 = tgpig[1];
|
||||
const int64_t i01 = tgpig[0];
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig[2];
|
||||
const int i02 = tgpig[1];
|
||||
const int i01 = tgpig[0];
|
||||
|
||||
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
|
||||
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
||||
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
||||
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
||||
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1;
|
||||
|
||||
device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
|
||||
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
|
||||
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
|
||||
float min = FLT_MAX;
|
||||
float max = -FLT_MAX;
|
||||
@ -3548,42 +3488,27 @@ kernel void kernel_cpy_f32_q4_1(
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f32_q5_0(
|
||||
device const float * src0,
|
||||
device void * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig[2];
|
||||
const int64_t i02 = tgpig[1];
|
||||
const int64_t i01 = tgpig[0];
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig[2];
|
||||
const int i02 = tgpig[1];
|
||||
const int i01 = tgpig[0];
|
||||
|
||||
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
|
||||
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
||||
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
||||
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
||||
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0;
|
||||
|
||||
device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
|
||||
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
|
||||
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
@ -3621,42 +3546,27 @@ kernel void kernel_cpy_f32_q5_0(
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f32_q5_1(
|
||||
device const float * src0,
|
||||
device void * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig[2];
|
||||
const int64_t i02 = tgpig[1];
|
||||
const int64_t i01 = tgpig[0];
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig[2];
|
||||
const int i02 = tgpig[1];
|
||||
const int i01 = tgpig[0];
|
||||
|
||||
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
|
||||
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
||||
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
||||
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
||||
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1;
|
||||
|
||||
device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
|
||||
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
|
||||
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
|
||||
float max = src[0];
|
||||
float min = src[0];
|
||||
@ -3704,42 +3614,27 @@ static inline int best_index_int8(int n, constant float * val, float x) {
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f32_iq4_nl(
|
||||
device const float * src0,
|
||||
device void * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig[2];
|
||||
const int64_t i02 = tgpig[1];
|
||||
const int64_t i01 = tgpig[0];
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig[2];
|
||||
const int i02 = tgpig[1];
|
||||
const int i01 = tgpig[0];
|
||||
|
||||
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
|
||||
const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
||||
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
||||
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
||||
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL;
|
||||
|
||||
device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
|
||||
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
|
||||
device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
@ -3775,7 +3670,6 @@ kernel void kernel_cpy_f32_iq4_nl(
|
||||
}
|
||||
|
||||
dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user