mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-10 10:41:47 +00:00
metal : some mul_mv experiments
This commit is contained in:
parent
ab6a3b7c36
commit
59b33b9c3a
@ -192,6 +192,29 @@ typedef struct {
|
|||||||
int16_t r3;
|
int16_t r3;
|
||||||
} ggml_metal_kargs_mul_mv;
|
} ggml_metal_kargs_mul_mv;
|
||||||
|
|
||||||
|
// Arguments of the "ext" mat-vec kernels (kernel_mul_mv_ext_*).
//
// The host fills this struct and uploads it by value; the kernel receives it as a constant
// reference. The field order and types therefore define the binary layout seen by the kernel
// and must not be changed on one side only.
//
// ne* fields are element counts, nb* fields are strides in bytes (the kernel adds
// nb01/nb02/nb03 and nb11/nb12/nb13 to char pointers).
typedef struct {
    // src0
    int32_t  ne00; // row length in elements (bounds the kernel's inner loop)
    int32_t  ne01; // number of rows (bounds the rows a thread may write)
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01; // byte stride between consecutive rows
    uint64_t nb02;
    uint64_t nb03;

    // src1
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12; // used to split the flattened batch index into (i12, i13)
    uint64_t nb10;
    uint64_t nb11; // byte stride between consecutive rows
    uint64_t nb12;
    uint64_t nb13;

    // dst
    int32_t  ne0;
    int32_t  ne1;

    // divisors applied to the src1 batch indices to obtain the src0 batch indices
    int16_t  r2;
    int16_t  r3;

    // launch configuration, used by the kernel to pick a template specialization
    int16_t  nsg;   // simdgroups per threadgroup
    int16_t  nxpsg; // threads of a simdgroup that cooperate on the same src0 rows
} ggml_metal_kargs_mul_mv_ext;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int32_t nei0;
|
int32_t nei0;
|
||||||
int32_t nei1;
|
int32_t nei1;
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
#import "ggml-metal.h"
|
#import "ggml-metal.h"
|
||||||
|
|
||||||
#import "ggml-impl.h"
|
#import "ggml-impl.h"
|
||||||
|
#define GGML_COMMON_DECL_C
|
||||||
|
#import "ggml-common.h"
|
||||||
#import "ggml-backend-impl.h"
|
#import "ggml-backend-impl.h"
|
||||||
#import "ggml-metal-impl.h"
|
#import "ggml-metal-impl.h"
|
||||||
|
|
||||||
@ -174,6 +176,7 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
|
||||||
@ -693,6 +696,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32, mul_mv_ext_q8_0_f32, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
|
||||||
@ -1908,7 +1912,7 @@ static void ggml_metal_encode_node(
|
|||||||
|
|
||||||
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
||||||
// to the matrix-vector kernel
|
// to the matrix-vector kernel
|
||||||
int ne11_mm_min = 1;
|
int ne11_mm_min = 16;
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
// the numbers below are measured on M2 Ultra for 7B and 13B models
|
// the numbers below are measured on M2 Ultra for 7B and 13B models
|
||||||
@ -1932,6 +1936,55 @@ static void ggml_metal_encode_node(
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Experimental: Q8_0 x F32 products whose row length is a multiple of 16 are sent to the
// "ext" mat-vec kernel instead of the regular mat-vec / mat-mat paths that follow.
//
// The kernel reinterprets src1 as float4x4 chunks, so src1 must be F32; any other src1 type
// now falls through to the existing paths instead of being misread as floats.
// NOTE(review): src1t is assumed to be declared next to src0t earlier in this function
// (outside the visible hunk) - confirm.
if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 && (ne00%16 == 0)) {
    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32].pipeline;

    // launch geometry - r0pt must stay in sync with the value hard-coded in the kernel
    const int nsg    = 2;                 // simdgroups per threadgroup
    const int r0pt   = 2;                 // src0 rows handled by each thread
    const int r1pt   = 1;                 // src1 rows handled by each threadgroup
    const int nxpsg  = ne11 > 1 ? 8 : 32; // threads of a simdgroup that share the same src0 rows
    const int nypsg  = 32/nxpsg;          // groups of nxpsg threads per simdgroup
    const int nr0ptg = nypsg*r0pt*nsg;    // src0 rows covered by one threadgroup

    // positional initializer: the order must match the field order of ggml_metal_kargs_mul_mv_ext
    ggml_metal_kargs_mul_mv_ext args = {
        /*.ne00  =*/ ne00,
        /*.ne01  =*/ ne01,
        /*.ne02  =*/ ne02,
        /*.nb00  =*/ nb00,
        /*.nb01  =*/ nb01,
        /*.nb02  =*/ nb02,
        /*.nb03  =*/ nb03,
        /*.ne10  =*/ ne10,
        /*.ne11  =*/ ne11,
        /*.ne12  =*/ ne12,
        /*.nb10  =*/ nb10,
        /*.nb11  =*/ nb11,
        /*.nb12  =*/ nb12,
        /*.nb13  =*/ nb13,
        /*.ne0   =*/ ne0,
        /*.ne1   =*/ ne1,
        /*.r2    =*/ r2,
        /*.r3    =*/ r3,
        /*.nsg   =*/ nsg,
        /*.nxpsg =*/ nxpsg,
    };

    [encoder setComputePipelineState:pipeline];
    [encoder setBytes:&args length:sizeof(args) atIndex:0];
    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];

    // grid: x - blocks of nr0ptg src0 rows (rounded up), y - src1 rows, z - flattened batch
    // threadgroup: 32 threads (one simdgroup) in x, nsg simdgroups in y
    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0ptg - 1)/nr0ptg, (ne11 + r1pt - 1)/r1pt, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} else
