From 4cd885beb5d1efc1b01bd182e3536ffe35d1371c Mon Sep 17 00:00:00 2001 From: Amy Date: Tue, 13 Jun 2023 08:59:03 +0100 Subject: [PATCH] added comments and scalar implementation for vec_dot_qx --- ggml.c | 216 ++++++++++++++++++++++++++++++++++-------------------- llama.cpp | 6 -- 2 files changed, 136 insertions(+), 86 deletions(-) diff --git a/ggml.c b/ggml.c index 7ad696ea8..dd973d6f9 100644 --- a/ggml.c +++ b/ggml.c @@ -3191,20 +3191,24 @@ __attribute__((optimize("unroll-loops"))) static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { uint32_t nb = n / QKX_0; GGML_ASSERT(QKX_0 % QK8_0 == 0); + *s = 0; uint8_t * quant_row = (uint8_t *) vx; - - // row_data stores dequantized values of the current block - float f32_row_data[QKX_0]; - const block_q8_0 * restrict column = vy; - - uint32_t column_idx = 0; + uint32_t column_i = 0; // current index in column - __m256 rolling_sum = _mm256_setzero_ps(); + // row_data is a buffer which stores dequantized float values for a current block + float f32_row_data[QKX_0]; - // IMPORTANT, Quantized weights should be kept <= 4bits. Change this number for higher values + // __AVX2__ doesn't seem to actually make much of a difference, + // a lot of optimizing could possibly be done, including possibly using AVX2 + // for dequantization...? + + #if defined(__AVX2__) + __m256 rolling_sum = _mm256_setzero_ps(); + #endif + float qvals[1 << 4]; for (int b = 0; b < nb; b++) { @@ -3218,22 +3222,31 @@ static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void * const uint8_t qbits = *((uint8_t *) data_start); data_start = (uint16_t*) ((uint8_t*) data_start + 1); - mult_value /= ((1 << qbits) - 1); quant_row = (uint8_t * ) data_start; + // Any qbits are supported, but the size of qvals needs to be changed to 1 << max_expected_qbits. + // So if you have at most 7bit values, you can change qvals's declaration to qvals[1 << 7]. + // Additionally, the "fp_chooser == 0" optimized branch only works if qbits is "3" or a power of 2, + // so feel free to disable it entirely and run the slower "else" statement which works for pretty much + // any qbit value. + GGML_ASSERT(qbits <= 4); + uint32_t offset = 0; uint8_t data_offset = 0; + // Cache quantized values for (int i = 0; i < (1 << qbits); i++) { qvals[i] = min_value + mult_value * i; } - // 64 is the size in bits of uint64_t + // Parse in sub-blocks of 64 since they are managed by a single uint64_t which decides if a given weight + // is on 16bit or quantized. This means that we can do a fast fp16_indicator == 0 check (i.e. all weights are quantized) + // to speed up peformance for (int jb = 0; jb < QKX_0 / 64; jb++) { - uint64_t fp16_chooser = block_start[jb]; + uint64_t fp16_indicator = block_start[jb]; // all weights are quantized in this section; ALSO this ONLY works when qbits is <= 4, since (qbits != 3) simply checks if qbits is a power of 2 - if (fp16_chooser == 0) { + if (fp16_indicator == 0) { if (qbits == 3) { // same principle as on the regular data_offset branch, but this time the qbits cross byte boundaries, so we need to manage it by hand for (int i = 0; i < 5; i++) { @@ -3242,8 +3255,8 @@ static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void * row_ptr[i * 11 + k] = qvals[((((uint64_t *) data_start)[0] >> (data_offset + k * qbits)) & ((1 << qbits) - 1))]; } - data_start += 2; // this is the same event as in if (data_start >= 16), but happening twice, stealthily - data_offset += 1; // it's actually +33, but we are rounding + data_start += 2; // this is the same event as in if (data_start >= 16), but happening twice + data_offset += 1; // it's actually +33, but the "+32" is represented in data_start above, so the remainder is simply +1 } for (int k = 0; k < 9; k ++) { @@ -3292,24 +3305,17 @@ static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void * offset += qbits * 64; } else { for (int i = 0; i < 64; i++) { - // next 32 values are free - if (fp16_chooser & 1) { + if (fp16_indicator & 1) { + // Current weight is fp16 offset += 16; row_ptr[i] = GGML_FP16_TO_FP32((((uint32_t *) data_start)[0] >> data_offset) & ((1 << 16) - 1)); - #ifdef ame_debug - printf("%f (16bit)\n", row_ptr[i]); - #endif - data_start += 1; } else { + // Current weight is quantized offset += qbits; row_ptr[i] = qvals[((((uint32_t *) data_start)[0] >> data_offset) & ((1 << qbits) - 1))]; - #ifdef ame_debug - printf("%ld -> %f (%dbit)\n", ((((uint32_t *) data_start)[0] >> data_offset) & ((1 << qbits) - 1)), row_ptr[i], qbits); - #endif - data_offset += qbits; if (data_offset >= 16) { @@ -3318,26 +3324,17 @@ static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void * } } - fp16_chooser >>= 1; - - // uint8_t sz = qbits << ((fp16_chooser & 1) << 1); - - // get_bits(data_start, offset, &w, sz); - // offset += sz; - - // if (sz == qbits) { - // row_save[i] = qvals_f16[w]; - // } else { - // row_save[i] = w; - // } + // Shift the fp16 indicator to the right, to move to the next weight + fp16_indicator >>= 1; } } for (int jb = 0; jb < 64 / QK8_0; jb++) { - __m256 column_multiplier = _mm256_set1_ps(GGML_FP16_TO_FP32(column[column_idx].d)); + #if defined(__AVX2__) + __m256 column_multiplier = _mm256_set1_ps(GGML_FP16_TO_FP32(column[column_i].d)); for (int i = 0; i < QK8_0/8; i++) { - __m128i test = _mm_loadu_si128((const __m128i *) (column[column_idx].qs + i * 8)); + __m128i test = _mm_loadu_si128((const __m128i *) (column[column_i].qs + i * 8)); __m256i work = _mm256_cvtepi8_epi32(test); __m256 workf = _mm256_cvtepi32_ps(work); @@ -3347,30 +3344,38 @@ static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void * workf = _mm256_mul_ps(workf, row); rolling_sum = _mm256_fmadd_ps(workf, column_multiplier, rolling_sum); } + + #else + // scalar + float sub_sum = 0; - column_idx += 1; + for (int i = 0; i < QK8_0; i++) { + sub_sum += row_ptr[jb * QK8_0 + i] * column[column_i].qs[i]; + } + + sub_sum *= GGML_FP16_TO_FP32(column[column_i].d); + *s += sub_sum; + + #endif + + column_i += 1; } - // horrible manual loop unroll for testing, 1 iteration only - // int i = 0; - // uint16_t w = 0; - - // if (unlikely(fp16_chooser & 1)) { get_bits(data_start, offset, &w, 16); offset += 16; row_save[i] = w; } else { get_bits(data_start, offset, &w, qbits); offset += qbits; row_save[i] = qvals_f16[w]; } fp16_chooser >>= 1; i++; - row_ptr += 64; } - //printf("offset: %d\n", offset); GGML_ASSERT(offset % 8 == 0); quant_row += offset / 8; } + #if defined(__AVX2__) float rolling_sum_vec[8]; _mm256_store_ps(rolling_sum_vec, rolling_sum); for (int i = 0; i < 8; i++) { *s += rolling_sum_vec[i]; } + #endif } static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -16524,31 +16529,68 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist, const uint8_t * dst_8 = dst; uint64_t dst_offset = 0; + // define max quantization errors for every bit precision + // i.e max_quantization_errors[1] holds max error for 1bit quantized weights + // max_quantization_errors[2] holds max error for 2bit quantized weights + // max_quantization_errors[3] holds max error for 3bit quantized weights + // etc. + // + // max quantization error here means that every single quantized weight is within + // said value (e.g. 0.004) from its original value + // + // this can be replaced with a max allowed RMSE, a set percentage of weights being within + // a certain range, etc... The current implementation here is pretty much just an example double max_quantization_errors[5] = {0, 0.004, 0.004, 0, 0.004}; + + + // How maximum quantization error is implemented here: + // + // Each block holds both fp16 and "qbit" quantized weights mixed together arbitrarily. + // This mixing is handled by a few numbers at the start of each block, the bit of each number + // indicating if a given weight (corresponding to that bit) is stored on 16bit or is quantized. + // + // There is a metadata byte which indicates the qbit precision of the current block, and + // its values are in [1,2,3,4], but this can easily be extended to allow any other bit precisions, + // such as 5, 6, 9, 13 bits or anything else. + // + // To guarantee that each weight is within max_quantization_error, we first need to look at what range + // of values this allows us to have. Since we have "qbits" bits, then we have (1 << qbits) possible values + // the quantized weights can take. The maximum distance between two quantized points can be "2 * max_quantization_error" + // since any weight situated within these two points will be <= max_quantization_error of its closest point. + // + // A visual 2bit example would be: -->|<---->|<---->|<---->|<-- + // Where "|" are the quantized points, and "-->" represents max_quantization_error on the number line. + // + // Any value outside this range will have to be kept on 16bit, since it cannot be within max_quantization_error + // of its quantized point. + // + // + // Note: Each block is kept byte-aligned for simplicity, which means that the number of 16bit weights and qbit weights + // in the bitstream has to be balanced such that the total number of bits is divisible by 8. + // e.g. If we have 3 4bit values and 253 16bit values, we will need to revert a 4bit value to 16bit in order + // to keep the total number of bits divisble by 8. If we were to quantize a weight instead, we would lose + // the "max_quantization_error" guarantee. However, each block doesn't need to remain byte-aligned, the requirement + // only holds for each row, so a big potential improvement could be made here, since we have quite a few unnecessary + // 16bit weights. for (int i = 0; i < nb; i++) { - // each 64bit TODO - uint64_t fp16s[QKX_0 / 64]; - - memset(fp16s, 0, sizeof(uint64_t) * (QKX_0 / 64)); + // each 64bit value holds binary data of whether the current weight (corresponding to a specific bit) + // is stored on 16bit or is quantized. "QKX_0 / 64" is here since we need multiple 64bit numbers if + // the QX_0 block is larger than 64 weights. + uint64_t fp16_indicators[QKX_0 / 64]; + memset(fp16_indicators, 0, sizeof(uint64_t) * (QKX_0 / 64)); uint8_t qbits = QX_0_STARTING_QBITS; float thresh = max_quantization_errors[qbits] * (1 << qbits); int fp16_count = 0; - // max_quantization_error indicates that no value should be >= max_quantization_error away from - // its quantized value; - // that means, the total range for the quantized values will be max_quantization_error * 2 * (1 << qbits) (here, 16) - // for simplicty, we are going to center on 0, meaning that our fp16 threshold will be max_quantization_error * 16 values to the left and right - // -->|<---->|<---->|<---->|<-- 4bit example, --> = max_quant_error; we have "--> * 2 * 3 + --> + -->" positions where quantized values can be, == "--> * 4" - for (int j = 0; j < QKX_0; j++) { float x = src[i * QKX_0 + j]; if (fabsf(x) > thresh) { - // deactivate quant - fp16s[j / 64] |= (uint64_t) 1 << (j % 64); + // store this value on 16bits + fp16_indicators[j / 64] |= (uint64_t) 1 << (j % 64); fp16_count += 1; } } @@ -16556,30 +16598,31 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist, uint16_t total_bits = fp16_count * 16 + (QKX_0 - fp16_count) * qbits; while ((total_bits % 8) != 0) { - total_bits += 16 - qbits; // simulate the replacement of a quantized weight with a 16bit one + total_bits += 16 - qbits; // simulate the replacement of a quantized weight with a 16bit one (needed for a block's byte alignment) } float min_value = -(max_quantization_errors[qbits] * ((1 << qbits) - 1)); - float mult_range = 2 * max_quantization_errors[qbits] * ((1 << qbits) - 1); + float mult_range = 2 * max_quantization_errors[qbits]; for (uint8_t test_qbit = QX_0_STARTING_QBITS_DOWNSCALING; test_qbit >= 1; test_qbit--) { + // calculate the mean of non-fp16 values and define that as the center of the quantization range double mean = 0; for (int j = 0; j < QKX_0; j++) { - if ((fp16s[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) { + if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) { float x_fp32 = src[i * QKX_0 + j]; mean += x_fp32; } } - mean /= (QKX_0 - fp16_count); // see where weights are centered + mean /= (QKX_0 - fp16_count); uint16_t total_fp16s_in_test_qbit = 0; thresh = max_quantization_errors[test_qbit] * (1 << test_qbit); for (int j = 0; j < QKX_0; j++) { - if ((fp16s[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) { + if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) { float x = src[i * QKX_0 + j]; - // this weight would need to be put on 16bit + // new outlier found for our current qbit if (x < mean - thresh || x > mean + thresh) { total_fp16s_in_test_qbit += 1; } @@ -16590,25 +16633,23 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist, uint16_t total_bits_in_test_qbit = total_fp16s_in_test_qbit * 16 + test_qbit * (QKX_0 - total_fp16s_in_test_qbit); while ((total_bits_in_test_qbit % 8) != 0) { - total_bits_in_test_qbit += 16 - test_qbit; // simulate the replacement of a 3bit weight with a 16bit one + total_bits_in_test_qbit += 16 - test_qbit; // simulate the replacement of a qbit weight with a 16bit one } if (total_bits_in_test_qbit < total_bits) { - //printf("switching to %dbit! %d vs %d\n", test_qbit, total_bits, total_bits_in_test_qbit); - total_bits = total_bits_in_test_qbit; qbits = test_qbit; min_value = mean - (max_quantization_errors[test_qbit] * ((1 << qbits) - 1)); - mult_range = 2 * max_quantization_errors[test_qbit] * ((1 << qbits) - 1); + mult_range = 2 * max_quantization_errors[test_qbit]; for (int j = 0; j < QKX_0; j++) { - if ((fp16s[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) { + if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) { float x = src[i * QKX_0 + j]; - // this weight would need to be put on 16bit + // mark outlier as stored on 16bit if (x < mean - thresh || x > mean + thresh) { - fp16s[j / 64] |= (uint64_t) 1 << (j % 64); + fp16_indicators[j / 64] |= (uint64_t) 1 << (j % 64); fp16_count += 1; } } @@ -16616,13 +16657,14 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist, uint16_t total_test_bits = fp16_count * 16 + (QKX_0 - fp16_count) * qbits; while ((total_test_bits % 8) != 0) { - total_test_bits += 16 - test_qbit; // simulate the replacement of a 3bit weight with a 16bit one + total_test_bits += 16 - test_qbit; // simulate the replacement of a qbit weight with a 16bit one } GGML_ASSERT(total_bits == total_test_bits); } } + // keep converting the largest qbit values to fp16 until the block is byte-aligned while (((QKX_0 - fp16_count) * qbits) % 8 != 0) { float maxi = 0; int target = -1; @@ -16631,7 +16673,7 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist, float x = src[i * QKX_0 + j]; // weight is not on 16bit - if ((fp16s[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) { + if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) { float diff = fabsf(x); if (diff > maxi || target == -1) { maxi = diff; @@ -16641,38 +16683,46 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist, } GGML_ASSERT(target != -1); - fp16s[target / 64] |= (uint64_t) 1 << (target % 64); + fp16_indicators[target / 64] |= (uint64_t) 1 << (target % 64); fp16_count += 1; } + // store the current byte-offset of the current row, if "i" indicates that this is the first + // block of a row if (((i * QKX_0) % tensor_width == 0) && i != 0) { uint32_t row = (i * QKX_0) / tensor_width; extra_data[row - 1] = dst_offset; } - uint64_t * fp16_data = (uint64_t *) (dst_8 + dst_offset); + // write the fp16 indicators to dst + uint64_t * stored_fp16_indicators = (uint64_t *) (dst_8 + dst_offset); - // write the data for (int j = 0; j < QKX_0 / 64; j++) { - fp16_data[j] = fp16s[j]; + stored_fp16_indicators[j] = fp16_indicators[j]; } dst_offset += (QKX_0 / 64) * sizeof(uint64_t); - // write min value and multiplier (min_value + mult * quant_number, result should be divided by (1 << QBits) during multplication) + // Each weight is stored as min_value + mult * quantized_weight + // Similar to Zero-point quantization, or Q4_1 + + // Write min value and multiplier to dst *((uint16_t*) (dst_8 + dst_offset)) = ggml_fp32_to_fp16(min_value); dst_offset += sizeof(uint16_t); *((uint16_t*) (dst_8 + dst_offset)) = ggml_fp32_to_fp16(mult_range); dst_offset += sizeof(uint16_t); + // Store the "metadata" byte (for now it's just "qbits") *((uint8_t*) (dst_8 + dst_offset)) = qbits; dst_offset += sizeof(uint8_t); + + // Store the quantization pivots / points float qvals[1 << qbits]; for (int i = 0; i < (1 << qbits); i++) { - qvals[i] = min_value + (mult_range * i) / ((1 << qbits) - 1); + qvals[i] = min_value + (mult_range * i); } uint64_t bit_offset = 0; @@ -16683,9 +16733,10 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist, for (int j = 0; j < QKX_0; j++) { float x = src[i * QKX_0 + j]; - if (fp16s[j / 64] & ((uint64_t) 1 << (j % 64))) { + if (fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) { ggml_fp16_t x_f16 = ggml_fp32_to_fp16(x); + // store the full fp16 weight write_bits(data, bit_offset, x_f16, 16); bit_offset += 16; fp16_count_chk += 1; @@ -16693,6 +16744,7 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist, uint8_t q = 0; float min_dist = fabsf(x - qvals[0]); + // find closest quantization point for (int iv = 0; iv < (1 << qbits); iv++) { float dist = fabsf(x - qvals[iv]); if (dist < min_dist) { @@ -16706,13 +16758,17 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist, } } + // check that the reported fp16_count is coherent with the bits stored in fp16_indicators GGML_ASSERT(fp16_count == fp16_count_chk); + + // check that the number of bits from quantized values is divisible by 8 GGML_ASSERT((((QKX_0 - fp16_count) * qbits) % 8) == 0); dst_offset += ((QKX_0 - fp16_count) * qbits) / 8; dst_offset += fp16_count * 2; } + // store the total size of the tensor as the last element of extra_data extra_data[n / tensor_width - 1] = dst_offset; return dst_offset; diff --git a/llama.cpp b/llama.cpp index e0ff34861..158d9ebe8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -883,12 +883,6 @@ struct llama_model_loader { if (lt.shards.at(0).extra_data_file_off != 0) { lt.extra_data = (uint64_t *) ((uint8_t *) mapping->addr + lt.shards.at(0).extra_data_file_off); } - printf("load data for %s\n", lt.name.c_str()); - - if (lt.extra_data != NULL) { - printf("extra_data_file_off: %zu, data: %p, extra_data: %p\n", lt.shards.at(0).extra_data_file_off, lt.data, lt.extra_data); - printf("extra_data for %s: %lu %lu ... %lu\n", lt.name.c_str(), lt.extra_data[0], lt.extra_data[1], lt.extra_data[lt.ne[1] - 1]); - } } else if (lt.split_type == SPLIT_NONE) { llama_file & file = file_loaders.at(lt.shards.at(0).file_idx)->file;