Add JAIS model(s) (#8118)

* Add `JAIS` model(s)

* cleanup

* address review comments

* remove hack

* un-hardcode max-alibi-bias

* minor tweaks

---------

Co-authored-by: fmz <quic_fzaghlou@quic.com>
This commit is contained in:
Faisal Zaghloul 2024-07-02 10:36:00 -04:00 committed by GitHub
parent 023b8807e1
commit 968967376d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 288 additions and 9 deletions

View File

@ -86,6 +86,7 @@ models = [
{"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
{"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },
{"name": "viking", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B
{"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
]

View File

@ -490,6 +490,9 @@ class Model:
if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
# ref: https://huggingface.co/LumiOpen/Viking-7B
res = "viking"
if chkhsh == "b53802fb28e26d645c3a310b34bfe07da813026ec7c7716883404d5e0f8b1901":
# ref: https://huggingface.co/core42/jais-13b
res = "jais"
if res is None:
logger.warning("\n")
@ -2965,6 +2968,96 @@ class T5Model(Model):
return [(self.map_tensor_name(name), data_torch)]
@Model.register("JAISLMHeadModel")
class JaisModel(Model):
model_arch = gguf.MODEL_ARCH.JAIS
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# SwigLU activation
assert self.hparams["activation_function"] == "swiglu"
# ALiBi position embedding
assert self.hparams["position_embedding_type"] == "alibi"
# Embeddings scale
self.embeddings_scale = 1.0
# note: For some JAIS flavors, output is tied to (same as) wte in original model
self.output_is_wte = False
if 'mup_embeddings_scale' in self.hparams:
self.output_is_wte = True # Hack (?)
self.embeddings_scale = self.hparams['mup_embeddings_scale']
elif 'embeddings_scale' in self.hparams:
self.embeddings_scale = self.hparams['embeddings_scale']
else:
assert False
self.width_scale = 1.0
if 'mup_output_alpha' in self.hparams:
assert 'mup_width_scale' in self.hparams
self.width_scale = self.hparams['mup_output_alpha'] * self.hparams['mup_width_scale']
elif 'width_scale' in self.hparams:
self.width_scale = self.hparams['width_scale']
else:
assert False
self.max_alibi_bias = 8.0
def set_vocab(self):
self._set_vocab_gpt2()
def set_gguf_parameters(self):
self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_context_length(self.hparams["n_positions"])
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"])
self.gguf_writer.add_head_count(self.hparams["n_head"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_file_type(self.ftype)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
tensors: list[tuple[str, Tensor]] = []
# we don't need these
if name.endswith((".attn.bias")):
return tensors
if name.endswith(("relative_pe.slopes")):
# Calculate max ALiBi bias (this is the inverse of the ALiBi calculation)
# Some other models has max_alibi_bias spelled out explicitly in the hyperparams,
# but Jais's PyTorch model simply precalculates the slope values and places them
# in relative_pes.slopes
n_head_closest_log2 = 2 ** math.floor(math.log2(self.hparams["n_head"]))
first_val = float(data_torch._data[0])
self.max_alibi_bias = -round(math.log2(first_val) * n_head_closest_log2)
return tensors
if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_fc2.weight")):
data_torch = data_torch.transpose(1, 0)
new_name = self.map_tensor_name(name)
if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
tensors.append((new_name, data_torch * self.embeddings_scale))
if self.output_is_wte:
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch * self.width_scale))
elif new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
assert not self.output_is_wte
tensors.append((new_name, data_torch * self.width_scale))
else:
tensors.append((new_name, data_torch))
return tensors
def write_tensors(self):
super().write_tensors()
self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
###### CONVERSION LOGIC ######

View File

@ -164,6 +164,7 @@ class MODEL_ARCH(IntEnum):
DEEPSEEK2 = auto()
BITNET = auto()
T5 = auto()
JAIS = auto()
class MODEL_TENSOR(IntEnum):
@ -288,6 +289,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
MODEL_ARCH.JAIS: "jais",
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@ -954,6 +956,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ENC_FFN_UP,
MODEL_TENSOR.ENC_OUTPUT_NORM,
],
MODEL_ARCH.JAIS: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_UP,
],
# TODO
}

View File

@ -10,7 +10,7 @@ class TensorNameMap:
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf
@ -49,7 +49,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
@ -58,7 +58,7 @@ class TensorNameMap:
# Output norm
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 gpt-j falcon
"transformer.ln_f", # gpt2 gpt-j falcon jais
"model.norm", # llama-hf baichuan internlm2
"norm", # llama-pth
"transformer.norm_f", # mpt dbrx
@ -81,7 +81,7 @@ class TensorNameMap:
# Attention norm
MODEL_TENSOR.ATTN_NORM: (
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais
"transformer.blocks.{bid}.norm_1", # mpt
"transformer.h.{bid}.input_layernorm", # falcon7b
"h.{bid}.input_layernorm", # bloom
@ -109,7 +109,7 @@ class TensorNameMap:
# Attention query-key-value
MODEL_TENSOR.ATTN_QKV: (
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
"transformer.h.{bid}.attn.c_attn", # gpt2 qwen
"transformer.h.{bid}.attn.c_attn", # gpt2 qwen jais
"transformer.blocks.{bid}.attn.Wqkv", # mpt
"transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv", # dbrx
"transformer.h.{bid}.self_attention.query_key_value", # falcon
@ -160,7 +160,7 @@ class TensorNameMap:
# Attention output
MODEL_TENSOR.ATTN_OUT: (
"gpt_neox.layers.{bid}.attention.dense", # gptneox
"transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
"transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen jais
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
@ -202,7 +202,7 @@ class TensorNameMap:
# Feed-forward norm
MODEL_TENSOR.FFN_NORM: (
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
"transformer.h.{bid}.ln_2", # gpt2 refact qwen
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais
"h.{bid}.post_attention_layernorm", # bloom
"transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf
@ -239,7 +239,7 @@ class TensorNameMap:
# Feed-forward up
MODEL_TENSOR.FFN_UP: (
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
"transformer.h.{bid}.mlp.c_fc", # gpt2
"transformer.h.{bid}.mlp.c_fc", # gpt2 jais
"transformer.blocks.{bid}.ffn.up_proj", # mpt
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
"h.{bid}.mlp.dense_h_to_4h", # bloom
@ -285,6 +285,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact
"layers.{bid}.feed_forward.w1", # llama-pth
"transformer.h.{bid}.mlp.w2", # qwen
"transformer.h.{bid}.mlp.c_fc2", # jais
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
"model.layers.{bid}.feed_forward.w1", # internlm2
"encoder.layers.{bid}.mlp.fc12", # nomic-bert
@ -308,7 +309,7 @@ class TensorNameMap:
# Feed-forward down
MODEL_TENSOR.FFN_DOWN: (
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
"transformer.h.{bid}.mlp.c_proj", # gpt2 refact qwen
"transformer.h.{bid}.mlp.c_proj", # gpt2 refact qwen jais
"transformer.blocks.{bid}.ffn.down_proj", # mpt
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
"h.{bid}.mlp.dense_4h_to_h", # bloom

View File

@ -89,6 +89,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
LLAMA_VOCAB_PRE_TYPE_VIKING = 16,
LLAMA_VOCAB_PRE_TYPE_JAIS = 17,
};
// note: these values should be synchronized with ggml_rope

View File

@ -228,6 +228,7 @@ enum llm_arch {
LLM_ARCH_DEEPSEEK2,
LLM_ARCH_BITNET,
LLM_ARCH_T5,
LLM_ARCH_JAIS,
LLM_ARCH_UNKNOWN,
};
@ -269,6 +270,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
{ LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -1236,6 +1238,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
},
},
{
LLM_ARCH_JAIS,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
},
},
{
LLM_ARCH_UNKNOWN,
{
@ -2035,6 +2052,7 @@ enum e_model {
MODEL_410M,
MODEL_0_5B,
MODEL_1B,
MODEL_1_3B,
MODEL_1_4B,
MODEL_2B,
MODEL_2_8B,
@ -4276,6 +4294,7 @@ static const char * llama_model_type_name(e_model type) {
case MODEL_410M: return "410M";
case MODEL_0_5B: return "0.5B";
case MODEL_1B: return "1B";
case MODEL_1_3B: return "1.3B";
case MODEL_1_4B: return "1.4B";
case MODEL_2B: return "2B";
case MODEL_2_8B: return "2.8B";
@ -4898,6 +4917,18 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_JAIS:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
switch (hparams.n_layer) {
case 24: model.type = e_model::MODEL_1_3B; break;
case 40: model.type = e_model::MODEL_13B; break;
/* TODO: add variants */
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
default: (void)0;
}
@ -5129,6 +5160,9 @@ static void llm_load_vocab(
} else if (
tokenizer_pre == "viking") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING;
} else if (
tokenizer_pre == "jais") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}
@ -6962,6 +6996,44 @@ static bool llm_load_tensors(
layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1});
}
} break;
case LLM_ARCH_JAIS:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
// Output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
}
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff});
}
} break;
default:
throw std::runtime_error("unknown architecture");
}
@ -12354,6 +12426,97 @@ struct llm_build_context {
return gf;
}
struct ggml_cgraph * build_jais() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
for (int il = 0; il < n_layer; ++il) {
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm,
model.layers[il].attn_norm_b,
LLM_NORM, cb, il);
cb(cur, "attn_norm", il);
// self-attention
{
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
cb(cur, "bqkv", il);
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd)));
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd)));
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)));
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/float(n_embd_head), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
// add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
// FF
{
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm,
model.layers[il].ffn_norm_b,
LLM_NORM, cb, il);
cb(cur, "ffn_norm", il);
cur = llm_build_ffn(ctx0, cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
}
inpL = ggml_add(ctx0, cur, ffn_inp);
cb(inpL, "l_out", il);
}
cur = llm_build_norm(ctx0, inpL, hparams,
model.output_norm,
model.output_norm_b,
LLM_NORM, cb, -1);
cb(cur, "result_norm", -1);
cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
};
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@ -12585,6 +12748,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_bitnet();
} break;
case LLM_ARCH_JAIS:
{
result = llm.build_jais();
} break;
default:
GGML_ASSERT(false);
}
@ -13947,6 +14114,7 @@ struct llm_tokenizer_bpe {
break;
case LLAMA_VOCAB_PRE_TYPE_GPT2:
case LLAMA_VOCAB_PRE_TYPE_OLMO:
case LLAMA_VOCAB_PRE_TYPE_JAIS:
regex_exprs = {
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
};
@ -17826,6 +17994,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_MAMBA:
case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_T5:
case LLM_ARCH_JAIS:
return LLAMA_ROPE_TYPE_NONE;
// use what we call a normal RoPE, operating on pairs of consecutive head values