llama.cpp (mirror of https://github.com/ggerganov/llama.cpp.git)

llama : enable flash attn automatically when supported

commit afc4a7de65 (parent b9e02e8184)
@@ -3148,6 +3148,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_RWKV_WKV:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
+            // FIXME: this is not accurate, the flash attn implementation only has kernels for a limited number of configurations,
+            // which varies depending on too many factors to duplicate here.
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
 #endif
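The FIXME above notes that the compile-time FLASH_ATTN_AVAILABLE guard cannot express which configurations actually have kernels. A minimal sketch of the runtime alternative that the llama.cpp side of this patch relies on: build a concrete GGML_OP_FLASH_ATTN_EXT node and ask the device about that exact op through ggml_backend_dev_supports_op(). The helper name and the tensor shapes below are illustrative assumptions, not code from the patch.

// sketch only: probe a device for flash-attention support at runtime
// (dev_supports_flash_attn and the shapes are assumptions, not part of the patch)
#include <math.h>

#include "ggml.h"
#include "ggml-backend.h"

static bool dev_supports_flash_attn(ggml_backend_dev_t dev) {
    ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // only tensor metadata is needed for the query
    };
    ggml_context * ctx = ggml_init(params);

    // assumed worst-case shapes: head size 128, 32 query heads, 8 KV heads, 128 tokens
    const int64_t n_embd_head = 128, n_head = 32, n_head_kv = 8, n_tokens = 128, n_kv = 128;

    ggml_tensor * q    = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_tokens, n_head);
    ggml_tensor * k    = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, n_embd_head, n_kv,     n_head_kv);
    ggml_tensor * v    = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, n_embd_head, n_kv,     n_head_kv);
    ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));

    // build the node without computing it, then ask the device about this exact op
    ggml_tensor * op = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf((float) n_embd_head), 0.0f, 0.0f);

    const bool supported = ggml_backend_dev_supports_op(dev, op);
    ggml_free(ctx);
    return supported;
}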
src/llama.cpp (130 changed lines)
@@ -2777,7 +2777,7 @@ struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
     bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
-    bool v_trans   = true;  // the value tensor is transposed
+    std::vector<bool> v_trans_l; // the value tensor is transposed (per layer)

     // Note: The value of head isn't only used to optimize searching
     // for a free KV slot. llama_decode_internal also uses it, so it
@@ -3395,7 +3395,7 @@ static int llama_get_device_count(const llama_model & model) {
 }

 template<typename F>
-static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
+static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, const F & fn) {
     ggml_init_params params = {
         /*.mem_size   =*/ ggml_tensor_overhead()*8,
         /*.mem_buffer =*/ NULL,
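For context, a minimal usage sketch of the buft_supported() helper changed above (it is internal to src/llama.cpp, so this would live in the same file): the callback builds a single test op and the helper reports whether the device backed by that buffer type can run it. The probed op here is a plain matrix multiplication chosen for illustration; the patch itself probes a flash-attention node, as shown further down.

// sketch only: how a caller might use buft_supported()
// (layer_dev_supports_mul_mat and the probed op are illustrative assumptions)
static bool layer_dev_supports_mul_mat(ggml_backend_dev_t dev) {
    ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(dev);

    return buft_supported(buft, dev, [&](ggml_context * ctx) -> ggml_tensor * {
        ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 4096, 4096); // weight
        ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 8);    // activations
        return ggml_mul_mat(ctx, a, b); // the op whose support is being checked
    });
}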
@@ -3446,16 +3446,12 @@ static bool llama_kv_cache_init(
                      uint32_t   kv_size,
                          bool   offload) {
     const llama_model & model = ctx->model;
-    const llama_cparams & cparams = ctx->cparams;
-
     const struct llama_hparams & hparams = model.hparams;

     const int64_t  n_layer = hparams.n_layer;

     cache.has_shift = false;

     cache.recurrent = llama_model_is_recurrent(&model);
-    cache.v_trans   = !cache.recurrent && !cparams.flash_attn;
-
     cache.head = 0;
     cache.size = kv_size;
@@ -9699,10 +9695,7 @@ static struct ggml_tensor * llm_build_kqv(

     struct ggml_tensor * cur;

-    if (cparams.flash_attn) {
-        GGML_UNUSED(model);
-        GGML_UNUSED(n_ctx);
-
+    if (kv.v_trans_l[il]) {
         // split cached v into n_head heads (not transposed)
         struct ggml_tensor * v =
             ggml_view_3d(ctx, kv.v_l[il],
@@ -19400,6 +19393,10 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }

+    if (llama_model_is_recurrent(model)) {
+        params.flash_attn = false;
+    }
+
     if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
@@ -19495,6 +19492,7 @@ struct llama_context * llama_new_context_with_model(
     // build worst-case graph for encoder if a model contains encoder
     ctx->is_encoding = llama_model_has_encoder(model);

+    if (!hparams.vocab_only) {
     uint32_t kv_size = cparams.n_ctx;
     ggml_type type_k = params.type_k;
     ggml_type type_v = params.type_v;
@@ -19511,7 +19509,81 @@ struct llama_context * llama_new_context_with_model(
     GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
     GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);

-    if (!hparams.vocab_only) {
+    if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
+        LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
+        llama_free(ctx);
+        return nullptr;
+    }
+
+    // find which layers can use flash attention
+    std::vector<bool> & flash_attn_layers = ctx->kv_self.v_trans_l;
+    flash_attn_layers.resize(hparams.n_layer, false);
+    if (cparams.flash_attn) {
+        for (uint32_t il = 0; il < hparams.n_layer; ++il) {
+            ggml_backend_dev_t layer_dev = model->dev_layer.at(il).dev;
+            ggml_backend_buffer_type_t layer_buft = ggml_backend_dev_buffer_type(layer_dev);
+
+            auto & kv = ctx->kv_self;
+            const int64_t n_ctx         = cparams.n_ctx;
+            const int64_t n_head        = hparams.n_head(il);
+            const int64_t n_head_kv     = hparams.n_head_kv(il);
+            const int64_t n_embd_head_k = hparams.n_embd_head_k;
+            const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa(il);
+            const int64_t n_embd_head_v = hparams.n_embd_head_v;
+            const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa(il);
+            const int64_t n_embd_head   = hparams.n_embd_head_v;
+            int n_kv     = 128;
+            int n_tokens = 128;
+
+            bool supported = buft_supported(layer_buft, layer_dev, [&](ggml_context * ctx) -> ggml_tensor * {
+                ggml_tensor * kq_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+                float kq_scale = 1.0f;
+                ggml_tensor * q_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);
+                ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
+                ggml_tensor * k = ggml_view_3d(ctx, kv.k_l[il],
+                        n_embd_head_k, n_kv, n_head_kv,
+                        ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa),
+                        ggml_row_size(kv.k_l[il]->type, n_embd_head_k),
+                        0);
+
+                // split cached v into n_head heads (not transposed)
+                ggml_tensor * v = ggml_view_3d(ctx, kv.v_l[il],
+                        n_embd_head_v, n_kv, n_head_kv,
+                        ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa),
+                        ggml_row_size(kv.v_l[il]->type, n_embd_head_v),
+                        0);
+
+                ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+                        hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
+
+                ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+                return cur;
+            });
+
+            LLAMA_LOG_INFO("%s: layer %2d %s flash_attn %s\n", __func__, il, ggml_backend_dev_name(layer_dev), supported ? "supported" : "not supported");
+
+            flash_attn_layers[il] = supported;
+        }
+    }
+
+    {
+        size_t memory_size_k = 0;
+        size_t memory_size_v = 0;
+
+        for (auto & k : ctx->kv_self.k_l) {
+            memory_size_k += ggml_nbytes(k);
+        }
+
+        for (auto & v : ctx->kv_self.v_l) {
+            memory_size_v += ggml_nbytes(v);
+        }
+
+        LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+    }

     // GPU backends
     for (auto * dev : model->devices) {
         ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
@@ -19558,30 +19630,6 @@ struct llama_context * llama_new_context_with_model(
         }
     }

-    if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
-        LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
-        llama_free(ctx);
-        return nullptr;
-    }
-
-    {
-        size_t memory_size_k = 0;
-        size_t memory_size_v = 0;
-
-        for (auto & k : ctx->kv_self.k_l) {
-            memory_size_k += ggml_nbytes(k);
-        }
-
-        for (auto & v : ctx->kv_self.v_l) {
-            memory_size_v += ggml_nbytes(v);
-        }
-
-        LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
-                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
-    }
-
     // graph outputs buffer
     {
         // resized during inference when a batch uses more outputs
@@ -20330,7 +20378,8 @@ struct llama_data_write {
        const struct llama_kv_cache & kv_self = ctx->kv_self;
        const struct llama_hparams & hparams = ctx->model.hparams;

-       const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
+       // FIXME
+       const uint32_t v_trans = kv_self.v_trans_l.at(0) ? 1 : 0;
        const uint32_t n_layer = hparams.n_layer;

        write(&v_trans, sizeof(v_trans));
@@ -20359,7 +20408,8 @@ struct llama_data_write {
            }
        }

-       if (!kv_self.v_trans) {
+       // FIXME
+       if (!kv_self.v_trans_l.at(0)) {
            for (uint32_t il = 0; il < n_layer; ++il) {
                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

@@ -20652,7 +20702,8 @@ struct llama_data_read {
            LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
            return false;
        }
-       if (kv_self.v_trans != (bool) v_trans) {
+       // FIXME
+       if (kv_self.v_trans_l.at(0) != (bool) v_trans) {
            LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
            return false;
        }
@@ -20685,7 +20736,8 @@ struct llama_data_read {
            }
        }

-       if (!kv_self.v_trans) {
+       // FIXME
+       if (!kv_self.v_trans_l.at(0)) {
            for (uint32_t il = 0; il < n_layer; ++il) {
                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();

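The repeated FIXMEs above exist because session save/load still serializes a single v_trans flag, while the cache now tracks one flag per layer. A hypothetical follow-up sketch, not part of this patch and implying a session-format change: write one flag per layer from inside llama_data_write, reusing the same write() helper as the surrounding code.

    // hypothetical member of llama_data_write, sketch only (would change the session format)
    void write_v_trans_per_layer(const llama_kv_cache & kv_self) {
        const uint32_t n_layer = (uint32_t) kv_self.v_trans_l.size();
        write(&n_layer, sizeof(n_layer));
        for (uint32_t il = 0; il < n_layer; ++il) {
            const uint8_t v_trans_il = kv_self.v_trans_l.at(il) ? 1 : 0; // one flag per layer
            write(&v_trans_il, sizeof(v_trans_il));
        }
    }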