mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 11:24:35 +00:00
llama : try to optimize offloading code
This commit is contained in:
parent
79617902ea
commit
b4ad03b3a7
335
llama.cpp
335
llama.cpp
@ -5267,6 +5267,90 @@ static struct ggml_cgraph * llm_build_mpt(
|
||||
return gf;
|
||||
}
|
||||
|
||||
enum offload_func_e {
|
||||
OFFLOAD_FUNC_NOP,
|
||||
OFFLOAD_FUNC,
|
||||
OFFLOAD_FUNC_KQ,
|
||||
OFFLOAD_FUNC_V,
|
||||
OFFLOAD_FUNC_NR,
|
||||
OFFLOAD_FUNC_EMB,
|
||||
OFFLOAD_FUNC_OUT,
|
||||
};
|
||||
|
||||
struct llama_offload_trie {
|
||||
struct node {
|
||||
~node() {
|
||||
for (int i = 0; i < 256; ++i) {
|
||||
if (children[i]) {
|
||||
delete children[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
node * children[256] = { nullptr };
|
||||
offload_func_e func = OFFLOAD_FUNC_NOP;
|
||||
};
|
||||
|
||||
llama_offload_trie() {
|
||||
root = new node;
|
||||
}
|
||||
|
||||
llama_offload_trie(const std::unordered_map<const char *, offload_func_e> & map) {
|
||||
root = new node;
|
||||
|
||||
for (const auto & kv : map) {
|
||||
add(kv.first, kv.second);
|
||||
}
|
||||
}
|
||||
|
||||
~llama_offload_trie() {
|
||||
delete root;
|
||||
}
|
||||
|
||||
void add(const char * name, offload_func_e func) {
|
||||
node * cur = root;
|
||||
|
||||
for (int i = 0; ; ++i) {
|
||||
const uint8_t c = name[i];
|
||||
|
||||
if (!c) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (!cur->children[c]) {
|
||||
cur->children[c] = new node;
|
||||
}
|
||||
|
||||
cur = cur->children[c];
|
||||
}
|
||||
|
||||
cur->func = func;
|
||||
}
|
||||
|
||||
offload_func_e find(const char * name) const {
|
||||
const node * cur = root;
|
||||
|
||||
for (int i = 0; ; ++i) {
|
||||
const uint8_t c = name[i];
|
||||
|
||||
if (!c) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (!cur->children[c]) {
|
||||
return OFFLOAD_FUNC_NOP;
|
||||
}
|
||||
|
||||
cur = cur->children[c];
|
||||
}
|
||||
|
||||
return cur->func;
|
||||
}
|
||||
|
||||
node * root = nullptr;
|
||||
};
|
||||
|
||||
|
||||
static void llama_build_graph_input(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch,
|
||||
@ -5441,6 +5525,8 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
// allocate memory and set the values for the input tensors of the graph
|
||||
llama_build_graph_input(lctx, batch, result);
|
||||
|
||||
//auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// offload layers
|
||||
// TODO: this code will be obsoleted with backend v2
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
@ -5456,132 +5542,113 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
const int i_gpu_start = n_layer - n_gpu_layers;
|
||||
|
||||
// should we offload the final norm? yes if we are not computing embeddings
|
||||
const bool off_res_norm = lctx.embedding.empty();
|
||||
|
||||
// offload functions set the tensor output backend to GPU
|
||||
// tensors are GPU-accelerated if any input or the output has been offloaded
|
||||
offload_func_t offload_func_nr = ggml_offload_nop; // nr = non-repeating
|
||||
offload_func_t offload_func_kq = ggml_offload_nop;
|
||||
offload_func_t offload_func_v = ggml_offload_nop;
|
||||
offload_func_t offload_func_emb = ggml_offload_nop;
|
||||
offload_func_t offload_func_out = ggml_offload_nop;
|
||||
offload_func_t offload_func = ggml_offload_nop;
|
||||
const bool offload_emb = lctx.embedding.empty();
|
||||
|
||||
static const std::unordered_map<offload_func_e, std::string> k_offload_func_name = {
|
||||
{ OFFLOAD_FUNC_NOP, "CPU" },
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
if (n_gpu_layers > n_layer) {
|
||||
offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
|
||||
}
|
||||
if (n_gpu_layers > n_layer + 1) {
|
||||
offload_func_v = ggml_cuda_assign_buffers_no_alloc;
|
||||
}
|
||||
if (n_gpu_layers > n_layer + 2) {
|
||||
offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
|
||||
}
|
||||
|
||||
offload_func_emb = off_res_norm ? ggml_cuda_assign_buffers_no_alloc : ggml_offload_nop;
|
||||
offload_func_out = ggml_offload_nop;
|
||||
|
||||
offload_func = ggml_cuda_assign_buffers_no_alloc;
|
||||
{ OFFLOAD_FUNC, "GPU (CUDA)" },
|
||||
{ OFFLOAD_FUNC_KQ, "GPU (CUDA) KQ" },
|
||||
{ OFFLOAD_FUNC_V, "GPU (CUDA) V" },
|
||||
{ OFFLOAD_FUNC_NR, "GPU (CUDA) NR" },
|
||||
{ OFFLOAD_FUNC_EMB, "GPU (CUDA) EMB" },
|
||||
{ OFFLOAD_FUNC_OUT, "GPU (CUDA) OUT" },
|
||||
#endif // GGML_USE_CUBLAS
|
||||
|
||||
static const std::unordered_map<offload_func_t, std::string> k_offload_func_name = {
|
||||
{ ggml_offload_nop, "CPU" },
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
{ ggml_cuda_assign_buffers_no_alloc, "GPU (CUDA)" },
|
||||
#endif
|
||||
};
|
||||
|
||||
const std::unordered_map<std::string, offload_func_t> k_offload_func = {
|
||||
{ "KQ_mask", offload_func_kq },
|
||||
{ "KQ_pos", offload_func_kq },
|
||||
{ "K_shift", offload_func_kq },
|
||||
{ "K_shifted", offload_func_kq },
|
||||
static const std::unordered_map<const char *, offload_func_e> k_offload_func = {
|
||||
{ "KQ_mask", OFFLOAD_FUNC_KQ },
|
||||
{ "KQ_pos", OFFLOAD_FUNC_KQ },
|
||||
{ "K_shift", OFFLOAD_FUNC_KQ },
|
||||
{ "K_shifted", OFFLOAD_FUNC_KQ },
|
||||
|
||||
{ "inp_norm", offload_func_nr },
|
||||
{ "inp_norm_w", offload_func_nr },
|
||||
{ "inp_norm_wb", offload_func_nr },
|
||||
{ "inp_norm", OFFLOAD_FUNC_NR },
|
||||
{ "inp_norm_w", OFFLOAD_FUNC_NR },
|
||||
{ "inp_norm_wb", OFFLOAD_FUNC_NR },
|
||||
|
||||
{ "rms_norm_0", offload_func },
|
||||
{ "rms_norm_0", OFFLOAD_FUNC },
|
||||
|
||||
{ "attn_norm_0", offload_func },
|
||||
{ "attn_norm_0_w", offload_func },
|
||||
{ "attn_norm_0_wb", offload_func },
|
||||
{ "attn_norm_0", OFFLOAD_FUNC },
|
||||
{ "attn_norm_0_w", OFFLOAD_FUNC },
|
||||
{ "attn_norm_0_wb", OFFLOAD_FUNC },
|
||||
|
||||
{ "attn_norm_2", offload_func },
|
||||
{ "attn_norm_2_w", offload_func },
|
||||
{ "attn_norm_2_wb", offload_func },
|
||||
{ "attn_norm_2", OFFLOAD_FUNC },
|
||||
{ "attn_norm_2_w", OFFLOAD_FUNC },
|
||||
{ "attn_norm_2_wb", OFFLOAD_FUNC },
|
||||
|
||||
{ "wqkv", offload_func_kq },
|
||||
{ "bqkv", offload_func_kq },
|
||||
{ "wqkv_clamped", offload_func_kq },
|
||||
{ "wqkv", OFFLOAD_FUNC_KQ },
|
||||
{ "bqkv", OFFLOAD_FUNC_KQ },
|
||||
{ "wqkv_clamped", OFFLOAD_FUNC_KQ },
|
||||
|
||||
{ "tmpk", offload_func_kq },
|
||||
{ "tmpq", offload_func_kq },
|
||||
{ "tmpv", offload_func_v },
|
||||
{ "tmpkqv", offload_func_kq }, // ??
|
||||
{ "Kcur", offload_func_kq },
|
||||
{ "Qcur", offload_func_kq },
|
||||
{ "Vcur", offload_func_v },
|
||||
{ "Vcur_0", offload_func_v },
|
||||
{ "Vcur_1", offload_func_v },
|
||||
{ "tmpk", OFFLOAD_FUNC_KQ },
|
||||
{ "tmpq", OFFLOAD_FUNC_KQ },
|
||||
{ "tmpv", OFFLOAD_FUNC_V },
|
||||
{ "tmpkqv", OFFLOAD_FUNC_KQ }, // ??
|
||||
{ "Kcur", OFFLOAD_FUNC_KQ },
|
||||
{ "Qcur", OFFLOAD_FUNC_KQ },
|
||||
{ "Vcur", OFFLOAD_FUNC_V },
|
||||
{ "Vcur_0", OFFLOAD_FUNC_V },
|
||||
{ "Vcur_1", OFFLOAD_FUNC_V },
|
||||
|
||||
{ "krot", offload_func_kq },
|
||||
{ "qrot", offload_func_kq },
|
||||
{ "kpass", offload_func_kq },
|
||||
{ "qpass", offload_func_kq },
|
||||
{ "krotated", offload_func_kq },
|
||||
{ "qrotated", offload_func_kq },
|
||||
{ "krot", OFFLOAD_FUNC_KQ },
|
||||
{ "qrot", OFFLOAD_FUNC_KQ },
|
||||
{ "kpass", OFFLOAD_FUNC_KQ },
|
||||
{ "qpass", OFFLOAD_FUNC_KQ },
|
||||
{ "krotated", OFFLOAD_FUNC_KQ },
|
||||
{ "qrotated", OFFLOAD_FUNC_KQ },
|
||||
|
||||
{ "k", offload_func_kq },
|
||||
{ "v", offload_func_v },
|
||||
{ "k", OFFLOAD_FUNC_KQ },
|
||||
{ "v", OFFLOAD_FUNC_V },
|
||||
|
||||
{ "Q", offload_func_kq },
|
||||
{ "K", offload_func_kq },
|
||||
{ "KQ", offload_func_kq },
|
||||
{ "KQ_scaled", offload_func_kq },
|
||||
{ "KQ_scaled_alibi", offload_func_kq },
|
||||
{ "KQ_masked", offload_func_kq },
|
||||
{ "KQ_soft_max", offload_func_v },
|
||||
{ "V", offload_func_v },
|
||||
{ "KQV", offload_func_v },
|
||||
{ "KQV_merged", offload_func_v },
|
||||
{ "KQV_merged_contiguous", offload_func_v },
|
||||
{ "Q", OFFLOAD_FUNC_KQ },
|
||||
{ "K", OFFLOAD_FUNC_KQ },
|
||||
{ "KQ", OFFLOAD_FUNC_KQ },
|
||||
{ "KQ_scaled", OFFLOAD_FUNC_KQ },
|
||||
{ "KQ_scaled_alibi", OFFLOAD_FUNC_KQ },
|
||||
{ "KQ_masked", OFFLOAD_FUNC_KQ },
|
||||
{ "KQ_soft_max", OFFLOAD_FUNC_V },
|
||||
{ "V", OFFLOAD_FUNC_V },
|
||||
{ "KQV", OFFLOAD_FUNC_V },
|
||||
{ "KQV_merged", OFFLOAD_FUNC_V },
|
||||
{ "KQV_merged_contiguous", OFFLOAD_FUNC_V },
|
||||
|
||||
{ "result_wo", offload_func },
|
||||
{ "result_wo_b", offload_func },
|
||||
{ "inpL_+_result_wo", offload_func },
|
||||
{ "result_wo", OFFLOAD_FUNC },
|
||||
{ "result_wo_b", OFFLOAD_FUNC },
|
||||
{ "inpL_+_result_wo", OFFLOAD_FUNC },
|
||||
|
||||
{ "inpFF", offload_func },
|
||||
{ "inpFF", OFFLOAD_FUNC },
|
||||
|
||||
{ "rms_norm_1", offload_func },
|
||||
{ "ffn_norm", offload_func },
|
||||
{ "ffn_norm_0", offload_func },
|
||||
{ "ffn_norm_0_w", offload_func },
|
||||
{ "ffn_norm_0_wb", offload_func },
|
||||
{ "rms_norm_1", OFFLOAD_FUNC },
|
||||
{ "ffn_norm", OFFLOAD_FUNC },
|
||||
{ "ffn_norm_0", OFFLOAD_FUNC },
|
||||
{ "ffn_norm_0_w", OFFLOAD_FUNC },
|
||||
{ "ffn_norm_0_wb", OFFLOAD_FUNC },
|
||||
|
||||
{ "result_w3", offload_func },
|
||||
{ "result_w3_b", offload_func },
|
||||
{ "result_w2", offload_func },
|
||||
{ "result_w2_b", offload_func },
|
||||
{ "result_w1", offload_func },
|
||||
{ "result_w3", OFFLOAD_FUNC },
|
||||
{ "result_w3_b", OFFLOAD_FUNC },
|
||||
{ "result_w2", OFFLOAD_FUNC },
|
||||
{ "result_w2_b", OFFLOAD_FUNC },
|
||||
{ "result_w1", OFFLOAD_FUNC },
|
||||
|
||||
{ "silu", offload_func },
|
||||
{ "gelu", offload_func },
|
||||
{ "relu", offload_func },
|
||||
{ "sqr(relu)", offload_func },
|
||||
{ "silu", OFFLOAD_FUNC },
|
||||
{ "gelu", OFFLOAD_FUNC },
|
||||
{ "relu", OFFLOAD_FUNC },
|
||||
{ "sqr(relu)", OFFLOAD_FUNC },
|
||||
|
||||
{ "silu_x_result_w3", offload_func },
|
||||
{ "inpFF_+_result_w2", offload_func },
|
||||
{ "inpL_+_inpFF_+_result_w2", offload_func },
|
||||
{ "silu_x_result_w3", OFFLOAD_FUNC },
|
||||
{ "inpFF_+_result_w2", OFFLOAD_FUNC },
|
||||
{ "inpL_+_inpFF_+_result_w2", OFFLOAD_FUNC },
|
||||
|
||||
{ "rms_norm_2", offload_func_nr },
|
||||
{ "out_norm_0", offload_func_nr },
|
||||
{ "out_norm_0_w", offload_func_nr },
|
||||
{ "rms_norm_2", OFFLOAD_FUNC_NR },
|
||||
{ "out_norm_0", OFFLOAD_FUNC_NR },
|
||||
{ "out_norm_0_w", OFFLOAD_FUNC_NR },
|
||||
|
||||
{ "result_norm", offload_func_emb },
|
||||
{ "result_output", offload_func_out },
|
||||
{ "result_norm", OFFLOAD_FUNC_EMB },
|
||||
{ "result_output", OFFLOAD_FUNC_OUT },
|
||||
};
|
||||
|
||||
static llama_offload_trie k_offload_func_trie(k_offload_func);
|
||||
|
||||
std::unordered_map<std::string, int> ofn;
|
||||
|
||||
for (int i = 0; i < result->n_nodes; ++i) {
|
||||
@ -5592,36 +5659,78 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::string name = cur->name;
|
||||
offload_func_e func_e = k_offload_func_trie.find(cur->name);
|
||||
|
||||
const auto it = k_offload_func.find(name);
|
||||
if (it == k_offload_func.end()) {
|
||||
if (func_e == OFFLOAD_FUNC_NOP) {
|
||||
// if a tensor hasn't been offloaded, we warn the user
|
||||
if (worst_case) {
|
||||
LLAMA_LOG_WARN("%s: node %4d %32s: not offloaded (ref: %s)\n", __func__,
|
||||
i, name.c_str(), "https://github.com/ggerganov/llama.cpp/pull/3837");
|
||||
i, cur->name, "https://github.com/ggerganov/llama.cpp/pull/3837");
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// count the number of layers and respect the provided n_gpu_layers
|
||||
offload_func_t f = it->second;
|
||||
if (n_gpu_layers < n_layer && f == offload_func) {
|
||||
if (ofn[name]++ < i_gpu_start) {
|
||||
f = ggml_offload_nop;
|
||||
}
|
||||
switch (func_e) {
|
||||
case OFFLOAD_FUNC_NOP:
|
||||
case OFFLOAD_FUNC_OUT: break;
|
||||
case OFFLOAD_FUNC:
|
||||
if (n_gpu_layers < n_layer) {
|
||||
if (ofn[cur->name]++ < i_gpu_start) {
|
||||
func_e = OFFLOAD_FUNC_NOP;
|
||||
}
|
||||
}
|
||||
break;
|
||||
case OFFLOAD_FUNC_NR:
|
||||
if (n_gpu_layers <= n_layer + 0) {
|
||||
func_e = OFFLOAD_FUNC_NOP;
|
||||
}
|
||||
break;
|
||||
case OFFLOAD_FUNC_V:
|
||||
if (n_gpu_layers <= n_layer + 1) {
|
||||
func_e = OFFLOAD_FUNC_NOP;
|
||||
}
|
||||
break;
|
||||
case OFFLOAD_FUNC_KQ:
|
||||
if (n_gpu_layers <= n_layer + 2) {
|
||||
func_e = OFFLOAD_FUNC_NOP;
|
||||
}
|
||||
break;
|
||||
case OFFLOAD_FUNC_EMB:
|
||||
if (!offload_emb || n_gpu_layers < n_layer) {
|
||||
func_e = OFFLOAD_FUNC_NOP;
|
||||
}
|
||||
break;
|
||||
default: GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
offload_func_t func = ggml_offload_nop;
|
||||
|
||||
switch (func_e) {
|
||||
case OFFLOAD_FUNC_NOP:
|
||||
case OFFLOAD_FUNC_OUT: func = ggml_offload_nop; break;
|
||||
case OFFLOAD_FUNC:
|
||||
case OFFLOAD_FUNC_KQ:
|
||||
case OFFLOAD_FUNC_V:
|
||||
case OFFLOAD_FUNC_NR:
|
||||
case OFFLOAD_FUNC_EMB: func = ggml_cuda_assign_buffers_no_alloc; break;
|
||||
default: GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
// apply offload function to the tensor
|
||||
f(cur);
|
||||
func(cur);
|
||||
|
||||
if (worst_case) {
|
||||
LLAMA_LOG_INFO("%s: node %4d %32s: %s\n", __func__, i, name.c_str(), k_offload_func_name.at(f).c_str());
|
||||
LLAMA_LOG_INFO("%s: node %4d %32s: %s\n", __func__, i, cur->name, k_offload_func_name.at(func_e).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//auto t_end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
//printf("offload time: %f ms\n", std::chrono::duration<double, std::milli>(t_end - t_start).count());
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user