refact : fix convert script + zero out KV cache to avoid nans

This commit is contained in:
Georgi Gerganov 2023-10-07 11:18:04 +03:00
parent 0e797c2fc5
commit bdbe11719d
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 28 additions and 78 deletions

View File

@ -17,33 +17,6 @@ if "NO_LOCAL_GGUF" not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / "gguf-py" / "gguf")) sys.path.insert(1, str(Path(__file__).parent / "gguf-py" / "gguf"))
import gguf import gguf
def bytes_to_unicode():
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
return dict(zip(bs, (chr(n) for n in cs)))
def count_model_parts(dir_model: Path) -> int: def count_model_parts(dir_model: Path) -> int:
num_parts = 0 num_parts = 0
for filename in os.listdir(dir_model): for filename in os.listdir(dir_model):
@ -153,53 +126,25 @@ tokens: list[bytearray] = []
scores: list[float] = [] scores: list[float] = []
toktypes: list[int] = [] toktypes: list[int] = []
tokenizer_json_file = dir_model / "tokenizer.json"
if not tokenizer_json_file.is_file():
print(f"Error: Missing {tokenizer_json_file}", file=sys.stderr)
sys.exit(1)
# gpt2 tokenizer # gpt2 tokenizer
gguf_writer.add_tokenizer_model("gpt2") gguf_writer.add_tokenizer_model("gpt2")
with open(tokenizer_json_file, "r", encoding="utf-8") as f:
tokenizer_json = json.load(f)
print("gguf: get gpt2 tokenizer vocab") print("gguf: get gpt2 tokenizer vocab")
# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
tokenizer = AutoTokenizer.from_pretrained(dir_model)
# The number of tokens in tokenizer.json can differ from the expected vocab size. # The number of tokens in tokenizer.json can differ from the expected vocab size.
# This causes downstream issues with mismatched tensor sizes when running the inference # This causes downstream issues with mismatched tensor sizes when running the inference
vocab_size = ( vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
hparams["vocab_size"] assert max(tokenizer.vocab.values()) < vocab_size
if "vocab_size" in hparams
else len(tokenizer_json["model"]["vocab"])
)
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()} reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
for i in range(vocab_size): for i in range(vocab_size):
if i in reverse_vocab: tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]")
text = reverse_vocab[i] scores.append(0.0) # dummy
try: toktypes.append(gguf.TokenType.NORMAL)
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
except KeyError:
text = bytearray()
for c in reverse_vocab[i]:
if ord(c) < 256: # single byte character
text.append(byte_decoder[ord(c)])
else: # multibyte special token character
text.extend(c.encode("utf-8"))
else:
print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
pad_token = f"[PAD{i}]".encode("utf8")
text = bytearray(pad_token)
tokens.append(text)
scores.append(0.0) # dymmy
toktypes.append(gguf.TokenType.NORMAL) # dummy
gguf_writer.add_token_list(tokens) gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores) gguf_writer.add_token_scores(scores)

View File

@ -167,7 +167,7 @@ int main(int argc, char ** argv) {
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
llama_batch batch = llama_batch_init(params.n_ctx, 0); llama_batch batch = llama_batch_init(n_ctx, 0);
int32_t n_total_prompt = 0; int32_t n_total_prompt = 0;
int32_t n_total_gen = 0; int32_t n_total_gen = 0;

27
ggml.c
View File

@ -3923,6 +3923,8 @@ inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float *
// Sigmoid Linear Unit (SiLU) function // Sigmoid Linear Unit (SiLU) function
inline static float ggml_silu_f32(float x) { inline static float ggml_silu_f32(float x) {
if (x == -INFINITY) return 0.0f;
return x/(1.0f + expf(-x)); return x/(1.0f + expf(-x));
} }
@ -13089,17 +13091,17 @@ static void ggml_compute_forward_alibi_f32(
assert(n_past >= 0); assert(n_past >= 0);
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int ne1 = src0->ne[1]; // seq_len_without_past const int64_t ne1 = src0->ne[1]; // seq_len_without_past
const int ne2 = src0->ne[2]; // n_head -> this is k const int64_t ne2 = src0->ne[2]; // n_head -> this is k
//const int ne3 = src0->ne[3]; // 1 -> bsz //const int64_t ne3 = src0->ne[3]; // 1 -> bsz
const int n = ggml_nrows(src0); const int64_t n = ggml_nrows(src0);
const int ne2_ne3 = n/ne1; // ne2*ne3 const int64_t ne2_ne3 = n/ne1; // ne2*ne3
const int nb0 = src0->nb[0]; const size_t nb0 = src0->nb[0];
const int nb1 = src0->nb[1]; const size_t nb1 = src0->nb[1];
const int nb2 = src0->nb[2]; const size_t nb2 = src0->nb[2];
//const int nb3 = src0->nb[3]; //const int nb3 = src0->nb[3];
GGML_ASSERT(nb0 == sizeof(float)); GGML_ASSERT(nb0 == sizeof(float));
@ -13111,9 +13113,9 @@ static void ggml_compute_forward_alibi_f32(
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
for (int i = 0; i < ne0; i++) { for (int64_t i = 0; i < ne0; i++) {
for (int j = 0; j < ne1; j++) { for (int64_t j = 0; j < ne1; j++) {
for (int k = 0; k < ne2_ne3; k++) { for (int64_t k = 0; k < ne2_ne3; k++) {
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
@ -13128,7 +13130,6 @@ static void ggml_compute_forward_alibi_f32(
} }
pdst[0] = i * m_k + src[0]; pdst[0] = i * m_k + src[0];
} }
} }
} }

View File

@ -1325,7 +1325,11 @@ static bool llama_kv_cache_init(
cache.cells.clear(); cache.cells.clear();
cache.cells.resize(n_ctx); cache.cells.resize(n_ctx);
// TODO: this should be:
// cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead());
// change it and test that it works
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB); cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
memset(cache.buf.data, 0, cache.buf.size);
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = cache.buf.size; params.mem_size = cache.buf.size;