llm : add Falcon support (#2717)

* llama : refactor GGUF constants into static maps

* llama : check if model architecture is known

* llama : refactor llama_model_load_internal()

* gguf : add KV constant maps

* llm : read arch-specific KVs

* convert : add dummy scores + types

* falcon : load tensor data (CPU only)

* llama : fix loading progress bar

* llama : add arch member to llama_model

* falcon : CPU inference working

* falcon : support non-40B models

* falcon : minor

* llama : minor updates

ggml-ci

* convert-falcon-hf-to-gguf.py : fix special token mapping

* llama.cpp : llama default UNK token = id 0

* llama.cpp : fix bpe tokenizer

* llama.cpp : fix the fix of bpe tokenizer

* ggml : pass eps to ggml_norm

* metal : implement RoPE (mode = 2) + avoid ggml_repeat

* ggml : ggml_repeat always creates new tensor

* falcon : copy-paste self-attention from LLaMA

* metal : print extra compute pipeline info

* falcon : minor changes (still chasing the Metal problem)

* llama.cpp : fix linefeed token

* metal : fix GELU kernel numerical stability by using precise::tanh

* metal : temporary workaround for the concurrency optimization bug

* falcon : add CUDA offloading (#2739)

* llama : better model naming and size reporting

* llama : prep new tokenizer support

* llama : advanced BPE tokenizer based on ggllm.cpp imlpementation

* llama : remove oboslete comment

ggml-ci

* common : remove obsolete BPE API + disable test-tokenizer-1

* llama : revert BPE special-case in llama_byte_to_token()

* cuda : add TODOs for RoPE NeoX implementation

* llama : default special tokens based on vocab type

* perplexity : add log for start of tokenization

---------

Co-authored-by: klosax <131523366+klosax@users.noreply.github.com>
Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
Georgi Gerganov 2023-08-23 23:08:04 +03:00 committed by GitHub
parent a192860cfe
commit cf658adc83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 1596 additions and 668 deletions

View File

@ -744,35 +744,3 @@ std::string llama_token_to_str(const struct llama_context * ctx, llama_token tok
return std::string(result.data(), result.size()); return std::string(result.data(), result.size());
} }
std::vector<llama_token> llama_tokenize_bpe(
struct llama_context * ctx,
const std::string & text,
bool add_bos) {
int n_tokens = text.length() + add_bos;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
}
return result;
}
std::string llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token) {
std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
if (n_tokens < 0) {
result.resize(-n_tokens);
const int check = llama_token_to_str_bpe(ctx, token, result.data(), result.size());
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
}
return std::string(result.data(), result.size());
}

View File

@ -120,15 +120,6 @@ std::vector<llama_token> llama_tokenize(
const std::string & text, const std::string & text,
bool add_bos); bool add_bos);
std::vector<llama_token> llama_tokenize_bpe(
struct llama_context * ctx,
const std::string & text,
bool add_bos);
std::string llama_token_to_str( std::string llama_token_to_str(
const struct llama_context * ctx, const struct llama_context * ctx,
llama_token token); llama_token token);
std::string llama_token_to_str_bpe(
const struct llama_context * ctx,
llama_token token);

View File

