From 4bd0f93e4ab4fe6682e7d0241c1bdec1397e954a Mon Sep 17 00:00:00 2001 From: Pierrick Hymbert Date: Sat, 13 Apr 2024 11:33:52 +0200 Subject: [PATCH] model: support arch `DbrxForCausalLM` (#6515) * model: dbrx convert to gguf #6344 * llama: support dbrx #6344 * doc: dbrx: add the model as supported * scripts: get-wikitext-2 add unzip * llama: increase maximum experts allowed * llama: factorize moe graph implementation between grok, mixtral and dbrx --------- Co-authored-by: Megha Agarwal <16129366+megha95@users.noreply.github.com> --- README.md | 1 + convert-hf-to-gguf.py | 96 ++++++ examples/eval-callback/eval-callback.cpp | 26 +- gguf-py/gguf/constants.py | 15 + gguf-py/gguf/tensor_mapping.py | 58 ++-- llama.cpp | 377 ++++++++++++++++------- scripts/get-wikitext-2.sh | 3 +- 7 files changed, 428 insertions(+), 148 deletions(-) diff --git a/README.md b/README.md index 00a487fc6..fa2a4db83 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,7 @@ Typically finetunes of the base models below are supported as well. - [x] LLaMA 2 🦙🦙 - [X] [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) - [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral) +- [x] [DBRX](https://huggingface.co/databricks/dbrx-instruct) - [X] Falcon - [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2) - [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 63710676b..e1ac09e02 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1427,6 +1427,102 @@ class GrokModel(Model): self.gguf_writer.add_tensor(new_name, data) +@Model.register("DbrxForCausalLM") +class DbrxModel(Model): + model_arch = gguf.MODEL_ARCH.DBRX + + def set_gguf_parameters(self): + ffn_config = self.hparams["ffn_config"] + attn_config = self.hparams["attn_config"] + self.gguf_writer.add_name(self.hparams["model_type"]) + self.gguf_writer.add_block_count(self.hparams["n_layers"]) + + self.gguf_writer.add_context_length(self.hparams["max_seq_len"]) + self.gguf_writer.add_embedding_length(self.hparams["d_model"]) + self.gguf_writer.add_feed_forward_length(ffn_config["ffn_hidden_size"]) + + self.gguf_writer.add_head_count(self.hparams["n_heads"]) + self.gguf_writer.add_head_count_kv(attn_config["kv_n_heads"]) + + self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"]) + + self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"]) + self.gguf_writer.add_file_type(self.ftype) + + self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"]) + self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"]) + + self.gguf_writer.add_layer_norm_eps(1e-5) + + self.gguf_writer.add_file_type(self.ftype) + print(f"gguf: file type = {self.ftype}") + + def write_tensors(self): + block_count = self.hparams.get("n_layers") + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + for name, data_torch in self.get_tensors(): + n_expert = self.hparams["ffn_config"]["moe_num_experts"] + n_ff = self.hparams["ffn_config"]["ffn_hidden_size"] + n_embd = self.hparams["d_model"] + + # Specific behavior for experts tensors: suffix .weight, view as 3D and transpose + # original implementation expects (n_expert, n_ff, n_embd) for all experts weights + # But llama.cpp moe graph works differently + # AND the dimensions in ggml are typically in the reverse order of the pytorch dimensions + # so (n_expert, n_ff, n_embd) in pytorch is {n_embd, n_ff, n_expert} in ggml_tensor + exp_tensor_names = {"ffn.experts.mlp.w1": None, # LLM_TENSOR_FFN_GATE_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert} + "ffn.experts.mlp.w2": (0, 2, 1), # LLM_TENSOR_FFN_DOWN_EXPS ggml_tensor->ne{n_ff, n_embd, n_expert} + "ffn.experts.mlp.v1": None} # LLM_TENSOR_FFN_UP_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert} + experts = False + for exp_tensor_name in exp_tensor_names.keys(): + if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1: + experts = True + data_torch = data_torch.view(n_expert, n_ff, n_embd) + if (permute_tensor := exp_tensor_names[exp_tensor_name]) is not None: + data_torch = data_torch.permute(*permute_tensor) + break + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.squeeze().numpy() + + # map tensor names + # In MoE models the ffn tensors are typically most of the model weights, + # and need to be quantizable. Quantize expects tensor names to be suffixed by .weight. + # Every other model has the weight names ending in .weight, + # let's assume that is the convention which is not the case for dbrx: + # https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15 + new_name = tensor_map.get_name(name if not experts else name + ".weight", try_suffixes=(".weight",)) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # Most of the codebase that takes in 1D tensors only handles F32 tensors + # and most of the outputs tensors are F32. + if data_dtype != np.float32 and n_dims == 1: + print(f"Can not map tensor {name!r}: all 1D tensors must be F32") + sys.exit() + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and n_dims > 1: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, shape = {data.shape}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + @Model.register("MiniCPMForCausalLM") class MiniCPMModel(Model): model_arch = gguf.MODEL_ARCH.MINICPM diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 05f7d6ab1..29b5f3b3c 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -28,14 +28,27 @@ static std::string ggml_ne_string(const ggml_tensor * t) { } static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) { + GGML_ASSERT(n > 0); float sum = 0; for (int64_t i3 = 0; i3 < ne[3]; i3++) { printf(" [\n"); - for (int64_t i2 = 0; i2 < ne[2] && i2 < n; i2++) { + for (int64_t i2 = 0; i2 < ne[2]; i2++) { + if (i2 == n && ne[2] > 2*n) { + printf(" ..., \n"); + i2 = ne[2] - n; + } printf(" [\n"); - for (int64_t i1 = 0; i1 < ne[1] && i1 < n; i1++) { + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + if (i1 == n && ne[1] > 2*n) { + printf(" ..., \n"); + i1 = ne[1] - n; + } printf(" ["); - for (int64_t i0 = 0; i0 < ne[0] && i0 < n; i0++) { + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + if (i0 == n && ne[0] > 2*n) { + printf("..., "); + i0 = ne[0] - n; + } size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; float v; if (type == GGML_TYPE_F16) { @@ -51,17 +64,14 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne } else { GGML_ASSERT(false); } - printf("%8.4f", v); + printf("%12.4f", v); sum += v; - if (i0 < ne[0] - 1 && i0 < n - 1) printf(", "); + if (i0 < ne[0] - 1) printf(", "); } - if (ne[0] > n) printf(", ..."); printf("],\n"); } - if (ne[1] > n) printf(" ...\n"); printf(" ],\n"); } - if (ne[2] > n) printf(" ...\n"); printf(" ]\n"); printf(" sum = %f\n", sum); } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a6454a10e..2566b2fb8 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -126,6 +126,7 @@ class MODEL_ARCH(IntEnum): MAMBA = auto() XVERSE = auto() COMMAND_R = auto() + DBRX = auto() class MODEL_TENSOR(IntEnum): @@ -195,6 +196,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.MAMBA: "mamba", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", + MODEL_ARCH.DBRX: "dbrx", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -642,6 +644,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_K_NORM, MODEL_TENSOR.ATTN_Q_NORM, ], + MODEL_ARCH.DBRX: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_OUT_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 4f02d298e..ec6fcbb83 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -10,7 +10,7 @@ class TensorNameMap: # Token embeddings MODEL_TENSOR.TOKEN_EMBD: ( "gpt_neox.embed_in", # gptneox - "transformer.wte", # gpt2 gpt-j mpt refact qwen + "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx "transformer.word_embeddings", # falcon "word_embeddings", # bloom "model.embed_tokens", # llama-hf @@ -48,7 +48,7 @@ class TensorNameMap: # Output MODEL_TENSOR.OUTPUT: ( "embed_out", # gptneox - "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba + "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx "output", # llama-pth bloom internlm2 "word_embeddings_for_head", # persimmon "lm_head.linear", # phi2 @@ -60,7 +60,7 @@ class TensorNameMap: "transformer.ln_f", # gpt2 gpt-j falcon "model.norm", # llama-hf baichuan internlm2 "norm", # llama-pth - "transformer.norm_f", # mpt + "transformer.norm_f", # mpt dbrx "ln_f", # refact bloom qwen gpt2 "language_model.encoder.final_layernorm", # persimmon "model.final_layernorm", # persimmon @@ -96,6 +96,7 @@ class TensorNameMap: "model.layers.{bid}.norm", # mamba-qbert "backbone.layers.{bid}.norm", # mamba "transformer.decoder_layer.{bid}.rms_norm", # Grok + "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx ), # Attention norm 2 @@ -108,6 +109,7 @@ class TensorNameMap: "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox "transformer.h.{bid}.attn.c_attn", # gpt2 qwen "transformer.blocks.{bid}.attn.Wqkv", # mpt + "transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv", # dbrx "transformer.h.{bid}.self_attention.query_key_value", # falcon "h.{bid}.self_attention.query_key_value", # bloom "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon @@ -152,23 +154,24 @@ class TensorNameMap: # Attention output MODEL_TENSOR.ATTN_OUT: ( - "gpt_neox.layers.{bid}.attention.dense", # gptneox - "transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen - "transformer.blocks.{bid}.attn.out_proj", # mpt - "transformer.h.{bid}.self_attention.dense", # falcon - "h.{bid}.self_attention.dense", # bloom - "model.layers.{bid}.self_attn.o_proj", # llama-hf - "layers.{bid}.attention.wo", # llama-pth - "encoder.layer.{bid}.attention.output.dense", # bert - "transformer.h.{bid}.attn.out_proj", # gpt-j - "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon - "model.layers.{bid}.self_attn.dense", # persimmon - "h.{bid}.attn.c_proj", # gpt2 - "transformer.h.{bid}.mixer.out_proj", # phi2 - "model.layers.layers.{bid}.self_attn.o_proj", # plamo - "model.layers.{bid}.attention.wo", # internlm2 - "encoder.layers.{bid}.attn.out_proj", # nomic-bert - "transformer.decoder_layer.{bid}.multi_head_attention.linear"# Grok + "gpt_neox.layers.{bid}.attention.dense", # gptneox + "transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen + "transformer.blocks.{bid}.attn.out_proj", # mpt + "transformer.h.{bid}.self_attention.dense", # falcon + "h.{bid}.self_attention.dense", # bloom + "model.layers.{bid}.self_attn.o_proj", # llama-hf + "layers.{bid}.attention.wo", # llama-pth + "encoder.layer.{bid}.attention.output.dense", # bert + "transformer.h.{bid}.attn.out_proj", # gpt-j + "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon + "model.layers.{bid}.self_attn.dense", # persimmon + "h.{bid}.attn.c_proj", # gpt2 + "transformer.h.{bid}.mixer.out_proj", # phi2 + "model.layers.layers.{bid}.self_attn.o_proj", # plamo + "model.layers.{bid}.attention.wo", # internlm2 + "encoder.layers.{bid}.attn.out_proj", # nomic-bert + "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok + "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx ), # Attention output norm @@ -176,6 +179,7 @@ class TensorNameMap: "encoder.layer.{bid}.attention.output.LayerNorm", # bert "encoder.layers.{bid}.norm1", # nomic-bert "transformer.decoder_layer.{bid}.rms_norm_1", # Grok + "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx ), # Rotary embeddings @@ -202,9 +206,10 @@ class TensorNameMap: ), MODEL_TENSOR.FFN_GATE_INP: ( - "layers.{bid}.feed_forward.gate", # mixtral - "model.layers.{bid}.block_sparse_moe.gate", # mixtral - "transformer.decoder_layer.{bid}.router" # Grok + "layers.{bid}.feed_forward.gate", # mixtral + "model.layers.{bid}.block_sparse_moe.gate", # mixtral + "transformer.decoder_layer.{bid}.router", # Grok + "transformer.blocks.{bid}.ffn.router.layer", # dbrx ), # Feed-forward up @@ -233,6 +238,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_UP_EXP: ( "layers.{bid}.feed_forward.experts.w3", # mixtral (merged) "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx ), # AWQ-activation gate @@ -251,8 +257,9 @@ class TensorNameMap: ), MODEL_TENSOR.FFN_GATE_EXP: ( - "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) - "transformer.decoder_layer.{bid}.moe.linear" # Grok (merged) + "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx ), # Feed-forward down @@ -280,6 +287,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_DOWN_EXP: ( "layers.{bid}.feed_forward.experts.w2", # mixtral (merged) "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx ), MODEL_TENSOR.ATTN_Q_NORM: ( diff --git a/llama.cpp b/llama.cpp index 83dd55efa..b93c1abcd 100644 --- a/llama.cpp +++ b/llama.cpp @@ -105,7 +105,7 @@ #endif #define LLAMA_MAX_NODES 8192 -#define LLAMA_MAX_EXPERTS 8 +#define LLAMA_MAX_EXPERTS 16 // @@ -220,6 +220,7 @@ enum llm_arch { LLM_ARCH_MAMBA, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, + LLM_ARCH_DBRX, LLM_ARCH_UNKNOWN, }; @@ -252,6 +253,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, + { LLM_ARCH_DBRX, "dbrx" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -934,6 +936,22 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, }, }, + { + LLM_ARCH_DBRX, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1707,6 +1725,7 @@ enum e_model { MODEL_XL, MODEL_8x7B, MODEL_8x22B, + MODEL_16x12B, }; static const size_t kiB = 1024; @@ -3562,6 +3581,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_XL: return "1.5B"; case MODEL_8x7B: return "8x7B"; case MODEL_8x22B: return "8x22B"; + case MODEL_16x12B: return "16x12B"; default: return "?B"; } } @@ -3983,6 +4003,16 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_DBRX: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + + switch (hparams.n_layer) { + case 40: model.type = e_model::MODEL_16x12B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -4671,6 +4701,39 @@ static bool llm_load_tensors( layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); } } break; + case LLM_ARCH_DBRX: + { + if (n_expert == 0) { + throw std::runtime_error("DBRX model cannot have zero experts"); + } + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); + + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + } + } break; case LLM_ARCH_BAICHUAN: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -6433,62 +6496,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] - cb(logits, "ffn_moe_logits", il); - - ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts] - cb(probs, "ffn_moe_probs", il); - - // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok] - cb(selected_experts->src[0], "ffn_moe_argsort", il); - - ggml_tensor * weights = ggml_get_rows(ctx0, - ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); - cb(weights, "ffn_moe_weights", il); - - weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok] - - ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); - cb(weights_sum, "ffn_moe_weights_sum", il); - - weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok] - cb(weights, "ffn_moe_weights_norm", il); - - // compute expert outputs - ggml_tensor * moe_out = nullptr; - - for (int i = 0; i < n_expert_used; ++i) { - ggml_tensor * cur_expert; - - ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur); - cb(cur_up, "ffn_moe_up", il); - - ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur); - cb(cur_gate, "ffn_moe_gate", il); - - cur_gate = ggml_silu(ctx0, cur_gate); - cb(cur_gate, "ffn_moe_silu", il); - - cur_expert = ggml_mul(ctx0, cur_up, cur_gate); - cb(cur_expert, "ffn_moe_gate_par", il); - - cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd] - cb(cur_expert, "ffn_moe_down", il); - - cur_expert = ggml_mul(ctx0, cur_expert, - ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0])); - cb(cur_expert, "ffn_moe_weighted", il); - - if (i == 0) { - moe_out = cur_expert; - } else { - moe_out = ggml_add(ctx0, moe_out, cur_expert); - cb(moe_out, "ffn_moe_out", il); - } - } - - cur = moe_out; + cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, il); } cur = ggml_add(ctx0, cur, ffn_inp); @@ -6520,6 +6528,78 @@ struct llm_build_context { return gf; } + // REVIEW: will be replaced by https://github.com/ggerganov/llama.cpp/pull/6505 + ggml_tensor * build_moe_ffn(ggml_tensor * cur, int32_t n_tokens, llm_ffn_op_type type_op, int il) { + ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] + cb(logits, "ffn_moe_logits", il); + + ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts] + cb(probs, "ffn_moe_probs", il); + + // select experts + ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok] + cb(selected_experts->src[0], "ffn_moe_argsort", il); + + ggml_tensor * weights = ggml_get_rows(ctx0, + ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); + cb(weights, "ffn_moe_weights", il); + + weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok] + + ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); + cb(weights_sum, "ffn_moe_weights_sum", il); + + weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok] + cb(weights, "ffn_moe_weights_norm", il); + + // compute expert outputs + ggml_tensor * moe_out = nullptr; + + for (int i = 0; i < n_expert_used; ++i) { + ggml_tensor * cur_expert; + + ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur); + cb(cur_up, "ffn_moe_up", il); + + ggml_tensor * gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur); + cb(gate, "ffn_moe_gate", il); + + switch (type_op) { + case LLM_FFN_SILU: + { + gate = ggml_silu(ctx0, gate); + cb(gate, "ffn_moe_silu", il); + } break; + case LLM_FFN_GELU: + { + gate = ggml_gelu(ctx0, gate); + cb(gate, "ffn_moe_gelu", il); + } break; + default: + GGML_ASSERT(false); + } + + cur_expert = ggml_mul(ctx0, cur_up, gate); + cb(cur_expert, "ffn_moe_gate_par", il); + + cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd] + cb(cur_expert, "ffn_moe_down", il); + + cur_expert = ggml_mul(ctx0, cur_expert, + ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0])); + cb(cur_expert, "ffn_moe_weighted", il); + + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = ggml_add(ctx0, moe_out, cur_expert); + cb(moe_out, "ffn_moe_out", il); + } + } + + return moe_out; + } + struct ggml_cgraph * build_baichuan() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -6967,63 +7047,7 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] - cb(logits, "ffn_moe_logits", il); - - ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts] - cb(probs, "ffn_moe_probs", il); - - // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok] - cb(selected_experts->src[0], "ffn_moe_argsort", il); - - ggml_tensor * weights = ggml_get_rows(ctx0, - ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); - cb(weights, "ffn_moe_weights", il); - - weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok] - - ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); - cb(weights_sum, "ffn_moe_weights_sum", il); - - weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok] - cb(weights, "ffn_moe_weights_norm", il); - - // compute expert outputs - ggml_tensor * moe_out = nullptr; - - for (int i = 0; i < n_expert_used; ++i) { - ggml_tensor * cur_expert; - - ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur); - cb(cur_up, "ffn_moe_up", il); - - ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur); - cb(cur_gate, "ffn_moe_gate", il); - - //GeLU - cur_gate = ggml_gelu(ctx0, cur_gate); - cb(cur_gate, "ffn_moe_gelu", il); - - cur_expert = ggml_mul(ctx0, cur_up, cur_gate); - cb(cur_expert, "ffn_moe_gate_par", il); - - cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd] - cb(cur_expert, "ffn_moe_down", il); - - cur_expert = ggml_mul(ctx0, cur_expert, - ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0])); - cb(cur_expert, "ffn_moe_weighted", il); - - if (i == 0) { - moe_out = cur_expert; - } else { - moe_out = ggml_add(ctx0, moe_out, cur_expert); - cb(moe_out, "ffn_moe_out", il); - } - } - - cur = moe_out; + cur = build_moe_ffn(cur, n_tokens, LLM_FFN_GELU, il); // Grok // if layer_out_norm is present then apply it before adding the input @@ -7071,6 +7095,126 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_dbrx() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + struct ggml_tensor * Qcur = nullptr; + struct ggml_tensor * Kcur = nullptr; + struct ggml_tensor * Vcur = nullptr; + + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(cur, "wqkv_clamped", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].attn_out_norm, NULL, + LLM_NORM, cb, il); + cb(cur, "attn_out_norm", il); + + cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx0, cur, layer_dir); + } + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_starcoder() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -9785,6 +9929,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_command_r(); } break; + case LLM_ARCH_DBRX: + { + result = llm.build_dbrx(); + } break; default: GGML_ASSERT(false); } @@ -14638,6 +14786,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { // the pairs of head values are offset by n_rot/2 case LLM_ARCH_FALCON: case LLM_ARCH_GROK: + case LLM_ARCH_DBRX: case LLM_ARCH_PERSIMMON: case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: diff --git a/scripts/get-wikitext-2.sh b/scripts/get-wikitext-2.sh index 7ca760fa6..b01476a46 100755 --- a/scripts/get-wikitext-2.sh +++ b/scripts/get-wikitext-2.sh @@ -1,10 +1,11 @@ #!/bin/bash wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip +unzip wikitext-2-raw-v1.zip echo "Usage:" echo "" -echo " ./perplexity -m model.gguf -f wiki.test.raw [other params]" +echo " ./perplexity -m model.gguf -f wikitext-2-raw/wiki.test.raw [other params]" echo "" exit 0