Try fix quantized k-cache on ROCm

This commit is contained in:
Iwan Kawrakow 2024-03-21 20:18:50 +02:00
parent f372c49ccd
commit a710d58d88

View File

@ -6684,7 +6684,7 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
const float d = amax / ((1 << 7) - 1); const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f; const float id = d ? 1.0f/d : 0.0f;
dsti->d = d; dsti->d = __float2half(d);
for (int j = 0; j < QK8_0; ++j) { for (int j = 0; j < QK8_0; ++j) {
const float x0 = xi[j]*id; const float x0 = xi[j]*id;
@ -6711,7 +6711,7 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
const float d = vmax / -8; const float d = vmax / -8;
const float id = d ? 1.0f/d : 0.0f; const float id = d ? 1.0f/d : 0.0f;
dsti->d = d; dsti->d = __float2half(d);
for (int j = 0; j < QK4_0/2; ++j) { for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = xi[0 + j]*id; const float x0 = xi[0 + j]*id;
@ -6742,8 +6742,8 @@ static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
const float d = (vmax - vmin) / ((1 << 4) - 1); const float d = (vmax - vmin) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f; const float id = d ? 1.0f/d : 0.0f;
dsti->dm.x = d; dsti->dm.x = __float2half(d);
dsti->dm.y = vmin; dsti->dm.y = __float2half(vmin);
for (int j = 0; j < QK4_1/2; ++j) { for (int j = 0; j < QK4_1/2; ++j) {
const float x0 = (xi[0 + j] - vmin)*id; const float x0 = (xi[0 + j] - vmin)*id;
@ -6775,7 +6775,7 @@ static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
const float d = vmax / -16; const float d = vmax / -16;
const float id = d ? 1.0f/d : 0.0f; const float id = d ? 1.0f/d : 0.0f;
dsti->d = d; dsti->d = __float2half(d);
uint32_t qh = 0; uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) { for (int j = 0; j < QK5_0/2; ++j) {
@ -6808,8 +6808,8 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
const float d = (max - min) / 31; const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f; const float id = d ? 1.0f/d : 0.0f;
dsti->dm.x = d; dsti->dm.x = __float2half(d);
dsti->dm.y = min; dsti->dm.y = __float2half(min);
uint32_t qh = 0; uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) { for (int j = 0; j < QK5_1/2; ++j) {
@ -6870,7 +6870,7 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
sumq2 += w0*v0*v0 + w1*v1*v1; sumq2 += w0*v0*v0 + w1*v1*v1;
} }
dsti->d = sumq2 > 0 ? sumqx/sumq2 : d; dsti->d = __float2half(sumq2 > 0 ? sumqx/sumq2 : d);
} }