@ -95,14 +95,17 @@ print("gguf: get model metadata")
block_count = hparams["n_layer"] block_count = hparams["n_layer"]
gguf_writer.add_name(last_dir) gguf_writer.add_name("Falcon")
gguf_writer.add_context_length(2048) # not in config.json gguf_writer.add_context_length(2048) # not in config.json
gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
gguf_writer.add_embedding_length(hparams["hidden_size"]) gguf_writer.add_embedding_length(hparams["hidden_size"])
gguf_writer.add_feed_forward_length(4 * hparams["hidden_size"]) gguf_writer.add_feed_forward_length(4 * hparams["hidden_size"])
gguf_writer.add_block_count(block_count) gguf_writer.add_block_count(block_count)
gguf_writer.add_head_count(hparams["n_head"]) gguf_writer.add_head_count(hparams["n_head"])
if "n_head_kv" in hparams: gguf_writer.add_head_count_kv(hparams["n_head_kv"]) if "n_head_kv" in hparams:
gguf_writer.add_head_count_kv(hparams["n_head_kv"])
else:
gguf_writer.add_head_count_kv(1)
gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"]) gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
# TOKENIZATION # TOKENIZATION
@ -110,6 +113,8 @@ gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
print("gguf: get tokenizer metadata") print("gguf: get tokenizer metadata")
tokens: List[str] = [] tokens: List[str] = []
scores: List[float] = []
toktypes: List[int] = []
merges: List[str] = [] merges: List[str] = []
@ -153,41 +158,30 @@ if Path(dir_model + "/tokenizer.json").is_file():
text = bytearray(pad_token) text = bytearray(pad_token)
tokens.append(text) tokens.append(text)
scores.append(0.0) # dymmy
toktypes.append(gguf.TokenType.NORMAL) # dummy
gguf_writer.add_token_list(tokens) gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)
if "added_tokens" in tokenizer_json and Path(dir_model + "/tokenizer_config.json").is_file(): print("gguf: get special token ids")
print("gguf: get special token ids") # Look for special tokens in config.json
with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f: if "bos_token_id" in hparams and hparams["bos_token_id"] != None:
tokenizer_config = json.load(f) gguf_writer.add_bos_token_id(hparams["bos_token_id"])
# find special token ids if "eos_token_id" in hparams and hparams["eos_token_id"] != None:
gguf_writer.add_eos_token_id(hparams["eos_token_id"])
if "bos_token" in tokenizer_config: if "unk_token_id" in hparams and hparams["unk_token_id"] != None:
for key in tokenizer_json["added_tokens"]: gguf_writer.add_unk_token_id(hparams["unk_token_id"])
if key["content"] == tokenizer_config["bos_token"]:
gguf_writer.add_bos_token_id(key["id"])
if "eos_token" in tokenizer_config: if "sep_token_id" in hparams and hparams["sep_token_id"] != None:
for key in tokenizer_json["added_tokens"]: gguf_writer.add_sep_token_id(hparams["sep_token_id"])
if key["content"] == tokenizer_config["eos_token"]:
gguf_writer.add_eos_token_id(key["id"])
if "unk_token" in tokenizer_config: if "pad_token_id" in hparams and hparams["pad_token_id"] != None:
for key in tokenizer_json["added_tokens"]: gguf_writer.add_pad_token_id(hparams["pad_token_id"])
if key["content"] == tokenizer_config["unk_token"]:
gguf_writer.add_unk_token_id(key["id"])
if "sep_token" in tokenizer_config:
for key in tokenizer_json["added_tokens"]:
if key["content"] == tokenizer_config["sep_token"]:
gguf_writer.add_sep_token_id(key["id"])
if "pad_token" in tokenizer_config:
for key in tokenizer_json["added_tokens"]:
if key["content"] == tokenizer_config["pad_token"]:
gguf_writer.add_pad_token_id(key["id"])
# TENSORS # TENSORS
@ -197,6 +191,7 @@ tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
# params for qkv transform # params for qkv transform
n_head = hparams["n_head"] n_head = hparams["n_head"]
n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1 n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1
head_dim = hparams["hidden_size"] // n_head head_dim = hparams["hidden_size"] // n_head
# tensor info # tensor info

View File

@ -733,7 +733,11 @@ class OutputFile:
self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH]) self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
def add_meta_arch(self, params: Params) -> None: def add_meta_arch(self, params: Params) -> None:
self.gguf.add_name ("LLaMA") ver = None
if (params.n_ctx == 4096):
ver = "v2"
self.gguf.add_name ("LLaMA" if ver == None else "LLaMA " + ver)
self.gguf.add_context_length (params.n_ctx) self.gguf.add_context_length (params.n_ctx)
self.gguf.add_embedding_length (params.n_embd) self.gguf.add_embedding_length (params.n_embd)
self.gguf.add_block_count (params.n_layer) self.gguf.add_block_count (params.n_layer)

View File

