4x4 -> 4x
Some checks failed
Python check requirements.txt / check-requirements (push) Has been cancelled
flake8 Lint / Lint (push) Has been cancelled
Python Type-Check / pyright type-check (push) Has been cancelled

This commit is contained in:
Georgi Gerganov 2024-11-12 14:47:04 +02:00
parent bf3494345e
commit dafedd33d2
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 34 additions and 19 deletions

View File

@ -1962,7 +1962,7 @@ static void ggml_metal_encode_node(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32].pipeline; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32].pipeline;
const int nsg = 2; const int nsg = 2;
const int r0pt = 2; const int r0pt = 1;
const int r1pt = 1; const int r1pt = 1;
const int nxpsg = ne11 > 1 ? 8 : 32; const int nxpsg = ne11 > 1 ? 8 : 32;
const int nypsg = 32/nxpsg; const int nypsg = 32/nxpsg;

View File

@ -170,6 +170,26 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
reg = (type4x4) reg_f; reg = (type4x4) reg_f;
} }
template <typename type4>
void dequantize_q4_0x(device const block_q4_0 *xb, short il, thread type4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const half d = xb->d;
for (int i = 0; i < 4; i++) {
reg[i] = qs[0];
}
}
template <typename type4>
void dequantize_q8_0x(device const block_q8_0 *xb, short il, thread type4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const half d = xb->d;
for (int i = 0; i < 4; i++) {
reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
}
}
template <typename type4x4> template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
const float d = xb->d; const float d = xb->d;
@ -1762,8 +1782,8 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
ushort3 ntg[[threads_per_threadgroup]], ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]], ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) { ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short chpt = 1; const short chpt = 4;
const short r0pt = 2; const short r0pt = 1;
//const short nxpsg = (32); //const short nxpsg = (32);
const short nypsg = (32/nxpsg)*r0pt; const short nypsg = (32/nxpsg)*r0pt;
@ -1784,36 +1804,31 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
device const block_q8_0 * xq[r0pt]; device const block_q8_0 * xq[r0pt];
for (short ir0 = 0; ir0 < r0pt; ++ir0) { for (short ir0 = 0; ir0 < r0pt; ++ir0) {
xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/2 : (device const block_q8_0 *) src0; //xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/8 : (device const block_q8_0 *) src0;
xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (tx)/8 : (device const block_q8_0 *) src0;
} }
device const float4x4 * y4x4 = (device const float4x4 *) (src1 + offset1) + chpt*tx; //device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx;
device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx;
float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f }; float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f };
for (int iib = 0; (16*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) { for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
float4x4 lx;
#pragma unroll(2)
for (short ir0 = 0; ir0 < r0pt; ++ir0) { for (short ir0 = 0; ir0 < r0pt; ++ir0) {
#pragma unroll #pragma unroll(4)
for (short ch = 0; ch < chpt; ++ch) { for (short ch = 0; ch < chpt; ++ch) {
dequantize_q8_0(xq[ir0] + ch/2, (chpt*tx + ch)%2, lx); float4 lx;
const float4x4 ly = y4x4[ch]; dequantize_q8_0x(xq[ir0] + (ch*nxpsg)/8, (tx)%8, lx);
sumf[ir0] += sumf[ir0] += dot(lx, y4[ch*nxpsg]);
dot(lx[0], ly[0]) +
dot(lx[1], ly[1]) +
dot(lx[2], ly[2]) +
dot(lx[3], ly[3]);
} }
} }
y4x4 += ((16*chpt)*nxpsg)/16; y4 += ((4*chpt)*nxpsg)/4;
for (short ir0 = 0; ir0 < r0pt; ++ir0) { for (short ir0 = 0; ir0 < r0pt; ++ir0) {
xq[ir0] += ((16*chpt)*nxpsg)/32; xq[ir0] += ((4*chpt)*nxpsg)/32;
} }
} }