|
||||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||||
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
||||||
|
@ -1739,6 +1739,135 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|||||||
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<short nsg, short nxpsg>
|
||||||
|
void kernel_mul_mv_ext_q8_0_f32_impl(
|
||||||
|
constant ggml_metal_kargs_mul_mv_ext & args,
|
||||||
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device char * dst,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
ushort3 ntg[[threads_per_threadgroup]],
|
||||||
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
const short chpt = 1;
|
||||||
|
const short r0pt = 2;
|
||||||
|
|
||||||
|
//const short nxpsg = (32);
|
||||||
|
const short nypsg = (32/nxpsg)*r0pt;
|
||||||
|
|
||||||
|
const short tx = tiisg%nxpsg;
|
||||||
|
const short ty = tiisg/nxpsg;
|
||||||
|
|
||||||
|
const int i01 = tgpig.x*(nypsg*nsg) + nypsg*sgitg + ty*r0pt;
|
||||||
|
const int i11 = tgpig.y;
|
||||||
|
const int i1m = tgpig.z;
|
||||||
|
|
||||||
|
const int i12 = i1m%args.ne12;
|
||||||
|
const int i13 = i1m/args.ne12;
|
||||||
|
|
||||||
|
const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
|
const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
|
device const block_q8_0 * xq[r0pt];
|
||||||
|
|
||||||
|
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
|
||||||
|
xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/2 : (device const block_q8_0 *) src0;
|
||||||
|
}
|
||||||
|
|
||||||
|
device const float4x4 * y4x4 = (device const float4x4 *) (src1 + offset1) + chpt*tx;
|
||||||
|
|
||||||
|
float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f };
|
||||||
|
|
||||||
|
for (int iib = 0; (16*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
|
||||||
|
float4x4 lx;
|
||||||
|
|
||||||
|
#pragma unroll(2)
|
||||||
|
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
|
||||||
|
#pragma unroll
|
||||||
|
for (short ch = 0; ch < chpt; ++ch) {
|
||||||
|
dequantize_q8_0(xq[ir0] + ch/2, (chpt*tx + ch)%2, lx);
|
||||||
|
|
||||||
|
const float4x4 ly = y4x4[ch];
|
||||||
|
|
||||||
|
sumf[ir0] +=
|
||||||
|
dot(lx[0], ly[0]) +
|
||||||
|
dot(lx[1], ly[1]) +
|
||||||
|
dot(lx[2], ly[2]) +
|
||||||
|
dot(lx[3], ly[3]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
y4x4 += ((16*chpt)*nxpsg)/16;
|
||||||
|
|
||||||
|
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
|
||||||
|
xq[ir0] += ((16*chpt)*nxpsg)/32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
|
||||||
|
if (nxpsg >= 32) {
|
||||||
|
sumf[ir0] += simd_shuffle_down(sumf[ir0], 16);
|
||||||
|
}
|
||||||
|
if (nxpsg >= 16) {
|
||||||
|
sumf[ir0] += simd_shuffle_down(sumf[ir0], 8);
|
||||||
|
}
|
||||||
|
if (nxpsg >= 8) {
|
||||||
|
sumf[ir0] += simd_shuffle_down(sumf[ir0], 4);
|
||||||
|
}
|
||||||
|
if (nxpsg >= 4) {
|
||||||
|
sumf[ir0] += simd_shuffle_down(sumf[ir0], 2);
|
||||||
|
}
|
||||||
|
if (nxpsg >= 2) {
|
||||||
|
sumf[ir0] += simd_shuffle_down(sumf[ir0], 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
//sumf[ir0] = simd_sum(sumf[ir0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)i11*args.ne0;
|
||||||
|
|
||||||
|
if (tx == 0) {
|
||||||
|
for (short ir0 = 0; ir0 < r0pt && i01 + ir0 < args.ne01; ++ir0) {
|
||||||
|
dst_f32[i01 + ir0] = sumf[ir0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[[host_name("kernel_mul_mv_ext_q8_0_f32")]]
|
||||||
|
kernel void kernel_mul_mv_ext_q8_0_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv_ext & args,
|
||||||
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device char * dst,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
ushort3 ntg[[threads_per_threadgroup]],
|
||||||
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
switch (args.nsg) {
|
||||||
|
case 1:
|
||||||
|
switch (args.nxpsg) {
|
||||||
|
case 4: kernel_mul_mv_ext_q8_0_f32_impl<1, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||||
|
case 8: kernel_mul_mv_ext_q8_0_f32_impl<1, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||||
|
case 16: kernel_mul_mv_ext_q8_0_f32_impl<1, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||||
|
case 32: kernel_mul_mv_ext_q8_0_f32_impl<1, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||||
|
} break;
|
||||||
|
case 2:
|
||||||
|
switch (args.nxpsg) {
|
||||||
|
case 4: kernel_mul_mv_ext_q8_0_f32_impl<2, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||||
|
case 8: kernel_mul_mv_ext_q8_0_f32_impl<2, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||||
|
case 16: kernel_mul_mv_ext_q8_0_f32_impl<2, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||||
|
case 32: kernel_mul_mv_ext_q8_0_f32_impl<2, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||||
|
} break;
|
||||||
|
case 4:
|
||||||
|
switch (args.nxpsg) {
|
||||||
|
case 4: kernel_mul_mv_ext_q8_0_f32_impl<4, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||||
|
case 8: kernel_mul_mv_ext_q8_0_f32_impl<4, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||||
|
case 16: kernel_mul_mv_ext_q8_0_f32_impl<4, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||||
|
case 32: kernel_mul_mv_ext_q8_0_f32_impl<4, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||||
|
} break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#define N_MV_T_T 4
|
#define N_MV_T_T 4
|
||||||
|
|
||||||
template<typename T0, typename T04, typename T1, typename T14, typename args_t>
|
template<typename T0, typename T04, typename T1, typename T14, typename args_t>
|
||||||
|
Loading…
Reference in New Issue
Block a user