@ -43,7 +43,7 @@ static bool is_interacting = false;
void sigint_handler(int signo) { void sigint_handler(int signo) {
if (signo == SIGINT) { if (signo == SIGINT) {
if (!is_interacting) { if (!is_interacting) {
is_interacting=true; is_interacting = true;
} else { } else {
console::cleanup(); console::cleanup();
printf("\n"); printf("\n");
@ -189,10 +189,12 @@ int main(int argc, char ** argv) {
} }
} }
const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
// tokenize the prompt // tokenize the prompt
std::vector<llama_token> embd_inp; std::vector<llama_token> embd_inp;
if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) { if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
embd_inp = ::llama_tokenize(ctx, params.prompt, true); embd_inp = ::llama_tokenize(ctx, params.prompt, is_spm);
} else { } else {
embd_inp = session_tokens; embd_inp = session_tokens;
} }
@ -208,9 +210,9 @@ int main(int argc, char ** argv) {
int original_prompt_len = 0; int original_prompt_len = 0;
if (ctx_guidance) { if (ctx_guidance) {
params.cfg_negative_prompt.insert(0, 1, ' '); params.cfg_negative_prompt.insert(0, 1, ' ');
guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true); guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, is_spm);
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true); std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, is_spm);
original_prompt_len = original_inp.size(); original_prompt_len = original_inp.size();
guidance_offset = (int)guidance_inp.size() - original_prompt_len; guidance_offset = (int)guidance_inp.size() - original_prompt_len;
} }
@ -257,7 +259,7 @@ int main(int argc, char ** argv) {
} }
// prefix & suffix for instruct mode // prefix & suffix for instruct mode
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true); const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", is_spm);
const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false); const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
// in instruct mode, we inject a prefix and a suffix to each input by the user // in instruct mode, we inject a prefix and a suffix to each input by the user

View File

@ -28,7 +28,6 @@ std::vector<float> softmax(const std::vector<float>& logits) {
} }
void perplexity_v2(llama_context * ctx, const gpt_params & params) { void perplexity_v2(llama_context * ctx, const gpt_params & params) {
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
// Output: `perplexity: 13.5106 [114/114]` // Output: `perplexity: 13.5106 [114/114]`
@ -38,7 +37,13 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride); fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
return; return;
} }
auto tokens = ::llama_tokenize(ctx, params.prompt, true);
const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
const bool add_bos = is_spm;
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
const int calc_chunk = params.n_ctx; const int calc_chunk = params.n_ctx;
@ -86,7 +91,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
const auto token_org = tokens[batch_start]; const auto token_org = tokens[batch_start];
// add BOS token for the first batch of each chunk // add BOS token for the first batch of each chunk
if (j == 0) { if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(ctx); tokens[batch_start] = llama_token_bos(ctx);
} }
@ -136,7 +141,6 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
} }
void perplexity(llama_context * ctx, const gpt_params & params) { void perplexity(llama_context * ctx, const gpt_params & params) {
if (params.ppl_stride > 0) { if (params.ppl_stride > 0) {
perplexity_v2(ctx, params); perplexity_v2(ctx, params);
return; return;
@ -146,7 +150,13 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
// Output: `perplexity: 13.5106 [114/114]` // Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval // BOS tokens will be added for each chunk before eval
auto tokens = ::llama_tokenize(ctx, params.prompt, true);
const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
const bool add_bos = is_spm;
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
const int n_chunk_max = tokens.size() / params.n_ctx; const int n_chunk_max = tokens.size() / params.n_ctx;
@ -177,7 +187,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
const auto token_org = tokens[batch_start]; const auto token_org = tokens[batch_start];
// add BOS token for the first batch of each chunk // add BOS token for the first batch of each chunk
if (j == 0) { if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(ctx); tokens[batch_start] = llama_token_bos(ctx);
} }
@ -295,8 +305,10 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
size_t hs_task_count = prompt_lines.size()/6; size_t hs_task_count = prompt_lines.size()/6;
fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count); fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
// This is needed as usual for LLaMA models // This is needed as usual for LLaMA models
bool prepend_bos = true; const bool add_bos = is_spm;
// Number of tasks to use when computing the score // Number of tasks to use when computing the score
if ( params.hellaswag_tasks < hs_task_count ) { if ( params.hellaswag_tasks < hs_task_count ) {
@ -352,14 +364,13 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
std::vector<float> tok_logits(n_vocab); std::vector<float> tok_logits(n_vocab);
for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) { for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
// Tokenize the context to count tokens // Tokenize the context to count tokens
std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos); std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos);
size_t context_size = context_embd.size(); size_t context_size = context_embd.size();
// Do the 1st ending // Do the 1st ending
// In this case we include the context when evaluating // In this case we include the context when evaluating
auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos); auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
auto query_size = query_embd.size(); auto query_size = query_embd.size();
//printf("First query: %d\n",(int)query_size); //printf("First query: %d\n",(int)query_size);

View File

@ -238,7 +238,7 @@ static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_t
alloc->n_free_blocks++; alloc->n_free_blocks++;
} }
void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n) { void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) {
int pos = 0; int pos = 0;
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
if (list[i] != -1) { if (list[i] != -1) {
@ -547,7 +547,7 @@ static size_t ggml_allocator_alloc_graph_tensors_n(
struct ggml_tensor * view_src = get_view_source(parent); struct ggml_tensor * view_src = get_view_source(parent);
struct hash_node * view_src_hn = hash_get(ht, view_src); struct hash_node * view_src_hn = hash_get(ht, view_src);
view_src_hn->n_views -= 1; view_src_hn->n_views -= 1;
AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src->n_children, view_src->n_views); AT_PRINTF("view_src %s\n", view_src->name);
if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) { if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
ggml_allocator_free_tensor(alloc, view_src); ggml_allocator_free_tensor(alloc, view_src);
} }

View File

@ -12,7 +12,7 @@ GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
// tell the allocator to parse nodes following the order described in the list // tell the allocator to parse nodes following the order described in the list
// you should call this if your graph are optimized to execute out-of-order // you should call this if your graph are optimized to execute out-of-order
GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, int * list, int n); GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n);
GGML_API void ggml_allocr_free(struct ggml_allocr * alloc); GGML_API void ggml_allocr_free(struct ggml_allocr * alloc);
GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc); GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc);

