mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 17:21:46 +00:00
llama : enable flash attn automatically when supported
This commit is contained in:
parent
b9e02e8184
commit
afc4a7de65
@ -3148,6 +3148,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_RWKV_WKV:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT: {
|
||||
// FIXME: this is not accurate, the flash attn implementation only has kernels for a limited number of configurations,
|
||||
// which varies depending on too many factors to duplicate here.
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
return false;
|
||||
#endif
|
||||
|
162
src/llama.cpp
162
src/llama.cpp
@ -2777,7 +2777,7 @@ struct llama_kv_cache {
|
||||
bool has_shift = false;
|
||||
bool do_defrag = false;
|
||||
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
std::vector<bool> v_trans_l; // the value tensor is transposed (per layer)
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_internal also uses it, so it
|
||||
@ -3395,7 +3395,7 @@ static int llama_get_device_count(const llama_model & model) {
|
||||
}
|
||||
|
||||
template<typename F>
|
||||
static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
|
||||
static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, const F & fn) {
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ ggml_tensor_overhead()*8,
|
||||
/*.mem_buffer =*/ NULL,
|
||||
@ -3446,16 +3446,12 @@ static bool llama_kv_cache_init(
|
||||
uint32_t kv_size,
|
||||
bool offload) {
|
||||
const llama_model & model = ctx->model;
|
||||
const llama_cparams & cparams = ctx->cparams;
|
||||
|
||||
const struct llama_hparams & hparams = model.hparams;
|
||||
|
||||
const int64_t n_layer = hparams.n_layer;
|
||||
const int64_t n_layer = hparams.n_layer;
|
||||
|
||||
cache.has_shift = false;
|
||||
|
||||
cache.recurrent = llama_model_is_recurrent(&model);
|
||||
cache.v_trans = !cache.recurrent && !cparams.flash_attn;
|
||||
|
||||
cache.head = 0;
|
||||
cache.size = kv_size;
|
||||
@ -9699,10 +9695,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
|
||||
if (cparams.flash_attn) {
|
||||
GGML_UNUSED(model);
|
||||
GGML_UNUSED(n_ctx);
|
||||
|
||||
if (kv.v_trans_l[il]) {
|
||||
// split cached v into n_head heads (not transposed)
|
||||
struct ggml_tensor * v =
|
||||
ggml_view_3d(ctx, kv.v_l[il],
|
||||
@ -19400,6 +19393,10 @@ struct llama_context * llama_new_context_with_model(
|
||||
params.flash_attn = false;
|
||||
}
|
||||
|
||||
if (llama_model_is_recurrent(model)) {
|
||||
params.flash_attn = false;
|
||||
}
|
||||
|
||||
if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
|
||||
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
|
||||
return nullptr;
|
||||
@ -19495,23 +19492,98 @@ struct llama_context * llama_new_context_with_model(
|
||||
// build worst-case graph for encoder if a model contains encoder
|
||||
ctx->is_encoding = llama_model_has_encoder(model);
|
||||
|
||||
uint32_t kv_size = cparams.n_ctx;
|
||||
ggml_type type_k = params.type_k;
|
||||
ggml_type type_v = params.type_v;
|
||||
|
||||
// Mamba only needs a constant number of KV cache cells per sequence
|
||||
if (llama_model_is_recurrent(model)) {
|
||||
// Mamba needs at least as many KV cells as there are sequences kept at any time
|
||||
kv_size = std::max((uint32_t) 1, params.n_seq_max);
|
||||
// it's probably best to keep as much precision as possible for the states
|
||||
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
|
||||
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
|
||||
}
|
||||
|
||||
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
|
||||
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
|
||||
|
||||
if (!hparams.vocab_only) {
|
||||
uint32_t kv_size = cparams.n_ctx;
|
||||
ggml_type type_k = params.type_k;
|
||||
ggml_type type_v = params.type_v;
|
||||
|
||||
// Mamba only needs a constant number of KV cache cells per sequence
|
||||
if (llama_model_is_recurrent(model)) {
|
||||
// Mamba needs at least as many KV cells as there are sequences kept at any time
|
||||
kv_size = std::max((uint32_t) 1, params.n_seq_max);
|
||||
// it's probably best to keep as much precision as possible for the states
|
||||
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
|
||||
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
|
||||
}
|
||||
|
||||
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
|
||||
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
|
||||
|
||||
if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
|
||||
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
|
||||
llama_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// find which layers can use flash attention
|
||||
std::vector<bool> & flash_attn_layers = ctx->kv_self.v_trans_l;
|
||||
flash_attn_layers.resize(hparams.n_layer, false);
|
||||
if (cparams.flash_attn) {
|
||||
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
|
||||
ggml_backend_dev_t layer_dev = model->dev_layer.at(il).dev;
|
||||
ggml_backend_buffer_type_t layer_buft = ggml_backend_dev_buffer_type(layer_dev);
|
||||
|
||||
auto & kv = ctx->kv_self;
|
||||
const int64_t n_ctx = cparams.n_ctx;
|
||||
const int64_t n_head = hparams.n_head(il);
|
||||
const int64_t n_head_kv = hparams.n_head_kv(il);
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
const int64_t n_embd_head_v = hparams.n_embd_head_v;
|
||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
int n_kv = 128;
|
||||
int n_tokens = 128;
|
||||
|
||||
bool supported = buft_supported(layer_buft, layer_dev, [&](ggml_context * ctx) -> ggml_tensor * {
|
||||
ggml_tensor * kq_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
float kq_scale = 1.0f;
|
||||
ggml_tensor * q_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);
|
||||
ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
|
||||
ggml_tensor * k = ggml_view_3d(ctx, kv.k_l[il],
|
||||
n_embd_head_k, n_kv, n_head_kv,
|
||||
ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
|
||||
ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
|
||||
0);
|
||||
|
||||
// split cached v into n_head heads (not transposed)
|
||||
ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il],
|
||||
n_embd_head_v, n_kv, n_head_kv,
|
||||
ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
|
||||
ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
|
||||
0);
|
||||
|
||||
ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
||||
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
||||
|
||||
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
||||
return cur;
|
||||
});
|
||||
|
||||
LLAMA_LOG_INFO("%s: layer %2d %s flash_attn %s\n", __func__, il, ggml_backend_dev_name(layer_dev), supported ? "supported" : "not supported");
|
||||
|
||||
flash_attn_layers[il] = supported;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
size_t memory_size_k = 0;
|
||||
size_t memory_size_v = 0;
|
||||
|
||||
for (auto & k : ctx->kv_self.k_l) {
|
||||
memory_size_k += ggml_nbytes(k);
|
||||
}
|
||||
|
||||
for (auto & v : ctx->kv_self.v_l) {
|
||||
memory_size_v += ggml_nbytes(v);
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
||||
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
||||
}
|
||||
|
||||
// GPU backends
|
||||
for (auto * dev : model->devices) {
|
||||
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
|
||||
@ -19558,30 +19630,6 @@ struct llama_context * llama_new_context_with_model(
|
||||
}
|
||||
}
|
||||
|
||||
if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
|
||||
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
|
||||
llama_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
{
|
||||
size_t memory_size_k = 0;
|
||||
size_t memory_size_v = 0;
|
||||
|
||||
for (auto & k : ctx->kv_self.k_l) {
|
||||
memory_size_k += ggml_nbytes(k);
|
||||
}
|
||||
|
||||
for (auto & v : ctx->kv_self.v_l) {
|
||||
memory_size_v += ggml_nbytes(v);
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
||||
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
||||
}
|
||||
|
||||
// graph outputs buffer
|
||||
{
|
||||
// resized during inference when a batch uses more outputs
|
||||
@ -20330,7 +20378,8 @@ struct llama_data_write {
|
||||
const struct llama_kv_cache & kv_self = ctx->kv_self;
|
||||
const struct llama_hparams & hparams = ctx->model.hparams;
|
||||
|
||||
const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
|
||||
// FIXME
|
||||
const uint32_t v_trans = kv_self.v_trans_l.at(0) ? 1 : 0;
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
|
||||
write(&v_trans, sizeof(v_trans));
|
||||
@ -20359,7 +20408,8 @@ struct llama_data_write {
|
||||
}
|
||||
}
|
||||
|
||||
if (!kv_self.v_trans) {
|
||||
// FIXME
|
||||
if (!kv_self.v_trans_l.at(0)) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
|
||||
@ -20652,7 +20702,8 @@ struct llama_data_read {
|
||||
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
|
||||
return false;
|
||||
}
|
||||
if (kv_self.v_trans != (bool) v_trans) {
|
||||
// FIXME
|
||||
if (kv_self.v_trans_l.at(0) != (bool) v_trans) {
|
||||
LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
|
||||
return false;
|
||||
}
|
||||
@ -20685,7 +20736,8 @@ struct llama_data_read {
|
||||
}
|
||||
}
|
||||
|
||||
if (!kv_self.v_trans) {
|
||||
// FIXME
|
||||
if (!kv_self.v_trans_l.at(0)) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user