From afc4a7de65f416e9fe1cb1248b4d7e4c104100df Mon Sep 17 00:00:00 2001
From: slaren
Date: Wed, 30 Oct 2024 23:30:04 +0100
Subject: [PATCH] llama : enable flash attn automatically when supported

---
 ggml/src/ggml-cuda.cu |   2 +
 src/llama.cpp         | 162 ++++++++++++++++++++++++++++--------------
 2 files changed, 109 insertions(+), 55 deletions(-)

diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 087091516..c2ccbb99f 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -3148,6 +3148,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_RWKV_WKV:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
+            // FIXME: this is not accurate, the flash attn implementation only has kernels for a limited number of configurations,
+            // which varies depending on too many factors to duplicate here.
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
 #endif
diff --git a/src/llama.cpp b/src/llama.cpp
index ef1b8ee59..b44f94e26 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2777,7 +2777,7 @@ struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
     bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
-    bool v_trans = true; // the value tensor is transposed
+    std::vector<bool> v_trans_l; // the value tensor is transposed (per layer)
 
     // Note: The value of head isn't only used to optimize searching
     // for a free KV slot. llama_decode_internal also uses it, so it
@@ -3395,7 +3395,7 @@ static int llama_get_device_count(const llama_model & model) {
 }
 
 template <typename F>
