llama : fix integer overflow during quantization (#4284)

happens with multi-threaded quantization of Qwen-72B

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-12-01 18:42:11 +02:00 committed by GitHub
parent 8d6d9f033b
commit 880f57973b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -7655,18 +7655,21 @@ static void llama_convert_tensor_internal(
return; return;
} }
auto block_size = tensor->type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor->type); size_t block_size = tensor->type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor->type);
auto block_size_bytes = ggml_type_size(tensor->type); size_t block_size_bytes = ggml_type_size(tensor->type);
GGML_ASSERT(nelements % block_size == 0); GGML_ASSERT(nelements % block_size == 0);
auto nblocks = nelements / block_size; size_t nblocks = nelements / block_size;
auto blocks_per_thread = nblocks / nthread; size_t blocks_per_thread = nblocks / nthread;
auto spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count size_t spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count
for (auto tnum = 0, in_buff_offs = 0, out_buff_offs = 0; tnum < nthread; tnum++) { size_t in_buff_offs = 0;
auto thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread size_t out_buff_offs = 0;
auto thr_elems = thr_blocks * block_size; // number of elements for this thread
auto thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread for (int tnum = 0; tnum < nthread; tnum++) {
size_t thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread
size_t thr_elems = thr_blocks * block_size; // number of elements for this thread
size_t thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread
auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) { auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
if (typ == GGML_TYPE_F16) { if (typ == GGML_TYPE_F16) {