mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-30 21:34:36 +00:00
Make quantize_row_iq4_nl do the same thing as quantization on CUDA
This commit is contained in:
parent
76aa30a263
commit
cd4a7c4cb4
@ -11705,9 +11705,8 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
|
|||||||
ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
|
ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
|
||||||
float * scales, float * weight, uint8_t * L,
|
float * scales, float * weight, uint8_t * L,
|
||||||
const int8_t * values,
|
const int8_t * values,
|
||||||
const float * quant_weights) {
|
const float * quant_weights,
|
||||||
|
const int ntry) {
|
||||||
const int ntry = 7;
|
|
||||||
|
|
||||||
float sigma2 = 0;
|
float sigma2 = 0;
|
||||||
for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j];
|
for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j];
|
||||||
@ -11823,7 +11822,7 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow
|
|||||||
for (int ibl = 0; ibl < nblock; ++ibl) {
|
for (int ibl = 0; ibl < nblock; ++ibl) {
|
||||||
const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
|
const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
|
||||||
quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
|
quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
|
||||||
&scale, weight, L, kvalues_iq4nl, qw);
|
&scale, weight, L, kvalues_iq4nl, qw, 7);
|
||||||
}
|
}
|
||||||
src += n_per_row;
|
src += n_per_row;
|
||||||
qrow += nblock*sizeof(block_iq4_nl);
|
qrow += nblock*sizeof(block_iq4_nl);
|
||||||
@ -11832,9 +11831,21 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int nrow
|
|||||||
}
|
}
|
||||||
|
|
||||||
void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) {
|
void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) {
|
||||||
assert(k % QK4_NL == 0);
|
GGML_ASSERT(k%QK4_NL == 0);
|
||||||
block_iq4_nl * restrict y = vy;
|
int nblock = k/QK4_NL;
|
||||||
quantize_row_iq4_nl_reference(x, y, k);
|
uint8_t L[QK4_NL];
|
||||||
|
float weight[QK4_NL];
|
||||||
|
uint16_t unused_h;
|
||||||
|
uint8_t * unused_l = NULL;
|
||||||
|
float scale;
|
||||||
|
block_iq4_nl * iq4 = (block_iq4_nl *)vy;
|
||||||
|
for (int ibl = 0; ibl < nblock; ++ibl) {
|
||||||
|
quantize_row_iq4_nl_impl(QK4_NL, 32, x + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
|
||||||
|
&scale, weight, L, kvalues_iq4nl, NULL, -1);
|
||||||
|
}
|
||||||
|
//assert(k % QK4_NL == 0);
|
||||||
|
//block_iq4_nl * restrict y = vy;
|
||||||
|
//quantize_row_iq4_nl_reference(x, y, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) {
|
void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) {
|
||||||
@ -11857,7 +11868,7 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int nrow
|
|||||||
for (int ibl = 0; ibl < nblock; ++ibl) {
|
for (int ibl = 0; ibl < nblock; ++ibl) {
|
||||||
const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL;
|
const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL;
|
||||||
quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l,
|
quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l,
|
||||||
scales, weight, L, kvalues_iq4nl, qw);
|
scales, weight, L, kvalues_iq4nl, qw, 7);
|
||||||
}
|
}
|
||||||
src += n_per_row;
|
src += n_per_row;
|
||||||
qrow += nblock*sizeof(block_iq4_xs);
|
qrow += nblock*sizeof(block_iq4_xs);
|
||||||
|
Loading…
Reference in New Issue
Block a user