added comments and scalar implementation for vec_dot_qx

This commit is contained in:
Amy 2023-06-13 08:59:03 +01:00
parent e5274378f7
commit 4cd885beb5
2 changed files with 136 additions and 86 deletions

ggml.c

@@ -3191,20 +3191,24 @@ __attribute__((optimize("unroll-loops")))
static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
uint32_t nb = n / QKX_0;
GGML_ASSERT(QKX_0 % QK8_0 == 0);
*s = 0;
uint8_t * quant_row = (uint8_t *) vx;
+ const block_q8_0 * restrict column = vy;
+ uint32_t column_i = 0; // current index in column
- // row_data stores dequantized values of the current block
+ // row_data is a buffer which stores the dequantized float values of the current block
float f32_row_data[QKX_0];
- const block_q8_0 * restrict column = vy;
- uint32_t column_idx = 0;
+ // __AVX2__ doesn't seem to actually make much of a difference; a lot of optimizing
+ // could still be done, including possibly using AVX2 for dequantization
#if defined(__AVX2__)
__m256 rolling_sum = _mm256_setzero_ps();
+ #endif
+ // IMPORTANT: quantized weights must be kept <= 4 bits; increase this array size for wider values
float qvals[1 << 4];
for (int b = 0; b < nb; b++) {
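
The qvals table declared above is the decode-side half of the format: every representable value for the block is precomputed once, so the inner loop dequantizes a weight with a single table lookup. A minimal standalone sketch of that idea, using hypothetical header values (in the kernel, min_value and mult_value are read from the block's stored fp16 fields):

#include <stdio.h>

int main(void) {
    // Hypothetical header values for one 4-bit block (not taken from the diff):
    const int   qbits     = 4;
    const float min_value = -0.06f;
    float mult_value      = 0.12f;      // full range, as read from the block header
    mult_value /= (1 << qbits) - 1;     // becomes the step between adjacent points

    float qvals[1 << 4];                // sized for the maximum supported qbits
    for (int i = 0; i < (1 << qbits); i++) {
        qvals[i] = min_value + mult_value * i;
    }
    // decoding a quantized weight q is now a single lookup: qvals[q]
    printf("qvals[0] = %f, qvals[15] = %f\n", qvals[0], qvals[15]);
    return 0;
}

The table costs at most 1 << 4 floats per block here, which is why the comment above pins qbits to 4 or fewer unless the declaration is enlarged.
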
@@ -3218,22 +3222,31 @@ static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void *
const uint8_t qbits = *((uint8_t *) data_start);
data_start = (uint16_t*) ((uint8_t*) data_start + 1);
mult_value /= ((1 << qbits) - 1);
quant_row = (uint8_t * ) data_start;
+ // Any qbits value is supported, but the size of qvals needs to be changed to 1 << max_expected_qbits.
+ // So if you have at most 7-bit values, change the declaration of qvals to qvals[1 << 7].
+ // Additionally, the optimized "fp16_indicator == 0" branch only works if qbits is 3 or a power of 2,
+ // so feel free to disable it entirely and run the slower "else" branch, which works for pretty much
+ // any qbits value.
GGML_ASSERT(qbits <= 4);
uint32_t offset = 0;
uint8_t data_offset = 0;
+ // Cache the dequantized value of each quantization level
for (int i = 0; i < (1 << qbits); i++) {
qvals[i] = min_value + mult_value * i;
}
- // 64 is the size in bits of uint64_t
+ // Parse in sub-blocks of 64 weights, since each sub-block is managed by a single uint64_t whose bits
+ // decide whether a given weight is stored on 16bit or quantized. This means we can do a fast
+ // "fp16_indicator == 0" check (i.e. all weights in the sub-block are quantized) to speed up performance
for (int jb = 0; jb < QKX_0 / 64; jb++) {
- uint64_t fp16_chooser = block_start[jb];
+ uint64_t fp16_indicator = block_start[jb];
+ // all weights are quantized in this sub-block; ALSO, this fast path ONLY works when qbits <= 4,
+ // since "(qbits != 3)" then simply checks whether qbits is a power of 2
- if (fp16_chooser == 0) {
+ if (fp16_indicator == 0) {
if (qbits == 3) {
// same principle as on the regular data_offset branch, but this time the qbits cross byte boundaries, so we need to manage it by hand
for (int i = 0; i < 5; i++) {
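
One uint64_t of indicator bits covers a 64-weight sub-block, bit j telling whether weight j is stored as fp16 or quantized; that is what makes the single-compare fp16_indicator == 0 fast path above possible. A sketch of just that dispatch structure, with placeholder constants standing in for the real fp16 and quantized reads:

#include <stdint.h>
#include <stdio.h>

// Sketch of the indicator-driven dispatch: one uint64_t covers a 64-weight
// sub-block; bit j set => weight j is stored as fp16, clear => quantized.
static void decode_subblock(uint64_t fp16_indicator, float * out) {
    if (fp16_indicator == 0) {
        // Fast path: every weight is quantized, so fixed-width fields can be
        // unpacked with no per-weight branching.
        for (int i = 0; i < 64; i++) out[i] = 0.25f;   // placeholder dequantized read
    } else {
        for (int i = 0; i < 64; i++) {
            if (fp16_indicator & 1) {
                out[i] = 1.0f;                         // placeholder fp16 read
            } else {
                out[i] = 0.25f;                        // placeholder quantized read
            }
            fp16_indicator >>= 1;                      // advance to the next weight's bit
        }
    }
}

int main(void) {
    float w[64];
    decode_subblock(0, w);                 // all-quantized fast path
    decode_subblock(UINT64_C(1) << 5, w);  // weight 5 stored as fp16
    printf("%f %f\n", w[5], w[6]);
    return 0;
}
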
@@ -3242,8 +3255,8 @@ static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void *
row_ptr[i * 11 + k] = qvals[((((uint64_t *) data_start)[0] >> (data_offset + k * qbits)) & ((1 << qbits) - 1))];
}
- data_start += 2; // this is the same event as in if (data_start >= 16), but happening twice, stealthily
- data_offset += 1; // it's actually +33, but we are rounding
+ data_start += 2; // this is the same event as in "if (data_offset >= 16)", but happening twice
+ data_offset += 1; // it's actually +33, but the "+32" is represented in data_start above, so the remainder is simply +1
}
for (int k = 0; k < 9; k ++) {
@@ -3292,24 +3305,17 @@ static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void *
offset += qbits * 64;
} else {
for (int i = 0; i < 64; i++) {
- // next 32 values are free
- if (fp16_chooser & 1) {
+ if (fp16_indicator & 1) {
+ // Current weight is fp16
offset += 16;
row_ptr[i] = GGML_FP16_TO_FP32((((uint32_t *) data_start)[0] >> data_offset) & ((1 << 16) - 1));
- #ifdef ame_debug
- printf("%f (16bit)\n", row_ptr[i]);
- #endif
data_start += 1;
} else {
+ // Current weight is quantized
offset += qbits;
row_ptr[i] = qvals[((((uint32_t *) data_start)[0] >> data_offset) & ((1 << qbits) - 1))];
- #ifdef ame_debug
- printf("%ld -> %f (%dbit)\n", ((((uint32_t *) data_start)[0] >> data_offset) & ((1 << qbits) - 1)), row_ptr[i], qbits);
- #endif
data_offset += qbits;
if (data_offset >= 16) {
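
The qbits == 3 special case above exists because 3-bit fields do not stay byte-aligned (64 weights * 3 bits = 192 bits = 24 bytes), so fields regularly straddle byte and 16-bit-word boundaries. A generic sketch of the underlying shift-and-mask extraction, assuming a little-endian host (as on ggml's usual x86/ARM targets); get_field is a hypothetical helper, not from the diff:

#include <stdint.h>
#include <string.h>
#include <stdio.h>

// Extract an nbits-wide field at an arbitrary bit offset: read a wide
// window around the byte containing the offset, then shift and mask.
static uint32_t get_field(const uint8_t * buf, uint64_t bit_offset, int nbits) {
    uint64_t window;
    memcpy(&window, buf + bit_offset / 8, sizeof(window)); // unaligned-safe read
    return (uint32_t)((window >> (bit_offset % 8)) & ((1u << nbits) - 1));
}

int main(void) {
    uint8_t buf[32] = {0};
    buf[0] = 0x8B;
    // print the first four 3-bit fields, which freely cross byte boundaries
    for (int i = 0; i < 4; i++) {
        printf("%u\n", get_field(buf, (uint64_t)i * 3, 3));
    }
    return 0;
}
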
@@ -3318,26 +3324,17 @@ static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void *
}
}
- fp16_chooser >>= 1;
- // uint8_t sz = qbits << ((fp16_chooser & 1) << 1);
- // get_bits(data_start, offset, &w, sz);
- // offset += sz;
- // if (sz == qbits) {
- // row_save[i] = qvals_f16[w];
- // } else {
- // row_save[i] = w;
- // }
+ // Shift the fp16 indicator to the right, to move to the next weight
+ fp16_indicator >>= 1;
}
}
for (int jb = 0; jb < 64 / QK8_0; jb++) {
- __m256 column_multiplier = _mm256_set1_ps(GGML_FP16_TO_FP32(column[column_idx].d));
+ #if defined(__AVX2__)
+ __m256 column_multiplier = _mm256_set1_ps(GGML_FP16_TO_FP32(column[column_i].d));
for (int i = 0; i < QK8_0/8; i++) {
- __m128i test = _mm_loadu_si128((const __m128i *) (column[column_idx].qs + i * 8));
+ __m128i test = _mm_loadu_si128((const __m128i *) (column[column_i].qs + i * 8));
__m256i work = _mm256_cvtepi8_epi32(test);
__m256 workf = _mm256_cvtepi32_ps(work);
@@ -3348,29 +3345,37 @@ static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void *
rolling_sum = _mm256_fmadd_ps(workf, column_multiplier, rolling_sum);
}
- column_idx += 1;
+ #else
+ // scalar
+ float sub_sum = 0;
+ for (int i = 0; i < QK8_0; i++) {
+ sub_sum += row_ptr[jb * QK8_0 + i] * column[column_i].qs[i];
+ }
+ sub_sum *= GGML_FP16_TO_FP32(column[column_i].d);
+ *s += sub_sum;
+ #endif
+ column_i += 1;
}
- // horrible manual loop unroll for testing, 1 iteration only
- // int i = 0;
- // uint16_t w = 0;
- // if (unlikely(fp16_chooser & 1)) { get_bits(data_start, offset, &w, 16); offset += 16; row_save[i] = w; } else { get_bits(data_start, offset, &w, qbits); offset += qbits; row_save[i] = qvals_f16[w]; } fp16_chooser >>= 1; i++;
row_ptr += 64;
}
- //printf("offset: %d\n", offset);
GGML_ASSERT(offset % 8 == 0);
quant_row += offset / 8;
}
+ #if defined(__AVX2__)
float rolling_sum_vec[8];
_mm256_store_ps(rolling_sum_vec, rolling_sum);
for (int i = 0; i < 8; i++) {
*s += rolling_sum_vec[i];
}
+ #endif
}
static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
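
The #else branch added above is the commit's scalar fallback: multiply the dequantized row values by the column's int8 values, then apply the column block's fp16 scale once per QK8_0 sub-block. A standalone sketch of that accumulation; block_q8_0_sketch is a simplified stand-in for block_q8_0 (the real struct stores d as ggml_fp16_t):

#include <stdint.h>
#include <stdio.h>

#define QK8_0 32

// Simplified stand-in for block_q8_0 (d is a plain float here for clarity).
typedef struct {
    float  d;            // per-block scale
    int8_t qs[QK8_0];    // quantized values
} block_q8_0_sketch;

// Scalar dot of one dequantized row block against one q8_0 column block:
// accumulate the raw products, then apply the block scale exactly once.
static float dot_block(const float * row, const block_q8_0_sketch * col) {
    float sub_sum = 0.0f;
    for (int i = 0; i < QK8_0; i++) {
        sub_sum += row[i] * col->qs[i];
    }
    return sub_sum * col->d;     // one scale per block, not per element
}

int main(void) {
    float row[QK8_0];
    block_q8_0_sketch col = { .d = 0.05f };
    for (int i = 0; i < QK8_0; i++) { row[i] = 1.0f; col.qs[i] = (int8_t)i; }
    printf("%f\n", dot_block(row, &col));  // 0.05 * (0 + 1 + ... + 31) = 24.8
    return 0;
}
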
@@ -16524,31 +16529,68 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
const uint8_t * dst_8 = dst;
uint64_t dst_offset = 0;
+ // define the max quantization error for every bit precision,
+ // i.e. max_quantization_errors[1] holds the max error for 1bit quantized weights,
+ // max_quantization_errors[2] holds the max error for 2bit quantized weights,
+ // max_quantization_errors[3] holds the max error for 3bit quantized weights,
+ // etc.
+ //
+ // "max quantization error" here means that every single quantized weight is within
+ // said value (e.g. 0.004) of its original value
+ //
+ // this could be replaced with a max allowed RMSE, a set percentage of weights being within
+ // a certain range, etc. The current implementation is pretty much just an example
double max_quantization_errors[5] = {0, 0.004, 0.004, 0, 0.004};
- for (int i = 0; i < nb; i++) {
- // each 64bit TODO
- uint64_t fp16s[QKX_0 / 64];
- memset(fp16s, 0, sizeof(uint64_t) * (QKX_0 / 64));
+ // How the maximum quantization error is implemented here:
+ //
+ // Each block holds both fp16 and "qbit"-quantized weights, mixed together arbitrarily.
+ // The mixing is described by a few numbers at the start of each block, each bit of those
+ // numbers indicating whether the corresponding weight is stored on 16bit or is quantized.
+ //
+ // There is a metadata byte which indicates the qbit precision of the current block, and
+ // its values are in [1,2,3,4], but this can easily be extended to allow other bit precisions,
+ // such as 5, 6, 9, 13 bits or anything else.
+ //
+ // To guarantee that each weight is within max_quantization_error, we first need to look at what range
+ // of values this allows us to have. Since we have "qbits" bits, there are (1 << qbits) possible values
+ // the quantized weights can take. The maximum distance between two adjacent quantized points can be
+ // "2 * max_quantization_error", since any weight between those two points is then within
+ // max_quantization_error of its closest point.
+ //
+ // A visual 2bit example would be: -->|<---->|<---->|<---->|<--
+ // where "|" are the quantized points, and "-->" represents max_quantization_error on the number line.
+ //
+ // Any value outside this range has to be kept on 16bit, since it cannot be within
+ // max_quantization_error of any quantized point.
+ //
+ //
+ // Note: Each block is kept byte-aligned for simplicity, which means that the number of 16bit weights and
+ // qbit weights in the bitstream has to be balanced such that the total number of bits is divisible by 8.
+ // e.g. if we have 3 4bit values and 253 16bit values, we need to revert one 4bit value to 16bit in order
+ // to keep the total number of bits divisible by 8. If we quantized a 16bit weight instead, we would lose
+ // the "max_quantization_error" guarantee. Strictly speaking, individual blocks don't need to remain
+ // byte-aligned, the requirement only holds for each row, so a big potential improvement could be made
+ // here, since we currently store quite a few unnecessary 16bit weights.
+ for (int i = 0; i < nb; i++) {
+ // each 64bit value holds one bit per weight, telling whether that weight is stored on 16bit or is
+ // quantized. "QKX_0 / 64" is here since we need multiple 64bit numbers if the QX_0 block is larger
+ // than 64 weights.
+ uint64_t fp16_indicators[QKX_0 / 64];
+ memset(fp16_indicators, 0, sizeof(uint64_t) * (QKX_0 / 64));
uint8_t qbits = QX_0_STARTING_QBITS;
float thresh = max_quantization_errors[qbits] * (1 << qbits);
int fp16_count = 0;
- // max_quantization_error indicates that no value should be >= max_quantization_error away from
- // its quantized value;
- // that means, the total range for the quantized values will be max_quantization_error * 2 * (1 << qbits) (here, 16)
- // for simplicty, we are going to center on 0, meaning that our fp16 threshold will be max_quantization_error * 16 values to the left and right
- // -->|<---->|<---->|<---->|<-- 4bit example, --> = max_quant_error; we have "--> * 2 * 3 + --> + -->" positions where quantized values can be, == "--> * 4"
for (int j = 0; j < QKX_0; j++) {
float x = src[i * QKX_0 + j];
if (fabsf(x) > thresh) {
- // deactivate quant
- fp16s[j / 64] |= (uint64_t) 1 << (j % 64);
+ // store this value on 16bits
+ fp16_indicators[j / 64] |= (uint64_t) 1 << (j % 64);
fp16_count += 1;
}
}
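
Since the quantization grid is centered on 0 and spans +/- max_quantization_error * ((1 << qbits) - 1), a weight can sit within tolerance of some grid point iff its magnitude is at most max_quantization_error * (1 << qbits), which is exactly the thresh used above. A sketch of that first outlier pass:

#include <stdint.h>
#include <math.h>
#include <stdio.h>

#define QKX_0 256

// Mark every weight that cannot be quantized within max_err of a grid point,
// i.e. |x| > max_err * (1 << qbits); those stay fp16.
static int mark_fp16_outliers(const float * src, int qbits, float max_err,
                              uint64_t fp16_indicators[QKX_0 / 64]) {
    const float thresh = max_err * (float)(1 << qbits);
    int fp16_count = 0;
    for (int j = 0; j < QKX_0; j++) {
        if (fabsf(src[j]) > thresh) {
            fp16_indicators[j / 64] |= (uint64_t)1 << (j % 64); // keep on fp16
            fp16_count++;
        }
    }
    return fp16_count;
}

int main(void) {
    float w[QKX_0] = {0};
    w[3] = 0.2f;                       // well outside 0.004 * 16 = 0.064
    uint64_t ind[QKX_0 / 64] = {0};
    printf("%d outliers\n", mark_fp16_outliers(w, 4, 0.004f, ind));
    return 0;
}
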
@@ -16556,30 +16598,31 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
uint16_t total_bits = fp16_count * 16 + (QKX_0 - fp16_count) * qbits;
while ((total_bits % 8) != 0) {
- total_bits += 16 - qbits; // simulate the replacement of a quantized weight with a 16bit one
+ total_bits += 16 - qbits; // simulate the replacement of a quantized weight with a 16bit one (needed for a block's byte alignment)
}
float min_value = -(max_quantization_errors[qbits] * ((1 << qbits) - 1));
- float mult_range = 2 * max_quantization_errors[qbits] * ((1 << qbits) - 1);
+ float mult_range = 2 * max_quantization_errors[qbits];
for (uint8_t test_qbit = QX_0_STARTING_QBITS_DOWNSCALING; test_qbit >= 1; test_qbit--) {
// calculate the mean of non-fp16 values and define that as the center of the quantization range
double mean = 0;
for (int j = 0; j < QKX_0; j++) {
- if ((fp16s[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) {
+ if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) {
float x_fp32 = src[i * QKX_0 + j];
mean += x_fp32;
}
}
- mean /= (QKX_0 - fp16_count); // see where weights are centered
+ mean /= (QKX_0 - fp16_count);
uint16_t total_fp16s_in_test_qbit = 0;
thresh = max_quantization_errors[test_qbit] * (1 << test_qbit);
for (int j = 0; j < QKX_0; j++) {
- if ((fp16s[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) {
+ if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) {
float x = src[i * QKX_0 + j];
- // this weight would need to be put on 16bit
+ // new outlier found for our current qbit
if (x < mean - thresh || x > mean + thresh) {
total_fp16s_in_test_qbit += 1;
}
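
Note that the alignment loop never rounds bits away: it only simulates promoting one quantized weight to fp16, each promotion adding 16 - qbits bits, until the total is divisible by 8. A worked sketch of the size accounting:

#include <stdio.h>

#define QKX_0 256

// Start from the mixed fp16/qbit bit count, then simulate promoting quantized
// weights to fp16 (each promotion adds 16 - qbits bits) until the block total
// is divisible by 8, i.e. the block would be byte-aligned.
static int aligned_total_bits(int fp16_count, int qbits) {
    int total_bits = fp16_count * 16 + (QKX_0 - fp16_count) * qbits;
    while (total_bits % 8 != 0) {
        total_bits += 16 - qbits;
    }
    return total_bits;
}

int main(void) {
    // e.g. 5 fp16 weights and 251 3-bit weights: 80 + 753 = 833 bits;
    // three simulated promotions (+13 bits each) reach 872, which is byte-aligned.
    printf("%d bits\n", aligned_total_bits(5, 3));
    return 0;
}
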
@@ -16590,25 +16633,23 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
uint16_t total_bits_in_test_qbit = total_fp16s_in_test_qbit * 16 + test_qbit * (QKX_0 - total_fp16s_in_test_qbit);
while ((total_bits_in_test_qbit % 8) != 0) {
- total_bits_in_test_qbit += 16 - test_qbit; // simulate the replacement of a 3bit weight with a 16bit one
+ total_bits_in_test_qbit += 16 - test_qbit; // simulate the replacement of a qbit weight with a 16bit one
}
if (total_bits_in_test_qbit < total_bits) {
//printf("switching to %dbit! %d vs %d\n", test_qbit, total_bits, total_bits_in_test_qbit);
total_bits = total_bits_in_test_qbit;
qbits = test_qbit;
min_value = mean - (max_quantization_errors[test_qbit] * ((1 << qbits) - 1));
- mult_range = 2 * max_quantization_errors[test_qbit] * ((1 << qbits) - 1);
+ mult_range = 2 * max_quantization_errors[test_qbit];
for (int j = 0; j < QKX_0; j++) {
- if ((fp16s[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) {
+ if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) {
float x = src[i * QKX_0 + j];
- // this weight would need to be put on 16bit
+ // mark outlier as stored on 16bit
if (x < mean - thresh || x > mean + thresh) {
fp16s[j / 64] |= (uint64_t) 1 << (j % 64);
fp16_indicators[j / 64] |= (uint64_t) 1 << (j % 64);
fp16_count += 1;
}
}
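
The descending test_qbit loop is a greedy precision search: recenter on the mean of the still-quantized weights, count the outliers the narrower grid would create, and switch only if the byte-aligned bit total shrinks. A compact sketch of that comparison over synthetic weights (the mean is hard-coded here purely for illustration):

#include <math.h>
#include <stdio.h>

#define QKX_0 256

// Count weights outside the mean-centered window [mean - t, mean + t],
// t = max_err * (1 << qbits); those would have to stay fp16.
static int count_outliers(const float * w, int n, float mean, float thresh) {
    int c = 0;
    for (int j = 0; j < n; j++) {
        if (w[j] < mean - thresh || w[j] > mean + thresh) c++;
    }
    return c;
}

int main(void) {
    const double max_err[5] = {0, 0.004, 0.004, 0, 0.004}; // 3bit disabled, as in the diff
    float w[QKX_0];
    for (int j = 0; j < QKX_0; j++) w[j] = 0.001f * (float)(j % 16);

    int best_qbits = 4;
    int best_bits  = -1;
    for (int q = 4; q >= 1; q--) {
        if (max_err[q] == 0) continue;                 // width not enabled
        float thresh = (float)max_err[q] * (float)(1 << q);
        int   fp16s  = count_outliers(w, QKX_0, 0.0075f /* illustrative mean */, thresh);
        int   bits   = fp16s * 16 + (QKX_0 - fp16s) * q;
        while (bits % 8 != 0) bits += 16 - q;          // byte-align, as in the diff
        if (best_bits < 0 || bits < best_bits) { best_bits = bits; best_qbits = q; }
    }
    printf("best: %d-bit, %d bits total\n", best_qbits, best_bits);
    return 0;
}
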
@@ -16616,13 +16657,14 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
uint16_t total_test_bits = fp16_count * 16 + (QKX_0 - fp16_count) * qbits;
while ((total_test_bits % 8) != 0) {
- total_test_bits += 16 - test_qbit; // simulate the replacement of a 3bit weight with a 16bit one
+ total_test_bits += 16 - test_qbit; // simulate the replacement of a qbit weight with a 16bit one
}
GGML_ASSERT(total_bits == total_test_bits);
}
}
+ // keep converting the largest qbit values to fp16 until the block is byte-aligned
while (((QKX_0 - fp16_count) * qbits) % 8 != 0) {
float maxi = 0;
int target = -1;
@@ -16631,7 +16673,7 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
float x = src[i * QKX_0 + j];
// weight is not on 16bit
- if ((fp16s[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) {
+ if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) {
float diff = fabsf(x);
if (diff > maxi || target == -1) {
maxi = diff;
@@ -16641,38 +16683,46 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
}
GGML_ASSERT(target != -1);
- fp16s[target / 64] |= (uint64_t) 1 << (target % 64);
+ fp16_indicators[target / 64] |= (uint64_t) 1 << (target % 64);
fp16_count += 1;
}
+ // store the current byte-offset into extra_data if "i" indicates that this is the
+ // first block of a new row
if (((i * QKX_0) % tensor_width == 0) && i != 0) {
uint32_t row = (i * QKX_0) / tensor_width;
extra_data[row - 1] = dst_offset;
}
- uint64_t * fp16_data = (uint64_t *) (dst_8 + dst_offset);
+ // write the fp16 indicators to dst
+ uint64_t * stored_fp16_indicators = (uint64_t *) (dst_8 + dst_offset);
- // write the data
for (int j = 0; j < QKX_0 / 64; j++) {
- fp16_data[j] = fp16s[j];
+ stored_fp16_indicators[j] = fp16_indicators[j];
}
dst_offset += (QKX_0 / 64) * sizeof(uint64_t);
- // write min value and multiplier (min_value + mult * quant_number, result should be divided by (1 << QBits) during multplication)
+ // Each weight is stored as min_value + mult * quantized_weight,
+ // similar to zero-point quantization, or Q4_1.
+ // Write min value and multiplier to dst
*((uint16_t*) (dst_8 + dst_offset)) = ggml_fp32_to_fp16(min_value);
dst_offset += sizeof(uint16_t);
*((uint16_t*) (dst_8 + dst_offset)) = ggml_fp32_to_fp16(mult_range);
dst_offset += sizeof(uint16_t);
// Store the "metadata" byte (for now it's just "qbits")
*((uint8_t*) (dst_8 + dst_offset)) = qbits;
dst_offset += sizeof(uint8_t);
+ // Store the quantization pivots / points
float qvals[1 << qbits];
for (int i = 0; i < (1 << qbits); i++) {
- qvals[i] = min_value + (mult_range * i) / ((1 << qbits) - 1);
+ qvals[i] = min_value + (mult_range * i);
}
uint64_t bit_offset = 0;
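
At this point one block's serialized layout is fully determined: the indicator words, an fp16 min, an fp16 multiplier, a one-byte qbits field, then the packed bitstream of mixed fp16/quantized weights. A sketch of the resulting byte count, mirroring the dst_offset increments in this function:

#include <stdint.h>
#include <stdio.h>

#define QKX_0 256

// Per-block layout, as written above:
//   [QKX_0/64 x uint64_t indicators][fp16 min][fp16 mult][u8 qbits][bitstream]
// The bitstream length is byte-aligned by construction.
static uint64_t block_size_bytes(int fp16_count, int qbits) {
    uint64_t sz = 0;
    sz += (QKX_0 / 64) * sizeof(uint64_t);              // fp16 indicator words
    sz += 2 * sizeof(uint16_t);                         // min_value + multiplier (fp16)
    sz += sizeof(uint8_t);                              // metadata byte: qbits
    sz += (uint64_t)((QKX_0 - fp16_count) * qbits) / 8; // quantized payload
    sz += (uint64_t)fp16_count * 2;                     // raw fp16 payload
    return sz;
}

int main(void) {
    // e.g. a 4bit block with 8 fp16 outliers out of 256 weights -> 177 bytes
    printf("%llu bytes\n", (unsigned long long)block_size_bytes(8, 4));
    return 0;
}
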
@@ -16683,9 +16733,10 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
for (int j = 0; j < QKX_0; j++) {
float x = src[i * QKX_0 + j];
- if (fp16s[j / 64] & ((uint64_t) 1 << (j % 64))) {
+ if (fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) {
ggml_fp16_t x_f16 = ggml_fp32_to_fp16(x);
+ // store the full fp16 weight
write_bits(data, bit_offset, x_f16, 16);
bit_offset += 16;
fp16_count_chk += 1;
@@ -16693,6 +16744,7 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
uint8_t q = 0;
float min_dist = fabsf(x - qvals[0]);
+ // find closest quantization point
for (int iv = 0; iv < (1 << qbits); iv++) {
float dist = fabsf(x - qvals[iv]);
if (dist < min_dist) {
@@ -16706,13 +16758,17 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
}
}
+ // check that the reported fp16_count is consistent with the bits stored in fp16_indicators
GGML_ASSERT(fp16_count == fp16_count_chk);
+ // check that the number of bits used by quantized values is divisible by 8
GGML_ASSERT((((QKX_0 - fp16_count) * qbits) % 8) == 0);
dst_offset += ((QKX_0 - fp16_count) * qbits) / 8;
dst_offset += fp16_count * 2;
}
+ // store the total size of the tensor as the last element of extra_data
extra_data[n / tensor_width - 1] = dst_offset;
return dst_offset;
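
The bitstream writes go through a write_bits(data, bit_offset, value, nbits) helper defined outside these hunks; the LSB-first version below is only a plausible sketch of such a helper, not the actual ggml.c implementation:

#include <stdint.h>
#include <stdio.h>

// Sketch of an LSB-first bit writer: bit i of value lands at absolute bit
// position bit_offset + i in the byte stream. This body is an assumption;
// the real helper is not shown in this diff.
static void write_bits(uint8_t * data, uint64_t bit_offset, uint32_t value, int nbits) {
    for (int i = 0; i < nbits; i++) {
        uint64_t pos = bit_offset + i;
        if ((value >> i) & 1) {
            data[pos / 8] |= (uint8_t)(1u << (pos % 8));
        } else {
            data[pos / 8] &= (uint8_t)~(1u << (pos % 8));
        }
    }
}

int main(void) {
    uint8_t buf[8] = {0};
    uint64_t off = 0;
    write_bits(buf, off, 0x5, 3);    off += 3;   // one 3-bit quantized field
    write_bits(buf, off, 0xBEEF, 16); off += 16; // one fp16-sized field
    printf("%02x %02x %02x\n", buf[0], buf[1], buf[2]);
    return 0;
}
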

llama.cpp

@@ -883,12 +883,6 @@ struct llama_model_loader {
if (lt.shards.at(0).extra_data_file_off != 0) {
lt.extra_data = (uint64_t *) ((uint8_t *) mapping->addr + lt.shards.at(0).extra_data_file_off);
}
printf("load data for %s\n", lt.name.c_str());
if (lt.extra_data != NULL) {
printf("extra_data_file_off: %zu, data: %p, extra_data: %p\n", lt.shards.at(0).extra_data_file_off, lt.data, lt.extra_data);
printf("extra_data for %s: %lu %lu ... %lu\n", lt.name.c_str(), lt.extra_data[0], lt.extra_data[1], lt.extra_data[lt.ne[1] - 1]);
}
} else if (lt.split_type == SPLIT_NONE) {
llama_file & file = file_loaders.at(lt.shards.at(0).file_idx)->file;