mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
4x4 -> 4x
This commit is contained in:
parent
bf3494345e
commit
dafedd33d2
@ -1962,7 +1962,7 @@ static void ggml_metal_encode_node(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32].pipeline;
|
||||
|
||||
const int nsg = 2;
|
||||
const int r0pt = 2;
|
||||
const int r0pt = 1;
|
||||
const int r1pt = 1;
|
||||
const int nxpsg = ne11 > 1 ? 8 : 32;
|
||||
const int nypsg = 32/nxpsg;
|
||||
|
@ -170,6 +170,26 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q4_0x(device const block_q4_0 *xb, short il, thread type4 & reg) {
|
||||
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||
const half d = xb->d;
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
reg[i] = qs[0];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q8_0x(device const block_q8_0 *xb, short il, thread type4 & reg) {
|
||||
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||
const half d = xb->d;
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
||||
const float d = xb->d;
|
||||
@ -1762,8 +1782,8 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
|
||||
ushort3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const short chpt = 1;
|
||||
const short r0pt = 2;
|
||||
const short chpt = 4;
|
||||
const short r0pt = 1;
|
||||
|
||||
//const short nxpsg = (32);
|
||||
const short nypsg = (32/nxpsg)*r0pt;
|
||||
@ -1784,36 +1804,31 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
|
||||
device const block_q8_0 * xq[r0pt];
|
||||
|
||||
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
|
||||
xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/2 : (device const block_q8_0 *) src0;
|
||||
//xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/8 : (device const block_q8_0 *) src0;
|
||||
xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (tx)/8 : (device const block_q8_0 *) src0;
|
||||
}
|
||||
|
||||
device const float4x4 * y4x4 = (device const float4x4 *) (src1 + offset1) + chpt*tx;
|
||||
//device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx;
|
||||
device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx;
|
||||
|
||||
float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f };
|
||||
|
||||
for (int iib = 0; (16*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
|
||||
float4x4 lx;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
|
||||
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
|
||||
#pragma unroll
|
||||
#pragma unroll(4)
|
||||
for (short ch = 0; ch < chpt; ++ch) {
|
||||
dequantize_q8_0(xq[ir0] + ch/2, (chpt*tx + ch)%2, lx);
|
||||
float4 lx;
|
||||
|
||||
const float4x4 ly = y4x4[ch];
|
||||
dequantize_q8_0x(xq[ir0] + (ch*nxpsg)/8, (tx)%8, lx);
|
||||
|
||||
sumf[ir0] +=
|
||||
dot(lx[0], ly[0]) +
|
||||
dot(lx[1], ly[1]) +
|
||||
dot(lx[2], ly[2]) +
|
||||
dot(lx[3], ly[3]);
|
||||
sumf[ir0] += dot(lx, y4[ch*nxpsg]);
|
||||
}
|
||||
}
|
||||
|
||||
y4x4 += ((16*chpt)*nxpsg)/16;
|
||||
y4 += ((4*chpt)*nxpsg)/4;
|
||||
|
||||
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
|
||||
xq[ir0] += ((16*chpt)*nxpsg)/32;
|
||||
xq[ir0] += ((4*chpt)*nxpsg)/32;
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user