Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-11-11 21:39:52 +00:00)
llm : add Falcon support (#2717)
* llama : refactor GGUF constants into static maps
* llama : check if model architecture is known
* llama : refactor llama_model_load_internal()
* gguf : add KV constant maps
* llm : read arch-specific KVs
* convert : add dummy scores + types
* falcon : load tensor data (CPU only)
* llama : fix loading progress bar
* llama : add arch member to llama_model
* falcon : CPU inference working
* falcon : support non-40B models
* falcon : minor
* llama : minor updates

ggml-ci

* convert-falcon-hf-to-gguf.py : fix special token mapping
* llama.cpp : llama default UNK token = id 0
* llama.cpp : fix bpe tokenizer
* llama.cpp : fix the fix of bpe tokenizer
* ggml : pass eps to ggml_norm
* metal : implement RoPE (mode = 2) + avoid ggml_repeat
* ggml : ggml_repeat always creates new tensor
* falcon : copy-paste self-attention from LLaMA
* metal : print extra compute pipeline info
* falcon : minor changes (still chasing the Metal problem)
* llama.cpp : fix linefeed token
* metal : fix GELU kernel numerical stability by using precise::tanh
* metal : temporary workaround for the concurrency optimization bug
* falcon : add CUDA offloading (#2739)
* llama : better model naming and size reporting
* llama : prep new tokenizer support
* llama : advanced BPE tokenizer based on ggllm.cpp implementation
* llama : remove obsolete comment

ggml-ci

* common : remove obsolete BPE API + disable test-tokenizer-1
* llama : revert BPE special-case in llama_byte_to_token()
* cuda : add TODOs for RoPE NeoX implementation
* llama : default special tokens based on vocab type
* perplexity : add log for start of tokenization

---------

Co-authored-by: klosax <131523366+klosax@users.noreply.github.com>
Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
parent a192860cfe
commit cf658adc83
@@ -744,35 +744,3 @@ std::string llama_token_to_str(const struct llama_context * ctx, llama_token tok
     return std::string(result.data(), result.size());
 }
 
-std::vector<llama_token> llama_tokenize_bpe(
-        struct llama_context * ctx,
-        const std::string & text,
-        bool add_bos) {
-    int n_tokens = text.length() + add_bos;
-    std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
-    if (n_tokens < 0) {
-        result.resize(-n_tokens);
-        int check = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
-        GGML_ASSERT(check == -n_tokens);
-    } else {
-        result.resize(n_tokens);
-    }
-    return result;
-}
-
-std::string llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token) {
-    std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
-    if (n_tokens < 0) {
-        result.resize(-n_tokens);
-        const int check = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
-        GGML_ASSERT(check == -n_tokens);
-    } else {
-        result.resize(n_tokens);
-    }
-
-    return std::string(result.data(), result.size());
-}
-
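Note: the removed BPE-specific wrappers above used the same grow-and-retry idiom that the remaining llama_tokenize wrapper keeps using. A minimal sketch of that idiom against the unified C API (illustrative code, not part of the commit; it assumes the llama_tokenize prototype visible in the llama.h hunk further down):

#include <cassert>
#include <string>
#include <vector>

#include "llama.h"

// First call reports the required buffer size as a negative value; second call fills the buffer.
static std::vector<llama_token> tokenize(llama_context * ctx, const std::string & text, bool add_bos) {
    std::vector<llama_token> result(text.length() + add_bos);
    int n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
    if (n_tokens < 0) {
        result.resize(-n_tokens);  // buffer was too small: grow to the reported size
        n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos);
        assert(n_tokens == (int) result.size());
    } else {
        result.resize(n_tokens);   // shrink to the number of tokens actually produced
    }
    return result;
}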
@@ -120,15 +120,6 @@ std::vector<llama_token> llama_tokenize(
         const std::string & text,
         bool add_bos);
 
-std::vector<llama_token> llama_tokenize_bpe(
-        struct llama_context * ctx,
-        const std::string & text,
-        bool add_bos);
-
 std::string llama_token_to_str(
         const struct llama_context * ctx,
         llama_token token);
-
-std::string llama_token_to_str_bpe(
-        const struct llama_context * ctx,
-        llama_token token);
@@ -95,14 +95,17 @@ print("gguf: get model metadata")
 
 block_count = hparams["n_layer"]
 
-gguf_writer.add_name(last_dir)
+gguf_writer.add_name("Falcon")
 gguf_writer.add_context_length(2048) # not in config.json
 gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
 gguf_writer.add_embedding_length(hparams["hidden_size"])
 gguf_writer.add_feed_forward_length(4 * hparams["hidden_size"])
 gguf_writer.add_block_count(block_count)
 gguf_writer.add_head_count(hparams["n_head"])
-if "n_head_kv" in hparams: gguf_writer.add_head_count_kv(hparams["n_head_kv"])
+if "n_head_kv" in hparams:
+    gguf_writer.add_head_count_kv(hparams["n_head_kv"])
+else:
+    gguf_writer.add_head_count_kv(1)
 gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
 
 # TOKENIZATION
@@ -110,6 +113,8 @@ gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
 print("gguf: get tokenizer metadata")
 
 tokens: List[str] = []
+scores: List[float] = []
+toktypes: List[int] = []
 merges: List[str] = []
 
 
@@ -153,41 +158,30 @@ if Path(dir_model + "/tokenizer.json").is_file():
     text = bytearray(pad_token)
 
     tokens.append(text)
+    scores.append(0.0) # dymmy
+    toktypes.append(gguf.TokenType.NORMAL) # dummy
 
 gguf_writer.add_token_list(tokens)
+gguf_writer.add_token_scores(scores)
+gguf_writer.add_token_types(toktypes)
 
-if "added_tokens" in tokenizer_json and Path(dir_model + "/tokenizer_config.json").is_file():
-    print("gguf: get special token ids")
-
-    with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f:
-        tokenizer_config = json.load(f)
-
-    # find special token ids
-
-    if "bos_token" in tokenizer_config:
-        for key in tokenizer_json["added_tokens"]:
-            if key["content"] == tokenizer_config["bos_token"]:
-                gguf_writer.add_bos_token_id(key["id"])
-
-    if "eos_token" in tokenizer_config:
-        for key in tokenizer_json["added_tokens"]:
-            if key["content"] == tokenizer_config["eos_token"]:
-                gguf_writer.add_eos_token_id(key["id"])
-
-    if "unk_token" in tokenizer_config:
-        for key in tokenizer_json["added_tokens"]:
-            if key["content"] == tokenizer_config["unk_token"]:
-                gguf_writer.add_unk_token_id(key["id"])
-
-    if "sep_token" in tokenizer_config:
-        for key in tokenizer_json["added_tokens"]:
-            if key["content"] == tokenizer_config["sep_token"]:
-                gguf_writer.add_sep_token_id(key["id"])
-
-    if "pad_token" in tokenizer_config:
-        for key in tokenizer_json["added_tokens"]:
-            if key["content"] == tokenizer_config["pad_token"]:
-                gguf_writer.add_pad_token_id(key["id"])
-
+print("gguf: get special token ids")
+# Look for special tokens in config.json
+
+if "bos_token_id" in hparams and hparams["bos_token_id"] != None:
+    gguf_writer.add_bos_token_id(hparams["bos_token_id"])
+
+if "eos_token_id" in hparams and hparams["eos_token_id"] != None:
+    gguf_writer.add_eos_token_id(hparams["eos_token_id"])
+
+if "unk_token_id" in hparams and hparams["unk_token_id"] != None:
+    gguf_writer.add_unk_token_id(hparams["unk_token_id"])
+
+if "sep_token_id" in hparams and hparams["sep_token_id"] != None:
+    gguf_writer.add_sep_token_id(hparams["sep_token_id"])
+
+if "pad_token_id" in hparams and hparams["pad_token_id"] != None:
+    gguf_writer.add_pad_token_id(hparams["pad_token_id"])
 
 # TENSORS
@@ -197,6 +191,7 @@ tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
 # params for qkv transform
 n_head = hparams["n_head"]
 n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1
+
 head_dim = hparams["hidden_size"] // n_head
 
 # tensor info
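Note: for reference, the qkv-transform parameters read here relate as follows; the fused query/key/value row count is an inference from Falcon's multi-query attention layout (n_head query heads sharing n_head_kv key/value heads), not something stated in this diff:

\[
\text{head\_dim} = \frac{\text{hidden\_size}}{\text{n\_head}},
\qquad
\text{rows}(W_{qkv}) = (\text{n\_head} + 2\,\text{n\_head\_kv}) \cdot \text{head\_dim}
\]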
@@ -733,7 +733,11 @@ class OutputFile:
         self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
 
     def add_meta_arch(self, params: Params) -> None:
-        self.gguf.add_name ("LLaMA")
+        ver = None
+        if (params.n_ctx == 4096):
+            ver = "v2"
+
+        self.gguf.add_name ("LLaMA" if ver == None else "LLaMA " + ver)
         self.gguf.add_context_length (params.n_ctx)
         self.gguf.add_embedding_length (params.n_embd)
         self.gguf.add_block_count (params.n_layer)
@@ -43,7 +43,7 @@ static bool is_interacting = false;
 void sigint_handler(int signo) {
     if (signo == SIGINT) {
         if (!is_interacting) {
-            is_interacting=true;
+            is_interacting = true;
         } else {
            console::cleanup();
            printf("\n");
@@ -189,10 +189,12 @@ int main(int argc, char ** argv) {
         }
     }
 
+    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+
     // tokenize the prompt
     std::vector<llama_token> embd_inp;
     if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
-        embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+        embd_inp = ::llama_tokenize(ctx, params.prompt, is_spm);
     } else {
         embd_inp = session_tokens;
     }
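Note: add_bos is now derived from the vocabulary type instead of being hard-coded to true, because SentencePiece (LLaMA-style) vocabularies expect a leading BOS token while Falcon's BPE vocabulary does not. A short illustrative helper (not from the commit) showing the same pattern in isolation:

#include <string>
#include <vector>

#include "common.h"  // ::llama_tokenize(ctx, const std::string &, bool add_bos)
#include "llama.h"   // llama_vocab_type(), LLAMA_VOCAB_TYPE_SPM

static std::vector<llama_token> tokenize_prompt(llama_context * ctx, const std::string & prompt) {
    const bool is_spm  = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM; // SentencePiece (LLaMA) vocab
    const bool add_bos = is_spm;                                        // BPE (Falcon) vocab gets no BOS
    return ::llama_tokenize(ctx, prompt, add_bos);
}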
@@ -208,9 +210,9 @@ int main(int argc, char ** argv) {
     int original_prompt_len = 0;
     if (ctx_guidance) {
         params.cfg_negative_prompt.insert(0, 1, ' ');
-        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true);
+        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, is_spm);
 
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, is_spm);
         original_prompt_len = original_inp.size();
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
     }
@@ -257,7 +259,7 @@ int main(int argc, char ** argv) {
     }
 
     // prefix & suffix for instruct mode
-    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true);
+    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", is_spm);
     const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
 
     // in instruct mode, we inject a prefix and a suffix to each input by the user
@@ -28,7 +28,6 @@ std::vector<float> softmax(const std::vector<float>& logits) {
 }
 
 void perplexity_v2(llama_context * ctx, const gpt_params & params) {
-
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
@@ -38,7 +37,13 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
         fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
         return;
     }
-    auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+
+    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = is_spm;
+
+    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
+
+    auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
 
     const int calc_chunk = params.n_ctx;
 
@@ -86,7 +91,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
             const auto token_org = tokens[batch_start];
 
             // add BOS token for the first batch of each chunk
-            if (j == 0) {
+            if (add_bos && j == 0) {
                 tokens[batch_start] = llama_token_bos(ctx);
             }
 
@@ -136,7 +141,6 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
 }
 
 void perplexity(llama_context * ctx, const gpt_params & params) {
-
     if (params.ppl_stride > 0) {
         perplexity_v2(ctx, params);
         return;
@@ -146,7 +150,13 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval
-    auto tokens = ::llama_tokenize(ctx, params.prompt, true);
+
+    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+    const bool add_bos = is_spm;
+
+    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
+
+    auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
 
     const int n_chunk_max = tokens.size() / params.n_ctx;
 
@@ -177,7 +187,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
             const auto token_org = tokens[batch_start];
 
             // add BOS token for the first batch of each chunk
-            if (j == 0) {
+            if (add_bos && j == 0) {
                 tokens[batch_start] = llama_token_bos(ctx);
             }
 
@@ -295,8 +305,10 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     size_t hs_task_count = prompt_lines.size()/6;
     fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
 
+    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+
     // This is needed as usual for LLaMA models
-    bool prepend_bos = true;
+    const bool add_bos = is_spm;
 
     // Number of tasks to use when computing the score
     if ( params.hellaswag_tasks < hs_task_count ) {
@@ -352,14 +364,13 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     std::vector<float> tok_logits(n_vocab);
 
     for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
-
         // Tokenize the context to count tokens
-        std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
+        std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos);
         size_t context_size = context_embd.size();
 
         // Do the 1st ending
         // In this case we include the context when evaluating
-        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos);
+        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
         auto query_size = query_embd.size();
         //printf("First query: %d\n",(int)query_size);
 
@@ -238,7 +238,7 @@ static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_t
     alloc->n_free_blocks++;
 }
 
-void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n) {
+void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) {
     int pos = 0;
     for (int i = 0; i < n; i++) {
         if (list[i] != -1) {
@@ -547,7 +547,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
                 struct ggml_tensor * view_src = get_view_source(parent);
                 struct hash_node * view_src_hn = hash_get(ht, view_src);
                 view_src_hn->n_views -= 1;
-                AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src->n_children, view_src->n_views);
+                AT_PRINTF("view_src %s\n", view_src->name);
                 if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
                     ggml_allocator_free_tensor(alloc, view_src);
                 }
@@ -12,7 +12,7 @@ GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
 
 // tell the allocator to parse nodes following the order described in the list
 // you should call this if your graph are optimized to execute out-of-order
-GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n);
+GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n);
 
 GGML_API void ggml_allocr_free(struct ggml_allocr * alloc);
 GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc);
ggml-cuda.cu (27 changed lines)
@@ -3907,6 +3907,29 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
+// TODO: this implementation is wrong!
+//static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
+// const float p_delta, const int p_delta_rows, const float theta_scale) {
+//    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+//
+//    if (col >= ncols) {
+//        return;
+//    }
+//
+//    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+//    const int i = row*ncols + col/2;
+//
+//    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
+//    const float sin_theta = sinf(theta);
+//    const float cos_theta = cosf(theta);
+//
+//    const float x0 = x[i + 0];
+//    const float x1 = x[i + ncols/2];
+//
+//    dst[i + 0] = x0*cos_theta - x1*sin_theta;
+//    dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
+//}
+
 static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int half_n_dims = ncols/4;
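Note: the rotation the disabled kernel is meant to perform is ordinary RoPE applied to a different element pairing: the existing kernel rotates adjacent pairs (x_{2i}, x_{2i+1}), while NeoX mode (mode & 2, used by Falcon) rotates elements half the rotary dimension apart, which is what the CPU and Metal implementations in this commit do. Restating the math of the code above, with p the (scaled) position and d = n_dims:

\[
\begin{pmatrix} x'_{i} \\ x'_{i+d/2} \end{pmatrix}
=
\begin{pmatrix} \cos\theta_i & -\sin\theta_i \\ \sin\theta_i & \cos\theta_i \end{pmatrix}
\begin{pmatrix} x_{i} \\ x_{i+d/2} \end{pmatrix},
\qquad
\theta_i = p \cdot \texttt{theta\_scale}^{\,i},
\qquad
\texttt{theta\_scale} = \texttt{freq\_base}^{-2/d}
\]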
@@ -5515,6 +5538,7 @@ inline void ggml_cuda_op_rope(
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
+    const bool is_neox = mode & 2;
     const bool is_glm = mode & 4;
 
     // compute
@@ -5523,6 +5547,9 @@ inline void ggml_cuda_op_rope(
         const float id_p = min(p, n_ctx - 2.f);
         const float block_p = max(p - (n_ctx - 2.f), 0.f);
         rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
+    } else if (is_neox) {
+        GGML_ASSERT(false && "RoPE NeoX not implemented yet");
+#pragma message("TODO: implement RoPE NeoX for CUDA")
     } else {
         const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
ggml-metal.m (24 changed lines)
@@ -167,7 +167,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #define GGML_METAL_ADD_KERNEL(name) \
         ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
         ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
-        fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \
+        fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+                (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
+                (int) ctx->pipeline_##name.threadExecutionWidth); \
         if (error) { \
             fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
             return NULL; \
@@ -538,7 +540,7 @@ void ggml_metal_graph_compute(
                 id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
 
                 const int node_start = (cb_idx + 0) * n_nodes_per_cb;
-                const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+                const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
 
                 for (int ind = node_start; ind < node_end; ++ind) {
                     const int i = has_concur ? ctx->concur_list[ind] : ind;
@@ -768,8 +770,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
                             [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                             [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
-                        }
-                        else {
+                        } else {
                             int nth0 = 32;
                             int nth1 = 1;
 
@@ -872,20 +873,20 @@ void ggml_metal_graph_compute(
 
                             if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
                                 src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                             }
                             else if (src0t == GGML_TYPE_Q5_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_Q6_K) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             } else {
                                 [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -938,7 +939,8 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_NORM:
                         {
-                            const float eps = 1e-5f;
+                            float eps;
+                            memcpy(&eps, dst->op_params, sizeof(float));
 
                             const int nth = 256;
 
@@ -990,7 +992,9 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
                             [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
                             [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+
                             const int nth = 32;
+
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     case GGML_OP_ROPE:
@@ -87,7 +87,12 @@ kernel void kernel_gelu(
         device float * dst,
         uint tpig[[thread_position_in_grid]]) {
     float x = src0[tpig];
-    dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+    // This was observed with Falcon 7B and 40B models
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }
 
 kernel void kernel_soft_max(
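Note: the expression itself is unchanged; it is the usual tanh approximation of GELU, and the fix only swaps the Metal fast-math tanh for precise::tanh because the fast variant returned NaN for some Falcon activations. With SQRT_2_OVER_PI = sqrt(2/pi) and GELU_COEF_A at its conventional value of about 0.044715 (both constants are defined elsewhere in ggml-metal.metal), the kernel computes:

\[
\operatorname{GELU}(x) \approx \tfrac{1}{2}\,x\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)
\]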
@@ -571,7 +576,25 @@ kernel void kernel_rope(
             dst_data[1] = x0*sin_theta + x1*cos_theta;
         }
     } else {
-        // TODO: implement
+        for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+            for (int64_t ic = 0; ic < n_dims; ic += 2) {
+                const float cos_theta = cos(theta);
+                const float sin_theta = sin(theta);
+
+                theta *= theta_scale;
+
+                const int64_t i0 = ib*n_dims + ic/2;
+
+                device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                const float x0 = src[0];
+                const float x1 = src[n_dims/2];
+
+                dst_data[0] = x0*cos_theta - x1*sin_theta;
+                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+            }
+        }
     }
 }
 
ggml.c (24 changed lines)
@@ -5555,10 +5555,6 @@ struct ggml_tensor * ggml_repeat(
         is_node = true;
     }
 
-    if (ggml_are_same_shape(a, b) && !is_node) {
-        return a;
-    }
-
     struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
 
     result->op = GGML_OP_REPEAT;
@@ -5789,6 +5785,7 @@ struct ggml_tensor * ggml_silu_back(
 static struct ggml_tensor * ggml_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
+        float eps,
         bool inplace) {
     bool is_node = false;
 
@@ -5799,7 +5796,7 @@ static struct ggml_tensor * ggml_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    // TODO: maybe store epsilon here?
+    ggml_set_op_params(result, &eps, sizeof(eps));
 
     result->op = GGML_OP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5810,14 +5807,16 @@ static struct ggml_tensor * ggml_norm_impl(
 
 struct ggml_tensor * ggml_norm(
         struct ggml_context * ctx,
-        struct ggml_tensor * a) {
-    return ggml_norm_impl(ctx, a, false);
+        struct ggml_tensor * a,
+        float eps) {
+    return ggml_norm_impl(ctx, a, eps, false);
 }
 
 struct ggml_tensor * ggml_norm_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor * a) {
-    return ggml_norm_impl(ctx, a, true);
+        struct ggml_tensor * a,
+        float eps) {
+    return ggml_norm_impl(ctx, a, eps, true);
 }
 
 // ggml_rms_norm
@@ -10619,7 +10618,8 @@ static void ggml_compute_forward_norm_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS;
 
-    const float eps = 1e-5f; // TODO: make this a parameter
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -12537,7 +12537,7 @@ static void ggml_compute_forward_rope_f32(
                 dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
             }
         } else {
-            // TODO: this is probably wrong, but I can't figure it out ..
+            // TODO: this might be wrong for ne0 != n_dims - need double check
             // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
             for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                 for (int64_t ic = 0; ic < n_dims; ic += 2) {
@@ -12666,7 +12666,7 @@ static void ggml_compute_forward_rope_f16(
                 dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
             }
         } else {
-            // TODO: this is probably wrong, but I can't figure it out ..
+            // TODO: this might be wrong for ne0 != n_dims - need double check
             // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
             for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
                 for (int64_t ic = 0; ic < n_dims; ic += 2) {
ggml.h (7 changed lines)
@@ -909,14 +909,15 @@ extern "C" {
             struct ggml_tensor * b);
 
     // normalize along rows
-    // TODO: eps is hardcoded to 1e-5 for now
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
-            struct ggml_tensor * a);
+            struct ggml_tensor * a,
+            float eps);
 
     GGML_API struct ggml_tensor * ggml_norm_inplace(
             struct ggml_context * ctx,
-            struct ggml_tensor * a);
+            struct ggml_tensor * a,
+            float eps);
 
     GGML_API struct ggml_tensor * ggml_rms_norm(
             struct ggml_context * ctx,
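Note: since eps is now an explicit argument, callers pick the epsilon per call instead of relying on the old hard-coded 1e-5f; this is what lets Falcon use the layer_norm_epsilon read from config.json. A minimal usage sketch of the new signature (the context setup is illustrative, not from the diff):

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x    = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    const float norm_eps      = 1e-5f;                       // e.g. the model's layer_norm_epsilon
    struct ggml_tensor * norm = ggml_norm(ctx, x, norm_eps); // eps is now passed explicitly

    (void) norm; // building the graph and computing it is omitted in this sketch
    ggml_free(ctx);
    return 0;
}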
gguf.py (26 changed lines)
@@ -30,12 +30,12 @@ KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository"
 KEY_GENERAL_FILE_TYPE = "general.file_type"
 
 # LLM
-KEY_LLM_CONTEXT_LENGTH = "{arch}.context_length"
-KEY_LLM_EMBEDDING_LENGTH = "{arch}.embedding_length"
-KEY_LLM_BLOCK_COUNT = "{arch}.block_count"
-KEY_LLM_FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
-KEY_LLM_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
-KEY_LLM_TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
+KEY_CONTEXT_LENGTH = "{arch}.context_length"
+KEY_EMBEDDING_LENGTH = "{arch}.embedding_length"
+KEY_BLOCK_COUNT = "{arch}.block_count"
+KEY_FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
+KEY_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
+KEY_TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
 
 # attention
 KEY_ATTENTION_HEAD_COUNT = "{arch}.attention.head_count"
@@ -583,7 +583,7 @@ class GGUFWriter:
         self.add_string(KEY_GENERAL_AUTHOR, author)
 
     def add_tensor_data_layout(self, layout: str):
-        self.add_string(KEY_LLM_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
+        self.add_string(KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
 
     def add_url(self, url: str):
         self.add_string(KEY_GENERAL_URL, url)
@@ -613,27 +613,27 @@ class GGUFWriter:
 
     def add_context_length(self, length: int):
         self.add_uint32(
-            KEY_LLM_CONTEXT_LENGTH.format(arch=self.arch), length)
+            KEY_CONTEXT_LENGTH.format(arch=self.arch), length)
 
     def add_embedding_length(self, length: int):
         self.add_uint32(
-            KEY_LLM_EMBEDDING_LENGTH.format(arch=self.arch), length)
+            KEY_EMBEDDING_LENGTH.format(arch=self.arch), length)
 
     def add_block_count(self, length: int):
         self.add_uint32(
-            KEY_LLM_BLOCK_COUNT.format(arch=self.arch), length)
+            KEY_BLOCK_COUNT.format(arch=self.arch), length)
 
     def add_feed_forward_length(self, length: int):
         self.add_uint32(
-            KEY_LLM_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+            KEY_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
 
     def add_parallel_residual(self, use: bool):
         self.add_bool(
-            KEY_LLM_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
+            KEY_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
 
     def add_tensor_data_layout(self, layout: str):
         self.add_string(
-            KEY_LLM_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
+            KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
 
     def add_head_count(self, count: int):
         self.add_uint32(
llama.h (15 changed lines)
@@ -247,6 +247,8 @@ extern "C" {
     LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
     LLAMA_API int llama_n_embd (const struct llama_context * ctx);
 
+    LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);
+
     LLAMA_API int llama_model_n_vocab(const struct llama_model * model);
     LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
     LLAMA_API int llama_model_n_embd (const struct llama_model * model);
@@ -368,13 +370,6 @@ extern "C" {
             int n_max_tokens,
             bool add_bos);
 
-    LLAMA_API int llama_tokenize_bpe(
-            struct llama_context * ctx,
-            const char * text,
-            llama_token * tokens,
-            int n_max_tokens,
-            bool add_bos);
-
     LLAMA_API int llama_tokenize_with_model(
             const struct llama_model * model,
             const char * text,
@@ -390,12 +385,6 @@ extern "C" {
             char * buf,
             int length);
 
-    LLAMA_API int llama_token_to_str_bpe(
-            const struct llama_context * ctx,
-            llama_token token,
-            char * buf,
-            int length);
-
    LLAMA_API int llama_token_to_str_with_model(
             const struct llama_model * model,
             llama_token token,
@@ -28,7 +28,8 @@ llama_build_and_test_executable(test-sampling.cpp)
 llama_build_executable(test-tokenizer-0.cpp)
 llama_test_executable (test-tokenizer-0.llama test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
 llama_build_executable(test-tokenizer-1.cpp)
-llama_test_executable (test-tokenizer-1.llama test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
+# test-tokenizer-1 requires a BPE vocab. re-enable when we have one.
+#llama_test_executable (test-tokenizer-1.llama test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
 #llama_test_executable(test-tokenizer-1.aquila test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
 llama_build_and_test_executable(test-grammar-parser.cpp)
 llama_build_and_test_executable(test-llama-grammar.cpp)
@@ -67,11 +67,13 @@ int main(int argc, char **argv) {
         }
     }
 
+    GGML_ASSERT(llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_BPE);
+
     const int n_vocab = llama_n_vocab(ctx);
 
     for (int i = 0; i < n_vocab; ++i) {
-        std::string forward = llama_token_to_str_bpe(ctx, i);
-        std::vector<llama_token> tokens = llama_tokenize_bpe(ctx, forward, false);
+        std::string forward = llama_token_to_str(ctx, i);
+        std::vector<llama_token> tokens = llama_tokenize(ctx, forward, false);
         if (tokens.size() == 1) {
             if (i != tokens[0]) {
                 std::string backward = llama_token_to_str(ctx, tokens[0]);
@@ -79,16 +81,6 @@ int main(int argc, char **argv) {
                     __func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str());
                 return 2;
             }
-        } else {
-            llama_token_type type = llama_token_get_type(ctx, i);
-            if (type == LLAMA_TOKEN_TYPE_UNKNOWN || type == LLAMA_TOKEN_TYPE_CONTROL || type == LLAMA_TOKEN_TYPE_BYTE) {
-                fprintf(stderr, "%s : info: token %d is string %s and bpe returns tokens %s\n",
-                    __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str());
-            } else {
-                fprintf(stderr, "%s : error: token %d is string %s but bpe returns tokens %s\n",
-                    __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str());
-                return 2;
-            }
         }
     }
 