mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 03:44:35 +00:00
wip
This commit is contained in:
parent
58ce4e2846
commit
c46e9e488b
@ -213,6 +213,7 @@ typedef struct {
|
||||
int16_t r3;
|
||||
int16_t nsg;
|
||||
int16_t nxpsg;
|
||||
int16_t r1pt;
|
||||
} ggml_metal_kargs_mul_mv_ext;
|
||||
|
||||
typedef struct {
|
||||
|
@ -23,6 +23,14 @@
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
|
||||
static int up2(int x) {
|
||||
int r = 1;
|
||||
while (r < x) {
|
||||
r *= 2;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
// globals
|
||||
|
||||
// overload of MTLGPUFamilyMetal3 (not available in some environments)
|
||||
@ -1956,7 +1964,7 @@ static void ggml_metal_encode_node(
|
||||
}
|
||||
#endif
|
||||
|
||||
if (src0t == GGML_TYPE_Q8_0 && (ne00%16 == 0) && (ne11 >= 2 && ne11 < 32)) {
|
||||
if (src0t == GGML_TYPE_Q8_0 && (ne00%256 == 0) && (ne11 >= 2 && ne11 < 16)) {
|
||||
//if (false) {
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
@ -1964,8 +1972,8 @@ static void ggml_metal_encode_node(
|
||||
|
||||
const int nsg = 2;
|
||||
const int r0pt = 1;
|
||||
const int r1pt = 4;
|
||||
const int nxpsg = ne11 > 1 ? 8 : 32;
|
||||
const int r1pt = ne11 < 3 ? 2 : 4;
|
||||
const int nxpsg = ne11 < 3 ? 16 : 8;
|
||||
const int nypsg = 32/nxpsg;
|
||||
const int nr0ptg = nypsg*r0pt*nsg;
|
||||
|
||||
@ -1994,6 +2002,7 @@ static void ggml_metal_encode_node(
|
||||
/*.r3 =*/ r3,
|
||||
/*.nsg =*/ nsg,
|
||||
/*.nxpsg =*/ nxpsg,
|
||||
/*.r1pt =*/ r1pt,
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
|
@ -1772,19 +1772,17 @@ kernel void kernel_mul_mv_q8_0_f32(
|
||||
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<short nsg, short nxpsg>
|
||||
template<short nsg, short nxpsg, short r1pt>
|
||||
void kernel_mul_mv_ext_q8_0_f32_impl(
|
||||
constant ggml_metal_kargs_mul_mv_ext & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const short chpt = 4;
|
||||
const short r0pt = 1;
|
||||
const short r1pt = 4;
|
||||
|
||||
//const short nxpsg = (32);
|
||||
const short nypsg = (32/nxpsg)*r0pt;
|
||||
@ -1802,47 +1800,42 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
|
||||
const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
//device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx;
|
||||
//device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx;
|
||||
|
||||
device const block_q8_0 * xq[r0pt];
|
||||
|
||||
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
|
||||
//xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/8 : (device const block_q8_0 *) src0;
|
||||
xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (tx)/8 : (device const block_q8_0 *) src0;
|
||||
}
|
||||
|
||||
device const float4 * y4[r1pt];
|
||||
device const float4 * y4[r1pt];
|
||||
|
||||
for (int ir1 = 0; ir1 < r1pt; ++ir1) {
|
||||
//y4[ir1] = (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx;
|
||||
y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
|
||||
y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
|
||||
}
|
||||
|
||||
//float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f };
|
||||
float sumf[r1pt][r0pt] = { [ 0 ... r1pt - 1 ] = { [0 ... r0pt - 1] = 0.0f } };
|
||||
|
||||
for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
|
||||
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
|
||||
#pragma unroll(4)
|
||||
#pragma unroll(chpt)
|
||||
for (short ch = 0; ch < chpt; ++ch) {
|
||||
float4 lx;
|
||||
|
||||
dequantize_q8_0x(xq[ir0] + (ch*nxpsg)/8, (tx)%8, lx);
|
||||
|
||||
#pragma unroll(4)
|
||||
#pragma unroll(r1pt)
|
||||
for (short ir1 = 0; ir1 < r1pt; ++ir1) {
|
||||
sumf[ir1][ir0] += dot(lx, y4[ir1][ch*nxpsg]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (short ir1 = 0; ir1 < r1pt; ++ir1) {
|
||||
y4[ir1] += ((4*chpt)*nxpsg)/4;
|
||||
}
|
||||
|
||||
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
|
||||
xq[ir0] += ((4*chpt)*nxpsg)/32;
|
||||
}
|
||||
|
||||
for (short ir1 = 0; ir1 < r1pt; ++ir1) {
|
||||
y4[ir1] += ((4*chpt)*nxpsg)/4;
|
||||
}
|
||||
}
|
||||
|
||||
for (short ir1 = 0; ir1 < r1pt; ++ir1) {
|
||||
@ -1867,8 +1860,6 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
|
||||
}
|
||||
}
|
||||
|
||||
//device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)i11*args.ne0;
|
||||
|
||||
if (tx == 0) {
|
||||
for (short ir1 = 0; ir1 < r1pt && i11 + ir1 < args.ne11; ++ir1) {
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
|
||||
@ -1887,32 +1878,37 @@ kernel void kernel_mul_mv_ext_q8_0_f32(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
switch (args.nsg) {
|
||||
case 1:
|
||||
switch (args.nxpsg) {
|
||||
case 4: kernel_mul_mv_ext_q8_0_f32_impl<1, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||
case 8: kernel_mul_mv_ext_q8_0_f32_impl<1, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||
case 16: kernel_mul_mv_ext_q8_0_f32_impl<1, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||
case 32: kernel_mul_mv_ext_q8_0_f32_impl<1, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||
} break;
|
||||
case 2:
|
||||
switch (args.nxpsg) {
|
||||
case 4: kernel_mul_mv_ext_q8_0_f32_impl<2, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||
case 8: kernel_mul_mv_ext_q8_0_f32_impl<2, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||
case 16: kernel_mul_mv_ext_q8_0_f32_impl<2, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||
case 32: kernel_mul_mv_ext_q8_0_f32_impl<2, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||
} break;
|
||||
case 4:
|
||||
switch (args.nxpsg) {
|
||||
case 4: kernel_mul_mv_ext_q8_0_f32_impl<4, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||
case 8: kernel_mul_mv_ext_q8_0_f32_impl<4, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||
case 16: kernel_mul_mv_ext_q8_0_f32_impl<4, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||
case 32: kernel_mul_mv_ext_q8_0_f32_impl<4, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break;
|
||||
} break;
|
||||
#define CASE_R1PT(r1pt) \
|
||||
switch (args.nsg) { \
|
||||
case 1: \
|
||||
switch (args.nxpsg) { \
|
||||
case 4: kernel_mul_mv_ext_q8_0_f32_impl<1, 4, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \
|
||||
case 8: kernel_mul_mv_ext_q8_0_f32_impl<1, 8, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \
|
||||
case 16: kernel_mul_mv_ext_q8_0_f32_impl<1, 16, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \
|
||||
case 32: kernel_mul_mv_ext_q8_0_f32_impl<1, 32, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \
|
||||
} break; \
|
||||
case 2: \
|
||||
switch (args.nxpsg) { \
|
||||
case 4: kernel_mul_mv_ext_q8_0_f32_impl<2, 4, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \
|
||||
case 8: kernel_mul_mv_ext_q8_0_f32_impl<2, 8, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \
|
||||
case 16: kernel_mul_mv_ext_q8_0_f32_impl<2, 16, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \
|
||||
case 32: kernel_mul_mv_ext_q8_0_f32_impl<2, 32, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \
|
||||
} break; \
|
||||
case 4: \
|
||||
switch (args.nxpsg) { \
|
||||
case 4: kernel_mul_mv_ext_q8_0_f32_impl<4, 4, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \
|
||||
case 8: kernel_mul_mv_ext_q8_0_f32_impl<4, 8, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \
|
||||
case 16: kernel_mul_mv_ext_q8_0_f32_impl<4, 16, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \
|
||||
case 32: kernel_mul_mv_ext_q8_0_f32_impl<4, 32, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \
|
||||
} break; \
|
||||
}
|
||||
|
||||
switch (args.r1pt) {
|
||||
case 2: CASE_R1PT( 2); break;
|
||||
case 4: CASE_R1PT( 4); break;
|
||||
};
|
||||
}
|
||||
|
||||
#define N_MV_T_T 4
|
||||
|
@ -3571,7 +3571,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
|
||||
|
||||
for (int i = 1; i < 64; ++i) {
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 64, i, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 64, i, 512, { 1, 1}, {1, 1}));
|
||||
}
|
||||
|
||||
#if 1
|
||||
|
Loading…
Reference in New Issue
Block a user