From e5274378f7e280c8513496f9ca91cc4f49b44ac7 Mon Sep 17 00:00:00 2001 From: Amy Date: Tue, 13 Jun 2023 05:08:57 +0100 Subject: [PATCH] cleaned-up implementation of QX mixed quantization --- examples/quantize/quantize.cpp | 1 + ggml.c | 510 ++++++++++++++++++++++++++++++++- ggml.h | 5 +- llama.cpp | 133 ++++++++- llama.h | 1 + 5 files changed, 630 insertions(+), 20 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index c6bf1b723..5f8c43944 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -25,6 +25,7 @@ static const std::map<std::string, llama_ftype> LLAMA_FTYPE_MAP = { {"q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S}, {"q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M}, {"q6_K", LLAMA_FTYPE_MOSTLY_Q6_K}, + {"qx_0", LLAMA_FTYPE_MOSTLY_QX_0}, }; bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::string & ftype_str_out) { diff --git a/ggml.c b/ggml.c index a13de5115..7ad696ea8 100644 --- a/ggml.c +++ b/ggml.c @@ -488,6 +488,44 @@ int64_t ggml_cycles_per_ms(void) { static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); +// +// bit manipulation helpers +// + +// writes "bit_count" bits of "data" at bit offset "bit_offset" in "dst" +// only used for data <= 16 bits; only needed by ggml_quantize_qx_0 +inline static void write_bits(uint32_t * dst, uint32_t bit_offset, uint16_t data, uint16_t bit_count) { + const uint32_t chunk_size = (sizeof(uint32_t) * 8); + const uint32_t chunk_id = bit_offset / chunk_size; + + dst = dst + chunk_id; + bit_offset %= (sizeof(uint32_t) * 8); + + if (bit_offset + bit_count > chunk_size) { + // the write straddles two 32-bit chunks: first fill the current chunk + uint16_t bitcount_1 = chunk_size - bit_offset; + + uint32_t bitmask = ((1u << bitcount_1) - 1) << (bit_offset); + *dst &= ~bitmask; + *dst |= (uint32_t) data << bit_offset; + + // move onto the next chunk + data >>= bitcount_1; + + bit_count -= bitcount_1; + bit_offset = 0; + dst += 1; + + bitmask = ((1u << bit_count) - 1) << (bit_offset); + *dst &= ~bitmask; + *dst |= (uint32_t) data << bit_offset; + } else { + uint32_t bitmask = ((1u << bit_count) - 1) << (bit_offset); + *dst &= ~bitmask; + *dst |= (uint32_t) data << bit_offset; + } +} + // // quantization // @@ -835,6 +873,22 @@ typedef struct { } block_q8_1; static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding"); + +// the block size is 256 (rather than 512) because some feed_forward tensors have a width of 11008 weights, which is not divisible by 512 +#define QKX_0 256 + +// there is no byte-exact C struct to represent a QX_0 block, but a high-level view of the serialized layout is: +// uint64_t fp16_flags[QKX_0 / 64]; (one bit per weight, set when that weight is stored as fp16) +// ggml_fp16_t min; +// ggml_fp16_t mult_range; +// uint8_t qbits; +// [bitstream of weights] + +// quantization parameters for QX_0 (used only when running ./quantize, irrelevant during inference) +// TODO maybe move these to command-line arguments...?
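For reference, a decode-side counterpart to write_bits is sketched here, just before the QX_0 quantization parameters that follow. It is an illustrative sketch only and not part of the patch: ggml_vec_dot_qx_0_q8_0 below decodes the bitstream inline, and only its commented-out get_bits calls hint at a helper like this. It assumes the same packing into 32-bit chunks and bit_count <= 16, and the name read_bits is made up for this example.

// illustrative sketch, not part of the patch: inverse of write_bits above
inline static uint16_t read_bits(const uint32_t * src, uint32_t bit_offset, uint16_t bit_count) {
    const uint32_t chunk_size = (sizeof(uint32_t) * 8);
    const uint32_t chunk_id = bit_offset / chunk_size;

    src = src + chunk_id;
    bit_offset %= chunk_size;

    // low bits of the field come from the current chunk
    uint32_t value = src[0] >> bit_offset;

    if (bit_offset + bit_count > chunk_size) {
        // the field straddles two chunks: splice in the remaining high bits from the next chunk
        value |= src[1] << (chunk_size - bit_offset);
    }

    return (uint16_t) (value & ((1u << bit_count) - 1));
}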
+#define QX_0_STARTING_QBITS 4 +#define QX_0_STARTING_QBITS_DOWNSCALING 2 + + // reference implementation for deterministic creation of model files static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { static const int qk = QK4_0; @@ -1530,6 +1584,7 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in } } +static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); @@ -1627,6 +1682,16 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, }, #endif + // GGML_TYPE_QX_0's quantize/dequantize functions aren't the same as other quantization methods' functions + // so we need to supply NULL instead and use if statements in the places where they are actually used + [GGML_TYPE_QX_0] = { + .dequantize_row_q = (dequantize_row_q_t) NULL, + .quantize_row_q = NULL, + .quantize_row_q_reference = (quantize_row_q_t) NULL, + .quantize_row_q_dot = quantize_row_q8_0, + .vec_dot_q = ggml_vec_dot_qx_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + }, }; // For internal test use @@ -3122,6 +3187,192 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * #endif } +__attribute__((optimize("unroll-loops"))) +static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + uint32_t nb = n / QKX_0; + GGML_ASSERT(QKX_0 % QK8_0 == 0); + *s = 0; + + uint8_t * quant_row = (uint8_t *) vx; + + // row_data stores dequantized values of the current block + float f32_row_data[QKX_0]; + + const block_q8_0 * restrict column = vy; + + uint32_t column_idx = 0; + + __m256 rolling_sum = _mm256_setzero_ps(); + + // IMPORTANT, Quantized weights should be kept <= 4bits. 
Change this number for higher values + float qvals[1 << 4]; + + for (int b = 0; b < nb; b++) { + float * row_ptr = f32_row_data; + + const uint64_t * block_start = (uint64_t *) quant_row; + + const float min_value = GGML_FP16_TO_FP32(*((uint16_t *) (block_start + (QKX_0 / 64)))); + float mult_value = GGML_FP16_TO_FP32(*((uint16_t *) (block_start + (QKX_0 / 64)) + 1)); + const uint16_t * data_start = (uint16_t *) (block_start + (QKX_0 / 64)) + 2; + const uint8_t qbits = *((uint8_t *) data_start); + data_start = (uint16_t*) ((uint8_t*) data_start + 1); + + mult_value /= ((1 << qbits) - 1); + quant_row = (uint8_t * ) data_start; + + uint32_t offset = 0; + uint8_t data_offset = 0; + + for (int i = 0; i < (1 << qbits); i++) { + qvals[i] = min_value + mult_value * i; + } + + // 64 is the size in bits of uint64_t + for (int jb = 0; jb < QKX_0 / 64; jb++) { + uint64_t fp16_chooser = block_start[jb]; + + // all weights are quantized in this section; ALSO this ONLY works when qbits is <= 4, since (qbits != 3) simply checks if qbits is a power of 2 + if (fp16_chooser == 0) { + if (qbits == 3) { + // same principle as on the regular data_offset branch, but this time the qbits cross byte boundaries, so we need to manage it by hand + for (int i = 0; i < 5; i++) { + for (int k = 0; k < 11; k ++) { + // here we cast to 64bit, to make sure that we don't lose bits that are outside the u32 range + row_ptr[i * 11 + k] = qvals[((((uint64_t *) data_start)[0] >> (data_offset + k * qbits)) & ((1 << qbits) - 1))]; + } + + data_start += 2; // this is the same event as in if (data_start >= 16), but happening twice, stealthily + data_offset += 1; // it's actually +33, but we are rounding + } + + for (int k = 0; k < 9; k ++) { + // here we cast to 64bit, to make sure that we don't lose bits that are outside the u32 range + row_ptr[55 + k] = qvals[((((uint64_t *) data_start)[0] >> (data_offset + k * qbits)) & ((1 << qbits) - 1))]; + } + + data_start += 1; + data_offset += 9 * 3 - 16; + + if (data_offset >= 16) { + data_start += 1; + data_offset -= 16; + } + + } else if (data_offset == 0) { + // This only properly works for QBits = power of 2 + const uint8_t data_block_size = 64; + // we can take a full 64bit block + const uint8_t weights_per_u64_data_block = data_block_size / qbits; + const uint8_t num_of_data_blocks_needed = 64 / weights_per_u64_data_block; // because we have 64 qbit-sized weights here + + for (int i = 0; i < num_of_data_blocks_needed; i++) { + for (int k = 0; k < weights_per_u64_data_block; k ++) { + row_ptr[i * weights_per_u64_data_block + k] = qvals[(((uint64_t *) data_start)[0] >> (k * qbits)) & ((1 << qbits) - 1)]; + } + + data_start += (data_block_size / 8) / sizeof(uint16_t); + } + } else { + // We are doing u32 instead of a simple u64, since data_offset may not be 0 and we need to account for that + const uint8_t data_block_size = 32; + const uint8_t weights_per_u32_data_block = data_block_size / qbits; + const uint8_t num_of_data_blocks_needed = 64 / weights_per_u32_data_block; + + for (int i = 0; i < num_of_data_blocks_needed; i++) { + for (int k = 0; k < weights_per_u32_data_block; k ++) { + // here we cast to 64bit, to make sure that we don't lose bits that are outside the u32 range + row_ptr[i * weights_per_u32_data_block + k] = qvals[((((uint64_t *) data_start)[0] >> (data_offset + k * qbits)) & ((1 << qbits) - 1))]; + } + + data_start += (data_block_size / 8) / sizeof(uint16_t); + } + } + + offset += qbits * 64; + } else { + for (int i = 0; i < 64; i++) { + // next 32 values are free + 
if (fp16_chooser & 1) { + offset += 16; + row_ptr[i] = GGML_FP16_TO_FP32((((uint32_t *) data_start)[0] >> data_offset) & ((1 << 16) - 1)); + + #ifdef ame_debug + printf("%f (16bit)\n", row_ptr[i]); + #endif + + data_start += 1; + } else { + offset += qbits; + row_ptr[i] = qvals[((((uint32_t *) data_start)[0] >> data_offset) & ((1 << qbits) - 1))]; + + #ifdef ame_debug + printf("%ld -> %f (%dbit)\n", ((((uint32_t *) data_start)[0] >> data_offset) & ((1 << qbits) - 1)), row_ptr[i], qbits); + #endif + + data_offset += qbits; + + if (data_offset >= 16) { + data_start += 1; + data_offset -= 16; + } + } + + fp16_chooser >>= 1; + + // uint8_t sz = qbits << ((fp16_chooser & 1) << 1); + + // get_bits(data_start, offset, &w, sz); + // offset += sz; + + // if (sz == qbits) { + // row_save[i] = qvals_f16[w]; + // } else { + // row_save[i] = w; + // } + } + } + + for (int jb = 0; jb < 64 / QK8_0; jb++) { + __m256 column_multiplier = _mm256_set1_ps(GGML_FP16_TO_FP32(column[column_idx].d)); + + for (int i = 0; i < QK8_0/8; i++) { + __m128i test = _mm_loadu_si128((const __m128i *) (column[column_idx].qs + i * 8)); + __m256i work = _mm256_cvtepi8_epi32(test); + __m256 workf = _mm256_cvtepi32_ps(work); + + // multiply with our 8 parts of the row at row_data + __m256 row = _mm256_loadu_ps(row_ptr + jb * QK8_0 + i * 8); + + workf = _mm256_mul_ps(workf, row); + rolling_sum = _mm256_fmadd_ps(workf, column_multiplier, rolling_sum); + } + + column_idx += 1; + } + + // horrible manual loop unroll for testing, 1 iteration only + // int i = 0; + // uint16_t w = 0; + + // if (unlikely(fp16_chooser & 1)) { get_bits(data_start, offset, &w, 16); offset += 16; row_save[i] = w; } else { get_bits(data_start, offset, &w, qbits); offset += qbits; row_save[i] = qvals_f16[w]; } fp16_chooser >>= 1; i++; + + row_ptr += 64; + } + + //printf("offset: %d\n", offset); + GGML_ASSERT(offset % 8 == 0); + quant_row += offset / 8; + } + + float rolling_sum_vec[8]; + _mm256_store_ps(rolling_sum_vec, rolling_sum); + + for (int i = 0; i < 8; i++) { + *s += rolling_sum_vec[i]; + } +} + static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const int qk = QK8_0; const int nb = n / qk; @@ -3514,11 +3765,12 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q6_K] = QK_K, [GGML_TYPE_Q8_K] = QK_K, #endif + // [GGML_TYPE_QX_0], // QX_0 doesn't have a fixed block size [GGML_TYPE_I8] = 1, [GGML_TYPE_I16] = 1, [GGML_TYPE_I32] = 1, }; -static_assert(GGML_TYPE_COUNT == 19, "GGML_BLCK_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 20, "GGML_BLCK_SIZE is outdated"); static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = sizeof(float), @@ -3537,11 +3789,12 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q6_K] = sizeof(block_q6_K), [GGML_TYPE_Q8_K] = sizeof(block_q8_K), #endif + // [GGML_TYPE_QX_0], // QX_0 doesn't have a fixed type size [GGML_TYPE_I8] = sizeof(int8_t), [GGML_TYPE_I16] = sizeof(int16_t), [GGML_TYPE_I32] = sizeof(int32_t), }; -static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 20, "GGML_TYPE_SIZE is outdated"); static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { @@ -3559,11 +3812,12 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { [GGML_TYPE_Q5_K] = "q5_K", [GGML_TYPE_Q6_K] = "q6_K", [GGML_TYPE_Q8_K] = "q8_K", + [GGML_TYPE_QX_0] = "qx_0", [GGML_TYPE_I8] = "i8", [GGML_TYPE_I16] = "i16", [GGML_TYPE_I32] = "i32", }; -static_assert(GGML_TYPE_COUNT == 
19, "GGML_TYPE_NAME is outdated"); +static_assert(GGML_TYPE_COUNT == 20, "GGML_TYPE_NAME is outdated"); static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = false, @@ -3580,11 +3834,12 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_Q5_K] = true, [GGML_TYPE_Q6_K] = true, [GGML_TYPE_Q8_K] = true, + [GGML_TYPE_QX_0] = true, [GGML_TYPE_I8] = false, [GGML_TYPE_I16] = false, [GGML_TYPE_I32] = false, }; -static_assert(GGML_TYPE_COUNT == 19, "GGML_IS_QUANTIZED is outdated"); +static_assert(GGML_TYPE_COUNT == 20, "GGML_IS_QUANTIZED is outdated"); static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "NONE", @@ -3890,6 +4145,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break; case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; + case GGML_FTYPE_MOSTLY_QX_0: wtype = GGML_TYPE_QX_0; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; } @@ -4266,7 +4522,14 @@ struct ggml_tensor * ggml_new_tensor_impl( } result->nb[0] = GGML_TYPE_SIZE[type]; - result->nb[1] = result->nb[0]*(result->ne[0]/GGML_BLCK_SIZE[type]); + + if (type == GGML_TYPE_QX_0) { + // QX_0 doesn't have a set stride size for a row; that value is stored in the "extra" part of the tensor + result->nb[1] = 0; + } else { + result->nb[1] = result->nb[0]*(result->ne[0]/GGML_BLCK_SIZE[type]); + } + for (int i = 2; i < GGML_MAX_DIMS; i++) { result->nb[i] = result->nb[i - 1]*result->ne[i - 1]; } @@ -7719,6 +7982,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_QX_0: { ggml_compute_forward_add_q_f32(params, src0, src1, dst); } break; @@ -8027,6 +8291,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_QX_0: { ggml_compute_forward_add1_q_f32(params, src0, src1, dst); } break; @@ -8154,6 +8419,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_QX_0: default: { GGML_ASSERT(false); @@ -10189,13 +10455,22 @@ static void ggml_compute_forward_mul_mat_q_f32( const int i2 = i02; const int i3 = i03; - void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + void * src0_row; + + if (type == GGML_TYPE_QX_0) { + if (ir > 0) { + src0_row = (void *) ((char *) src0->data + ((uint64_t *) src0->extra)[ir - 1]); + } else { + src0_row = (void *) ((char *) src0->data); + } + } else { + src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + assert(ne00 % 32 == 0); + } + char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size)); - float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); - - assert(ne00 % 32 == 0); - + for (int64_t ic = 0; ic < ne11; ++ic) { vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size)); } @@ -10231,6 +10506,7 @@ static void ggml_compute_forward_mul_mat( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_QX_0: { ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); } break; @@ -10419,6 +10695,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_QX_0: default: { GGML_ASSERT(false); @@ -10589,6 +10866,7 @@ static void ggml_compute_forward_get_rows( 
case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_QX_0: { ggml_compute_forward_get_rows_q(params, src0, src1, dst); } break; @@ -11141,6 +11419,7 @@ static void ggml_compute_forward_alibi( case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_Q8_K: + case GGML_TYPE_QX_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -11218,6 +11497,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_Q8_K: + case GGML_TYPE_QX_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -16236,7 +16516,211 @@ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * return (n/QK8_0*sizeof(block_q8_0)); } -size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) { +size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist, uint64_t * extra_data, uint32_t tensor_width) { + assert(n % QKX_0 == 0); + assert(tensor_width % QKX_0 == 0); + const int nb = n / QKX_0; + + uint8_t * dst_8 = dst; + uint64_t dst_offset = 0; + + double max_quantization_errors[5] = {0, 0.004, 0.004, 0, 0.004}; + + for (int i = 0; i < nb; i++) { + // fp16 selection bitmap: one bit per weight, set when that weight is kept as fp16 (one uint64_t covers 64 weights) + uint64_t fp16s[QKX_0 / 64]; + + memset(fp16s, 0, sizeof(uint64_t) * (QKX_0 / 64)); + + uint8_t qbits = QX_0_STARTING_QBITS; + float thresh = max_quantization_errors[qbits] * (1 << qbits); + + int fp16_count = 0; + + // max_quantization_error indicates that no value should be >= max_quantization_error away from + // its quantized value; + // that means the total range for the quantized values will be max_quantization_error * 2 * (1 << qbits) (here, 16) + // for simplicity, we are going to center on 0, meaning that our fp16 threshold will be max_quantization_error * 16 values to the left and right + // -->|<---->|<---->|<---->|<-- 4bit example, --> = max_quant_error; we have "--> * 2 * 3 + --> + -->" positions where quantized values can be, == "--> * 4" + + for (int j = 0; j < QKX_0; j++) { + float x = src[i * QKX_0 + j]; + + if (fabsf(x) > thresh) { + // deactivate quant + fp16s[j / 64] |= (uint64_t) 1 << (j % 64); + fp16_count += 1; + } + } + + uint16_t total_bits = fp16_count * 16 + (QKX_0 - fp16_count) * qbits; + + while ((total_bits % 8) != 0) { + total_bits += 16 - qbits; // simulate the replacement of a quantized weight with a 16bit one + } + + float min_value = -(max_quantization_errors[qbits] * ((1 << qbits) - 1)); + float mult_range = 2 * max_quantization_errors[qbits] * ((1 << qbits) - 1); + + for (uint8_t test_qbit = QX_0_STARTING_QBITS_DOWNSCALING; test_qbit >= 1; test_qbit--) { + double mean = 0; + for (int j = 0; j < QKX_0; j++) { + if ((fp16s[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) { + float x_fp32 = src[i * QKX_0 + j]; + mean += x_fp32; + } + } + mean /= (QKX_0 - fp16_count); // see where weights are centered + + uint16_t total_fp16s_in_test_qbit = 0; + thresh = max_quantization_errors[test_qbit] * (1 << test_qbit); + + for (int j = 0; j < QKX_0; j++) { + if ((fp16s[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) { + float x = src[i * QKX_0 + j]; + + // this weight would need to be put on 16bit + if (x < mean - thresh || x > mean + thresh) { + total_fp16s_in_test_qbit += 1; + } + } else { + total_fp16s_in_test_qbit += 1; + } + } + + uint16_t total_bits_in_test_qbit = total_fp16s_in_test_qbit * 16 + test_qbit * (QKX_0 - total_fp16s_in_test_qbit); + while ((total_bits_in_test_qbit % 8) != 0) { + total_bits_in_test_qbit += 16 - test_qbit; // simulate the
replacement of a quantized weight with a 16bit one + } + + if (total_bits_in_test_qbit < total_bits) { + //printf("switching to %dbit! %d vs %d\n", test_qbit, total_bits, total_bits_in_test_qbit); + + total_bits = total_bits_in_test_qbit; + qbits = test_qbit; + + min_value = mean - (max_quantization_errors[test_qbit] * ((1 << qbits) - 1)); + mult_range = 2 * max_quantization_errors[test_qbit] * ((1 << qbits) - 1); + + for (int j = 0; j < QKX_0; j++) { + if ((fp16s[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) { + float x = src[i * QKX_0 + j]; + + // this weight would need to be put on 16bit + if (x < mean - thresh || x > mean + thresh) { + fp16s[j / 64] |= (uint64_t) 1 << (j % 64); + fp16_count += 1; + } + } + } + + uint16_t total_test_bits = fp16_count * 16 + (QKX_0 - fp16_count) * qbits; + while ((total_test_bits % 8) != 0) { + total_test_bits += 16 - test_qbit; // simulate the replacement of a quantized weight with a 16bit one + } + + GGML_ASSERT(total_bits == total_test_bits); + } + } + + while (((QKX_0 - fp16_count) * qbits) % 8 != 0) { + float maxi = 0; + int target = -1; + + for (int j = 0; j < QKX_0; j++) { + float x = src[i * QKX_0 + j]; + + // weight is not on 16bit + if ((fp16s[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) { + float diff = fabsf(x); + if (diff > maxi || target == -1) { + maxi = diff; + target = j; + } + } + } + + GGML_ASSERT(target != -1); + fp16s[target / 64] |= (uint64_t) 1 << (target % 64); + fp16_count += 1; + } + + if (((i * QKX_0) % tensor_width == 0) && i != 0) { + uint32_t row = (i * QKX_0) / tensor_width; + extra_data[row - 1] = dst_offset; + } + + uint64_t * fp16_data = (uint64_t *) (dst_8 + dst_offset); + + // write the fp16 selection bitmap + for (int j = 0; j < QKX_0 / 64; j++) { + fp16_data[j] = fp16s[j]; + } + + dst_offset += (QKX_0 / 64) * sizeof(uint64_t); + + // write the min value and the multiplier range (a quantized value q decodes as min_value + (mult_range * q) / ((1 << qbits) - 1)) + *((uint16_t*) (dst_8 + dst_offset)) = ggml_fp32_to_fp16(min_value); + dst_offset += sizeof(uint16_t); + + *((uint16_t*) (dst_8 + dst_offset)) = ggml_fp32_to_fp16(mult_range); + dst_offset += sizeof(uint16_t); + + *((uint8_t*) (dst_8 + dst_offset)) = qbits; + dst_offset += sizeof(uint8_t); + + float qvals[1 << qbits]; + + for (int iq = 0; iq < (1 << qbits); iq++) { + qvals[iq] = min_value + (mult_range * iq) / ((1 << qbits) - 1); + } + + uint64_t bit_offset = 0; + uint32_t * data = (uint32_t*) (dst_8 + dst_offset); + + int fp16_count_chk = 0; + + for (int j = 0; j < QKX_0; j++) { + float x = src[i * QKX_0 + j]; + + if (fp16s[j / 64] & ((uint64_t) 1 << (j % 64))) { + ggml_fp16_t x_f16 = ggml_fp32_to_fp16(x); + + write_bits(data, bit_offset, x_f16, 16); + bit_offset += 16; + fp16_count_chk += 1; + } else { + uint8_t q = 0; + float min_dist = fabsf(x - qvals[0]); + + for (int iv = 0; iv < (1 << qbits); iv++) { + float dist = fabsf(x - qvals[iv]); + if (dist < min_dist) { + q = iv; + min_dist = dist; + } + } + + write_bits(data, bit_offset, q, qbits); + bit_offset += qbits; + } + } + + GGML_ASSERT(fp16_count == fp16_count_chk); + GGML_ASSERT((((QKX_0 - fp16_count) * qbits) % 8) == 0); + + dst_offset += ((QKX_0 - fp16_count) * qbits) / 8; + dst_offset += fp16_count * 2; + } + + extra_data[n / tensor_width - 1] = dst_offset; + + return dst_offset; +} + +// Pass in additional information such as the tensor's "extra_data" and width, since QX_0 needs this info.
We can't pass in a pointer to +// a ggml_tensor (since none exists where quantize_chunk is created), nor to llama_load_tensor since ggml.c doesn't have access to the struct +size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist, uint64_t * extra_data, uint32_t tensor_width) { size_t result = 0; switch (type) { case GGML_TYPE_Q4_0: @@ -16301,6 +16785,10 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i result = ggml_quantize_q6_K(src + start, block, n, n, hist); } break; #endif + case GGML_TYPE_QX_0: + { + result = ggml_quantize_qx_0(src, dst, n, hist, extra_data, tensor_width); + } break; default: assert(false); } diff --git a/ggml.h b/ggml.h index 1b26da3ad..a474e3e8c 100644 --- a/ggml.h +++ b/ggml.h @@ -248,6 +248,7 @@ extern "C" { GGML_TYPE_Q5_K = 13, GGML_TYPE_Q6_K = 14, GGML_TYPE_Q8_K = 15, + GGML_TYPE_QX_0 = 16, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -276,6 +277,7 @@ extern "C" { GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors + GGML_FTYPE_MOSTLY_QX_0 = 15, // except 1d tensors }; // available tensor operations: @@ -1135,13 +1137,14 @@ extern "C" { // quantization // + GGML_API size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist, uint64_t * extra_data, uint32_t tensor_width); GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist); - GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); + GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist, uint64_t * extra_data, uint32_t tensor_width); // // system info diff --git a/llama.cpp b/llama.cpp index f0f9124d8..e0ff34861 100644 --- a/llama.cpp +++ b/llama.cpp @@ -342,8 +342,11 @@ struct llama_load_tensor_shard { enum ggml_type type; size_t file_idx; size_t file_off; + size_t extra_data_file_off; void calc_size() { + // For QX_0, the size is manually written-in, since it comes from extra_data + GGML_ASSERT(type != GGML_TYPE_QX_0); size = llama_calc_tensor_size(ne, type); } }; @@ -364,6 +367,7 @@ struct llama_load_tensor { size_t size; struct ggml_tensor * ggml_tensor = NULL; uint8_t * data; + uint64_t * extra_data = NULL; llama_load_tensor(const std::string & name) : name(name) {} @@ -424,7 +428,18 @@ struct llama_load_tensor { } void calc_size() { - size = llama_calc_tensor_size(ne, type); + // For QX_0 the size comes from extra_data, but since extra_data might not be initialized here + // we can take it from the shard instead + if (type == GGML_TYPE_QX_0) { + GGML_ASSERT(shards.size() == 1); + GGML_ASSERT(ne.size() == 2); + + size = shards.at(0).size; + + GGML_ASSERT(size != 0); + } else { + size = llama_calc_tensor_size(ne, type); + } } }; @@ -520,6 +535,7 @@ struct llama_file_loader { shard.ne.resize(n_dims); file.read_raw(shard.ne.data(), sizeof(shard.ne[0]) * n_dims); std::string name = file.read_string(name_len); + if (n_dims < 1 || n_dims > 2) { throw 
std::runtime_error(format("llama.cpp: tensor '%s' should not be %u-dimensional", name.c_str(), n_dims)); } @@ -536,6 +552,7 @@ struct llama_file_loader { case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_QX_0: break; default: { throw std::runtime_error(format("unrecognized tensor type %u\n", shard.type)); @@ -546,12 +563,36 @@ struct llama_file_loader { // skip to the next multiple of 32 bytes file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); } + + if (shard.type == GGML_TYPE_QX_0) { + shard.extra_data_file_off = file.tell(); + + uint64_t extra_data[shard.ne[1]]; + file.read_raw(extra_data, sizeof(uint64_t) * shard.ne[1]); + + // set the size of the tensor here + shard.size = extra_data[shard.ne[1] - 1]; + + // realign, just in case extra_data isn't a multiple of 32B + file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); + } else { + shard.extra_data_file_off = 0; + } + shard.file_idx = file_idx; shard.file_off = file.tell(); - shard.calc_size(); + if (shard.type != GGML_TYPE_QX_0) { + shard.calc_size(); + } + file.seek(shard.size, SEEK_CUR); + // QX_0's data may not be 32-byte aligned + if (shard.type == GGML_TYPE_QX_0) { + file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); + } + auto it = tensors_map.name_to_idx.find(name); size_t idx; if (it != tensors_map.name_to_idx.end()) { @@ -602,7 +643,9 @@ struct llama_file_saver { file.write_raw(&token_score.score, sizeof(token_score.score)); } } - void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) { + + // pass extra_data by reference to avoid excessive copying + void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size, llama_buffer & extra_data) { switch (new_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: @@ -616,6 +659,7 @@ struct llama_file_saver { case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_QX_0: break; default: LLAMA_ASSERT(false); } @@ -624,9 +668,29 @@ struct llama_file_saver { file.write_u32(new_type); file.write_raw(tensor.ne.data(), sizeof(tensor.ne[0]) * tensor.ne.size()); file.write_raw(tensor.name.data(), tensor.name.size()); + + size_t tensor_size; + file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); - LLAMA_ASSERT(new_size == llama_calc_tensor_size(tensor.ne, new_type)); + + // The tensor's size for QX_0 is stored in the last element of extra_data + if (new_type == GGML_TYPE_QX_0) { + file.write_raw(extra_data.addr, sizeof(uint64_t) * tensor.ne[1]); + tensor_size = ((uint64_t *) extra_data.addr)[tensor.ne[1] - 1]; + + // realign, just in case extra_data isn't a multiple of 32B + file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); + } else { + tensor_size = llama_calc_tensor_size(tensor.ne, new_type); + } + + LLAMA_ASSERT(new_size == tensor_size); file.write_raw(new_data, new_size); + + // QX_0 data may not be 32-byte aligned + if (new_type == GGML_TYPE_QX_0) { + file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); + } } }; @@ -666,7 +730,7 @@ struct llama_model_loader { bool alignment_prevents_mmap() { for (const llama_load_tensor & lt : tensors_map.tensors) { for (const llama_load_tensor_shard & shard : lt.shards) { - if (shard.file_off & 3) { + if ((shard.file_off & 3)) { return true; } } @@ -725,6 +789,7 @@ struct llama_model_loader { tensor->backend = backend; lt.ggml_tensor = tensor; num_ggml_tensors_created++; + return tensor; } @@ -771,6 +836,13 @@ struct llama_model_loader { switch(lt.ggml_tensor->backend) { case GGML_BACKEND_CPU: 
lt.ggml_tensor->data = lt.data; + + if (lt.type == GGML_TYPE_QX_0) { + // QX_0 uses the extra field to store byte offsets in *data for each row except row 0 + // (so extra[0] stores where row 1 starts, extra[1] is for row 2, and the last element + // in extra stores the total tensor size) + lt.ggml_tensor->extra = lt.extra_data; + } if (use_mmap && lmlock) { lock_size += lt.size; lmlock->grow_to(lock_size); @@ -801,9 +873,23 @@ struct llama_model_loader { } void load_data_for(llama_load_tensor & lt) { + // QX_0 only supports mmap + GGML_ASSERT(use_mmap || lt.type != GGML_TYPE_QX_0); + if (use_mmap) { LLAMA_ASSERT(lt.shards.size() == 1); lt.data = (uint8_t *) mapping->addr + lt.shards.at(0).file_off; + + if (lt.shards.at(0).extra_data_file_off != 0) { + lt.extra_data = (uint64_t *) ((uint8_t *) mapping->addr + lt.shards.at(0).extra_data_file_off); + } + printf("load data for %s\n", lt.name.c_str()); + + if (lt.extra_data != NULL) { + printf("extra_data_file_off: %zu, data: %p, extra_data: %p\n", lt.shards.at(0).extra_data_file_off, lt.data, lt.extra_data); + printf("extra_data for %s: %lu %lu ... %lu\n", lt.name.c_str(), lt.extra_data[0], lt.extra_data[1], lt.extra_data[lt.ne[1] - 1]); + } + } else if (lt.split_type == SPLIT_NONE) { llama_file & file = file_loaders.at(lt.shards.at(0).file_idx)->file; file.seek(lt.shards.at(0).file_off, SEEK_SET); @@ -988,6 +1074,7 @@ static const char *llama_ftype_name(enum llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "mostly Q5_K - Small"; case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "mostly Q5_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q6_K: return "mostly Q6_K"; + case LLAMA_FTYPE_MOSTLY_QX_0: return "mostly QX_0"; default: return "unknown, may not work"; } } @@ -1665,6 +1752,8 @@ static bool llama_eval_internal( lctx.n_p_eval += N; } + fprintf(stderr, "\nmodel eval time: %ldms\n", (ggml_time_us() - t_start_us) / 1000); + fflush(stderr); return true; } @@ -2309,12 +2398,22 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q5_K_S: case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break; case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break; + case LLAMA_FTYPE_MOSTLY_QX_0: quantized_type = GGML_TYPE_QX_0; break; default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); } if (nthread <= 0) { nthread = std::thread::hardware_concurrency(); } + + // multithreaded QX_0 quantization is not compatible with the current multithreaded quantization impl. + // because, since blocks have an unknown size in bytes, we cannot section the output data in exact + // chunks assigned to 1 thread. Multithreading would technically only be possible if we quantize + // multiple entire tensors at once, but the overall implementation doesn't seem to allow that to be done easily + if (quantized_type == GGML_TYPE_QX_0) { + nthread = 1; + printf("Setting nthread to 1 due to the implementation for QX_0 quantization being single-threaded.\n"); + } std::unique_ptr model_loader(new llama_model_loader(fname_inp, /*use_mmap*/ false, /*vocab_only*/ false)); @@ -2363,12 +2462,23 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (!params->quantize_output_tensor && tensor.name == "output.weight") { quantize = false; } + + // Allow only attention and FFN matrices to be quantized under QX_0, since they only require vec_dot + // to be implemented. 
Output weights and other matrices require more functions to be implemented, so + // for simplicity we'll only quantize attn and ffn for now. + if (quantized_type == GGML_TYPE_QX_0) { + if (tensor.name.find("attention") == std::string::npos && tensor.name.find("feed_forward") == std::string::npos) { + quantize = false; + } + } + quantize = quantize && quantized_type != tensor.type; enum ggml_type new_type; void * new_data; size_t new_size; llama_buffer work; + llama_buffer extra_data; if (!quantize) { new_type = tensor.type; @@ -2421,11 +2531,16 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s new_data = work.addr; std::vector<int64_t> hist_cur(1 << 4, 0); + if (new_type == GGML_TYPE_QX_0) { + extra_data.resize(sizeof(uint64_t) * tensor.ne[1]); + } + int chunk_size = 32 * 512; const int nchunk = (nelements + chunk_size - 1)/chunk_size; const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1; + if (nthread_use < 2) { - new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nelements, hist_cur.data()); + new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nelements, hist_cur.data(), (uint64_t *) extra_data.addr, tensor.ne[0]); } else { size_t counter = 0; new_size = 0; @@ -2449,7 +2564,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (local_hist.empty()) { local_hist.resize(hist_cur.size(), 0); } - local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first, last - first, local_hist.data()); + + // pass in NULL for extra_data, since it's only needed for QX_0, which doesn't support multithreaded quantization + local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first, last - first, local_hist.data(), NULL, 0); } }; if ((int) workers.size() < nthread_use - 1) { @@ -2480,7 +2597,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } total_size_org += tensor.size; total_size_new += new_size; - file_saver.write_tensor(tensor, new_type, new_data, new_size); + file_saver.write_tensor(tensor, new_type, new_data, new_size, extra_data); } printf("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); diff --git a/llama.h b/llama.h index 7c7fd481c..920779d02 100644 --- a/llama.h +++ b/llama.h @@ -113,6 +113,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors + LLAMA_FTYPE_MOSTLY_QX_0 = 19, // except 1d tensors }; // model quantization parameters
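Closing illustration (not part of the patch): the sketch below shows how the per-row byte offsets that QX_0 keeps in a tensor's "extra" field are meant to be consumed, mirroring the row lookup in ggml_compute_forward_mul_mat_q_f32 and the description of the extra field in the loader changes above. Row 0 starts at byte offset 0, extra[r - 1] holds the byte offset of row r, and the last entry holds the total tensor size. The helper name qx0_row_ptr is made up for this example; with the patch applied, a QX_0 model would presumably be produced through the quantize example using the newly registered "qx_0" ftype name.

// illustrative sketch, not part of the patch: resolve the start of row "row" of a QX_0 tensor
// from the byte offsets stored in tensor->extra (requires ggml.h for struct ggml_tensor);
// the helper name qx0_row_ptr is hypothetical
static const void * qx0_row_ptr(const struct ggml_tensor * t, int64_t row) {
    const uint64_t * row_offsets = (const uint64_t *) t->extra; // extra[r - 1] = offset of row r; last entry = total size
    const size_t offset = (row == 0) ? 0 : (size_t) row_offsets[row - 1];
    return (const char *) t->data + offset;
}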