mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 20:04:35 +00:00
Try fix quantized k-cache on ROCm
This commit is contained in:
parent
f372c49ccd
commit
a710d58d88
16
ggml-cuda.cu
16
ggml-cuda.cu
@ -6684,7 +6684,7 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
|||||||
const float d = amax / ((1 << 7) - 1);
|
const float d = amax / ((1 << 7) - 1);
|
||||||
const float id = d ? 1.0f/d : 0.0f;
|
const float id = d ? 1.0f/d : 0.0f;
|
||||||
|
|
||||||
dsti->d = d;
|
dsti->d = __float2half(d);
|
||||||
|
|
||||||
for (int j = 0; j < QK8_0; ++j) {
|
for (int j = 0; j < QK8_0; ++j) {
|
||||||
const float x0 = xi[j]*id;
|
const float x0 = xi[j]*id;
|
||||||
@ -6711,7 +6711,7 @@ static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
|
|||||||
const float d = vmax / -8;
|
const float d = vmax / -8;
|
||||||
const float id = d ? 1.0f/d : 0.0f;
|
const float id = d ? 1.0f/d : 0.0f;
|
||||||
|
|
||||||
dsti->d = d;
|
dsti->d = __float2half(d);
|
||||||
|
|
||||||
for (int j = 0; j < QK4_0/2; ++j) {
|
for (int j = 0; j < QK4_0/2; ++j) {
|
||||||
const float x0 = xi[0 + j]*id;
|
const float x0 = xi[0 + j]*id;
|
||||||
@ -6742,8 +6742,8 @@ static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
|
|||||||
const float d = (vmax - vmin) / ((1 << 4) - 1);
|
const float d = (vmax - vmin) / ((1 << 4) - 1);
|
||||||
const float id = d ? 1.0f/d : 0.0f;
|
const float id = d ? 1.0f/d : 0.0f;
|
||||||
|
|
||||||
dsti->dm.x = d;
|
dsti->dm.x = __float2half(d);
|
||||||
dsti->dm.y = vmin;
|
dsti->dm.y = __float2half(vmin);
|
||||||
|
|
||||||
for (int j = 0; j < QK4_1/2; ++j) {
|
for (int j = 0; j < QK4_1/2; ++j) {
|
||||||
const float x0 = (xi[0 + j] - vmin)*id;
|
const float x0 = (xi[0 + j] - vmin)*id;
|
||||||
@ -6775,7 +6775,7 @@ static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
|
|||||||
const float d = vmax / -16;
|
const float d = vmax / -16;
|
||||||
const float id = d ? 1.0f/d : 0.0f;
|
const float id = d ? 1.0f/d : 0.0f;
|
||||||
|
|
||||||
dsti->d = d;
|
dsti->d = __float2half(d);
|
||||||
|
|
||||||
uint32_t qh = 0;
|
uint32_t qh = 0;
|
||||||
for (int j = 0; j < QK5_0/2; ++j) {
|
for (int j = 0; j < QK5_0/2; ++j) {
|
||||||
@ -6808,8 +6808,8 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
|
|||||||
const float d = (max - min) / 31;
|
const float d = (max - min) / 31;
|
||||||
const float id = d ? 1.0f/d : 0.0f;
|
const float id = d ? 1.0f/d : 0.0f;
|
||||||
|
|
||||||
dsti->dm.x = d;
|
dsti->dm.x = __float2half(d);
|
||||||
dsti->dm.y = min;
|
dsti->dm.y = __float2half(min);
|
||||||
|
|
||||||
uint32_t qh = 0;
|
uint32_t qh = 0;
|
||||||
for (int j = 0; j < QK5_1/2; ++j) {
|
for (int j = 0; j < QK5_1/2; ++j) {
|
||||||
@ -6870,7 +6870,7 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
|
|||||||
sumq2 += w0*v0*v0 + w1*v1*v1;
|
sumq2 += w0*v0*v0 + w1*v1*v1;
|
||||||
}
|
}
|
||||||
|
|
||||||
dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
|
dsti->d = __float2half(sumq2 > 0 ? sumqx/sumq2 : d);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user