diff --git a/CMakeLists.txt b/CMakeLists.txt index 7bd640966..3fc65eaf2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -846,7 +846,7 @@ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfig.cmake ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Llama) -set(GGML_PUBLIC_HEADERS "ggml.h" +set(GGML_PUBLIC_HEADERS "ggml.h" "ggml-alloc.h" "ggml-backend.h" "${GGML_HEADERS_CUDA}" "${GGML_HEADERS_OPENCL}" "${GGML_HEADERS_METAL}" "${GGML_HEADERS_MPI}" "${GGML_HEADERS_EXTRA}") diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 1178d63a2..5cb3e63fb 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -189,6 +189,8 @@ class Model: return StableLMModel if model_architecture == "QWenLMHeadModel": return QwenModel + if model_architecture == "Qwen2ForCausalLM": + return Model if model_architecture == "MixtralForCausalLM": return MixtralModel if model_architecture == "GPT2LMHeadModel": @@ -197,6 +199,8 @@ class Model: return Phi2Model if model_architecture == "PlamoForCausalLM": return PlamoModel + if model_architecture == "CodeShellForCausalLM": + return CodeShellModel return Model def _is_model_safetensors(self) -> bool: @@ -234,6 +238,8 @@ class Model: return gguf.MODEL_ARCH.STABLELM if arch == "QWenLMHeadModel": return gguf.MODEL_ARCH.QWEN + if arch == "Qwen2ForCausalLM": + return gguf.MODEL_ARCH.QWEN2 if arch == "MixtralForCausalLM": return gguf.MODEL_ARCH.LLAMA if arch == "GPT2LMHeadModel": @@ -242,6 +248,8 @@ class Model: return gguf.MODEL_ARCH.PHI2 if arch == "PlamoForCausalLM": return gguf.MODEL_ARCH.PLAMO + if arch == "CodeShellForCausalLM": + return gguf.MODEL_ARCH.CODESHELL raise NotImplementedError(f'Architecture "{arch}" not supported!') @@ -1176,6 +1184,70 @@ class PlamoModel(Model): self.gguf_writer.add_tensor(new_name, data) +class CodeShellModel(Model): + def set_gguf_parameters(self): + block_count = self.hparams["n_layer"] + + self.gguf_writer.add_name("CodeShell") + self.gguf_writer.add_context_length(self.hparams["n_positions"]) + self.gguf_writer.add_embedding_length(self.hparams["n_embd"]) + self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"]) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_head_count(self.hparams["n_head"]) + self.gguf_writer.add_head_count_kv(self.hparams["num_query_groups"]) + self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_rope_freq_base(10000.0) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(1.0) + + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + tensors = dict(self.get_tensors()) + has_lm_head = "lm_head.weight" in tensors.keys() or "output.weight" in tensors.keys() + for name, data_torch in tensors.items(): + # we don't need these + if name.endswith((".attn.rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + if not has_lm_head and name == "transformer.wte.weight": + self.gguf_writer.add_tensor("output.weight", data) + print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}") + ###### CONVERSION LOGIC ###### diff --git a/convert.py b/convert.py index e38ee5315..980e6fc72 100755 --- a/convert.py +++ b/convert.py @@ -348,7 +348,7 @@ class Params: f_rope_freq_base = 1e6 return Params( - n_vocab=config.get("vocab_size", model["tok_embeddings.weight"].shape[0]), + n_vocab=model["tok_embeddings.weight"].shape[0], n_embd=config["dim"], n_layer=config["n_layers"], n_ctx=n_ctx, diff --git a/examples/imatrix/README.md b/examples/imatrix/README.md new file mode 100644 index 000000000..578e8fc27 --- /dev/null +++ b/examples/imatrix/README.md @@ -0,0 +1,32 @@ +# llama.cpp/examples/imatrix + +Compute an importance matrix for a model and given text dataset. Can be used during quantization to enchance the quality of the quantum models. +More information is available here: https://github.com/ggerganov/llama.cpp/pull/4861 + +## Usage + +``` +./imatrix -m -f [-o ] [--verbosity ] + [-ofreq num_chunks] [-ow <0 or 1>] [other common params] +``` + +Here `-m` with a model name and `-f` with a file containing training data (such as e.g. `wiki.train.raw`) are mandatory. +The parameters in square brackets are optional and have the following meaning: +* `-o` (or `--output-file`) specifies the name of the file where the computed data will be stored. If missing `imatrix.dat` is used. +* `--verbosity` specifies the verbosity level. If set to `0`, no output other than the perplexity of the processed chunks will be generated. If set to `1`, each time the results are saved a message is written to `stderr`. If `>=2`, a message is output each time data is collected for any tensor. Default verbosity level is `1`. +* `-ofreq` (or `--output-frequency`) specifies how often the so far computed result is saved to disk. Default is 10 (i.e., every 10 chunks) +* `-ow` (or `--output-weight`) specifies if data will be collected for the `output.weight` tensor. My experience is that it is better to not utilize the importance matrix when quantizing `output.weight`, so this is set to `false` by default. + +For faster computation, make sure to use GPU offloading via the `-ngl` argument + +## Example + +```bash +LLAMA_CUBLAS=1 make -j + +# generate importance matrix (imatrix.dat) +./imatrix -m ggml-model-f16.gguf -f train-data.txt -ngl 99 + +# use the imatrix to perform a Q4_K_M quantization +./quantize --imatrix imatrix.dat ggml-model-f16.gguf ./ggml-model-q4_k_m.gguf q4_k_m +``` diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index af78711c5..5a3d30b88 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -80,7 +80,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * // for simplicity, always copy src0 to host, because it is small // take into account that src0 is not contiguous! GGML_ASSERT(src0->ne[1] == src1->ne[1]); - GGML_ASSERT(n_as*ggml_nrows(src0)); + GGML_ASSERT(n_as*ggml_nrows(src0)*sizeof(int) == GGML_PAD(ggml_nbytes(src0), n_as*sizeof(int))); m_ids.resize(ggml_nbytes(src0)/sizeof(int)); ggml_backend_tensor_get(src0, m_ids.data(), 0, ggml_nbytes(src0)); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index ea2c8026c..b07320190 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -324,6 +325,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par double nll = 0.0; double nll2 = 0.0; + const int num_batches = (n_ctx + n_batch - 1) / n_batch; + + std::vector logits; + if (num_batches > 1) { + logits.reserve((size_t)n_ctx * n_vocab); + } + fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); std::vector workers(std::thread::hardware_concurrency() - 1); @@ -332,10 +340,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const int start = i * n_ctx; const int end = start + n_ctx; - const int num_batches = (n_ctx + n_batch - 1) / n_batch; - - std::vector logits; - const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache @@ -361,8 +365,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // restore the original token in case it was set to BOS tokens[batch_start] = token_org; - const auto * batch_logits = llama_get_logits(ctx); - logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); + if (num_batches > 1) { + const auto * batch_logits = llama_get_logits(ctx); + logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); + } } const auto t_end = std::chrono::high_resolution_clock::now(); @@ -391,7 +397,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // last 256 tokens. Then, we split the input up into context window size chunks to // process the entire prompt. const int first = n_ctx/2; - process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, + const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx); + process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first); count += n_ctx - first - 1; @@ -405,6 +412,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2); } fflush(stdout); + + logits.clear(); } printf("\n"); @@ -422,26 +431,73 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par return {tokens, ppl, logit_history, prob_history}; } -static std::vector evaluate_tokens(llama_context * ctx, std::vector & tokens, - int n_past, int n_batch, int n_vocab) { - std::vector result; - result.reserve(tokens.size() * n_vocab); - size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch; - for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) { - size_t n_tokens = tokens.size() - i_chunk * n_batch; - n_tokens = std::min(n_tokens, size_t(n_batch)); - llama_kv_cache_seq_rm(ctx, 0, n_past, -1); - if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return {}; +static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector & batch_logits, int32_t n_batch, int32_t n_vocab) { + for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + 0, 0, 0, // unused + }; + + const int ret = llama_decode(ctx, batch_view); + if (ret != 0) { + LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); + return false; } - const auto logits = llama_get_logits(ctx); - result.insert(result.end(), logits, logits + n_tokens * n_vocab); - - n_past += n_tokens; + memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float)); } - return result; + + return true; +} + +static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector& workers, + const std::vector>& eval_pairs, std::vector& eval_results) { + constexpr int k_token_chunk = 4; + if (eval_results.size() != eval_pairs.size()) { + eval_results.resize(eval_pairs.size()); + } + if (eval_pairs.empty()) return; + + size_t max_threads = std::min((eval_pairs.size() + k_token_chunk - 1)/k_token_chunk, workers.size()); + + std::atomic counter(0); + auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () { + float local_logprobs[k_token_chunk]; + while (true) { + size_t first = counter.fetch_add(k_token_chunk, std::memory_order_relaxed); + if (first >= eval_results.size()) break; + size_t last = std::min(first + k_token_chunk, eval_results.size()); + for (size_t i = first; i < last; ++i) { + auto logits = batch_logits + eval_pairs[i].first * n_vocab; + float max_logit = logits[0]; + for (int j = 1; j < n_vocab; ++j) { + max_logit = std::max(max_logit, logits[j]); + } + float sum_p = 0.f; + for (int j = 0; j < n_vocab; ++j) { + sum_p += expf(logits[j] - max_logit); + } + local_logprobs[i - first] = logits[eval_pairs[i].second] - max_logit - std::log(sum_p); + } + std::memcpy(eval_results.data() + first, local_logprobs, (last - first)*sizeof(float)); + } + }; + + for (size_t it = 0; it < max_threads; ++it) { + workers[it] = std::thread(compute); + } + for (size_t it = 0; it < max_threads; ++it) { + workers[it].join(); + } + } static void hellaswag_score(llama_context * ctx, const gpt_params & params) { @@ -533,7 +589,6 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { // determine the common prefix of the endings hs_cur.common_prefix = 0; - hs_cur.required_tokens = 0; for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) { if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] || hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] || @@ -566,40 +621,17 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { const int n_ctx = llama_n_ctx(ctx); const int n_batch = params.n_batch; - const int max_tasks_per_batch = params.n_parallel; + const int max_tasks_per_batch = 32; const int max_seq = 4*max_tasks_per_batch; llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); std::vector tok_logits(n_vocab); - std::vector batch_logits(n_ctx*n_vocab); + std::vector batch_logits(n_vocab*n_ctx); - auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) { - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); - - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - 0, 0, 0, // unused - }; - - const int ret = llama_decode(ctx, batch_view); - if (ret != 0) { - LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); - return false; - } - - memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float)); - } - - return true; - }; + std::vector> eval_pairs; + std::vector eval_results; + std::vector workers(std::thread::hardware_concurrency()); for (size_t i0 = 0; i0 < hs_task_count; i0++) { int n_cur = 0; @@ -649,11 +681,29 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { llama_kv_cache_clear(ctx); // decode all tasks [i0, i1) - if (!decode_helper(ctx, batch, n_batch)) { + if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { fprintf(stderr, "%s: llama_decode() failed\n", __func__); return; } + // Compute log-probs in parallel + // First we collect all tasks + eval_pairs.clear(); + for (size_t i = i0; i < i1; ++i) { + auto & hs_cur = hs_data[i]; + size_t li = hs_cur.common_prefix; + for (int s = 0; s < 4; ++s) { + for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) { + eval_pairs.push_back(std::make_pair(hs_cur.i_batch + li++, hs_cur.seq_tokens[s][j + 1])); + } + ++li; + } + } + // Then we do the actual calculation + compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results); + + size_t ir = 0; + // compute the logprobs for each ending of the decoded tasks for (size_t i = i0; i < i1; ++i) { auto & hs_cur = hs_data[i]; @@ -662,26 +712,13 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { const auto first_probs = softmax(tok_logits); - size_t li = hs_cur.common_prefix; // logits index in the batch - for (int s = 0; s < 4; ++s) { hs_cur.ending_logprob_count[s] = 1; hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]); - - // Calculate the logprobs over the ending for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) { - std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float)); - - const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]]; - - hs_cur.ending_logprob[s] += std::log(prob); + hs_cur.ending_logprob[s] += eval_results[ir++]; hs_cur.ending_logprob_count[s]++; } - - // account that we skip the last token in the ending - ++li; - - // Calculate the mean token logprob for acc_norm hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s]; } @@ -720,6 +757,13 @@ struct winogrande_entry { std::string second; std::array choices; int answer; + + size_t i_batch; + size_t common_prefix; + size_t required_tokens; + size_t n_base1; // number of tokens for context + choice 1 + size_t n_base2; // number of tokens for context + choice 2 + std::vector seq_tokens[2]; }; static std::vector load_winogrande_from_csv(const std::string& prompt) { @@ -813,7 +857,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { } float scale = 1/(1.f + (float)rng.max()); std::vector selected; - selected.reserve(params.winogrande_tasks); + selected.resize(params.winogrande_tasks); for (int i = 0; i < int(params.winogrande_tasks); ++i) { int j = int(scale*rng()*aux.size()); selected[i] = std::move(data[aux[j]]); @@ -823,115 +867,159 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { data = std::move(selected); } + fprintf(stderr, "%s : tokenizing selected tasks\n", __func__); + // This is needed as usual for LLaMA models const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); + for (auto & task : data) { + task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, add_bos); + task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, add_bos); + + task.common_prefix = 0; + for (size_t k = 0; k < task.seq_tokens[0].size(); k++) { + if (task.seq_tokens[0][k] != task.seq_tokens[1][k]) { + break; + } + task.common_prefix++; + } + + task.required_tokens = task.common_prefix + + task.seq_tokens[0].size() - task.common_prefix + + task.seq_tokens[1].size() - task.common_prefix; + + task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size(); + task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size(); + } + fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__); const int n_vocab = llama_n_vocab(llama_get_model(ctx)); - const int n_ctx = llama_n_ctx(ctx); + const int n_ctx = llama_n_ctx(ctx); + const int n_batch = params.n_batch; + + const int max_tasks_per_batch = 128; + const int max_seq = 2*max_tasks_per_batch; + + llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); std::vector tok_logits(n_vocab); + std::vector batch_logits(n_vocab*n_ctx); + + std::vector> eval_pairs; + std::vector eval_results; + std::vector workers(std::thread::hardware_concurrency()); int n_correct = 0; int n_done = 0; - for (size_t task_idx = 0; task_idx < data.size(); task_idx++) { - const auto& task = data[task_idx]; + for (size_t i0 = 0; i0 < data.size(); i0++) { + int n_cur = 0; - auto base_context = ::llama_tokenize(ctx, task.first, add_bos); - auto base_ctx_1st = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos); - auto base_ctx_2nd = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos); + size_t i1 = i0; + size_t i_batch = 0; - auto sentence_1st = task.first + task.choices[0] + task.second; - auto sentence_2nd = task.first + task.choices[1] + task.second; - auto query_1st = ::llama_tokenize(ctx, sentence_1st, add_bos); - auto query_2nd = ::llama_tokenize(ctx, sentence_2nd, add_bos); + llama_batch_clear(batch); - if (query_1st.size() > (size_t)n_ctx || query_2nd.size() > (size_t)n_ctx) { - fprintf(stderr, "%s : number of tokens in queries %zu, %zu > n_ctxl\n", __func__, query_1st.size(), query_2nd.size()); + while (n_cur + (int) data[i1].required_tokens <= n_ctx) { + const int s0 = 2*(i1 - i0); + if (s0 + 2 > max_seq) { + break; + } + + for (size_t i = 0; i < data[i1].common_prefix; ++i) { + llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1}, false); + } + batch.logits[batch.n_tokens - 1] = true; + + for (int s = 0; s < 2; ++s) { + for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { + llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true); + } + } + + data[i1].i_batch = i_batch; + i_batch += data[i1].required_tokens; + + n_cur += data[i1].required_tokens; + if (++i1 == data.size()) { + break; + } + } + + if (i0 == i1) { + fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0); return; } - auto query_1st_size = query_1st.size(); - auto query_2nd_size = query_2nd.size(); - - // Speedup small evaluations by evaluating atleast 32 tokens - // For Winogrande this seems to slow it down rather than speed it up. - //if (query_1st.size() < 32) query_1st.resize(32); - //if (query_2nd.size() < 32) query_2nd.resize(32); - llama_kv_cache_clear(ctx); - auto logits_1st = evaluate_tokens(ctx, query_1st, 0, params.n_batch, n_vocab); - llama_kv_cache_clear(ctx); - auto logits_2nd = evaluate_tokens(ctx, query_2nd, 0, params.n_batch, n_vocab); - - if (logits_1st.empty() || logits_2nd.empty()) { - fprintf(stderr, "%s : failed to eval\n", __func__); + // decode all tasks [i0, i1) + if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { + fprintf(stderr, "%s: llama_decode() failed\n", __func__); return; } - bool skip_choice = query_1st_size - base_ctx_1st.size() > k_min_trailing_ctx && - query_2nd_size - base_ctx_2nd.size() > k_min_trailing_ctx; + eval_pairs.clear(); + for (size_t i = i0; i < i1; ++i) { + auto & task = data[i]; - float score_1st = 0; - bool is_nan_1st = false; - const auto& base_1 = skip_choice ? base_ctx_1st : base_context; - const int last_1st = query_1st_size - base_1.size() > 1 ? 1 : 0; - for (size_t j = base_1.size()-1; j < query_1st_size-1-last_1st; ++j) { - std::memcpy(tok_logits.data(), logits_1st.data() + j*n_vocab, n_vocab*sizeof(float)); - const float prob = softmax(tok_logits)[query_1st[j+1]]; - if (std::isnan(prob) || !prob) { - fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__, - prob, j, sentence_1st.c_str(), base_context.size()); - is_nan_1st = true; - break; + const bool skip_choice = + task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx && + task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx; + + const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix; + const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0; + size_t li = n_base1 - 1; + for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) { + eval_pairs.push_back(std::make_pair(task.i_batch + li++, task.seq_tokens[0][j+1])); } - score_1st += std::log(prob); - } - score_1st /= (query_1st_size - base_1.size() - last_1st); - - float score_2nd = 0; - bool is_nan_2nd = false; - const auto& base_2 = skip_choice ? base_ctx_2nd : base_context; - const int last_2nd = query_2nd_size - base_2.size() > 1 ? 1 : 0; - for (size_t j = base_2.size()-1; j < query_2nd_size-1-last_2nd; ++j) { - std::memcpy(tok_logits.data(), logits_2nd.data() + j*n_vocab, n_vocab*sizeof(float)); - const float prob = softmax(tok_logits)[query_2nd[j+1]]; - if (std::isnan(prob) || !prob) { - fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__, - prob, j, sentence_2nd.c_str(), base_context.size()); - is_nan_2nd = true; - break; + const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix; + const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0; + li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1; + for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) { + eval_pairs.push_back(std::make_pair(task.i_batch + li++, task.seq_tokens[1][j+1])); } - score_2nd += std::log(prob); } - score_2nd /= (query_2nd_size - base_2.size() - last_2nd); + compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results); - if (is_nan_1st || is_nan_2nd) { - continue; + size_t ir = 0; + for (size_t i = i0; i < i1; ++i) { + auto & task = data[i]; + + const bool skip_choice = + task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx && + task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx; + + float score_1st = 0; + const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix; + const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0; + for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) { + score_1st += eval_results[ir++]; + } + score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st); + + float score_2nd = 0; + const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix; + const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0; + for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) { + score_2nd += eval_results[ir++]; + } + score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd); + + int result = score_1st > score_2nd ? 1 : 2; + + if (result == task.answer) { + ++n_correct; + } + ++n_done; + + // print the accumulated accuracy mean x 100 + printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer); + fflush(stdout); } - if (std::isnan(score_1st) || std::isnan(score_2nd)) { - printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd); - printf("Q1: <%s> - %zu tokens\n", sentence_1st.c_str(), query_1st_size); - printf("Q2: <%s> - %zu tokens\n", sentence_2nd.c_str(), query_2nd_size); - printf("B : <%s> - %zu tokens\n", task.first.c_str(), base_context.size()); - printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", base_1.size(), base_2.size(), skip_choice); - continue; - } - - int result = score_1st > score_2nd ? 1 : 2; - - if (result == task.answer) { - ++n_correct; - } - ++n_done; - - // Print the accumulated accuracy mean x 100 - printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n",task_idx+1, 100.0 * n_correct/n_done,score_1st,score_2nd,result,task.answer); - fflush(stdout); + i0 = i1 - 1; } printf("\n"); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 93f999298..0462fbd24 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1558,6 +1558,7 @@ struct llama_server_context void process_tasks() { std::unique_lock lock(mutex_tasks); + std::vector deferred_tasks; while (!queue_tasks.empty()) { task_server task = queue_tasks.front(); @@ -1568,9 +1569,8 @@ struct llama_server_context llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1)); if (slot == nullptr) { - LOG_TEE("slot unavailable\n"); - // send error result - send_error(task, "slot unavailable"); + // if no slot is available, we defer this task for processing later + deferred_tasks.push_back(task); break; } @@ -1616,6 +1616,12 @@ struct llama_server_context } } + // add all the deferred tasks back the the queue + for (task_server &task : deferred_tasks) + { + queue_tasks.push_back(task); + } + // remove finished multitasks from the queue of multitasks, and add the corresponding result to the result queue std::vector agg_results; auto queue_iterator = queue_multitasks.begin(); diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 4a9a2340b..eee9d4de3 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -263,7 +263,6 @@ static void init_model(struct my_llama_model * model) { model->data.resize(size + tensor_alignment); alloc = ggml_allocr_new(model->data.data(), model->data.size(), tensor_alignment); alloc_model(alloc, model); - ggml_allocr_free(alloc); } static void randomize_model(struct my_llama_model * model, int seed, float mean, float std, float min, float max) { @@ -1102,7 +1101,6 @@ int main(int argc, char ** argv) { alloc = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment); ggml_allocr_alloc(alloc, tokens_input); ggml_allocr_alloc(alloc, target_probs); - ggml_allocr_free(alloc); // context for compute tensors without their data const size_t estimated_compute_size_wo_data = ( @@ -1149,7 +1147,6 @@ int main(int argc, char ** argv) { best_compute_size = max_compute_size; best_order = gf->order; } - ggml_allocr_free(alloc); ggml_free(ctx_compute); } size_t max_compute_size = best_compute_size; @@ -1177,7 +1174,6 @@ int main(int argc, char ** argv) { params.common.use_flash, params.common.use_checkpointing ); - ggml_allocr_free(alloc); std::vector train_tokens; std::vector train_samples_begin; diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b2211d858..ec3837fb8 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -12,9 +12,6 @@ #include #include #include -#include "ggml-cuda.h" -#include "ggml.h" -#include "ggml-backend-impl.h" #if defined(GGML_USE_HIPBLAS) #include @@ -118,6 +115,11 @@ #endif // defined(GGML_USE_HIPBLAS) +// ggml-cuda need half type so keep ggml headers include at last +#include "ggml-cuda.h" +#include "ggml.h" +#include "ggml-backend-impl.h" + #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed) #define CC_PASCAL 600 diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 972b4e9a7..2d9c33c7d 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -97,8 +97,10 @@ class MODEL_ARCH(IntEnum): BLOOM = auto() STABLELM = auto() QWEN = auto() + QWEN2 = auto() PHI2 = auto() PLAMO = auto() + CODESHELL = auto() class MODEL_TENSOR(IntEnum): @@ -145,8 +147,10 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.BLOOM: "bloom", MODEL_ARCH.STABLELM: "stablelm", MODEL_ARCH.QWEN: "qwen", + MODEL_ARCH.QWEN2: "qwen2", MODEL_ARCH.PHI2: "phi2", MODEL_ARCH.PLAMO: "plamo", + MODEL_ARCH.CODESHELL: "codeshell", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -356,6 +360,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.QWEN2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.PLAMO: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -396,6 +414,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.CODESHELL: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, ] # TODO } @@ -417,6 +448,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], + MODEL_ARCH.CODESHELL: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], } # diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index e5b146106..de177af13 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -154,6 +154,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo + "transformer.h.{bid}.attn.rotary_emb.inv_freq", # codeshell ), # Feed-forward norm diff --git a/llama.cpp b/llama.cpp index d4bebe520..f0a63afef 100644 --- a/llama.cpp +++ b/llama.cpp @@ -192,8 +192,10 @@ enum llm_arch { LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, + LLM_ARCH_QWEN2, LLM_ARCH_PHI2, LLM_ARCH_PLAMO, + LLM_ARCH_CODESHELL, LLM_ARCH_UNKNOWN, }; @@ -211,8 +213,10 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, + { LLM_ARCH_QWEN2, "qwen2" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_CODESHELL, "codeshell" }, }; enum llm_kv { @@ -566,6 +570,23 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_QWEN2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_PHI2, { @@ -600,6 +621,26 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_CODESHELL, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, @@ -1599,7 +1640,7 @@ struct llama_model { std::unique_ptr mapping; // objects representing data potentially being locked in memory - llama_mlock mlock_buf; + std::vector> mlock_bufs; llama_mlock mlock_mmap; // for quantize-stats only @@ -2847,6 +2888,17 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_QWEN2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1B; break; + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + case 80: model.type = e_model::MODEL_70B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_PHI2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -2877,6 +2929,14 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_CODESHELL: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 42: model.type = e_model::MODEL_SMALL; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -3438,7 +3498,12 @@ static bool llm_load_tensors( { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_OUTPUT, "weight").c_str()) >= 0) { + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } else { + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // needs to be on GPU + ml.n_created--; // artificial tensor + } } for (int i = 0; i < n_layer; ++i) { @@ -3669,6 +3734,41 @@ static bool llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}); } } break; + case LLM_ARCH_QWEN2: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + // optional bias tensors + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; case LLM_ARCH_PHI2: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -3779,6 +3879,42 @@ static bool llm_load_tensors( layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; + case LLM_ARCH_CODESHELL: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); + + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -3815,8 +3951,10 @@ static bool llm_load_tensors( else { buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (buf != nullptr && use_mlock && ggml_backend_buffer_is_host(buf)) { - model.mlock_buf.init (ggml_backend_buffer_get_base(buf)); - model.mlock_buf.grow_to(ggml_backend_buffer_get_size(buf)); + model.mlock_bufs.emplace_back(new llama_mlock); + auto & mlock_buf = model.mlock_bufs.back(); + mlock_buf->init (ggml_backend_buffer_get_base(buf)); + mlock_buf->grow_to(ggml_backend_buffer_get_size(buf)); } } if (buf == nullptr) { @@ -5638,6 +5776,128 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_qwen2() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, Qcur); + ggml_build_forward_expand(gf, Kcur); + ggml_build_forward_expand(gf, Vcur); + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, model, hparams, kv_self, + model.layers[il].wo, model.layers[il].bo, + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_phi2() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -5971,6 +6231,117 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_codeshell() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(tmpq, "tmpq", il); + cb(tmpk, "tmpk", il); + cb(Vcur, "Vcur", il); + + struct ggml_tensor * Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, model, hparams, kv_self, + model.layers[il].wo, model.layers[il].bo, + Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cb(cur, "kqv_out", il); + } + + // add the input + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + } + + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph( @@ -6153,6 +6524,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_qwen(); } break; + case LLM_ARCH_QWEN2: + { + result = llm.build_qwen2(); + } break; case LLM_ARCH_PHI2: { result = llm.build_phi2(); @@ -6165,6 +6540,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_gpt2(); } break; + case LLM_ARCH_CODESHELL: + { + result = llm.build_codeshell(); + } break; default: GGML_ASSERT(false); } diff --git a/scripts/get-hellaswag.sh b/scripts/get-hellaswag.sh index ef8dcceb0..121979fe2 100755 --- a/scripts/get-hellaswag.sh +++ b/scripts/get-hellaswag.sh @@ -4,7 +4,7 @@ wget https://raw.githubusercontent.com/klosax/hellaswag_text_data/main/hellaswag echo "Usage:" echo "" -echo " ./perplexity --hellaswag --hellaswag-tasks N -f hellaswag_val_full.txt -m modelfile.gguf" +echo " ./perplexity -m model.gguf -f hellaswag_val_full.txt --hellaswag [--hellaswag-tasks N] [other params]" echo "" exit 0 diff --git a/scripts/get-wikitext-2.sh b/scripts/get-wikitext-2.sh index 98aec3e3e..ff96f331e 100755 --- a/scripts/get-wikitext-2.sh +++ b/scripts/get-wikitext-2.sh @@ -1,3 +1,10 @@ #!/bin/bash wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip + +echo "Usage:" +echo "" +echo " ./perplexity -m model.gguf -f wiki.test.raw [other params]" +echo "" + +exit 0 diff --git a/scripts/get-winogrande.sh b/scripts/get-winogrande.sh new file mode 100755 index 000000000..5f234468e --- /dev/null +++ b/scripts/get-winogrande.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +wget https://huggingface.co/datasets/ikawrakow/winogrande-eval-for-llama.cpp/raw/main/winogrande-debiased-eval.csv + +echo "Usage:" +echo "" +echo " ./perplexity -m model.gguf -f winogrande-debiased-eval.csv --winogrande [--winogrande-tasks N] [other params]" +echo "" + +exit 0