mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 21:39:52 +00:00
refact : fix convert script + zero out KV cache to avoid nans (#3523)
* refact : fix convert script + zero out KV cache to avoid nans * ggml : silu(-inf) should never happen * metal : assert various kernel requirements
This commit is contained in:
parent
dcc09d2596
commit
fcca0a7004
@ -17,33 +17,6 @@ if "NO_LOCAL_GGUF" not in os.environ:
|
||||
sys.path.insert(1, str(Path(__file__).parent / "gguf-py" / "gguf"))
|
||||
import gguf
|
||||
|
||||
|
||||
def bytes_to_unicode():
|
||||
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a significant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = (
|
||||
list(range(ord("!"), ord("~") + 1))
|
||||
+ list(range(ord("¡"), ord("¬") + 1))
|
||||
+ list(range(ord("®"), ord("ÿ") + 1))
|
||||
)
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8 + n)
|
||||
n += 1
|
||||
return dict(zip(bs, (chr(n) for n in cs)))
|
||||
|
||||
|
||||
def count_model_parts(dir_model: Path) -> int:
|
||||
num_parts = 0
|
||||
for filename in os.listdir(dir_model):
|
||||
@ -153,53 +126,25 @@ tokens: list[bytearray] = []
|
||||
scores: list[float] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
tokenizer_json_file = dir_model / "tokenizer.json"
|
||||
if not tokenizer_json_file.is_file():
|
||||
print(f"Error: Missing {tokenizer_json_file}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# gpt2 tokenizer
|
||||
gguf_writer.add_tokenizer_model("gpt2")
|
||||
|
||||
with open(tokenizer_json_file, "r", encoding="utf-8") as f:
|
||||
tokenizer_json = json.load(f)
|
||||
|
||||
print("gguf: get gpt2 tokenizer vocab")
|
||||
|
||||
# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model)
|
||||
|
||||
# The number of tokens in tokenizer.json can differ from the expected vocab size.
|
||||
# This causes downstream issues with mismatched tensor sizes when running the inference
|
||||
vocab_size = (
|
||||
hparams["vocab_size"]
|
||||
if "vocab_size" in hparams
|
||||
else len(tokenizer_json["model"]["vocab"])
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
|
||||
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
|
||||
assert max(tokenizer.vocab.values()) < vocab_size
|
||||
|
||||
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
|
||||
byte_encoder = bytes_to_unicode()
|
||||
byte_decoder = {v: k for k, v in byte_encoder.items()}
|
||||
|
||||
for i in range(vocab_size):
|
||||
if i in reverse_vocab:
|
||||
text = reverse_vocab[i]
|
||||
try:
|
||||
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
|
||||
except KeyError:
|
||||
text = bytearray()
|
||||
for c in reverse_vocab[i]:
|
||||
if ord(c) < 256: # single byte character
|
||||
text.append(byte_decoder[ord(c)])
|
||||
else: # multibyte special token character
|
||||
text.extend(c.encode("utf-8"))
|
||||
else:
|
||||
print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
|
||||
pad_token = f"[PAD{i}]".encode("utf8")
|
||||
text = bytearray(pad_token)
|
||||
|
||||
tokens.append(text)
|
||||
scores.append(0.0) # dymmy
|
||||
toktypes.append(gguf.TokenType.NORMAL) # dummy
|
||||
tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]")
|
||||
scores.append(0.0) # dummy
|
||||
toktypes.append(gguf.TokenType.NORMAL)
|
||||
|
||||
gguf_writer.add_token_list(tokens)
|
||||
gguf_writer.add_token_scores(scores)
|
||||
|
@ -167,7 +167,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
|
||||
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
|
||||
llama_batch batch = llama_batch_init(params.n_ctx, 0);
|
||||
llama_batch batch = llama_batch_init(n_ctx, 0);
|
||||
|
||||
int32_t n_total_prompt = 0;
|
||||
int32_t n_total_gen = 0;
|
||||
|
20
ggml-metal.m
20
ggml-metal.m
@ -779,8 +779,8 @@ void ggml_metal_graph_compute(
|
||||
} break;
|
||||
case GGML_OP_CONCAT:
|
||||
{
|
||||
const int64_t nb = ne00;
|
||||
|
||||
int64_t nb = ne00;
|
||||
[encoder setComputePipelineState:ctx->pipeline_concat];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
@ -812,6 +812,7 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
|
||||
|
||||
const int nth = MIN(1024, ne0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_ADD:
|
||||
@ -909,9 +910,10 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
|
||||
|
||||
const int64_t n = ggml_nelements(dst)/4;
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
GGML_ASSERT(n % 4 == 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(gf->nodes[i])) {
|
||||
@ -921,9 +923,10 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst)/4;
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
GGML_ASSERT(n % 4 == 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_RELU:
|
||||
{
|
||||
@ -941,9 +944,10 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst)/4;
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
GGML_ASSERT(n % 4 == 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
@ -1251,6 +1255,8 @@ void ggml_metal_graph_compute(
|
||||
} break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
{
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
|
@ -345,10 +345,11 @@ kernel void kernel_rms_norm(
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
||||
device const float * x_scalar = (device const float *) x;
|
||||
float4 sumf=0;
|
||||
float all_sum=0;
|
||||
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
|
||||
device const float * x_scalar = (device const float *) x;
|
||||
|
||||
float4 sumf = 0;
|
||||
float all_sum = 0;
|
||||
|
||||
// parallel sum
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
@ -361,6 +362,7 @@ kernel void kernel_rms_norm(
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// broadcast, simd group number is ntg / 32
|
||||
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
||||
if (tpitg < i) {
|
||||
@ -368,7 +370,9 @@ kernel void kernel_rms_norm(
|
||||
}
|
||||
}
|
||||
if (tpitg == 0) {
|
||||
for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
|
||||
for (int i = 4 * (ne00 / 4); i < ne00; i++) {
|
||||
sum[0] += x_scalar[i];
|
||||
}
|
||||
sum[0] /= ne00;
|
||||
}
|
||||
|
||||
@ -383,7 +387,9 @@ kernel void kernel_rms_norm(
|
||||
y[i00] = x[i00] * scale;
|
||||
}
|
||||
if (tpitg == 0) {
|
||||
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
|
||||
for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
|
||||
y_scalar[i00] = x_scalar[i00] * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
27
ggml.c
27
ggml.c
@ -11233,7 +11233,7 @@ static void ggml_compute_forward_silu_f32(
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||
const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
|
||||
UNUSED(x);
|
||||
assert(!isnan(x));
|
||||
assert(!isinf(x));
|
||||
@ -13066,17 +13066,17 @@ static void ggml_compute_forward_alibi_f32(
|
||||
|
||||
assert(n_past >= 0);
|
||||
|
||||
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
||||
const int ne1 = src0->ne[1]; // seq_len_without_past
|
||||
const int ne2 = src0->ne[2]; // n_head -> this is k
|
||||
//const int ne3 = src0->ne[3]; // 1 -> bsz
|
||||
const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
||||
const int64_t ne1 = src0->ne[1]; // seq_len_without_past
|
||||
const int64_t ne2 = src0->ne[2]; // n_head -> this is k
|
||||
//const int64_t ne3 = src0->ne[3]; // 1 -> bsz
|
||||
|
||||
const int n = ggml_nrows(src0);
|
||||
const int ne2_ne3 = n/ne1; // ne2*ne3
|
||||
const int64_t n = ggml_nrows(src0);
|
||||
const int64_t ne2_ne3 = n/ne1; // ne2*ne3
|
||||
|
||||
const int nb0 = src0->nb[0];
|
||||
const int nb1 = src0->nb[1];
|
||||
const int nb2 = src0->nb[2];
|
||||
const size_t nb0 = src0->nb[0];
|
||||
const size_t nb1 = src0->nb[1];
|
||||
const size_t nb2 = src0->nb[2];
|
||||
//const int nb3 = src0->nb[3];
|
||||
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
@ -13088,9 +13088,9 @@ static void ggml_compute_forward_alibi_f32(
|
||||
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
||||
|
||||
for (int i = 0; i < ne0; i++) {
|
||||
for (int j = 0; j < ne1; j++) {
|
||||
for (int k = 0; k < ne2_ne3; k++) {
|
||||
for (int64_t i = 0; i < ne0; i++) {
|
||||
for (int64_t j = 0; j < ne1; j++) {
|
||||
for (int64_t k = 0; k < ne2_ne3; k++) {
|
||||
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
||||
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
||||
|
||||
@ -13105,7 +13105,6 @@ static void ggml_compute_forward_alibi_f32(
|
||||
}
|
||||
|
||||
pdst[0] = i * m_k + src[0];
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1325,7 +1325,11 @@ static bool llama_kv_cache_init(
|
||||
cache.cells.clear();
|
||||
cache.cells.resize(n_ctx);
|
||||
|
||||
// TODO: this should be:
|
||||
// cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead());
|
||||
// change it and test that it works
|
||||
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
|
||||
memset(cache.buf.data, 0, cache.buf.size);
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = cache.buf.size;
|
||||
|
Loading…
Reference in New Issue
Block a user