Mirror of https://github.com/ggerganov/llama.cpp.git

minor fixes

commit 24cc6f008f
parent 5765d7a587
@@ -14,7 +14,7 @@ struct ggml_buffer ggml_backend_alloc_buffer(struct ggml_backend * backend, size
     buffer.mem_size = ggml_tensor_overhead() * max_tensors;
     buffer.mem_buffer = malloc(buffer.mem_size);
     buffer.backend = backend;
-    // size += 128 * max_tensors; // alignment overhead
+    size += 128 * max_tensors; // alignment overhead
     buffer.backend_buffer = backend->interface->alloc_buffer(backend->context, size);
     return buffer;
 }
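
Note: the uncommented line restores the per-tensor alignment padding: when many tensors are sub-allocated from a single backend buffer, each allocation can be rounded up to the backend's alignment, so the buffer has to be over-provisioned by roughly alignment * max_tensors bytes. A minimal standalone sketch of that accounting follows; the 128-byte alignment and the helper names are illustrative assumptions, not the ggml API.

#include <cstddef>
#include <cstdio>

// Hypothetical helper: round `size` up to the next multiple of `align` (power of two).
static size_t align_up(size_t size, size_t align) {
    return (size + align - 1) & ~(align - 1);
}

// Worst case, every tensor loses up to (align - 1) bytes to padding, so the buffer
// is over-provisioned by about align * max_tensors bytes -- the same idea as
// `size += 128 * max_tensors` in the patch.
static size_t buffer_size_with_alignment(size_t data_size, int max_tensors, size_t align) {
    return data_size + align * (size_t) max_tensors;
}

int main() {
    const size_t align = 128;              // assumed backend alignment
    const size_t tensor_sizes[] = { 1000, 4096, 31 };
    size_t offset = 0;
    for (size_t sz : tensor_sizes) {
        offset = align_up(offset, align);  // each sub-allocation starts aligned
        printf("tensor of %zu bytes placed at offset %zu\n", sz, offset);
        offset += sz;
    }
    printf("used: %zu bytes, reserved: %zu bytes\n",
           offset, buffer_size_with_alignment(1000 + 4096 + 31, 3, align));
    return 0;
}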
@@ -172,7 +172,7 @@ static void ggml_backend_cpu_cpy_tensor_from(ggml_backend_context_t ctx, struct
 }

 static void ggml_backend_cpu_cpy_tensor_to(ggml_backend_context_t ctx, struct ggml_tensor * src, struct ggml_tensor * dst) {
-    ggml_backend_set_tensor(dst, src->data, 0, ggml_nbytes(src));
+    ggml_backend_set_tensor_async(dst, src->data, 0, ggml_nbytes(src));

     UNUSED(ctx);
 }
@@ -409,7 +409,7 @@ void ggml_graph_splits_compute(struct ggml_graph_splits * splits) {
                 ggml_backend_cpy_tensor(split->dst_inputs[j], split->src_inputs[j]);
             }
         }
-        ggml_backend_synchronize(split->dst_inputs[0]->backend);
+        // ggml_backend_synchronize(split->dst_inputs[0]->backend);
         copy_us += ggml_time_us() - copy_start_us;

 #if 0
@@ -419,7 +419,7 @@ void ggml_graph_splits_compute(struct ggml_graph_splits * splits) {
 #endif
         uint64_t start = ggml_time_us();
         ggml_backend_graph_compute(split->dst_inputs[0]->backend, split->graph);
-        ggml_backend_synchronize(split->dst_inputs[0]->backend);
+        //ggml_backend_synchronize(split->dst_inputs[0]->backend);
         uint64_t end = ggml_time_us();
         if (strcmp(ggml_backend_name(split->dst_inputs[0]->backend), "CPU") == 0) {
             compute_cpu_us += end - start;
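
Note: both synchronize calls are commented out, so on an asynchronous backend the input copies and the per-split graph compute are only queued here, and copy_us and the compute timers now measure submission rather than completion; synchronization is deferred to where results are actually read back (see the llama.cpp changes below). A small self-contained illustration of that timing difference, using std::async as a stand-in for an asynchronous backend:

#include <chrono>
#include <cstdint>
#include <cstdio>
#include <future>
#include <thread>

// Stand-in for ggml_time_us().
static int64_t time_us() {
    using namespace std::chrono;
    return duration_cast<microseconds>(steady_clock::now().time_since_epoch()).count();
}

int main() {
    // Queue some "backend work" without waiting for it.
    int64_t t0 = time_us();
    auto work = std::async(std::launch::async, [] {
        std::this_thread::sleep_for(std::chrono::milliseconds(50));  // pretend compute
    });
    int64_t submit_us = time_us() - t0;   // roughly what the patched code measures (no synchronize)

    work.wait();                          // what the old code effectively measured
    int64_t total_us = time_us() - t0;

    printf("submission: %lld us, completion: %lld us\n",
           (long long) submit_us, (long long) total_us);
    return 0;
}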
llama.cpp (44 changed lines)
@@ -624,8 +624,9 @@ struct llama_model_loader {
         }
         LLAMA_ASSERT(lt.ggml_tensor); // unused tensors should have been caught by load_data already

-        bool is_cpu = lt.ggml_tensor->backend == &model->backend_cpu; // TODO
+        bool is_cpu = lt.ggml_tensor->backend == &model->backend_cpu;
+
         // select buffer to load data into
         if (!use_mmap) {
             if (is_cpu) {
                 lt.data = (uint8_t *) lt.ggml_tensor->data;
@@ -641,7 +642,7 @@ struct llama_model_loader {
         if (is_cpu) {
             if (use_mmap) {
                 lt.ggml_tensor->data = lt.data;
-                // TODO: this assumes that the data is contiguous, which may not always be the case
+                // TODO: this assumes that the data to lock is contiguous, which may not always be the case
                 if (lmlock) {
                     lock_size += lt.size;
                     lmlock->grow_to(lock_size);
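
Note: the reworded TODO makes explicit that it is the region being locked that is assumed contiguous: lock_size grows monotonically and lmlock->grow_to() extends a single locked range, which only covers the tensor data if it all lives in one contiguous mapping. A rough standalone sketch of that grow-one-locked-prefix idea using POSIX mmap/mlock (llama.cpp's llama_mlock wrapper is not reproduced here, and the sizes are made up):

#include <sys/mman.h>
#include <cstddef>
#include <cstdio>

// Grow a single locked prefix of one contiguous mapping from old_size to new_size.
// This mirrors the `lock_size += lt.size; lmlock->grow_to(lock_size)` pattern:
// it only works if every tensor lives inside the same contiguous region.
static bool grow_locked_prefix(void * base, size_t old_size, size_t new_size) {
    if (new_size <= old_size) {
        return true;                       // nothing new to lock
    }
    // Lock only the newly covered tail of the region.
    return mlock((char *) base + old_size, new_size - old_size) == 0;
}

int main() {
    const size_t total = 1 << 20;          // pretend this is the mmapped model file
    void * base = mmap(NULL, total, PROT_READ | PROT_WRITE,
                       MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (base == MAP_FAILED) {
        perror("mmap");
        return 1;
    }
    const size_t tensor_sizes[] = { 4096, 65536 };   // page-multiple sizes for simplicity
    size_t locked = 0;
    for (size_t tensor_size : tensor_sizes) {
        if (!grow_locked_prefix(base, locked, locked + tensor_size)) {
            perror("mlock");               // may fail if RLIMIT_MEMLOCK is low
        }
        locked += tensor_size;
    }
    printf("locked %zu bytes\n", locked);
    munmap(base, total);
    return 0;
}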
@@ -1227,6 +1228,10 @@ static ggml_graph_splits llama_build_graph(
         inpL = ggml_get_rows(ctx_i, model.tok_embeddings, token_in);
     }

+    // reuse the scale tensor for all layers since it requires a memory transfer
+    struct ggml_tensor * KQ_scale = ggml_new_f32(ctx_kv, 1.0f/sqrtf(float(n_embd)/n_head));
+    ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)");
+
     struct ggml_tensor * cur = nullptr;
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_context * ctx_l = ctx_ls[il];
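
Note: the scale tensor is now created once in ctx_kv and reused by every layer, so the one-element constant is transferred to the KV backend a single time instead of once per layer. The value is just 1/sqrt(head_dim); a tiny self-contained check, assuming 7B-style shapes (n_embd = 4096, n_head = 32), which are not stated in the patch:

#include <cmath>
#include <cstdio>

int main() {
    const int n_embd = 4096;   // assumed model width (7B-class config)
    const int n_head = 32;     // assumed number of attention heads
    // Same expression as in the patch: 1.0f/sqrtf(float(n_embd)/n_head)
    const float kq_scale = 1.0f / sqrtf((float) n_embd / n_head);
    // head_dim = n_embd/n_head = 128, so the scale is 1/sqrt(128), about 0.0884
    printf("KQ_scale = %f\n", kq_scale);
    return 0;
}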
@@ -1267,9 +1272,6 @@ static ggml_graph_splits llama_build_graph(
         struct ggml_tensor * Vcur = ggml_transpose(ctx_l, ggml_reshape_2d(ctx_l, tmpv, n_embd, N));
         ggml_set_name(Vcur, "Vcur");

-        //ggml_graph_splits_add(&splits, &Kcur, ctx_kv, "Kcur");
-        //ggml_graph_splits_add(&splits, &Vcur, ctx_kv, "Vcur");
-        //ggml_graph_splits_add(&splits, &Qcur, ctx_kv, "Qcur");
         ggml_tensor ** attn_inputs[] = {&Kcur, &Vcur, &Qcur, NULL};
         ggml_graph_splits_add_n(&splits, attn_inputs, ctx_kv, "l%d_attn", il);

@@ -1316,9 +1318,6 @@ static ggml_graph_splits llama_build_graph(
         ggml_set_name(KQ, "KQ");

         // KQ_scaled = KQ / sqrt(n_embd/n_head)
-        struct ggml_tensor * KQ_scale = ggml_new_f32(ctx_kv, 1.0f/sqrtf(float(n_embd)/n_head));
-        ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)");
-
         // KQ_scaled shape [n_past + N, N, n_head, 1]
         struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx_kv, KQ, KQ_scale);
         ggml_set_name(KQ_scaled, "KQ_scaled");
@@ -1395,7 +1394,7 @@ static ggml_graph_splits llama_build_graph(
         cur = ggml_mul_mat(ctx_l,
                 model.layers[il].w1,
                 cur);
-        ggml_set_name(cur, "result_w2");
+        ggml_set_name(cur, "result_w1");

         // SILU activation
         cur = ggml_silu(ctx_l, cur);
@@ -1531,6 +1530,12 @@ static bool llama_eval_internal(

     LLAMA_ASSERT(lctx.graph_logits != nullptr);

+
+    // for big prompts, if BLAS is enabled, it is better to use only one thread
+    // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
+    n_threads = N >= 32 && ggml_cpu_has_blas() ? 1 : n_threads;
+    ggml_backend_cpu_set_n_threads(const_cast<ggml_backend*>(&model.backend_cpu), n_threads);
+
     struct ggml_graph_splits splits = llama_build_graph(lctx, N, n_past, embd_input);

     // TODO: use backend functions
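
Note: the thread-count clamp is moved ahead of llama_build_graph so the CPU backend is configured before the graph and its splits are created. The heuristic itself is unchanged: for prompts of 32 or more tokens with BLAS available, the heavy mat-muls run inside BLAS and extra ggml threads would only spin-wait. A standalone restatement of that rule; has_blas stands in for ggml_cpu_has_blas():

#include <cstdio>

// Same policy as `n_threads = N >= 32 && ggml_cpu_has_blas() ? 1 : n_threads;`
static int pick_n_threads(int n_tokens, bool has_blas, int n_threads) {
    // Large batches with BLAS: one thread drives the BLAS calls,
    // extra ggml threads would only spin-wait on them.
    return (n_tokens >= 32 && has_blas) ? 1 : n_threads;
}

int main() {
    printf("%d\n", pick_n_threads(512, true, 8));   // prompt eval with BLAS -> 1
    printf("%d\n", pick_n_threads(1, true, 8));     // single-token eval     -> 8
    printf("%d\n", pick_n_threads(512, false, 8));  // no BLAS               -> 8
    return 0;
}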
@@ -1542,11 +1547,7 @@ static bool llama_eval_internal(
         ggml_backend_set_tensor(lctx.graph_embeddings_in, embd, 0, N*n_embd*ggml_element_size(lctx.graph_embeddings_in));
     }

-    // for big prompts, if BLAS is enabled, it is better to use only one thread
-    // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
-    n_threads = N >= 32 && ggml_cpu_has_blas() ? 1 : n_threads;

-    ggml_backend_cpu_set_n_threads(const_cast<ggml_backend*>(&model.backend_cpu), n_threads);

     // run the computation
     ggml_graph_splits_compute(&splits);
@@ -1573,11 +1574,11 @@ static bool llama_eval_internal(

         if (lctx.logits_all) {
             logits_out.resize(n_vocab * N);
-            ggml_backend_get_tensor(lctx.graph_logits, logits_out.data(), 0, N*n_vocab*sizeof(float));
+            ggml_backend_get_tensor_async(lctx.graph_logits, logits_out.data(), 0, N*n_vocab*sizeof(float));
         } else {
             // return result for just the last token
             logits_out.resize(n_vocab);
-            ggml_backend_get_tensor(lctx.graph_logits, logits_out.data(), 0, n_vocab*sizeof(float));
+            ggml_backend_get_tensor_async(lctx.graph_logits, logits_out.data(), 0, n_vocab*sizeof(float));
         }
     }

@@ -1585,9 +1586,16 @@ static bool llama_eval_internal(
     if (!lctx.embedding.empty()) {
         auto & embedding_out = lctx.embedding;
         embedding_out.resize(n_embd);
-        ggml_backend_get_tensor(lctx.graph_embeddings_out, embedding_out.data(), 0, n_embd*sizeof(float));
+        ggml_backend_get_tensor_async(lctx.graph_embeddings_out, embedding_out.data(), 0, n_embd*sizeof(float));
     }

+#ifdef GGML_USE_CUDA
+    // wait for the async copy to finish
+    if (lctx.model.n_gpu_layers > 0) {
+        ggml_backend_synchronize(const_cast<ggml_backend*>(&lctx.model.backend_cuda));
+    }
+#endif
+
     // measure the performance only for the single-token evals
     if (N == 1) {
         lctx.t_eval_us += ggml_time_us() - t_start_us;
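
Note: the logits and embedding read-backs now use the async getter, and a single explicit ggml_backend_synchronize (only when layers are offloaded to the CUDA backend) replaces one blocking read per tensor. A rough standalone analogy of the queue-several-reads-then-wait-once pattern, using the standard library rather than the ggml backend API:

#include <cstring>
#include <future>
#include <vector>
#include <cstdio>

// Stand-in for ggml_backend_get_tensor_async: start a copy and return a handle.
static std::future<void> get_tensor_async(std::vector<float> & dst, const std::vector<float> & src) {
    return std::async(std::launch::async, [&dst, &src] {
        std::memcpy(dst.data(), src.data(), src.size() * sizeof(float));
    });
}

int main() {
    std::vector<float> logits_src(32000, 0.5f), embd_src(4096, 0.25f);
    std::vector<float> logits_out(logits_src.size()), embedding_out(embd_src.size());

    // Queue both read-backs without blocking between them ...
    auto f1 = get_tensor_async(logits_out, logits_src);
    auto f2 = get_tensor_async(embedding_out, embd_src);

    // ... and synchronize once before the results are used,
    // like the single ggml_backend_synchronize() under GGML_USE_CUDA.
    f1.wait();
    f2.wait();

    printf("logits_out[0] = %f, embedding_out[0] = %f\n", logits_out[0], embedding_out[0]);
    return 0;
}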
@@ -2638,7 +2646,7 @@ struct llama_context * llama_new_context_with_model(
     // initialize the graph input/output buffers
     // input buffer
     {
-        size_t buf_input_size = 1024;
+        size_t buf_input_size = 0;
         buf_input_size += hparams.n_ctx * ggml_type_size(GGML_TYPE_F32); // input tokens
         // TODO: input embeddings should be optional to save memory
         buf_input_size += hparams.n_embd * hparams.n_ctx * ggml_type_size(GGML_TYPE_F32); // input embeddings
@@ -2657,7 +2665,7 @@ struct llama_context * llama_new_context_with_model(
     }
     // output buffer
     {
-        size_t buf_output_size = 1024;
+        size_t buf_output_size = 0;
         if (params.logits_all) {
             buf_output_size += hparams.n_ctx * hparams.n_vocab * ggml_type_size(GGML_TYPE_F32);
         } else {
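
Note: both buffer sizes now start from 0 instead of an arbitrary 1024-byte slack, so the graph input/output buffers are sized exactly from the context parameters. For a feel of the numbers, a self-contained calculation with assumed values (n_ctx = 512, n_embd = 4096, n_vocab = 32000, logits_all = false; only the terms visible in the hunks are included):

#include <cstddef>
#include <cstdio>

int main() {
    // Assumed example parameters, not taken from the patch.
    const size_t n_ctx    = 512;
    const size_t n_embd   = 4096;
    const size_t n_vocab  = 32000;
    const bool logits_all = false;
    const size_t f32_size = sizeof(float);   // ggml_type_size(GGML_TYPE_F32) == 4

    // Input buffer: token ids + input embeddings, as in the patch.
    size_t buf_input_size = 0;
    buf_input_size += n_ctx * f32_size;            // input tokens
    buf_input_size += n_embd * n_ctx * f32_size;   // input embeddings

    // Output buffer: logits for all tokens or just the last one.
    size_t buf_output_size = 0;
    if (logits_all) {
        buf_output_size += n_ctx * n_vocab * f32_size;
    } else {
        buf_output_size += n_vocab * f32_size;
    }

    printf("buf_input_size  = %zu bytes (~%.1f MiB)\n",
           buf_input_size,  buf_input_size  / (1024.0 * 1024.0));
    printf("buf_output_size = %zu bytes (~%.1f KiB)\n",
           buf_output_size, buf_output_size / 1024.0);
    return 0;
}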