llama : enable flash attn automatically when supported

This commit is contained in:
slaren 2024-10-30 23:30:04 +01:00
parent b9e02e8184
commit afc4a7de65
2 changed files with 109 additions and 55 deletions

View File

@ -3148,6 +3148,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_RWKV_WKV:
return true;
case GGML_OP_FLASH_ATTN_EXT: {
// FIXME: this is not accurate, the flash attn implementation only has kernels for a limited number of configurations,
// which varies depending on too many factors to duplicate here.
#ifndef FLASH_ATTN_AVAILABLE
return false;
#endif

View File

@ -2777,7 +2777,7 @@ struct llama_kv_cache {
bool has_shift = false;
bool do_defrag = false;
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
bool v_trans = true; // the value tensor is transposed
std::vector<bool> v_trans_l; // the value tensor is transposed (per layer)
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_internal also uses it, so it
@ -3395,7 +3395,7 @@ static int llama_get_device_count(const llama_model & model) {
}
template<typename F>
static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, const F & fn) {
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead()*8,
/*.mem_buffer =*/ NULL,
@ -3446,16 +3446,12 @@ static bool llama_kv_cache_init(
uint32_t kv_size,
bool offload) {
const llama_model & model = ctx->model;
const llama_cparams & cparams = ctx->cparams;
const struct llama_hparams & hparams = model.hparams;
const int64_t n_layer = hparams.n_layer;
const int64_t n_layer = hparams.n_layer;
cache.has_shift = false;
cache.recurrent = llama_model_is_recurrent(&model);
cache.v_trans = !cache.recurrent && !cparams.flash_attn;
cache.head = 0;
cache.size = kv_size;
@ -9699,10 +9695,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * cur;
if (cparams.flash_attn) {
GGML_UNUSED(model);
GGML_UNUSED(n_ctx);
if (kv.v_trans_l[il]) {
// split cached v into n_head heads (not transposed)
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v_l[il],
@ -19400,6 +19393,10 @@ struct llama_context * llama_new_context_with_model(
params.flash_attn = false;
}
if (llama_model_is_recurrent(model)) {
params.flash_attn = false;
}
if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
return nullptr;
@ -19495,23 +19492,98 @@ struct llama_context * llama_new_context_with_model(
// build worst-case graph for encoder if a model contains encoder
ctx->is_encoding = llama_model_has_encoder(model);
uint32_t kv_size = cparams.n_ctx;
ggml_type type_k = params.type_k;
ggml_type type_v = params.type_v;
// Mamba only needs a constant number of KV cache cells per sequence
if (llama_model_is_recurrent(model)) {
// Mamba needs at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
}
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
if (!hparams.vocab_only) {
uint32_t kv_size = cparams.n_ctx;
ggml_type type_k = params.type_k;
ggml_type type_v = params.type_v;
// Mamba only needs a constant number of KV cache cells per sequence
if (llama_model_is_recurrent(model)) {
// Mamba needs at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
}
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;
}
// find which layers can use flash attention
std::vector<bool> & flash_attn_layers = ctx->kv_self.v_trans_l;
flash_attn_layers.resize(hparams.n_layer, false);
if (cparams.flash_attn) {
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
ggml_backend_dev_t layer_dev = model->dev_layer.at(il).dev;
ggml_backend_buffer_type_t layer_buft = ggml_backend_dev_buffer_type(layer_dev);
auto & kv = ctx->kv_self;
const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head(il);
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_head_v = hparams.n_embd_head_v;
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
const int64_t n_embd_head = hparams.n_embd_head_v;
int n_kv = 128;
int n_tokens = 128;
bool supported = buft_supported(layer_buft, layer_dev, [&](ggml_context * ctx) -> ggml_tensor * {
ggml_tensor * kq_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
float kq_scale = 1.0f;
ggml_tensor * q_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);
ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
ggml_tensor * k = ggml_view_3d(ctx, kv.k_l[il],
n_embd_head_k, n_kv, n_head_kv,
ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
0);
// split cached v into n_head heads (not transposed)
ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il],
n_embd_head_v, n_kv, n_head_kv,
ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
0);
ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
return cur;
});
LLAMA_LOG_INFO("%s: layer %2d %s flash_attn %s\n", __func__, il, ggml_backend_dev_name(layer_dev), supported ? "supported" : "not supported");
flash_attn_layers[il] = supported;
}
}
{
size_t memory_size_k = 0;
size_t memory_size_v = 0;
for (auto & k : ctx->kv_self.k_l) {
memory_size_k += ggml_nbytes(k);
}
for (auto & v : ctx->kv_self.v_l) {
memory_size_v += ggml_nbytes(v);
}
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
// GPU backends
for (auto * dev : model->devices) {
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
@ -19558,30 +19630,6 @@ struct llama_context * llama_new_context_with_model(
}
}
if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;
}
{
size_t memory_size_k = 0;
size_t memory_size_v = 0;
for (auto & k : ctx->kv_self.k_l) {
memory_size_k += ggml_nbytes(k);
}
for (auto & v : ctx->kv_self.v_l) {
memory_size_v += ggml_nbytes(v);
}
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
// graph outputs buffer
{
// resized during inference when a batch uses more outputs
@ -20330,7 +20378,8 @@ struct llama_data_write {
const struct llama_kv_cache & kv_self = ctx->kv_self;
const struct llama_hparams & hparams = ctx->model.hparams;
const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
// FIXME
const uint32_t v_trans = kv_self.v_trans_l.at(0) ? 1 : 0;
const uint32_t n_layer = hparams.n_layer;
write(&v_trans, sizeof(v_trans));
@ -20359,7 +20408,8 @@ struct llama_data_write {
}
}
if (!kv_self.v_trans) {
// FIXME
if (!kv_self.v_trans_l.at(0)) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
@ -20652,7 +20702,8 @@ struct llama_data_read {
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
return false;
}
if (kv_self.v_trans != (bool) v_trans) {
// FIXME
if (kv_self.v_trans_l.at(0) != (bool) v_trans) {
LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
return false;
}
@ -20685,7 +20736,8 @@ struct llama_data_read {
}
}
if (!kv_self.v_trans) {
// FIXME
if (!kv_self.v_trans_l.at(0)) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();