-static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
+static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, const F & fn) {
     ggml_init_params params = {
         /*.mem_size   =*/ ggml_tensor_overhead()*8,
         /*.mem_buffer =*/ NULL,
@@ -3446,16 +3446,12 @@ static bool llama_kv_cache_init(
                      uint32_t   kv_size,
                          bool   offload) {
     const llama_model & model = ctx->model;
-    const llama_cparams & cparams = ctx->cparams;
-
     const struct llama_hparams & hparams = model.hparams;
-
-    const int64_t n_layer = hparams.n_layer;
+    const int64_t n_layer = hparams.n_layer;
 
     cache.has_shift = false;
 
     cache.recurrent = llama_model_is_recurrent(&model);
-    cache.v_trans = !cache.recurrent && !cparams.flash_attn;
 
     cache.head = 0;
     cache.size = kv_size;
@@ -9699,10 +9695,7 @@ static struct ggml_tensor * llm_build_kqv(
 
     struct ggml_tensor * cur;
 
-    if (cparams.flash_attn) {
-        GGML_UNUSED(model);
-        GGML_UNUSED(n_ctx);
-
+    if (kv.v_trans_l[il]) {
         // split cached v into n_head heads (not transposed)
         struct ggml_tensor * v =
             ggml_view_3d(ctx, kv.v_l[il],
@@ -19400,6 +19393,10 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }
 
+    if (llama_model_is_recurrent(model)) {
+        params.flash_attn = false;
+    }
+
     if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
@@ -19495,23 +19492,98 @@ struct llama_context * llama_new_context_with_model(
     // build worst-case graph for encoder if a model contains encoder
     ctx->is_encoding = llama_model_has_encoder(model);
 
-    uint32_t kv_size = cparams.n_ctx;
-    ggml_type type_k = params.type_k;
-    ggml_type type_v = params.type_v;
-
-    // Mamba only needs a constant number of KV cache cells per sequence
-    if (llama_model_is_recurrent(model)) {
-        // Mamba needs at least as many KV cells as there are sequences kept at any time
-        kv_size = std::max((uint32_t) 1, params.n_seq_max);
-        // it's probably best to keep as much precision as possible for the states
-        type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
-        type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
-    }
-
-    GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
-    GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
-
     if (!hparams.vocab_only) {
+        uint32_t kv_size = cparams.n_ctx;
+        ggml_type type_k = params.type_k;
+        ggml_type type_v = params.type_v;
+
+        // Mamba only needs a constant number of KV cache cells per sequence
+        if (llama_model_is_recurrent(model)) {
+            // Mamba needs at least as many KV cells as there are sequences kept at any time
+            kv_size = std::max((uint32_t) 1, params.n_seq_max);
+            // it's probably best to keep as much precision as possible for the states
+            type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
+            type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
+        }
+
+        GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
+        GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
+
+        if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
+            LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
+            llama_free(ctx);
+            return nullptr;
+        }
+
+        // find which layers can use flash attention
+        std::vector<bool> & flash_attn_layers = ctx->kv_self.v_trans_l;
+        flash_attn_layers.resize(hparams.n_layer, false);
+        if (cparams.flash_attn) {
+            for (uint32_t il = 0; il < hparams.n_layer; ++il) {
+                ggml_backend_dev_t layer_dev = model->dev_layer.at(il).dev;
+                ggml_backend_buffer_type_t layer_buft = ggml_backend_dev_buffer_type(layer_dev);
+
+                auto & kv = ctx->kv_self;
+                const int64_t n_ctx         = cparams.n_ctx;
+                const int64_t n_head        = hparams.n_head(il);
+                const int64_t n_head_kv     = hparams.n_head_kv(il);
+                const int64_t n_embd_head_k = hparams.n_embd_head_k;
+                const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa(il);
+                const int64_t n_embd_head_v = hparams.n_embd_head_v;
+                const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa(il);
+                const int64_t n_embd_head   = hparams.n_embd_head_v;
+                int n_kv     = 128;
+                int n_tokens = 128;
+
+                bool supported = buft_supported(layer_buft, layer_dev, [&](ggml_context * ctx) -> ggml_tensor * {
+                    ggml_tensor * kq_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+                    float kq_scale = 1.0f;
+                    ggml_tensor * q_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);
+                    ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
+                    ggml_tensor * k = ggml_view_3d(ctx, kv.k_l[il],
+                            n_embd_head_k, n_kv, n_head_kv,
+                            ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
+                            ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
+                            0);
+
+                    // split cached v into n_head heads (not transposed)
+                    ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il],
+                            n_embd_head_v, n_kv, n_head_kv,
+                            ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
+                            ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
+                            0);
+
+                    ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+                            hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+
+                    ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+                    return cur;
+                });
+
+                LLAMA_LOG_INFO("%s: layer %2d %s flash_attn %s\n", __func__, il, ggml_backend_dev_name(layer_dev), supported ? "supported" : "not supported");
+
+                flash_attn_layers[il] = supported;
+            }
+        }
+
+        {
+            size_t memory_size_k = 0;
+            size_t memory_size_v = 0;
+
+            for (auto & k : ctx->kv_self.k_l) {
+                memory_size_k += ggml_nbytes(k);
+            }
+
+            for (auto & v : ctx->kv_self.v_l) {
+                memory_size_v += ggml_nbytes(v);
+            }
+
+            LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+        }
+
         // GPU backends
         for (auto * dev : model->devices) {
             ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
@@ -19558,30 +19630,6 @@ struct llama_context * llama_new_context_with_model(
             }
         }
 
-        if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
-            LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
-            llama_free(ctx);
-            return nullptr;
-        }
-
-        {
-            size_t memory_size_k = 0;
-            size_t memory_size_v = 0;
-
-            for (auto & k : ctx->kv_self.k_l) {
-                memory_size_k += ggml_nbytes(k);
-            }
-
-            for (auto & v : ctx->kv_self.v_l) {
-                memory_size_v += ggml_nbytes(v);
-            }
-
-            LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
-                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
-        }
-
         // graph outputs buffer
         {
             // resized during inference when a batch uses more outputs
@@ -20330,7 +20378,8 @@ struct llama_data_write {
         const struct llama_kv_cache & kv_self = ctx->kv_self;
         const struct llama_hparams & hparams = ctx->model.hparams;
 
-        const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
+        // FIXME
+        const uint32_t v_trans = kv_self.v_trans_l.at(0) ? 1 : 0;
         const uint32_t n_layer = hparams.n_layer;
 
         write(&v_trans, sizeof(v_trans));
@@ -20359,7 +20408,8 @@ struct llama_data_write {
             }
         }
 
-        if (!kv_self.v_trans) {
+        // FIXME
+        if (!kv_self.v_trans_l.at(0)) {
            for (uint32_t il = 0; il < n_layer; ++il) {
                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
@@ -20652,7 +20702,8 @@ struct llama_data_read {
             LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
             return false;
         }
-        if (kv_self.v_trans != (bool) v_trans) {
+        // FIXME
+        if (kv_self.v_trans_l.at(0) != (bool) v_trans) {
             LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
             return false;
         }
@@ -20685,7 +20736,8 @@ struct llama_data_read {
            }
        }
 
-       if (!kv_self.v_trans) {
+       // FIXME
+       if (!kv_self.v_trans_l.at(0)) {
            for (uint32_t il = 0; il < n_layer; ++il) {
                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();