mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 10:54:36 +00:00
4x4 -> 4x
This commit is contained in:
parent
bf3494345e
commit
dafedd33d2
@ -1962,7 +1962,7 @@ static void ggml_metal_encode_node(
|
|||||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32].pipeline;
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32].pipeline;
|
||||||
|
|
||||||
const int nsg = 2;
|
const int nsg = 2;
|
||||||
const int r0pt = 2;
|
const int r0pt = 1;
|
||||||
const int r1pt = 1;
|
const int r1pt = 1;
|
||||||
const int nxpsg = ne11 > 1 ? 8 : 32;
|
const int nxpsg = ne11 > 1 ? 8 : 32;
|
||||||
const int nypsg = 32/nxpsg;
|
const int nypsg = 32/nxpsg;
|
||||||
|
@ -170,6 +170,26 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
|||||||
reg = (type4x4) reg_f;
|
reg = (type4x4) reg_f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename type4>
|
||||||
|
void dequantize_q4_0x(device const block_q4_0 *xb, short il, thread type4 & reg) {
|
||||||
|
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||||
|
const half d = xb->d;
|
||||||
|
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
reg[i] = qs[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename type4>
|
||||||
|
void dequantize_q8_0x(device const block_q8_0 *xb, short il, thread type4 & reg) {
|
||||||
|
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||||
|
const half d = xb->d;
|
||||||
|
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
||||||
const float d = xb->d;
|
const float d = xb->d;
|
||||||
@ -1762,8 +1782,8 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
|
|||||||
ushort3 ntg[[threads_per_threadgroup]],
|
ushort3 ntg[[threads_per_threadgroup]],
|
||||||
ushort tiisg[[thread_index_in_simdgroup]],
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
const short chpt = 1;
|
const short chpt = 4;
|
||||||
const short r0pt = 2;
|
const short r0pt = 1;
|
||||||
|
|
||||||
//const short nxpsg = (32);
|
//const short nxpsg = (32);
|
||||||
const short nypsg = (32/nxpsg)*r0pt;
|
const short nypsg = (32/nxpsg)*r0pt;
|
||||||
@ -1784,36 +1804,31 @@ void kernel_mul_mv_ext_q8_0_f32_impl(
|
|||||||
device const block_q8_0 * xq[r0pt];
|
device const block_q8_0 * xq[r0pt];
|
||||||
|
|
||||||
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
|
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
|
||||||
xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/2 : (device const block_q8_0 *) src0;
|
//xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/8 : (device const block_q8_0 *) src0;
|
||||||
|
xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (tx)/8 : (device const block_q8_0 *) src0;
|
||||||
}
|
}
|
||||||
|
|
||||||
device const float4x4 * y4x4 = (device const float4x4 *) (src1 + offset1) + chpt*tx;
|
//device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx;
|
||||||
|
device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx;
|
||||||
|
|
||||||
float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f };
|
float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f };
|
||||||
|
|
||||||
for (int iib = 0; (16*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
|
for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
|
||||||
float4x4 lx;
|
|
||||||
|
|
||||||
#pragma unroll(2)
|
|
||||||
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
|
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
|
||||||
#pragma unroll
|
#pragma unroll(4)
|
||||||
for (short ch = 0; ch < chpt; ++ch) {
|
for (short ch = 0; ch < chpt; ++ch) {
|
||||||
dequantize_q8_0(xq[ir0] + ch/2, (chpt*tx + ch)%2, lx);
|
float4 lx;
|
||||||
|
|
||||||
const float4x4 ly = y4x4[ch];
|
dequantize_q8_0x(xq[ir0] + (ch*nxpsg)/8, (tx)%8, lx);
|
||||||
|
|
||||||
sumf[ir0] +=
|
sumf[ir0] += dot(lx, y4[ch*nxpsg]);
|
||||||
dot(lx[0], ly[0]) +
|
|
||||||
dot(lx[1], ly[1]) +
|
|
||||||
dot(lx[2], ly[2]) +
|
|
||||||
dot(lx[3], ly[3]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
y4x4 += ((16*chpt)*nxpsg)/16;
|
y4 += ((4*chpt)*nxpsg)/4;
|
||||||
|
|
||||||
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
|
for (short ir0 = 0; ir0 < r0pt; ++ir0) {
|
||||||
xq[ir0] += ((16*chpt)*nxpsg)/32;
|
xq[ir0] += ((4*chpt)*nxpsg)/32;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user