mirror of https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
multi-thread across src1 rows
parent 0e15a0863f
commit 58ce4e2846
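
In short: the mul_mv "ext" Q8_0 kernel used to produce one dst column (one src1 row) per threadgroup. With this commit each threadgroup processes r1pt = 4 consecutive src1 rows, keeps a 2D accumulator sumf[r1pt][r0pt] per thread, and guards the tail when ne11 is not a multiple of r1pt; the dispatch condition on the host side is also relaxed from ne11 >= 4 to ne11 >= 2. Below is a rough CPU-side sketch of the same tiling with made-up toy sizes (a sketch of the indexing only, not the Metal kernel itself):

#include <cstdio>
#include <vector>

int main() {
    const int ne00 = 16; // shared (reduction) dimension
    const int ne01 = 8;  // src0 rows -> dst rows
    const int ne11 = 6;  // src1 rows -> dst columns (deliberately not a multiple of r1pt)

    const int r0pt = 1;  // src0 rows per thread (unchanged by this commit)
    const int r1pt = 4;  // src1 rows per threadgroup (was 1, now 4)

    std::vector<float> src0(ne01*ne00, 1.0f);
    std::vector<float> src1(ne11*ne00, 2.0f);
    std::vector<float> dst (ne01*ne11, 0.0f);

    // one iteration of tgpig_y corresponds to one threadgroup along dim y
    for (int tgpig_y = 0; tgpig_y*r1pt < ne11; ++tgpig_y) {
        const int i11 = tgpig_y*r1pt; // first src1 row of this group (cf. i11 = tgpig.y*r1pt)

        for (int i01 = 0; i01 < ne01; i01 += r0pt) {
            float sumf[r1pt][r0pt] = {}; // cf. sumf[r1pt][r0pt] in the kernel

            for (int k = 0; k < ne00; ++k) {
                for (int ir1 = 0; ir1 < r1pt; ++ir1) {
                    if (i11 + ir1 >= ne11) continue; // tail rows contribute nothing
                    for (int ir0 = 0; ir0 < r0pt; ++ir0) {
                        sumf[ir1][ir0] += src0[(i01 + ir0)*ne00 + k]*src1[(i11 + ir1)*ne00 + k];
                    }
                }
            }

            // store only rows/columns that exist (cf. the i11 + ir1 < args.ne11 guard)
            for (int ir1 = 0; ir1 < r1pt && i11 + ir1 < ne11; ++ir1) {
                for (int ir0 = 0; ir0 < r0pt && i01 + ir0 < ne01; ++ir0) {
                    dst[(i11 + ir1)*ne01 + i01 + ir0] = sumf[ir1][ir0];
                }
            }
        }
    }

    printf("dst[0] = %g (expected %g)\n", dst[0], 2.0*ne00);
    return 0;
}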
@@ -1956,7 +1956,7 @@ static void ggml_metal_encode_node(
                 }
 #endif

-                if (src0t == GGML_TYPE_Q8_0 && (ne00%16 == 0) && (ne11 >= 4 && ne11 < 32)) {
+                if (src0t == GGML_TYPE_Q8_0 && (ne00%16 == 0) && (ne11 >= 2 && ne11 < 32)) {
               //if (false) {
                     id<MTLComputePipelineState> pipeline = nil;

@@ -1964,7 +1964,7 @@ static void ggml_metal_encode_node(

                     const int nsg = 2;
                     const int r0pt = 1;
-                    const int r1pt = 1;
+                    const int r1pt = 4;
                     const int nxpsg = ne11 > 1 ? 8 : 32;
                     const int nypsg = 32/nxpsg;
                     const int nr0ptg = nypsg*r0pt*nsg;
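
For reference, a worked example of the launch constants above (reading nxpsg as the lanes that walk ne00 and nypsg as the src0 rows per simdgroup is inferred from the kernel hunks below):

#include <cstdio>

int main() {
    const int nsg  = 2; // simdgroups per threadgroup
    const int r0pt = 1; // src0 rows per thread
    const int r1pt = 4; // src1 rows per threadgroup (was 1 before this commit)

    for (int ne11 : {1, 8}) {
        const int nxpsg  = ne11 > 1 ? 8 : 32; // lanes walking ne00 within a simdgroup
        const int nypsg  = 32/nxpsg;          // src0 rows covered by one simdgroup
        const int nr0ptg = nypsg*r0pt*nsg;    // src0 rows covered by one threadgroup

        printf("ne11 = %d: nxpsg = %2d, nypsg = %d, nr0ptg = %d, src1 rows per threadgroup = %d\n",
               ne11, nxpsg, nypsg, nr0ptg, r1pt);
    }
    return 0;
}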
@@ -1784,6 +1784,7 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
     const short chpt = 4;
     const short r0pt = 1;
+    const short r1pt = 4;

   //const short nxpsg = (32);
     const short nypsg = (32/nxpsg)*r0pt;
@@ -1792,7 +1793,7 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
     const short ty = tiisg/nxpsg;

     const int i01 = tgpig.x*(nypsg*nsg) + nypsg*sgitg + ty*r0pt;
-    const int i11 = tgpig.y;
+    const int i11 = tgpig.y*r1pt;
     const int i1m = tgpig.z;

     const int i12 = i1m%args.ne12;
@@ -1801,6 +1802,9 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
     const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
     const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;

+  //device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx;
+  //device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx;
+
     device const block_q8_0 * xq[r0pt];

     for (short ir0 = 0; ir0 < r0pt; ++ir0) {
@@ -1808,10 +1812,14 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
         xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (tx)/8 : (device const block_q8_0 *) src0;
     }

-  //device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx;
-    device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx;
+    device const float4 * y4[r1pt];
+    for (int ir1 = 0; ir1 < r1pt; ++ir1) {
+      //y4[ir1] = (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx;
+        y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
+    }

-    float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f };
+  //float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f };
+    float sumf[r1pt][r0pt] = { [ 0 ... r1pt - 1 ] = { [0 ... r0pt - 1] = 0.0f } };

     for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
         for (short ir0 = 0; ir0 < r0pt; ++ir0) {
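
The out-of-range handling above mirrors what is already done for xq: src1 rows past ne11 get a harmless pointer (the base of src1) so the hot loop stays branch-free, and their results are simply never written back. A CPU-level sketch of that pattern (toy sizes and a hypothetical row-major layout, not the kernel itself):

#include <cstdio>

int main() {
    const int ne00 = 4; // elements per row
    const int ne11 = 6; // number of src1 rows
    const int r1pt = 4; // rows per tile

    float src1[ne11][ne00];
    for (int r = 0; r < ne11; ++r) {
        for (int k = 0; k < ne00; ++k) {
            src1[r][k] = (float) r;
        }
    }

    const int i11 = 4; // last tile wants rows 4..7, but only rows 4 and 5 exist

    const float * y[r1pt];
    for (int ir1 = 0; ir1 < r1pt; ++ir1) {
        // rows past ne11 point at row 0: always safe to read, never stored
        y[ir1] = (i11 + ir1 < ne11) ? src1[i11 + ir1] : src1[0];
    }

    float sum[r1pt] = {};
    for (int k = 0; k < ne00; ++k) {
        for (int ir1 = 0; ir1 < r1pt; ++ir1) {
            sum[ir1] += y[ir1][k]; // no bounds check in the inner loop
        }
    }

    for (int ir1 = 0; ir1 < r1pt && i11 + ir1 < ne11; ++ir1) {
        printf("row %d: sum = %g\n", i11 + ir1, sum[ir1]); // only valid rows are written out
    }
    return 0;
}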
@@ -1821,42 +1829,53 @@ void kernel_mul_mv_ext_q8_0_f32_impl(

                 dequantize_q8_0x(xq[ir0] + (ch*nxpsg)/8, (tx)%8, lx);

-                sumf[ir0] += dot(lx, y4[ch*nxpsg]);
+#pragma unroll(4)
+                for (short ir1 = 0; ir1 < r1pt; ++ir1) {
+                    sumf[ir1][ir0] += dot(lx, y4[ir1][ch*nxpsg]);
+                }
             }
         }

-        y4 += ((4*chpt)*nxpsg)/4;
+        for (short ir1 = 0; ir1 < r1pt; ++ir1) {
+            y4[ir1] += ((4*chpt)*nxpsg)/4;
+        }

         for (short ir0 = 0; ir0 < r0pt; ++ir0) {
             xq[ir0] += ((4*chpt)*nxpsg)/32;
         }
     }

-    for (short ir0 = 0; ir0 < r0pt; ++ir0) {
-        if (nxpsg >= 32) {
-            sumf[ir0] += simd_shuffle_down(sumf[ir0], 16);
-        }
-        if (nxpsg >= 16) {
-            sumf[ir0] += simd_shuffle_down(sumf[ir0], 8);
-        }
-        if (nxpsg >= 8) {
-            sumf[ir0] += simd_shuffle_down(sumf[ir0], 4);
-        }
-        if (nxpsg >= 4) {
-            sumf[ir0] += simd_shuffle_down(sumf[ir0], 2);
-        }
-        if (nxpsg >= 2) {
-            sumf[ir0] += simd_shuffle_down(sumf[ir0], 1);
-        }
+    for (short ir1 = 0; ir1 < r1pt; ++ir1) {
+        for (short ir0 = 0; ir0 < r0pt; ++ir0) {
+            if (nxpsg >= 32) {
+                sumf[ir1][ir0] += simd_shuffle_down(sumf[ir1][ir0], 16);
+            }
+            if (nxpsg >= 16) {
+                sumf[ir1][ir0] += simd_shuffle_down(sumf[ir1][ir0], 8);
+            }
+            if (nxpsg >= 8) {
+                sumf[ir1][ir0] += simd_shuffle_down(sumf[ir1][ir0], 4);
+            }
+            if (nxpsg >= 4) {
+                sumf[ir1][ir0] += simd_shuffle_down(sumf[ir1][ir0], 2);
+            }
+            if (nxpsg >= 2) {
+                sumf[ir1][ir0] += simd_shuffle_down(sumf[ir1][ir0], 1);
+            }

-      //sumf[ir0] = simd_sum(sumf[ir0]);
+          //sumf[ir1][ir0] = simd_sum(sumf[ir1][ir0]);
+        }
     }

-    device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)i11*args.ne0;
+  //device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)i11*args.ne0;

     if (tx == 0) {
-        for (short ir0 = 0; ir0 < r0pt && i01 + ir0 < args.ne01; ++ir0) {
-            dst_f32[i01 + ir0] = sumf[ir0];
+        for (short ir1 = 0; ir1 < r1pt && i11 + ir1 < args.ne11; ++ir1) {
+            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
+
+            for (short ir0 = 0; ir0 < r0pt && i01 + ir0 < args.ne01; ++ir0) {
+                dst_f32[i01 + ir0] = sumf[ir1][ir0];
+            }
         }
     }
 }
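
After the main loop each sumf[ir1][ir0] still holds only one lane's partial sum; the shuffle-down ladder above folds the nxpsg lanes of a row group so that the lane with tx == 0 ends up with the complete dot product. A CPU emulation of that ladder (out-of-range lanes are modeled as contributing 0 here; in the kernel their values do not matter because only the tx == 0 lane of each group is read back):

#include <array>
#include <cstdio>

int main() {
    const int nxpsg = 8; // lanes walking ne00 per row group (the ne11 > 1 configuration)

    std::array<float, 32> lane;
    lane.fill(1.0f); // pretend every lane holds a partial sum of 1

    // shuffle-down ladder: lane t accumulates the value held by lane t + delta
    for (int delta = 16; delta >= 1; delta /= 2) {
        if (nxpsg < 2*delta) {
            continue; // mirrors the if (nxpsg >= ...) guards in the kernel
        }
        const std::array<float, 32> snap = lane;
        for (int t = 0; t < 32; ++t) {
            lane[t] += (t + delta < 32) ? snap[t + delta] : 0.0f;
        }
    }

    // lane tx == 0 of each nxpsg-wide group now holds that group's total
    for (int t = 0; t < 32; t += nxpsg) {
        printf("group at lane %2d: %g (expected %d)\n", t, lane[t], nxpsg);
    }
    return 0;
}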
@@ -3570,6 +3570,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));

+    for (int i = 1; i < 64; ++i) {
+        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 64, i, 256, { 1, 1}, {1, 1}));
+    }
+
 #if 1
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
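
Assuming test_mul_mat takes (type_a, type_b, m, n, k, batch dims, repeat dims) as in the rest of this file, the new loop sweeps n (the number of src1 rows, i.e. ne11) from 1 to 63. That covers the lowered ne11 >= 2 dispatch threshold and every possible remainder of the r1pt = 4 row tiling. A quick enumeration of what gets added:

#include <cstdio>

int main() {
    const int m = 64, k = 256, r1pt = 4;
    for (int n = 1; n < 64; ++n) {
        printf("MUL_MAT Q8_0 x F32: m = %d, n = %2d, k = %d -> %d full row tile(s) of %d, remainder %d\n",
               m, n, k, n/r1pt, r1pt, n%r1pt);
    }
    return 0;
}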