View File

@ -3907,6 +3907,29 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
dst[i + 1] = x0*sin_theta + x1*cos_theta; dst[i + 1] = x0*sin_theta + x1*cos_theta;
} }
// TODO: this implementation is wrong!
//static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
// const float p_delta, const int p_delta_rows, const float theta_scale) {
// const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
//
// if (col >= ncols) {
// return;
// }
//
// const int row = blockDim.x*blockIdx.x + threadIdx.x;
// const int i = row*ncols + col/2;
//
// const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
// const float sin_theta = sinf(theta);
// const float cos_theta = cosf(theta);
//
// const float x0 = x[i + 0];
// const float x1 = x[i + ncols/2];
//
// dst[i + 0] = x0*cos_theta - x1*sin_theta;
// dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
//}
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) { static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
const int col = blockDim.x*blockIdx.x + threadIdx.x; const int col = blockDim.x*blockIdx.x + threadIdx.x;
const int half_n_dims = ncols/4; const int half_n_dims = ncols/4;
@ -5515,6 +5538,7 @@ inline void ggml_cuda_op_rope(
const float theta_scale = powf(freq_base, -2.0f/n_dims); const float theta_scale = powf(freq_base, -2.0f/n_dims);
const bool is_neox = mode & 2;
const bool is_glm = mode & 4; const bool is_glm = mode & 4;
// compute // compute
@ -5523,6 +5547,9 @@ inline void ggml_cuda_op_rope(
const float id_p = min(p, n_ctx - 2.f); const float id_p = min(p, n_ctx - 2.f);
const float block_p = max(p - (n_ctx - 2.f), 0.f); const float block_p = max(p - (n_ctx - 2.f), 0.f);
rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main); rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
} else if (is_neox) {
GGML_ASSERT(false && "RoPE NeoX not implemented yet");
#pragma message("TODO: implement RoPE NeoX for CUDA")
} else { } else {
const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main); rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);

View File

