mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 17:21:46 +00:00
cont : int safety + register optimizations
ggml-ci
This commit is contained in:
parent
3855622da9
commit
5d4cbc0845
@ -1645,8 +1645,8 @@ void mul_vec_q_n_f32_impl(
|
|||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
//const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
//const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
//device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
|
//device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
|
||||||
device const float * y = (device const float *) (src1 + offset1);
|
device const float * y = (device const float *) (src1 + offset1);
|
||||||
@ -1654,7 +1654,7 @@ void mul_vec_q_n_f32_impl(
|
|||||||
// pointers to src0 rows
|
// pointers to src0 rows
|
||||||
device const block_q_type * ax[nr];
|
device const block_q_type * ax[nr];
|
||||||
for (int row = 0; row < nr; ++row) {
|
for (int row = 0; row < nr; ++row) {
|
||||||
const uint offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
|
|
||||||
ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
|
ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
|
||||||
}
|
}
|
||||||
@ -1662,10 +1662,10 @@ void mul_vec_q_n_f32_impl(
|
|||||||
float yl[16]; // src1 vector cache
|
float yl[16]; // src1 vector cache
|
||||||
float sumf[nr] = {0.f};
|
float sumf[nr] = {0.f};
|
||||||
|
|
||||||
const int ix = (tiisg/2);
|
const short ix = (tiisg/2);
|
||||||
const int il = (tiisg%2)*8;
|
const short il = (tiisg%2)*8;
|
||||||
|
|
||||||
device const float * yb = y + ix * QK4_0 + il;
|
device const float * yb = y + ix*QK4_0 + il;
|
||||||
|
|
||||||
// each thread in a SIMD group deals with half a block.
|
// each thread in a SIMD group deals with half a block.
|
||||||
for (int ib = ix; ib < nb; ib += nw/2) {
|
for (int ib = ix; ib < nb; ib += nw/2) {
|
||||||
@ -1771,8 +1771,8 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
//const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
//const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
//device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
|
//device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
|
||||||
device const float * y = (device const float *) (src1 + offset1);
|
device const float * y = (device const float *) (src1 + offset1);
|
||||||
@ -1780,7 +1780,7 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|||||||
// pointers to src0 rows
|
// pointers to src0 rows
|
||||||
device const block_q8_0 * ax[nr];
|
device const block_q8_0 * ax[nr];
|
||||||
for (int row = 0; row < nr; ++row) {
|
for (int row = 0; row < nr; ++row) {
|
||||||
const uint offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
|
|
||||||
ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
|
ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
|
||||||
}
|
}
|
||||||
@ -1788,21 +1788,21 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|||||||
float yl[NB_Q8_0];
|
float yl[NB_Q8_0];
|
||||||
float sumf[nr] = { 0.f };
|
float sumf[nr] = { 0.f };
|
||||||
|
|
||||||
const int ix = tiisg/4;
|
const short ix = tiisg/4;
|
||||||
const int il = tiisg%4;
|
const short il = tiisg%4;
|
||||||
|
|
||||||
device const float * yb = y + ix*QK8_0 + il*NB_Q8_0;
|
device const float * yb = y + ix*QK8_0 + il*NB_Q8_0;
|
||||||
|
|
||||||
// each thread in a SIMD group deals with NB_Q8_0 quants at a time
|
// each thread in a SIMD group deals with NB_Q8_0 quants at a time
|
||||||
for (int ib = ix; ib < nb; ib += nw/4) {
|
for (int ib = ix; ib < nb; ib += nw/4) {
|
||||||
for (int i = 0; i < NB_Q8_0; ++i) {
|
for (short i = 0; i < NB_Q8_0; ++i) {
|
||||||
yl[i] = yb[i];
|
yl[i] = yb[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int row = 0; row < nr; row++) {
|
for (int row = 0; row < nr; row++) {
|
||||||
device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
|
device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
|
||||||
float sumq = 0.f;
|
float sumq = 0.f;
|
||||||
for (int iq = 0; iq < NB_Q8_0; ++iq) {
|
for (short iq = 0; iq < NB_Q8_0; ++iq) {
|
||||||
sumq += qs[iq] * yl[iq];
|
sumq += qs[iq] * yl[iq];
|
||||||
}
|
}
|
||||||
sumf[row] += sumq*ax[row][ib].d;
|
sumf[row] += sumq*ax[row][ib].d;
|
||||||
@ -1811,7 +1811,7 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|||||||
yb += nw*NB_Q8_0;
|
yb += nw*NB_Q8_0;
|
||||||
}
|
}
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||||
|
|
||||||
for (int row = 0; row < nr; ++row) {
|
for (int row = 0; row < nr; ++row) {
|
||||||
const float tot = simd_sum(sumf[row]);
|
const float tot = simd_sum(sumf[row]);
|
||||||
@ -1844,18 +1844,18 @@ void kernel_mul_mv_impl(
|
|||||||
device char * dst,
|
device char * dst,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
ushort tiisg) {
|
ushort tiisg) {
|
||||||
const int64_t r0 = tgpig.x;
|
const int r0 = tgpig.x;
|
||||||
const int64_t rb = tgpig.y*N_MV_T_T;
|
const int rb = tgpig.y*N_MV_T_T;
|
||||||
const int64_t im = tgpig.z;
|
const int im = tgpig.z;
|
||||||
|
|
||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
|
|
||||||
device const T0 * x = (device const T0 *) (src0 + offset0);
|
device const T0 * x = (device const T0 *) (src0 + offset0);
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1;
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
|
||||||
|
|
||||||
if (args.ne00 < 128) {
|
if (args.ne00 < 128) {
|
||||||
for (int row = 0; row < N_MV_T_T; ++row) {
|
for (int row = 0; row < N_MV_T_T; ++row) {
|
||||||
@ -1864,7 +1864,7 @@ void kernel_mul_mv_impl(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const T1 * y = (device const T1 *) (src1 + offset1);
|
device const T1 * y = (device const T1 *) (src1 + offset1);
|
||||||
|
|
||||||
@ -1875,7 +1875,7 @@ void kernel_mul_mv_impl(
|
|||||||
|
|
||||||
float all_sum = simd_sum(sumf);
|
float all_sum = simd_sum(sumf);
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
dst_f32[r1*args.ne0 + r0] = all_sum;
|
dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -1886,20 +1886,20 @@ void kernel_mul_mv_impl(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const T1 * y = (device const T1 *) (src1 + offset1);
|
device const T1 * y = (device const T1 *) (src1 + offset1);
|
||||||
device const T14 * y4 = (device const T14 *) y;
|
device const T14 * y4 = (device const T14 *) y;
|
||||||
|
|
||||||
float sumf = 0;
|
float sumf = 0;
|
||||||
for (int i = tiisg; i < args.ne00/4; i += 32) {
|
for (int i = tiisg; i < args.ne00/4; i += 32) {
|
||||||
for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
|
sumf += dot((T14) x4[i], y4[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
float all_sum = simd_sum(sumf);
|
float all_sum = simd_sum(sumf);
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
|
for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
|
||||||
dst_f32[r1*args.ne0 + r0] = all_sum;
|
dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1935,25 +1935,27 @@ template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<
|
|||||||
template<typename T, typename T4>
|
template<typename T, typename T4>
|
||||||
kernel void kernel_mul_mv_1row(
|
kernel void kernel_mul_mv_1row(
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device char * dst,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort tiisg[[thread_index_in_simdgroup]]) {
|
ushort tiisg[[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
const int64_t r0 = tgpig.x;
|
const int r0 = tgpig.x;
|
||||||
const int64_t r1 = tgpig.y;
|
const int r1 = tgpig.y;
|
||||||
const int64_t im = tgpig.z;
|
const int im = tgpig.z;
|
||||||
|
|
||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const T * x = (device const T *) (src0 + offset0);
|
device const T * x = (device const T *) (src0 + offset0);
|
||||||
device const float * y = (device const float *) (src1 + offset1);
|
device const float * y = (device const float *) (src1 + offset1);
|
||||||
|
|
||||||
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||||
|
|
||||||
float sumf = 0;
|
float sumf = 0;
|
||||||
if (args.ne00 < 128) {
|
if (args.ne00 < 128) {
|
||||||
for (int i = tiisg; i < args.ne00; i += 32) {
|
for (int i = tiisg; i < args.ne00; i += 32) {
|
||||||
@ -1961,21 +1963,21 @@ kernel void kernel_mul_mv_1row(
|
|||||||
}
|
}
|
||||||
float all_sum = simd_sum(sumf);
|
float all_sum = simd_sum(sumf);
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum;
|
dst_f32[r0] = all_sum;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
device const T4 * x4 = (device const T4 *) x;
|
device const T4 * x4 = (device const T4 *) x;
|
||||||
device const float4 * y4 = (device const float4 *) y;
|
device const float4 * y4 = (device const float4 *) y;
|
||||||
|
|
||||||
for (int i = tiisg; i < args.ne00/4; i += 32) {
|
for (int i = tiisg; i < args.ne00/4; i += 32) {
|
||||||
for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
|
sumf += dot((float4) x4[i], y4[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
float all_sum = simd_sum(sumf);
|
float all_sum = simd_sum(sumf);
|
||||||
|
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
|
for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
|
||||||
dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum;
|
dst_f32[r0] = all_sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1991,36 +1993,38 @@ template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kerne
|
|||||||
template<typename T, typename T4>
|
template<typename T, typename T4>
|
||||||
kernel void kernel_mul_mv_l4(
|
kernel void kernel_mul_mv_l4(
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device char * dst,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort tiisg[[thread_index_in_simdgroup]]) {
|
ushort tiisg[[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
const int nrows = args.ne11;
|
const int nrows = args.ne11;
|
||||||
const int64_t r0 = tgpig.x;
|
const int r0 = tgpig.x;
|
||||||
const int64_t im = tgpig.z;
|
const int im = tgpig.z;
|
||||||
|
|
||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
|
|
||||||
device const T4 * x4 = (device const T4 *) (src0 + offset0);
|
device const T4 * x4 = (device const T4 *) (src0 + offset0);
|
||||||
|
|
||||||
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
|
||||||
|
|
||||||
for (int r1 = 0; r1 < nrows; ++r1) {
|
for (int r1 = 0; r1 < nrows; ++r1) {
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const float4 * y4 = (device const float4 *) (src1 + offset1);
|
device const float4 * y4 = (device const float4 *) (src1 + offset1);
|
||||||
|
|
||||||
float sumf = 0;
|
float sumf = 0;
|
||||||
for (int i = tiisg; i < args.ne00/4; i += 32) {
|
for (int i = tiisg; i < args.ne00/4; i += 32) {
|
||||||
for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
|
sumf += dot((float4) x4[i], y4[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
float all_sum = simd_sum(sumf);
|
float all_sum = simd_sum(sumf);
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum;
|
dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2966,7 +2970,7 @@ kernel void kernel_flash_attn_ext(
|
|||||||
const float S = ss[j*TS + 0];
|
const float S = ss[j*TS + 0];
|
||||||
|
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
for (short i = tiisg; i < D4; i += NW) {
|
||||||
dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
|
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3352,7 +3356,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|||||||
const float S = ss[0];
|
const float S = ss[0];
|
||||||
|
|
||||||
for (short i = tiisg; i < D16; i += NW) {
|
for (short i = tiisg; i < D16; i += NW) {
|
||||||
dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
|
dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3947,8 +3951,8 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
|
device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
|
||||||
device const float * y = (device const float *) (src1 + offset1);
|
device const float * y = (device const float *) (src1 + offset1);
|
||||||
@ -4008,7 +4012,7 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|||||||
y4 += 4 * QK_K;
|
y4 += 4 * QK_K;
|
||||||
}
|
}
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
@ -4044,17 +4048,17 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|||||||
|
|
||||||
const int nb = args.ne00/QK_K;
|
const int nb = args.ne00/QK_K;
|
||||||
|
|
||||||
const int64_t r0 = tgpig.x;
|
const int r0 = tgpig.x;
|
||||||
const int64_t r1 = tgpig.y;
|
const int r1 = tgpig.y;
|
||||||
const int64_t im = tgpig.z;
|
const int im = tgpig.z;
|
||||||
|
|
||||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
||||||
|
|
||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
|
device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
|
||||||
device const float * yy = (device const float *) (src1 + offset1);
|
device const float * yy = (device const float *) (src1 + offset1);
|
||||||
@ -4087,9 +4091,10 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|||||||
|
|
||||||
const ushort4 hm = mm[2*ip + il/2];
|
const ushort4 hm = mm[2*ip + il/2];
|
||||||
|
|
||||||
const int shift = 2*il;
|
const short shift = 2*il;
|
||||||
const float v1 = il == 0 ? 4.f : 64.f;
|
|
||||||
const float v2 = 4.f * v1;
|
const float v1 = il == 0 ? 4.f : 64.f;
|
||||||
|
const float v2 = 4.f * v1;
|
||||||
|
|
||||||
const uint16_t s_shift1 = 4*ip;
|
const uint16_t s_shift1 = 4*ip;
|
||||||
const uint16_t s_shift2 = s_shift1 + il;
|
const uint16_t s_shift2 = s_shift1 + il;
|
||||||
@ -4172,7 +4177,7 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|||||||
sumf1[row] = simd_sum(sumf);
|
sumf1[row] = simd_sum(sumf);
|
||||||
}
|
}
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||||
|
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
for (int row = 0; row < 2; ++row) {
|
for (int row = 0; row < 2; ++row) {
|
||||||
@ -4224,8 +4229,8 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
|
device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
|
||||||
device const float * y = (device const float *) (src1 + offset1);
|
device const float * y = (device const float *) (src1 + offset1);
|
||||||
@ -4289,7 +4294,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|||||||
y4 += 4 * QK_K;
|
y4 += 4 * QK_K;
|
||||||
}
|
}
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
@ -4325,8 +4330,8 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|||||||
|
|
||||||
const int nb = args.ne00/QK_K;
|
const int nb = args.ne00/QK_K;
|
||||||
|
|
||||||
const int64_t r0 = tgpig.x;
|
const int r0 = tgpig.x;
|
||||||
const int64_t r1 = tgpig.y;
|
const int r1 = tgpig.y;
|
||||||
const int im = tgpig.z;
|
const int im = tgpig.z;
|
||||||
|
|
||||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
||||||
@ -4334,8 +4339,8 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
|
device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
|
||||||
device const float * yy = (device const float *) (src1 + offset1);
|
device const float * yy = (device const float *) (src1 + offset1);
|
||||||
@ -4421,7 +4426,7 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|||||||
y1 += 4 * QK_K;
|
y1 += 4 * QK_K;
|
||||||
}
|
}
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||||
|
|
||||||
for (int row = 0; row < 2; ++row) {
|
for (int row = 0; row < 2; ++row) {
|
||||||
const float tot = simd_sum(sumf[row]);
|
const float tot = simd_sum(sumf[row]);
|
||||||
@ -4462,17 +4467,17 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|||||||
|
|
||||||
const int nb = args.ne00/QK_K;
|
const int nb = args.ne00/QK_K;
|
||||||
|
|
||||||
const int64_t r0 = tgpig.x;
|
const int r0 = tgpig.x;
|
||||||
const int64_t r1 = tgpig.y;
|
const int r1 = tgpig.y;
|
||||||
const int im = tgpig.z;
|
const int im = tgpig.z;
|
||||||
|
|
||||||
const int row = 2 * r0 + sgitg;
|
const int row = 2*r0 + sgitg;
|
||||||
|
|
||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
|
device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
|
||||||
device const float * yy = (device const float *) (src1 + offset1);
|
device const float * yy = (device const float *) (src1 + offset1);
|
||||||
@ -4492,7 +4497,6 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|||||||
const int q_offset_h = 32*ip + l0;
|
const int q_offset_h = 32*ip + l0;
|
||||||
|
|
||||||
for (int i = ix; i < nb; i += 2) {
|
for (int i = ix; i < nb; i += 2) {
|
||||||
|
|
||||||
device const uint8_t * q1 = x[i].ql + q_offset_l;
|
device const uint8_t * q1 = x[i].ql + q_offset_l;
|
||||||
device const uint8_t * q2 = q1 + 32;
|
device const uint8_t * q2 = q1 + 32;
|
||||||
device const uint8_t * qh = x[i].qh + q_offset_h;
|
device const uint8_t * qh = x[i].qh + q_offset_h;
|
||||||
@ -4514,7 +4518,7 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||||
|
|
||||||
const float tot = simd_sum(sumf);
|
const float tot = simd_sum(sumf);
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
@ -4558,8 +4562,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
|
device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
|
||||||
device const float * y = (device const float *) (src1 + offset1);
|
device const float * y = (device const float *) (src1 + offset1);
|
||||||
@ -4622,7 +4626,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|||||||
y4 += 32 * 32;
|
y4 += 32 * 32;
|
||||||
}
|
}
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
@ -4666,8 +4670,8 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
|
device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
|
||||||
device const float * y = (device const float *) (src1 + offset1);
|
device const float * y = (device const float *) (src1 + offset1);
|
||||||
@ -4740,7 +4744,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|||||||
y4 += 32 * 32;
|
y4 += 32 * 32;
|
||||||
}
|
}
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
@ -4785,8 +4789,8 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
|
device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
|
||||||
device const float * y = (device const float *) (src1 + offset1);
|
device const float * y = (device const float *) (src1 + offset1);
|
||||||
@ -4813,7 +4817,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|||||||
device const float * y4 = y + 32 * ix;
|
device const float * y4 = y + 32 * ix;
|
||||||
|
|
||||||
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
||||||
|
|
||||||
for (int i = 0; i < 32; ++i) {
|
for (int i = 0; i < 32; ++i) {
|
||||||
yl[i] = y4[i];
|
yl[i] = y4[i];
|
||||||
}
|
}
|
||||||
@ -4827,7 +4830,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|||||||
device const half * dh = &xr->d;
|
device const half * dh = &xr->d;
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; row++) {
|
for (int row = 0; row < N_DST; row++) {
|
||||||
|
|
||||||
const float db = dh[0];
|
const float db = dh[0];
|
||||||
const uint32_t aux32 = gas[0] | (gas[1] << 16);
|
const uint32_t aux32 = gas[0] | (gas[1] << 16);
|
||||||
const float d = db * (0.5f + (aux32 >> 28));
|
const float d = db * (0.5f + (aux32 >> 28));
|
||||||
@ -4852,7 +4854,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|||||||
y4 += 32 * 32;
|
y4 += 32 * 32;
|
||||||
}
|
}
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
@ -4897,8 +4899,8 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
|
device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
|
||||||
device const float * y = (device const float *) (src1 + offset1);
|
device const float * y = (device const float *) (src1 + offset1);
|
||||||
@ -4964,7 +4966,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|||||||
y4 += 32 * 32;
|
y4 += 32 * 32;
|
||||||
}
|
}
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
@ -5009,8 +5011,8 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
|
device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
|
||||||
device const float * y = (device const float *) (src1 + offset1);
|
device const float * y = (device const float *) (src1 + offset1);
|
||||||
@ -5077,7 +5079,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|||||||
y4 += 32 * 32;
|
y4 += 32 * 32;
|
||||||
}
|
}
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
@ -5122,8 +5124,8 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
|
device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
|
||||||
device const float * y = (device const float *) (src1 + offset1);
|
device const float * y = (device const float *) (src1 + offset1);
|
||||||
@ -5177,7 +5179,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||||||
y4 += 32 * 32;
|
y4 += 32 * 32;
|
||||||
}
|
}
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
@ -5208,8 +5210,8 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
|
device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
|
||||||
device const float * y = (device const float *) (src1 + offset1);
|
device const float * y = (device const float *) (src1 + offset1);
|
||||||
@ -5272,7 +5274,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|||||||
y4 += 32 * 32;
|
y4 += 32 * 32;
|
||||||
}
|
}
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
@ -5303,8 +5305,8 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
|
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
|
||||||
device const float * y = (device const float *) (src1 + offset1);
|
device const float * y = (device const float *) (src1 + offset1);
|
||||||
@ -5362,7 +5364,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|||||||
yb += 16 * QK4_NL;
|
yb += 16 * QK4_NL;
|
||||||
}
|
}
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||||
|
|
||||||
for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
|
for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
@ -5393,8 +5395,8 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|||||||
const uint i12 = im%args.ne12;
|
const uint i12 = im%args.ne12;
|
||||||
const uint i13 = im/args.ne12;
|
const uint i13 = im/args.ne12;
|
||||||
|
|
||||||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||||
|
|
||||||
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
|
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
|
||||||
device const float * y = (device const float *) (src1 + offset1);
|
device const float * y = (device const float *) (src1 + offset1);
|
||||||
@ -5418,25 +5420,23 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|||||||
float4 qf1, qf2;
|
float4 qf1, qf2;
|
||||||
|
|
||||||
for (int ibl = ix; ibl < nb; ibl += 2) {
|
for (int ibl = ix; ibl < nb; ibl += 2) {
|
||||||
|
|
||||||
device const float4 * y4 = (device const float4 *)yb;
|
device const float4 * y4 = (device const float4 *)yb;
|
||||||
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
||||||
|
|
||||||
for (int row = 0; row < 2; ++row) {
|
for (int row = 0; row < 2; ++row) {
|
||||||
|
|
||||||
device const block_iq4_xs & xb = x[row*nb + ibl];
|
device const block_iq4_xs & xb = x[row*nb + ibl];
|
||||||
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
|
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
|
||||||
|
|
||||||
float4 acc1 = {0.f}, acc2 = {0.f};
|
float4 acc1 = {0.f}, acc2 = {0.f};
|
||||||
|
|
||||||
aux32[0] = q4[0] & 0x0f0f0f0f;
|
aux32[0] = (q4[0] ) & 0x0f0f0f0f;
|
||||||
aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
|
aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
|
||||||
qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
|
qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
|
||||||
qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
|
qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
|
||||||
acc1 += yl[0] * qf1;
|
acc1 += yl[0] * qf1;
|
||||||
acc2 += yl[1] * qf2;
|
acc2 += yl[1] * qf2;
|
||||||
|
|
||||||
aux32[0] = q4[1] & 0x0f0f0f0f;
|
aux32[0] = (q4[1] ) & 0x0f0f0f0f;
|
||||||
aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
|
aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
|
||||||
qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
|
qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
|
||||||
qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
|
qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
|
||||||
@ -5453,7 +5453,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|||||||
yb += 2 * QK_K;
|
yb += 2 * QK_K;
|
||||||
}
|
}
|
||||||
|
|
||||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||||
|
|
||||||
for (int row = 0; row < 2; ++row) {
|
for (int row = 0; row < 2; ++row) {
|
||||||
all_sum = simd_sum(sumf[row]);
|
all_sum = simd_sum(sumf[row]);
|
||||||
@ -5656,8 +5656,8 @@ kernel void kernel_mul_mm(
|
|||||||
const int i12 = im%args.ne12;
|
const int i12 = im%args.ne12;
|
||||||
const int i13 = im/args.ne12;
|
const int i13 = im/args.ne12;
|
||||||
|
|
||||||
int offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
short offset1 = il/nl;
|
short offset1 = il/nl;
|
||||||
|
|
||||||
device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1;
|
device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1;
|
||||||
device const float * y = (device const float *)(src1
|
device const float * y = (device const float *)(src1
|
||||||
@ -5782,10 +5782,10 @@ void kernel_mul_mm_id_impl(
|
|||||||
threadgroup half * sa = (threadgroup half *)(shmem);
|
threadgroup half * sa = (threadgroup half *)(shmem);
|
||||||
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
|
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
|
||||||
|
|
||||||
const uint r0 = tgpig.y;
|
const int r0 = tgpig.y;
|
||||||
const uint r1 = tgpig.x;
|
const int r1 = tgpig.x;
|
||||||
|
|
||||||
if (r1 * BLOCK_SIZE_N >= ne1) return;
|
if (r1*BLOCK_SIZE_N >= ne1) return;
|
||||||
|
|
||||||
// if this block is of 64x32 shape or smaller
|
// if this block is of 64x32 shape or smaller
|
||||||
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
||||||
@ -5916,7 +5916,7 @@ kernel void kernel_mul_mm_id(
|
|||||||
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
|
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
|
||||||
|
|
||||||
// TODO: parallelize this loop
|
// TODO: parallelize this loop
|
||||||
int64_t _ne1 = 0;
|
int32_t _ne1 = 0;
|
||||||
for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
|
for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
|
||||||
for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
|
for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
|
||||||
int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
|
int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
|
||||||
|
Loading…
Reference in New Issue
Block a user