Removed trailing whitespaces, removed variable-length arrays, removed debug print

Amy 2023-06-13 10:39:04 +01:00
parent 124b4172ef
commit 1e06f12714
2 changed files with 44 additions and 40 deletions

ggml.c

@@ -3203,7 +3203,7 @@ static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void *
// row_data is a buffer which stores dequantized float values for a current block
float f32_row_data[QKX_0];
// __AVX2__ doesn't seem to actually make much of a difference,
// a lot of optimizing could possibly be done, including possibly using AVX2
// for dequantization...?
@@ -3280,8 +3280,8 @@ static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void *
const uint8_t data_block_size = 64;
// we can take a full 64bit block
const uint8_t weights_per_u64_data_block = data_block_size / qbits;
const uint8_t num_of_data_blocks_needed = 64 / weights_per_u64_data_block; // because we have 64 qbit-sized weights here
for (int i = 0; i < num_of_data_blocks_needed; i++) {
for (int k = 0; k < weights_per_u64_data_block; k ++) {
row_ptr[i * weights_per_u64_data_block + k] = qvals[(((const uint64_t *) data_start)[0] >> (k * qbits)) & ((1 << qbits) - 1)];
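For reference, a minimal standalone sketch of the unpacking step above (the helper name is made up, not part of the commit): each 64-bit word of block data packs 64 / qbits indices, and every index selects a dequantized value from the qvals table.

    #include <stdint.h>

    /* Illustrative helper mirroring the inner loop above: expand one packed
     * 64-bit word into floats through the qvals lookup table. */
    static void unpack_u64_block(uint64_t packed, int qbits,
                                 const float *qvals, float *out) {
        const int weights_per_word = 64 / qbits;          /* e.g. 16 weights at 4 bits */
        const uint64_t mask = ((uint64_t)1 << qbits) - 1; /* low qbits bits set */
        for (int k = 0; k < weights_per_word; k++) {
            out[k] = qvals[(packed >> (k * qbits)) & mask];
        }
    }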
@@ -3340,14 +3340,14 @@ static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void *
__m128i test = _mm_loadu_si128((const __m128i *) (column[column_i].qs + i * 8));
__m256i work = _mm256_cvtepi8_epi32(test);
__m256 workf = _mm256_cvtepi32_ps(work);
// multiply with our 8 parts of the row at row_data
__m256 row = _mm256_loadu_ps(row_ptr + jb * QK8_0 + i * 8);
workf = _mm256_mul_ps(workf, row);
rolling_sum = _mm256_fmadd_ps(workf, column_multiplier, rolling_sum);
}
#else
// scalar
float sub_sum = 0;
@@ -3370,11 +3370,11 @@ static void ggml_vec_dot_qx_0_q8_0(const int n, float * restrict s, const void *
GGML_ASSERT(offset % 8 == 0);
quant_row += offset / 8;
}
#if defined(__AVX2__)
float rolling_sum_vec[8];
_mm256_store_ps(rolling_sum_vec, rolling_sum);
for (int i = 0; i < 8; i++) {
*s += rolling_sum_vec[i];
}
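Read on its own, the epilogue above is just a horizontal reduction of the AVX2 accumulator. A minimal sketch (illustrative, not from the commit; the unaligned store variant is used so the sketch makes no alignment assumptions about the stack array):

    #include <immintrin.h>

    /* Collapse an 8-lane float accumulator into a single scalar sum. */
    static float hsum256_ps(__m256 v) {
        float tmp[8];
        _mm256_storeu_ps(tmp, v);
        float sum = 0.0f;
        for (int i = 0; i < 8; i++) {
            sum += tmp[i];
        }
        return sum;
    }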
@@ -4530,7 +4530,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
}
result->nb[0] = GGML_TYPE_SIZE[type];
if (type == GGML_TYPE_QX_0) {
// QX_0 doesn't have a set stride size for a row; that value is stored in the "extra" part of the tensor
result->nb[1] = 0;
@@ -10464,7 +10464,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
const int i3 = i03;
void * src0_row;
if (type == GGML_TYPE_QX_0) {
if (ir > 0) {
src0_row = (void *) ((char *) src0->data + ((uint64_t *) src0->extra)[ir - 1]);
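A sketch of the row lookup this hunk relies on, following the convention described later for the extra field (extra[i - 1] holds the byte offset of row i, and row 0 starts at the data pointer). The helper is illustrative, not part of the commit:

    #include <stdint.h>

    /* Resolve a pointer to row `ir` of a QX_0 tensor whose per-row byte
     * offsets are stored in `extra` (extra[0] = start of row 1, and so on). */
    static const void *qx0_row_ptr(const void *data, const uint64_t *extra, int ir) {
        if (ir == 0) {
            return data;
        }
        return (const char *)data + extra[ir - 1];
    }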
@@ -10478,7 +10478,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
for (int64_t ic = 0; ic < ne11; ++ic) {
vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
}
@@ -16528,7 +16528,7 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
assert(n % QKX_0 == 0);
assert(tensor_width % QKX_0 == 0);
const int nb = n / QKX_0;
uint8_t * dst_8 = dst;
uint64_t dst_offset = 0;
@@ -16544,7 +16544,7 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
// this can be replaced with a max allowed RMSE, a set percentage of weights being within
// a certain range, etc... The current implementation here is pretty much just an example
float max_quantization_errors[5] = {0, 0.004, 0.004, 0, 0.004};
// How maximum quantization error is implemented here:
//
@@ -16599,14 +16599,14 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
}
uint16_t total_bits = fp16_count * 16 + (QKX_0 - fp16_count) * qbits;
while ((total_bits % 8) != 0) {
total_bits += 16 - qbits; // simulate the replacement of a quantized weight with a 16bit one (needed for a block's byte alignment)
}
float min_value = -(max_quantization_errors[qbits] * ((1 << qbits) - 1));
float mult_range = 2 * max_quantization_errors[qbits];
// The quantizer starts at a QX_0_STARTING_QBITS quantized block (e.g. 4bits), but then
// attempts to move to a lower precision defined by QX_0_START_OF_ATTEMPTED_QBITS.
// It keeps looking to see if 3, 2 or 1 bit precision leads to a smaller file size.
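A small worked example of the byte-alignment rule above, with made-up numbers (the block size, qbits and fp16 count are illustrative only): padding proceeds in steps of 16 - qbits bits, i.e. one quantized weight promoted to fp16 per step, until the payload is a whole number of bytes.

    #include <stdio.h>

    int main(void) {
        const int block_size = 256;  /* stand-in for QKX_0, illustrative */
        const int qbits      = 3;
        const int fp16_count = 5;

        unsigned total_bits = fp16_count * 16 + (block_size - fp16_count) * qbits;
        int promoted = 0;
        while (total_bits % 8 != 0) {
            total_bits += 16 - qbits;  /* simulate storing one more weight as fp16 */
            promoted++;
        }
        /* with these numbers: 833 -> 872 bits (109 bytes) after 3 promotions */
        printf("%u bits = %u bytes after %d simulated promotions\n",
               total_bits, total_bits / 8, promoted);
        return 0;
    }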
@@ -16629,7 +16629,7 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
}
}
mean /= (QKX_0 - fp16_count);
uint16_t total_fp16s_in_test_qbit = 0;
thresh = max_quantization_errors[test_qbit] * (1 << test_qbit);
@@ -16645,12 +16645,12 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
total_fp16s_in_test_qbit += 1;
}
}
uint16_t total_bits_in_test_qbit = total_fp16s_in_test_qbit * 16 + test_qbit * (QKX_0 - total_fp16s_in_test_qbit);
while ((total_bits_in_test_qbit % 8) != 0) {
total_bits_in_test_qbit += 16 - test_qbit; // simulate the replacement of a qbit weight with a 16bit one
}
if (total_bits_in_test_qbit < total_bits) {
total_bits = total_bits_in_test_qbit;
qbits = test_qbit;
@@ -16686,7 +16686,7 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
for (int j = 0; j < QKX_0; j++) {
float x = src[i * QKX_0 + j];
// weight is not on 16bit
if ((fp16_indicators[j / 64] & ((uint64_t) 1 << (j % 64))) == 0) {
float diff = fabsf(x);
@@ -16717,25 +16717,27 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
}
dst_offset += (QKX_0 / 64) * sizeof(uint64_t);
// Each weight is stored as min_value + mult * quantized_weight
// Similar to Zero-point quantization, or Q4_1
// Write min value and multiplier to dst
*((uint16_t*) (dst_8 + dst_offset)) = ggml_fp32_to_fp16(min_value);
dst_offset += sizeof(uint16_t);
*((uint16_t*) (dst_8 + dst_offset)) = ggml_fp32_to_fp16(mult_range);
dst_offset += sizeof(uint16_t);
// Store the "metadata" byte (for now it's just "qbits")
*((uint8_t*) (dst_8 + dst_offset)) = qbits;
dst_offset += sizeof(uint8_t);
// Store the quantization pivots / points
- float qvals[1 << qbits]; // IMPORTANT: Change qvals's size depending on the maximum qbits expected
+ GGML_ASSERT(qbits <= 8);
+ float qvals[1 << 8];
for (int j = 0; j < (1 << qbits); j++) {
qvals[j] = min_value + (mult_range * j);
}
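A reader-side sketch of the per-block header these writes produce (struct and helper names are illustrative, not from the commit): an fp16 indicator bitmap of QKX_0 / 64 words, the fp16 min value and multiplier, one metadata byte holding qbits, then the packed payload. The qvals table is rebuilt the same way it is defined here, level j -> min_value + mult_range * j.

    #include <stdint.h>

    /* Decoded view of one QX_0 block header (fields already converted to f32). */
    typedef struct {
        const uint64_t *fp16_indicators;  /* one bit per weight: 1 = stored as fp16 */
        float           min_value;        /* from the first fp16 header field */
        float           mult_range;       /* from the second fp16 header field */
        uint8_t         qbits;            /* metadata byte: bits per quantized weight */
    } qx0_block_header;

    /* Rebuild the dequantization table: level j maps to min_value + mult_range * j. */
    static void qx0_build_qvals(const qx0_block_header *h, float qvals[1 << 8]) {
        for (int j = 0; j < (1 << h->qbits); j++) {
            qvals[j] = h->min_value + h->mult_range * (float)j;
        }
    }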
@@ -16744,7 +16746,7 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
uint32_t * data = (uint32_t*) (dst_8 + dst_offset);
int fp16_count_chk = 0;
for (int j = 0; j < QKX_0; j++) {
float x = src[i * QKX_0 + j];
@@ -16772,17 +16774,17 @@ size_t ggml_quantize_qx_0(const float * src, void * dst, int n, int64_t * hist,
bit_offset += qbits;
}
}
// check that the reported fp16_count is coherent with the bits stored in fp16_indicators
GGML_ASSERT(fp16_count == fp16_count_chk);
// check that the number of bits from quantized values is divisible by 8
GGML_ASSERT((((QKX_0 - fp16_count) * qbits) % 8) == 0);
dst_offset += ((QKX_0 - fp16_count) * qbits) / 8;
dst_offset += fp16_count * 2;
}
// store the total size of the tensor as the last element of extra_data
extra_data[n / tensor_width - 1] = dst_offset;

llama.cpp

@@ -435,7 +435,7 @@ struct llama_load_tensor {
GGML_ASSERT(ne.size() == 2);
size = shards.at(0).size;
GGML_ASSERT(size != 0);
} else {
size = llama_calc_tensor_size(ne, type);
@@ -566,12 +566,14 @@ struct llama_file_loader {
if (shard.type == GGML_TYPE_QX_0) {
shard.extra_data_file_off = file.tell();
- uint64_t extra_data[shard.ne[1]];
- file.read_raw(extra_data, sizeof(uint64_t) * shard.ne[1]);
- // set the size of the tensor here
- shard.size = extra_data[shard.ne[1] - 1];
+ // seek until before the last element of extra_data
+ file.seek(sizeof(uint64_t) * (shard.ne[1] - 1), SEEK_CUR);
+ // get the tensor's size from here
+ uint64_t tensor_size = 0;
+ file.read_raw(&tensor_size, sizeof(uint64_t));
+ shard.size = tensor_size;
// realign, just in case extra_data isn't a multiple of 32B
file.seek(-static_cast<ptrdiff_t>(file.tell()) & 31, SEEK_CUR);
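The same idea in plain C stdio, as a sketch of why the VLA-free version works (the function and names here are made up): at this point only the last extra_data entry, the tensor's total byte size, is needed, so the loader can seek past the first ne[1] - 1 per-row offsets instead of buffering the whole array on the stack.

    #include <stdint.h>
    #include <stdio.h>

    /* Read only the last of `n_rows` uint64_t extra_data entries, which holds
     * the total byte size of the tensor; assumes n_rows >= 1. */
    static int read_qx0_tensor_size(FILE *f, uint64_t n_rows, uint64_t *size_out) {
        if (fseek(f, (long)(sizeof(uint64_t) * (n_rows - 1)), SEEK_CUR) != 0) {
            return -1;
        }
        return fread(size_out, sizeof(uint64_t), 1, f) == 1 ? 0 : -1;
    }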
@@ -686,7 +688,7 @@ struct llama_file_saver {
LLAMA_ASSERT(new_size == tensor_size);
file.write_raw(new_data, new_size);
// QX_0 data may not be 32-byte aligned
if (new_type == GGML_TYPE_QX_0) {
file.seek(-static_cast<ptrdiff_t>(file.tell()) & 31, SEEK_CUR);
@@ -836,7 +838,7 @@ struct llama_model_loader {
switch(lt.ggml_tensor->backend) {
case GGML_BACKEND_CPU:
lt.ggml_tensor->data = lt.data;
if (lt.type == GGML_TYPE_QX_0) {
// QX_0 uses the extra field to store byte offsets in *data for each row except row 0
// (so extra[0] stores where row 1 starts, extra[1] is for row 2, and the last element
@@ -883,7 +885,7 @@ struct llama_model_loader {
if (lt.shards.at(0).extra_data_file_off != 0) {
lt.extra_data = (uint64_t *) ((uint8_t *) mapping->addr + lt.shards.at(0).extra_data_file_off);
}
} else if (lt.split_type == SPLIT_NONE) {
llama_file & file = file_loaders.at(lt.shards.at(0).file_idx)->file;
file.seek(lt.shards.at(0).file_off, SEEK_SET);
@@ -1746,8 +1748,8 @@ static bool llama_eval_internal(
lctx.n_p_eval += N;
}
- fprintf(stderr, "\nmodel eval time: %ldms\n", (ggml_time_us() - t_start_us) / 1000);
- fflush(stderr);
+ // fprintf(stderr, "\nmodel eval time: %ldms\n", (ggml_time_us() - t_start_us) / 1000);
+ // fflush(stderr);
return true;
}
@@ -2399,7 +2401,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (nthread <= 0) {
nthread = std::thread::hardware_concurrency();
}
// multithreaded QX_0 quantization is not compatible with the current multithreaded quantization impl.
// because, since blocks have an unknown size in bytes, we cannot section the output data in exact
// chunks assigned to 1 thread. Multithreading would technically only be possible if we quantize
@@ -2558,7 +2560,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (local_hist.empty()) {
local_hist.resize(hist_cur.size(), 0);
}
// pass in NULL for extra_data, since it's only required for QX_0, which doesn't support quantized threading
local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first, last - first, local_hist.data(), NULL, 0);
}