@ -167,7 +167,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
#define GGML_METAL_ADD_KERNEL(name) \ #define GGML_METAL_ADD_KERNEL(name) \
ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \ ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \ ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); \ fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
(int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
(int) ctx->pipeline_##name.threadExecutionWidth); \
if (error) { \ if (error) { \
fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
return NULL; \ return NULL; \
@ -538,7 +540,7 @@ void ggml_metal_graph_compute(
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
const int node_start = (cb_idx + 0) * n_nodes_per_cb; const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb; const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
for (int ind = node_start; ind < node_end; ++ind) { for (int ind = node_start; ind < node_end; ++ind) {
const int i = has_concur ? ctx->concur_list[ind] : ind; const int i = has_concur ? ctx->concur_list[ind] : ind;
@ -768,8 +770,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
[encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} } else {
else {
int nth0 = 32; int nth0 = 32;
int nth1 = 1; int nth1 = 1;
@ -872,20 +873,20 @@ void ggml_metal_graph_compute(
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
else if (src0t == GGML_TYPE_Q3_K) { else if (src0t == GGML_TYPE_Q3_K) {
#ifdef GGML_QKK_64 #ifdef GGML_QKK_64
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#else #else
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#endif #endif
} }
else if (src0t == GGML_TYPE_Q5_K) { else if (src0t == GGML_TYPE_Q5_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
else if (src0t == GGML_TYPE_Q6_K) { else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else { } else {
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@ -938,7 +939,8 @@ void ggml_metal_graph_compute(
} break; } break;
case GGML_OP_NORM: case GGML_OP_NORM:
{ {
const float eps = 1e-5f; float eps;
memcpy(&eps, dst->op_params, sizeof(float));
const int nth = 256; const int nth = 256;
@ -990,7 +992,9 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&m0 length:sizeof( float) atIndex:18]; [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
const int nth = 32; const int nth = 32;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;
case GGML_OP_ROPE: case GGML_OP_ROPE:

View File

@ -87,7 +87,12 @@ kernel void kernel_gelu(
device float * dst, device float * dst,
uint tpig[[thread_position_in_grid]]) { uint tpig[[thread_position_in_grid]]) {
float x = src0[tpig]; float x = src0[tpig];
dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
// BEWARE !!!
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
// This was observed with Falcon 7B and 40B models
//
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
} }
kernel void kernel_soft_max( kernel void kernel_soft_max(
@ -571,7 +576,25 @@ kernel void kernel_rope(
dst_data[1] = x0*sin_theta + x1*cos_theta; dst_data[1] = x0*sin_theta + x1*cos_theta;
} }
} else { } else {
// TODO: implement for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);
theta *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2;
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
}
}
} }
} }

24
ggml.c
View File

@ -5555,10 +5555,6 @@ struct ggml_tensor * ggml_repeat(
is_node = true; is_node = true;
} }
if (ggml_are_same_shape(a, b) && !is_node) {
return a;
}
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne); struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, b->n_dims, b->ne);
result->op = GGML_OP_REPEAT; result->op = GGML_OP_REPEAT;
@ -5789,6 +5785,7 @@ struct ggml_tensor * ggml_silu_back(
static struct ggml_tensor * ggml_norm_impl( static struct ggml_tensor * ggml_norm_impl(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
float eps,
bool inplace) { bool inplace) {
bool is_node = false; bool is_node = false;
@ -5799,7 +5796,7 @@ static struct ggml_tensor * ggml_norm_impl(
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
// TODO: maybe store epsilon here? ggml_set_op_params(result, &eps, sizeof(eps));
result->op = GGML_OP_NORM; result->op = GGML_OP_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -5810,14 +5807,16 @@ static struct ggml_tensor * ggml_norm_impl(
struct ggml_tensor * ggml_norm( struct ggml_tensor * ggml_norm(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a) { struct ggml_tensor * a,
return ggml_norm_impl(ctx, a, false); float eps) {
return ggml_norm_impl(ctx, a, eps, false);
} }
struct ggml_tensor * ggml_norm_inplace( struct ggml_tensor * ggml_norm_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a) { struct ggml_tensor * a,
return ggml_norm_impl(ctx, a, true); float eps) {
return ggml_norm_impl(ctx, a, eps, true);
} }
// ggml_rms_norm // ggml_rms_norm
@ -10619,7 +10618,8 @@ static void ggml_compute_forward_norm_f32(
GGML_TENSOR_UNARY_OP_LOCALS; GGML_TENSOR_UNARY_OP_LOCALS;
const float eps = 1e-5f; // TODO: make this a parameter float eps;
memcpy(&eps, dst->op_params, sizeof(float));
// TODO: optimize // TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i03 = 0; i03 < ne03; i03++) {
@ -12537,7 +12537,7 @@ static void ggml_compute_forward_rope_f32(
dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta; dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
} }
} else { } else {
// TODO: this is probably wrong, but I can't figure it out .. // TODO: this might be wrong for ne0 != n_dims - need double check
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) { for (int64_t ic = 0; ic < n_dims; ic += 2) {
@ -12666,7 +12666,7 @@ static void ggml_compute_forward_rope_f16(
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
} }
} else { } else {
// TODO: this is probably wrong, but I can't figure it out .. // TODO: this might be wrong for ne0 != n_dims - need double check
// ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) { for (int64_t ic = 0; ic < n_dims; ic += 2) {

7
ggml.h
View File

@ -909,14 +909,15 @@ extern "C" {
struct ggml_tensor * b); struct ggml_tensor * b);
// normalize along rows // normalize along rows
// TODO: eps is hardcoded to 1e-5 for now
GGML_API struct ggml_tensor * ggml_norm( GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a,
float eps);
GGML_API struct ggml_tensor * ggml_norm_inplace( GGML_API struct ggml_tensor * ggml_norm_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a,
float eps);
GGML_API struct ggml_tensor * ggml_rms_norm( GGML_API struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx, struct ggml_context * ctx,

26
gguf.py
View File

@ -30,12 +30,12 @@ KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository"
KEY_GENERAL_FILE_TYPE = "general.file_type" KEY_GENERAL_FILE_TYPE = "general.file_type"
# LLM # LLM
KEY_LLM_CONTEXT_LENGTH = "{arch}.context_length" KEY_CONTEXT_LENGTH = "{arch}.context_length"
KEY_LLM_EMBEDDING_LENGTH = "{arch}.embedding_length" KEY_EMBEDDING_LENGTH = "{arch}.embedding_length"
KEY_LLM_BLOCK_COUNT = "{arch}.block_count" KEY_BLOCK_COUNT = "{arch}.block_count"
KEY_LLM_FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" KEY_FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
KEY_LLM_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" KEY_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
KEY_LLM_TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" KEY_TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
# attention # attention
KEY_ATTENTION_HEAD_COUNT = "{arch}.attention.head_count" KEY_ATTENTION_HEAD_COUNT = "{arch}.attention.head_count"
@ -583,7 +583,7 @@ class GGUFWriter:
self.add_string(KEY_GENERAL_AUTHOR, author) self.add_string(KEY_GENERAL_AUTHOR, author)
def add_tensor_data_layout(self, layout: str): def add_tensor_data_layout(self, layout: str):
self.add_string(KEY_LLM_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout) self.add_string(KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
def add_url(self, url: str): def add_url(self, url: str):
self.add_string(KEY_GENERAL_URL, url) self.add_string(KEY_GENERAL_URL, url)
@ -613,27 +613,27 @@ class GGUFWriter:
def add_context_length(self, length: int): def add_context_length(self, length: int):
self.add_uint32( self.add_uint32(
KEY_LLM_CONTEXT_LENGTH.format(arch=self.arch), length) KEY_CONTEXT_LENGTH.format(arch=self.arch), length)
def add_embedding_length(self, length: int): def add_embedding_length(self, length: int):
self.add_uint32( self.add_uint32(
KEY_LLM_EMBEDDING_LENGTH.format(arch=self.arch), length) KEY_EMBEDDING_LENGTH.format(arch=self.arch), length)
def add_block_count(self, length: int): def add_block_count(self, length: int):
self.add_uint32( self.add_uint32(
KEY_LLM_BLOCK_COUNT.format(arch=self.arch), length) KEY_BLOCK_COUNT.format(arch=self.arch), length)
def add_feed_forward_length(self, length: int): def add_feed_forward_length(self, length: int):
self.add_uint32( self.add_uint32(
KEY_LLM_FEED_FORWARD_LENGTH.format(arch=self.arch), length) KEY_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
def add_parallel_residual(self, use: bool): def add_parallel_residual(self, use: bool):
self.add_bool( self.add_bool(
KEY_LLM_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use) KEY_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
def add_tensor_data_layout(self, layout: str): def add_tensor_data_layout(self, layout: str):
self.add_string( self.add_string(
KEY_LLM_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout) KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
def add_head_count(self, count: int): def add_head_count(self, count: int):
self.add_uint32( self.add_uint32(

1570
llama.cpp

File diff suppressed because it is too large Load Diff

15
llama.h
View File

@ -247,6 +247,8 @@ extern "C" {
LLAMA_API int llama_n_ctx (const struct llama_context * ctx); LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
LLAMA_API int llama_n_embd (const struct llama_context * ctx); LLAMA_API int llama_n_embd (const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);
LLAMA_API int llama_model_n_vocab(const struct llama_model * model); LLAMA_API int llama_model_n_vocab(const struct llama_model * model);
LLAMA_API int llama_model_n_ctx (const struct llama_model * model); LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
LLAMA_API int llama_model_n_embd (const struct llama_model * model); LLAMA_API int llama_model_n_embd (const struct llama_model * model);
@ -368,13 +370,6 @@ extern "C" {
int n_max_tokens, int n_max_tokens,
bool add_bos); bool add_bos);
LLAMA_API int llama_tokenize_bpe(
struct llama_context * ctx,
const char * text,
llama_token * tokens,
int n_max_tokens,
bool add_bos);
LLAMA_API int llama_tokenize_with_model( LLAMA_API int llama_tokenize_with_model(
const struct llama_model * model, const struct llama_model * model,
const char * text, const char * text,
@ -390,12 +385,6 @@ extern "C" {
char * buf, char * buf,
int length); int length);
LLAMA_API int llama_token_to_str_bpe(
const struct llama_context * ctx,
llama_token token,
char * buf,
int length);
LLAMA_API int llama_token_to_str_with_model( LLAMA_API int llama_token_to_str_with_model(
const struct llama_model * model, const struct llama_model * model,
llama_token token, llama_token token,

View File

@ -28,7 +28,8 @@ llama_build_and_test_executable(test-sampling.cpp)
llama_build_executable(test-tokenizer-0.cpp) llama_build_executable(test-tokenizer-0.cpp)
llama_test_executable (test-tokenizer-0.llama test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf) llama_test_executable (test-tokenizer-0.llama test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
llama_build_executable(test-tokenizer-1.cpp) llama_build_executable(test-tokenizer-1.cpp)
llama_test_executable (test-tokenizer-1.llama test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf) # test-tokenizer-1 requires a BPE vocab. re-enable when we have one.
#llama_test_executable (test-tokenizer-1.llama test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
#llama_test_executable(test-tokenizer-1.aquila test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf) #llama_test_executable(test-tokenizer-1.aquila test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
llama_build_and_test_executable(test-grammar-parser.cpp) llama_build_and_test_executable(test-grammar-parser.cpp)
llama_build_and_test_executable(test-llama-grammar.cpp) llama_build_and_test_executable(test-llama-grammar.cpp)

View File

@ -67,11 +67,13 @@ int main(int argc, char **argv) {
} }
} }
GGML_ASSERT(llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_BPE);
const int n_vocab = llama_n_vocab(ctx); const int n_vocab = llama_n_vocab(ctx);
for (int i = 0; i < n_vocab; ++i) { for (int i = 0; i < n_vocab; ++i) {
std::string forward = llama_token_to_str_bpe(ctx, i); std::string forward = llama_token_to_str(ctx, i);
std::vector<llama_token> tokens = llama_tokenize_bpe(ctx, forward, false); std::vector<llama_token> tokens = llama_tokenize(ctx, forward, false);
if (tokens.size() == 1) { if (tokens.size() == 1) {
if (i != tokens[0]) { if (i != tokens[0]) {
std::string backward = llama_token_to_str(ctx, tokens[0]); std::string backward = llama_token_to_str(ctx, tokens[0]);
@ -79,16 +81,6 @@ int main(int argc, char **argv) {
__func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str()); __func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str());
return 2; return 2;
} }
} else {
llama_token_type type = llama_token_get_type(ctx, i);
if (type == LLAMA_TOKEN_TYPE_UNKNOWN || type == LLAMA_TOKEN_TYPE_CONTROL || type == LLAMA_TOKEN_TYPE_BYTE) {
fprintf(stderr, "%s : info: token %d is string %s and bpe returns tokens %s\n",
__func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str());
} else {
fprintf(stderr, "%s : error: token %d is string %s but bpe returns tokens %s\n",
__func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str());
return 2;
}
} }
} }