This commit is contained in:
abhilash1910 2024-03-07 01:54:26 -08:00
parent c810047b53
commit 94f33d7ae3

View File

@ -8357,7 +8357,6 @@ vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
const uint8_t h1 = bq1->scales[2*ib32+0]; const uint8_t h1 = bq1->scales[2*ib32+0];
const uint8_t h2 = bq1->scales[2*ib32+1]; const uint8_t h2 = bq1->scales[2*ib32+1];
#if DPCT_COMPATIBILITY_TEMP >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const int * q8 = (const int *)bq8_1[ib32].qs; const int * q8 = (const int *)bq8_1[ib32].qs;
const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5))); const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1))); const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
@ -8369,19 +8368,6 @@ vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
sumi3 = dpct::dp4a(q8[j+4], grid3[j], sumi3); sumi3 = dpct::dp4a(q8[j+4], grid3[j], sumi3);
sumi4 = dpct::dp4a(q8[j+6], grid4[j], sumi4); sumi4 = dpct::dp4a(q8[j+6], grid4[j], sumi4);
} }
#else
const int8_t * q8 = bq8_1[ib32].qs;
const int8_t * grid1 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
const int8_t * grid2 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
const int8_t * grid3 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
const int8_t * grid4 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
for (int j = 0; j < 8; ++j) {
sumi1 += q8[j+ 0] * grid1[j];
sumi2 += q8[j+ 8] * grid2[j];
sumi3 += q8[j+16] * grid3[j];
sumi4 += q8[j+24] * grid4[j];
}
#endif
const float d = (float)bq1->d * bq8_1[ib32].ds[0]; const float d = (float)bq1->d * bq8_1[ib32].ds[0];
return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) + return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) +
sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1)); sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1));